Source code for draugr.torch_utilities.tensors.tensor_container

# Names exported by ``from <this module> import *``.
__all__ = ["NamedTensorTuple"]


class NamedTensorTuple:
    """
    Helper class for managing boxes, labels, etc.

    Deliberately does not inherit from ``dict``, because `default_collate`
    would change a dict subclass into a plain ``dict``.

    Entries are kept in the private mapping ``_data_dict`` and are reachable
    through item access (``container[key]``); iterating over the container
    yields its keys."""

    def __init__(self, **kwargs) -> None:
        """
        Create a container from keyword arguments.

        :param kwargs: initial entries; every keyword becomes a key
        :type kwargs: Any"""
        self._data_dict = kwargs

    def __setattr__(self, key, value) -> None:
        # Plain attribute assignment: attributes are stored on the instance
        # itself and never end up in ``_data_dict``.
        object.__setattr__(self, key, value)

    def __getitem__(self, key):
        """
        Return the entry stored under ``key``.

        :raises KeyError: if ``key`` is not present"""
        return self._data_dict[key]

    def __iter__(self):
        """Iterate over the keys of the stored entries."""
        return iter(self._data_dict)

    def __setitem__(self, key, value) -> None:
        """Store ``value`` under ``key``, replacing any previous entry."""
        self._data_dict[key] = value

    def _call(self, name: str, *args, **kwargs) -> "NamedTensorTuple":
        """
        Invoke the method ``name`` on every entry that has such an attribute
        and replace the entry with the returned value.

        Entries without an attribute called ``name`` are left untouched.

        :param name: name of the method to invoke on each entry
        :type name: str
        :param args: positional arguments forwarded to the method
        :param kwargs: keyword arguments forwarded to the method
        :return: this container itself (updated in place), to allow chaining
        :rtype: NamedTensorTuple"""
        # Iterate over a snapshot so the mapping can be rewritten in the loop.
        for key, value in list(self._data_dict.items()):
            if hasattr(value, name):
                self._data_dict[key] = getattr(value, name)(*args, **kwargs)
        return self

    def to(self, *args, **kwargs) -> "NamedTensorTuple":
        """
        Call ``to(*args, **kwargs)`` on every entry that provides a ``to``
        attribute, replacing the entry with the result.

        :param args: positional arguments forwarded to each entry's ``to``
        :param kwargs: keyword arguments forwarded to each entry's ``to``
        :return: this container itself, updated in place
        :rtype: NamedTensorTuple"""
        return self._call("to", *args, **kwargs)

    def numpy(self) -> "NamedTensorTuple":
        """
        Call ``numpy()`` on every entry that provides a ``numpy`` attribute,
        replacing the entry with the result.

        :return: this container itself, updated in place
        :rtype: NamedTensorTuple"""
        return self._call("numpy")

    def __repr__(self) -> str:
        # Same text as the repr of the underlying mapping.
        return repr(self._data_dict)