# Source code for draugr.python_utilities.torch_like_channel_transformation

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
# Module docstring is assigned explicitly as a raw string; it currently only records the creation date.
__doc__ = r"""

           Created on 15/04/2020
           """

# Public API exported by ``from ... import *``: RGB channel slicing for channels-last
# images/batches, and per-channel mean/std (de)normalisation for NCHW batches.
__all__ = [
    "rgb_drop_alpha_hwc",
    "rgb_drop_alpha_batch_nhwc",
    "torch_vision_normalize_batch_nchw",
    "reverse_torch_vision_normalize_batch_nchw",
]

# from numba import jit

from warg.typing_extension import StrictNumbers


# @jit(nopython=True, fastmath=True)
def rgb_drop_alpha_hwc(inp: "StrictNumbers") -> "StrictNumbers":
    """
    Keep only the first three channels of a channels-last (HWC) image, dropping alpha or any extra channels.

    The channel count is read from the last axis (the one that is sliced), so an input whose
    last axis has fewer than three entries is rejected regardless of its rank.

    :param inp: Channels-last array-like exposing ``.ndim``/``.shape`` (e.g. a numpy array or torch
        tensor), expected layout HWC. Leading extra dims are tolerated because slicing uses ``...``.
    :type inp: StrictNumbers
    :return: ``inp[..., :3]``; basic slicing, so numpy/torch return a view rather than a copy.
    :rtype: StrictNumbers
    :raises AssertionError: if ``inp`` has fewer than 3 dims or fewer than 3 channels
        (checks are ``assert`` statements and are skipped under ``python -O``).
    """
    # A 2-D (H, W) array has no channel axis; without this guard W would be mistaken for channels.
    assert inp.ndim >= 3, f"expected an HWC array with at least 3 dims, got {inp.ndim}"
    channels = inp.shape[-1]
    assert channels >= 3, f"not enough channels, only had {channels}"
    return inp[..., :3]
# @jit(nopython=True, fastmath=True)
def rgb_drop_alpha_batch_nhwc(inp: "StrictNumbers") -> "StrictNumbers":
    """
    Keep only the first three channels of a channels-last batch (NHWC), dropping alpha or any extra channels.

    :param inp: Channels-last batch exposing ``.ndim``/``.shape`` (e.g. a numpy array or torch tensor),
        expected layout NHWC.
    :type inp: StrictNumbers
    :return: ``inp[..., :3]``; basic slicing, so numpy/torch return a view rather than a copy.
    :rtype: StrictNumbers
    :raises AssertionError: if ``inp`` has fewer than 4 dims or fewer than 3 channels
        (checks are ``assert`` statements and are skipped under ``python -O``).
    """
    # Reject unbatched input explicitly instead of failing on an indexing side effect.
    assert inp.ndim >= 4, f"expected an NHWC batch with at least 4 dims, got {inp.ndim}"
    channels = inp.shape[-1]  # the axis that is actually sliced below
    assert channels >= 3, f"not enough channels, only had {channels}"
    return inp[..., :3]
# @jit(nopython=True, fastmath=True)
def torch_vision_normalize_batch_nchw(
    inp: "StrictNumbers",
    *,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
) -> "StrictNumbers":
    """
    Normalize an NCHW batch channel-wise: ``inp[:, c] = (inp[:, c] - mean[c]) / std[c]``.

    warning INPLACE! ``inp`` is modified in place and the same object is returned.

    :param inp: NCHW batch supporting ``.shape`` and slice assignment (e.g. a numpy array or torch
        tensor). Use a floating-point dtype; assigning the float results into an integer array truncates.
    :type inp: StrictNumbers
    :param mean: Per-channel mean to subtract. The default is the constant the function always used
        (the values commonly used with torchvision pretrained models).
    :param std: Per-channel standard deviation to divide by; must have the same length as ``mean``.
    :return: ``inp`` itself, normalized.
    :rtype: StrictNumbers
    :raises AssertionError: if ``mean``/``std`` lengths differ or ``inp.shape[1]`` != ``len(mean)``.
    """
    assert len(mean) == len(std), f"mean and std lengths differ: {len(mean)} != {len(std)}"
    channels = inp.shape[1]  # channel axis of NCHW
    assert channels == len(mean), f"was {channels}"
    for c, (m, s) in enumerate(zip(mean, std)):
        # Plain assignment (not -=, /=) so integer inputs are cast rather than raising on in-place ufuncs.
        inp[:, c] = (inp[:, c] - m) / s
    return inp
# @jit(nopython=True, fastmath=True)
def reverse_torch_vision_normalize_batch_nchw(
    inp: "StrictNumbers",
    *,
    mean: tuple = (0.485, 0.456, 0.406),
    std: tuple = (0.229, 0.224, 0.225),
) -> "StrictNumbers":
    """
    Undo channel-wise normalization of an NCHW batch: ``inp[:, c] = inp[:, c] * std[c] + mean[c]``.

    warning INPLACE! ``inp`` is modified in place and the same object is returned.

    :param inp: Normalized NCHW batch supporting ``.shape`` and slice assignment (e.g. a numpy array or
        torch tensor), in a floating-point dtype.
    :type inp: StrictNumbers
    :param mean: Per-channel mean to add back. The default is the constant the function always used
        (the values commonly used with torchvision pretrained models).
    :param std: Per-channel standard deviation to multiply by; must have the same length as ``mean``.
    :return: ``inp`` itself, de-normalized.
    :rtype: StrictNumbers
    :raises AssertionError: if ``mean``/``std`` lengths differ or ``inp.shape[1]`` != ``len(mean)``.
    """
    assert len(mean) == len(std), f"mean and std lengths differ: {len(mean)} != {len(std)}"
    channels = inp.shape[1]  # channel axis of NCHW
    assert channels == len(mean), f"was {channels}"
    for c, (m, s) in enumerate(zip(mean, std)):
        inp[:, c] = inp[:, c] * s + m
    return inp
if __name__ == "__main__":
    import numpy

    def _demo_rgb_drop_alpha() -> None:
        """Run the RGB channel slicers on valid inputs and on inputs they are expected to reject."""
        rgba_batch = numpy.ones((1, 4, 4, 4))
        rgb_batch = numpy.ones((1, 4, 4, 3))
        rgb_image = numpy.ones((4, 4, 3))
        two_channel_batch = numpy.ones((1, 4, 4, 2))

        print(rgb_drop_alpha_batch_nhwc(rgba_batch).shape)
        print(rgb_drop_alpha_batch_nhwc(rgb_batch).shape)
        print(rgb_drop_alpha_hwc(rgb_image).shape)

        # Wrong-rank or too-few-channel inputs. The checks are asserts, but wrong-rank input can also
        # surface as TypeError/IndexError from indexing, so catch exactly those and report the outcome
        # instead of silently swallowing every exception.
        for fn, arr in (
            (rgb_drop_alpha_batch_nhwc, rgb_image),
            (rgb_drop_alpha_batch_nhwc, two_channel_batch),
            (rgb_drop_alpha_hwc, two_channel_batch),
        ):
            try:
                result = fn(arr)
            except (AssertionError, TypeError, IndexError) as err:
                print(f"{fn.__name__}{arr.shape} rejected: {err!r}")
            else:
                print(f"{fn.__name__}{arr.shape} accepted -> {result.shape}")

    def _demo_normalize_roundtrip() -> None:
        """Normalize an NCHW batch, undo it, and check that the round trip restores the input."""
        batch = numpy.ones((1, 3, 4, 4))
        # Both functions work in place, so `batch`, `normalized` and `restored` are the same array.
        normalized = torch_vision_normalize_batch_nchw(batch)
        print(normalized)
        restored = reverse_torch_vision_normalize_batch_nchw(normalized)
        print(restored)
        print("round trip ok:", numpy.allclose(restored, 1.0))

    _demo_rgb_drop_alpha()
    _demo_normalize_roundtrip()