Source code for draugr.torch_utilities.tensors.reshaping
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 23/02/2020
"""
__all__ = ["flatten_tn_dim", "flatten_keep_batch", "safe_concat"]
import torch
[docs]def flatten_tn_dim(_tensor: torch.Tensor) -> torch.tensor:
"""
:param _tensor:
:return:"""
t, n, *r = _tensor.size()
return _tensor.reshape(t * n, *r)
[docs]def flatten_keep_batch(t: torch.Tensor) -> torch.Tensor:
"""
:param t:
:return:"""
return t.reshape(t.shape[0], -1)
[docs]def safe_concat(arr: torch.Tensor, el: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""
:param arr:
:param el:
:param dim:
:return:"""
if arr is None:
return el
return torch.cat((arr, el), dim=dim)
if __name__ == "__main__":
def a() -> None:
"""
:rtype: None
"""
shape = (2, 3, 4, 5)
from warg import prod
t = torch.reshape(torch.arange(0, prod(shape)), shape)
f = flatten_tn_dim(t)
tf = t.flatten(0, 1)
print(t, f, tf)
print(t.shape, f.shape, tf.shape)
a()