Source code for draugr.torch_utilities.generators.batching
from torch.utils.data.sampler import BatchSampler
class LimitedBatchResampler(BatchSampler):
    """Wraps a BatchSampler, re-sampling from it until a specified number
    of iterations have been sampled."""

    def __init__(
        self, batch_sampler: BatchSampler, num_iterations: int, start_iter: int = 0
    ):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter
    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # If the underlying sampler has a set_epoch method, as
            # DistributedSampler does (it is used to make each process see
            # a different split of the dataset), call it so the split
            # changes between passes over the batch sampler.
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch
    def __len__(self):
        # Length reflects the configured iteration budget, not the length
        # of the wrapped batch sampler.
        return self.num_iterations
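

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the
    # original module): wrap a plain BatchSampler so that exactly
    # num_iterations batches are drawn, cycling over the underlying data
    # as needed. The dataset size, batch size and iteration count below
    # are arbitrary example values.
    from torch.utils.data.sampler import SequentialSampler

    base = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
    limited = LimitedBatchResampler(base, num_iterations=7)

    for i, batch in enumerate(limited):
        # The 3 underlying batches are re-sampled until 7 have been yielded.
        print(i, batch)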