test_pytorch_writer.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import pytest
  4. from draugr import PROJECT_APP_PATH
  5. from draugr.torch_utilities import TensorBoardPytorchWriter
  6. __author__ = "Christian Heider Nielsen"
  7. __doc__ = r"""
  8. """
  9. @pytest.mark.parametrize(
  10. ["tag", "val", "step"],
  11. (("signal", 0, 0), ("signal", 20, 1), ("signal", -1, 6)),
  12. ids=["signal_first", "signal_second", "signal_sixth"],
  13. )
  14. def test_valid_scalars(tag, val, step):
  15. with TensorBoardPytorchWriter(path=PROJECT_APP_PATH.user_log) as w:
  16. w.scalar(tag, val, step)
  17. @pytest.mark.parametrize(
  18. ["tag", "val", "step"],
  19. (("signal", "", 0), ("signal", None, 1), ("signal", object(), 6)),
  20. ids=["str_scalar", "None_scalar", "object_scalar"],
  21. )
  22. def test_invalid_val_type_scalars(tag, val, step):
  23. try:
  24. with TensorBoardPytorchWriter(path=PROJECT_APP_PATH.user_log) as w:
  25. w.scalar(tag, val, step)
  26. assert False
  27. except Exception as e:
  28. assert True
  29. @pytest.mark.parametrize(
  30. ["tag", "val", "step"],
  31. ((1, 0, 0), (None, 20, 1), (object(), -1, 6)),
  32. ids=["numeral_tag", "None_tag", "object_tag"],
  33. )
  34. def test_invalid_tag_scalars(tag, val, step):
  35. try:
  36. with TensorBoardPytorchWriter(path=PROJECT_APP_PATH.user_log) as w:
  37. w.scalar(tag, val, step)
  38. assert False
  39. except Exception as e:
  40. print(e)
  41. assert True
  42. @pytest.mark.parametrize(
  43. ["tag", "val", "step"],
  44. (("signal", 0, ""), ("signal", 20, None), ("tag1", -0, object())),
  45. ids=["str_step", "None_step", "object_step"],
  46. )
  47. def test_invalid_step_type_scalars(tag, val, step):
  48. try:
  49. with TensorBoardPytorchWriter(path=PROJECT_APP_PATH.user_log) as w:
  50. w.scalar(tag, val, step)
  51. assert False
  52. except Exception as e:
  53. print(e)
  54. assert True
  55. @pytest.mark.parametrize(
  56. ["tag", "val", "truth", "step"],
  57. (
  58. ("signal", range(9), range(9), 0),
  59. ("signal", range(9), range(9), None),
  60. ("tag1", range(9), range(9), object()),
  61. ),
  62. ids=["str_step", "None_step", "object_step"],
  63. )
  64. def test_precision_recall(tag, val, truth, step):
  65. try:
  66. with TensorBoardPytorchWriter(path=PROJECT_APP_PATH.user_log) as w:
  67. w.precision_recall_curve(tag, val, truth, step)
  68. assert False
  69. except Exception as e:
  70. print(e)
  71. assert True