Source code for draugr.torch_utilities.architectures.mlp_variants.concatination
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Iterable, List, Sequence
import numpy
import torch
from draugr.torch_utilities.architectures.mlp import MLP
from draugr.torch_utilities.tensors.to_tensor import to_tensor
__author__ = "Christian Heider Nielsen"
__doc__ = "Fusion variants of MLPs"
__all__ = ["PreConcatInputMLP", "LateConcatInputMLP"]
from warg import passes_kws_to
class PreConcatInputMLP(MLP):
    """
    Early fusion: all inputs are concatenated along the last dimension before being fed through a single MLP.
    """

    def __init__(self, input_shape: Sequence = (2,), **kwargs):
        if isinstance(input_shape, Iterable):
            input_shape = sum(input_shape)  # collapse the per-input widths into one input width
        super().__init__(input_shape=input_shape, **kwargs)
    @passes_kws_to(MLP.forward)
    def forward(self, *x, **kwargs) -> List:
        """
        Concatenate all positional inputs along the last dimension and forward the result through the MLP.

        :param x: input tensors to be fused
        :param kwargs: forwarded to MLP.forward
        :return: the output(s) of the underlying MLP
        """
        return super().forward(torch.cat(x, dim=-1), **kwargs)
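# A minimal usage sketch for PreConcatInputMLP; the concrete shapes below are
# illustrative assumptions, not requirements of the class:
#
#   model = PreConcatInputMLP(input_shape=(3, 2), output_shape=(4,))
#   a, b = torch.rand(8, 3), torch.rand(8, 2)
#   out = model(a, b)  # equivalent to the base MLP applied to torch.cat((a, b), dim=-1)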
class LateConcatInputMLP(MLP):
    """
    Late fusion; quite a botch job, only a single additive fusion block is supported for now.
    You have been warned! ;)
    """
    def __init__(
        self,
        input_shape: Sequence = (2, 2),
        output_shape: Sequence = (2,),
        fusion_hidden_multiplier: int = 10,
        **kwargs
    ):
        forward_shape, *res = input_shape  # the first width feeds the MLP, the rest bypass it
        self._residual_shape = res
        if not isinstance(self._residual_shape, Iterable):
            self._residual_shape = (self._residual_shape,)
        if not isinstance(output_shape, Iterable):
            output_shape = (*self._residual_shape, output_shape)
        assert len(output_shape) == 2
        super().__init__(
            input_shape=(forward_shape,), output_shape=output_shape, **kwargs
        )
        s = sum((*output_shape, *self._residual_shape))  # width after concatenating MLP outputs and residuals
        t = s * fusion_hidden_multiplier  # hidden width of the fusion block
        self.post_concat_layer = torch.nn.Sequential(
            torch.nn.Linear(s, t), torch.nn.ReLU(), torch.nn.Linear(t, output_shape[-1])
        )
    @passes_kws_to(MLP.forward)
    def forward(self, *x, **kwargs) -> torch.Tensor:
        """
        Forward the first input through the MLP, then fuse its outputs with the remaining (residual) inputs.

        :param x: the first tensor is passed through the MLP; any further tensors are concatenated in afterwards
        :param kwargs: forwarded to MLP.forward
        :return: the fused output tensor
        """
        forward_x, *residual_x = x
        return self.post_concat_layer(
            torch.cat((*super().forward(forward_x, **kwargs), *residual_x), dim=-1)
        )
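# A minimal usage sketch for LateConcatInputMLP; the sizes below mirror the
# tests further down and are illustrative assumptions, not requirements:
#
#   model = LateConcatInputMLP(input_shape=(19, 4), output_shape=(2, 10))
#   # The MLP sees only the 19-wide input; the 4-wide input skips it. The fusion
#   # block then maps 2 + 10 + 4 = 16 concatenated features through a 160-wide
#   # hidden layer (fusion_hidden_multiplier=10) down to output_shape[-1] = 10.
#   out = model(torch.rand(100, 19), torch.rand(100, 4))  # -> shape (100, 10)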
if __name__ == "__main__":

    def stest_normal():
        """Single unbatched input through PreConcatInputMLP"""
        s = (10,)
        a = (10,)
        model = PreConcatInputMLP(input_shape=s, output_shape=a)
        inp = to_tensor(range(s[0]), device="cpu")
        print(model.forward(inp))

    def stest_multi_dim_normal():
        """Two batched inputs, early fusion"""
        s = (19,)
        s1 = (4,)
        batch_size = (100,)
        a = (2, 10)
        model = PreConcatInputMLP(input_shape=s + s1, output_shape=a)
        inp = to_tensor(numpy.random.random((*batch_size, *s)), device="cpu")
        late_input = to_tensor(numpy.random.random((*batch_size, *s1)), device="cpu")
        print(model.forward(inp, late_input))

    def stest_multi_dim_normal21():
        """Two batched inputs, late fusion"""
        s = (19,)
        s1 = (4,)
        batch_size = (100,)
        a = (2, 10)
        model = LateConcatInputMLP(input_shape=s + s1, output_shape=a)
        inp = to_tensor(numpy.random.random((*batch_size, *s)), device="cpu")
        late_input = to_tensor(numpy.random.random((*batch_size, *s1)), device="cpu")
        print(model.forward(inp, late_input))

    def stest_multi_dim_normal23121():
        """Late fusion with an explicit two-element output_shape"""
        s = (19,)
        s1 = (4,)
        batch_size = (100,)
        output_shape = (1, 2)
        model = LateConcatInputMLP(input_shape=s + s1, output_shape=output_shape)
        inp = to_tensor(numpy.random.random((*batch_size, *s)), device="cpu")
        late_input = to_tensor(numpy.random.random((*batch_size, *s1)), device="cpu")
        print(model.forward(inp, late_input))

    def stest_multi_dim_normal2321412121():
        """Late fusion with a scalar output_shape, expanded internally"""
        s = (19,)
        s1 = (4,)
        batch_size = (100,)
        output_shape = 2
        model = LateConcatInputMLP(input_shape=s + s1, output_shape=output_shape)
        inp = to_tensor(
            numpy.random.random((*batch_size, *s)), device="cpu", dtype=float
        )
        late_input = to_tensor(
            numpy.random.random((*batch_size, *s1)), device="cpu", dtype=float
        )
        print(model.forward(inp, late_input).shape)

    # stest_normal()
    # stest_multi_dim_normal()
    # stest_multi_dim_normal21()
    # stest_multi_dim_normal23121()
    stest_multi_dim_normal2321412121()