Source code for draugr.torch_utilities.tensors.types

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 02-02-2021
           """

import numpy
import torch

__all__ = [
    "numpy_to_torch_dtype_dict",
    "torch_to_numpy_dtype_dict",
    "numpy_to_torch_dtype",
    "torch_to_numpy_dtype",
]

numpy_to_torch_dtype_dict = (
    {  # Dict of NumPy dtype -> torch dtype (when the correspondence exists)
        bool: torch.bool,
        numpy.uint8: torch.uint8,
        numpy.int8: torch.int8,
        numpy.int16: torch.int16,
        numpy.int32: torch.int32,
        numpy.int64: torch.int64,
        numpy.float16: torch.float16,
        numpy.float32: torch.float32,
        numpy.float64: torch.float64,
        numpy.complex64: torch.complex64,
        numpy.complex128: torch.complex128,
    }
)

torch_to_numpy_dtype_dict = {
    value: key for (key, value) in numpy_to_torch_dtype_dict.items()
}  # Dict of torch dtype -> NumPy dtype


[docs]def numpy_to_torch_dtype(numpy_dtype: numpy.dtype) -> torch.dtype: """description""" return numpy_to_torch_dtype_dict[numpy_dtype.type]
[docs]def torch_to_numpy_dtype(torch_dtype: torch.dtype) -> numpy.dtype: """description""" return torch_to_numpy_dtype_dict[torch_dtype]
if __name__ == "__main__": def iusahdu() -> None: """ :rtype: None """ a = numpy.zeros((1, 1)) print(a.dtype) b = numpy_to_torch_dtype(a.dtype) print(b) print(type(b)) c = torch_to_numpy_dtype(b) print(c) print(type(c)) iusahdu()