Source code for draugr.torch_utilities.tensors.dimension_order

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

          A few low dimensional orders common in raster-grids

           Created on 28-03-2021
           """

import torch

__all__ = [
    "hwc_to_chw_tensor",
    "chw_to_hwc_tensor",
    "nhwc_to_nchw_tensor",
    "nchw_to_nhwc_tensor",
    "nthwc_to_ntchw_tensor",
    "ntchw_to_nthwc_tensor",
]


[docs]def hwc_to_chw_tensor(tensor: torch.Tensor) -> torch.Tensor: """ :param tensor: :return: :rtype:""" assert len(tensor.shape) == 3 return tensor.permute(2, 0, 1)
[docs]def chw_to_hwc_tensor(tensor: torch.Tensor) -> torch.Tensor: """ :param tensor: :return: :rtype:""" assert len(tensor.shape) == 3 return tensor.permute(1, 2, 0)
[docs]def nhwc_to_nchw_tensor(tensor: torch.Tensor) -> torch.Tensor: """ :param tensor: :return: :rtype:""" assert len(tensor.shape) == 4 return tensor.permute(0, 3, 1, 2)
[docs]def nchw_to_nhwc_tensor(tensor: torch.Tensor) -> torch.Tensor: """ :param tensor: :return: :rtype:""" assert len(tensor.shape) == 4 return tensor.permute(0, 2, 3, 1)
[docs]def nthwc_to_ntchw_tensor(tensor: torch.Tensor) -> torch.Tensor: """ :param tensor: :return: :rtype:""" assert len(tensor.shape) == 5 return tensor.permute(0, 1, 4, 2, 3)
[docs]def ntchw_to_nthwc_tensor(tensor: torch.Tensor) -> torch.Tensor: """ :param tensor: :return: :rtype:""" assert len(tensor.shape) == 5 return tensor.permute(0, 1, 3, 4, 2)