cross_valid_example.py 639 B

1234567891011121314151617181920212223242526272829
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. __author__ = "Christian Heider Nielsen"
  4. __doc__ = r"""
  5. Created on 31-10-2020
  6. """
  7. import torch
  8. from torch.utils.data import TensorDataset
  9. from draugr.torch_utilities import cross_validation_generator, to_tensor
  10. def asdasidoj() -> None:
  11. """
  12. :rtype: None
  13. """
  14. X = to_tensor([torch.diag(torch.arange(i, i + 2)) for i in range(200)])
  15. x_train = TensorDataset(X[:100])
  16. x_val = TensorDataset(X[100:])
  17. for train, val in cross_validation_generator(x_train, x_val):
  18. print(len(train), len(val))
  19. print(train[0], val[0])
  20. asdasidoj()