test_merged.py 698 B

12345678910111213141516171819202122232425262728293031323334
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. __author__ = "Christian Heider Nielsen"
  4. __doc__ = r"""
  5. Created on 04/12/2019
  6. """
  7. from draugr.torch_utilities import to_tensor, PreConcatInputMLP
  8. def test_normal():
  9. s = (10,)
  10. a = (10,)
  11. model = PreConcatInputMLP(input_shape=s, output_shape=a)
  12. inp = to_tensor(range(s[0]), device="cpu")
  13. print(model.forward(inp))
  14. def test_multi_dim_normal():
  15. s = (10, 2, 3)
  16. a = (2, 10)
  17. model = PreConcatInputMLP(input_shape=s, output_shape=a)
  18. inp = [to_tensor(range(s_), device="cpu") for s_ in s]
  19. print(model.forward(*inp))
  20. if __name__ == "__main__":
  21. test_normal()
  22. test_multi_dim_normal()