# Source code for draugr.torch_utilities.datasets.non_sequential_dataset

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List

import numpy
from torch.utils.data import Dataset

# Author attribution for this module.
__author__ = "Christian Heider Nielsen"
# Module docstring, assigned explicitly; it only records the creation date.
__doc__ = r"""

           Created on 09/10/2019
           """
# Names exported by a star-import of this module.
__all__ = ["NonSequentialDataset"]


class NonSequentialDataset(Dataset):
    """
    Dataset that flattens ``N*T*...`` arrays into ``B*...`` rows.

    * ``N`` - number of parallel environments
    * ``T`` - number of time steps explored in environments

    ``B`` is equal to ``N*T``. Rows are returned one by one, so information
    about sequence order is lost; each item holds, for example, one state,
    action and reward. It can be used for models that do not need to keep the
    order of events, like MLP models. For ``LSTM`` use another implementation
    that slices the dataset differently.
    """

    def __init__(self, *arrays: numpy.ndarray) -> None:
        """
        :param arrays: arrays to be flattened from ``N*T*...`` to ``B*...``
            and returned in each call to get item
        """
        super().__init__()
        # Merge the two leading axes into a single one; trailing axes are kept.
        self.arrays = [array.reshape(-1, *array.shape[2:]) for array in arrays]

    def __getitem__(self, index: int) -> List:
        """
        :param index: row index into the flattened arrays
        :return: list with the ``index``-th row of every stored array, in the
            order the arrays were passed to the constructor
        """
        return [array[index] for array in self.arrays]

    def __len__(self) -> int:
        """
        :return: number of rows of the first stored array, or ``0`` when the
            dataset was constructed without any arrays
        """
        # Without this guard an empty dataset raises IndexError on arrays[0].
        if not self.arrays:
            return 0
        return len(self.arrays[0])