cross_valid_example.py 458 B

12345678910111213141516171819
  1. import torch
  2. from torch.utils.data import TensorDataset
  3. from draugr.torch_utilities import cross_validation_generator, to_tensor
  def asdasidoj():
      """Demonstrate ``cross_validation_generator`` on two small ``TensorDataset``s.

      Builds 200 2x2 diagonal matrices (the i-th has diagonal ``[i, i + 1]``)
      and converts the list with ``to_tensor``. It wraps the first 100 entries
      and the remaining entries in separate ``TensorDataset``s and passes both
      to ``cross_validation_generator``. For every ``(train, val)`` pair the
      generator yields, it prints the two lengths on one line and the first
      item of each on the next.

      Returns:
          None; all output goes to stdout.
      """
      # NOTE(review): assumes draugr's ``to_tensor`` turns the list of 2x2
      # tensors into one tensor indexable along its first axis — confirm.
      matrices = [torch.diag(torch.arange(start, start + 2)) for start in range(200)]
      stacked = to_tensor(matrices)

      # Fixed split: first 100 samples vs. everything after them.
      train_part = TensorDataset(stacked[:100])
      val_part = TensorDataset(stacked[100:])

      for fold_train, fold_val in cross_validation_generator(train_part, val_part):
          print(len(fold_train), len(fold_val))
          print(fold_train[0], fold_val[0])
  # Run the demo only when this file is executed as a script, so importing
  # the module does not build datasets and print output as a side effect.
  if __name__ == "__main__":
      asdasidoj()