test_to_tensor.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import numpy
  4. import torch
  5. from draugr.torch_utilities.tensors.to_tensor import to_tensor
  6. __author__ = "Christian Heider Nielsen"
  7. __doc__ = ""
  8. def test_to_tensor_none():
  9. try:
  10. tensor = to_tensor(None)
  11. except:
  12. return
  13. assert False
  14. def test_to_tensor_empty_list():
  15. try:
  16. tensor = to_tensor([])
  17. except:
  18. return
  19. assert False
  20. def test_to_tensor_empty_tuple():
  21. try:
  22. tensor = to_tensor(())
  23. except:
  24. return
  25. assert False
  26. def test_to_tensor_list():
  27. ref = [0]
  28. tensor = to_tensor(ref, device="cpu")
  29. assert tensor.equal(torch.FloatTensor([0]))
  30. def test_to_tensor_multi_list():
  31. ref = [[0], [1]]
  32. tensor = to_tensor(ref, device="cpu")
  33. assert tensor.equal(torch.FloatTensor([[0], [1]]))
  34. def test_to_tensor_tuple():
  35. ref = (0,)
  36. tensor = to_tensor(ref, device="cpu")
  37. assert tensor.equal(torch.FloatTensor([0]))
  38. def test_to_tensor_multi_tuple():
  39. ref = ([0], [1])
  40. tensor = to_tensor(ref, device="cpu")
  41. assert tensor.equal(torch.FloatTensor([[0], [1]]))
  42. def test_to_tensor_from_numpy_tensor():
  43. ref = torch.from_numpy(numpy.random.sample((1, 2)))
  44. tensor = to_tensor(ref, dtype=torch.double, device="cpu")
  45. assert tensor.equal(ref)
  46. def test_to_tensor_float_tensor():
  47. ref = torch.FloatTensor([0])
  48. tensor = to_tensor(ref, device="cpu")
  49. assert tensor.equal(ref)
  50. def test_generator_to_float_tensor():
  51. s = range(9)
  52. ref = torch.FloatTensor([*s])
  53. tensor = to_tensor(s, device="cpu")
  54. assert tensor.equal(ref)
  55. if __name__ == "__main__":
  56. pass