test_sessions.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. __author__ = "Christian Heider Nielsen"
  4. __doc__ = r"""
  5. Created on 11/05/2020
  6. """
  7. import torch
  8. from draugr.torch_utilities import (
  9. TorchCpuSession,
  10. TorchCudaSession,
  11. TorchEvalSession,
  12. TorchTrainSession,
  13. global_torch_device,
  14. )
  15. def test_cpu():
  16. print(global_torch_device(override=global_torch_device(device_preference=True)))
  17. print(global_torch_device())
  18. with TorchCpuSession():
  19. print(global_torch_device())
  20. print(global_torch_device())
  21. def test_nested_model_sessions():
  22. model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Dropout(0.1))
  23. print(model.training)
  24. with TorchEvalSession(model):
  25. print(model.training)
  26. with TorchTrainSession(model):
  27. print(model.training)
  28. with TorchEvalSession(model):
  29. print(model.training)
  30. with TorchTrainSession(model):
  31. print(model.training)
  32. with TorchEvalSession(model):
  33. print(model.training)
  34. print(model.training)
  35. def test_nested_device_sessions():
  36. print(global_torch_device(override=global_torch_device(device_preference=True)))
  37. print(global_torch_device())
  38. with TorchCpuSession():
  39. print(global_torch_device())
  40. with TorchCudaSession():
  41. print(global_torch_device())
  42. with TorchCpuSession():
  43. print(global_torch_device())
  44. print(global_torch_device())