# Source code for draugr.torch_utilities.tensors.info

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 28/06/2020
           """

import itertools
import sys

import torch

__all__ = ["size_of_tensor"]


def size_of_tensor(obj: object) -> int:
    r"""**Get size in bytes of Tensor, torch.nn.Module or standard object.**

    Specific routines are defined for ``torch.Tensor`` objects and
    ``torch.nn.Module`` objects; they calculate how much memory in bytes
    those objects' elements consume. If any other object is passed,
    ``sys.getsizeof`` is called on it.

    This function works similarly to C++'s ``sizeof`` operator.

    Parameters
    ----------
    obj
        Object whose size will be measured.

    Returns
    -------
    int
        Size in bytes of the object:

        * tensor -- ``element_size() * numel()``;
        * module -- the sum of the sizes of every tensor yielded by
          ``obj.buffers()`` and ``obj.parameters()``;
        * anything else -- ``sys.getsizeof(obj)`` (a shallow size that does
          not follow references into contained objects).
    """
    if torch.is_tensor(obj):
        # Bytes per element times the number of elements.
        return obj.element_size() * obj.numel()

    if isinstance(obj, torch.nn.Module):
        # A module's footprint is taken as its buffers plus its parameters.
        return sum(
            size_of_tensor(tensor)
            for tensor in itertools.chain(obj.buffers(), obj.parameters())
        )

    return sys.getsizeof(obj)
if __name__ == "__main__":
    # Self-check: the size reported for a Linear(20, 20) module must equal the
    # bytes of its bias plus its weight matrix.
    # Assumes the layer's default dtype uses 4 bytes per element.
    module = torch.nn.Linear(20, 20)
    bias = 20 * 4  # in bytes
    weights = 20 * 20 * 4  # in bytes
    assert size_of_tensor(module) == bias + weights