test_batch_generator.py 771 B

123456789101112131415161718192021222324252627282930313233343536
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. from draugr.python_utilities import batched_recycle
  4. __author__ = "Christian Heider Nielsen"
  5. __doc__ = r"""
  6. Created on 28/10/2019
  7. """
  8. def test_batch_generator1():
  9. a = range(9)
  10. batch_size = 3
  11. for i, b in zip(range(18), batched_recycle(a, batch_size)):
  12. assert [b_ in a for b_ in b]
  13. assert i == 17
  14. def test_batch_with_label():
  15. import numpy
  16. channels_in = 3
  17. batches = 3
  18. batch_size = 32
  19. data_shape = (batches * batch_size, 256, 256, channels_in)
  20. generator = batched_recycle(
  21. zip(numpy.random.sample(data_shape), numpy.random.sample(data_shape[0])),
  22. batch_size,
  23. )
  24. for i, a in enumerate(generator):
  25. print(a)
  26. break