Source code for draugr.python_utilities.generators.batching_generator
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Iterable, Sequence
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 21/10/2019
"""
import numpy
__all__ = ["sized_batch", "shuffled_batches", "random_batches", "batch_generator"]
[docs]def sized_batch(sized: Iterable, n: int = 32, drop_not_full: bool = True) -> Any:
r"""
:param sized:
:param n:
:param drop_not_full:
:return:"""
if not isinstance(sized, Sequence):
sized = list(sized)
l = len(sized)
for ndx in range(0, l, n):
if drop_not_full and ndx + n > l - 1:
return
yield sized[ndx : min(ndx + n, l)]
[docs]def random_batches(*args, size: int, batch_size: int) -> Sequence:
r"""
:param args:
:type args:
:param size:
:type size:
:param batch_size:
:type batch_size:"""
for _ in range(size // batch_size):
rand_ids = numpy.random.randint(0, size, batch_size)
yield [a[rand_ids] for a in args]
[docs]def shuffled_batches(*args, size: int, batch_size: int) -> Sequence:
r"""
:param args:
:type args:
:param size:
:type size:
:param batch_size:
:type batch_size:"""
permutation = numpy.random.permutation(size)
r = size // batch_size
assert r > 0, f"{size}/{batch_size}={r}"
for i in range(r):
perm = permutation[i * batch_size : (i + 1) * batch_size]
yield [a[perm] for a in args]
[docs]def batch_generator(iterable: Iterable, n: int = 32, drop_not_full: bool = True) -> Any:
r"""
:param iterable:
:param n:
:param drop_not_full:
:return:"""
b = []
i = 0
for a in iterable:
b.append(a)
i += 1
if i >= n:
yield b
b.clear()
i = 0
if drop_not_full:
return
return b
if __name__ == "__main__":
def asda() -> None:
"""
:rtype: None
"""
arg_num = 4
size = 12
mini_batch_size = 5
b = numpy.random.random((arg_num, size))
for a in shuffled_batches(*b, size=size, batch_size=mini_batch_size):
print(list(a))
asda()