# Source code for draugr.torch_utilities.architectures.mlp_variants.disjunction
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Fission variants of MLPs
Created on 24/02/2020
"""
import copy
from typing import Optional, Sequence, Tuple

import torch
from torch import nn
__all__ = ["DisjunctMLP"]
from draugr.torch_utilities.architectures.mlp import MLP
class DisjunctMLP(MLP):
    """
    MLP trunk followed by two independent ("disjunct") sub-networks.

    The inherited :class:`MLP` is configured to emit a ``disjunction_size``-wide
    output with an identity output activation. Both sub-networks read that
    shared trunk output:

    * ``subnet_1``: Linear -> activation -> Linear to ``output_shape[-1]`` units
    * ``subnet_2``: Linear -> activation -> Linear to a single unit

    NOTE(review): the 1-unit second head resembles a dueling value/advantage
    split -- confirm with callers before documenting it as such.
    """

    def __init__(
        self,
        output_shape: Sequence = (2,),
        disjunction_size: int = 256,
        subnet_size: int = 128,
        hidden_layer_activation: Optional[nn.Module] = None,
        **kwargs
    ):
        """
        :param output_shape: only the last entry is used; it sets the output
            width of ``subnet_1``.
        :type output_shape: Sequence
        :param disjunction_size: output width of the shared trunk, which is
            also the input width of both subnets.
        :type disjunction_size: int
        :param subnet_size: hidden width of each subnet.
        :type subnet_size: int
        :param hidden_layer_activation: activation module for the trunk and
            the subnets. Defaults to a fresh ``nn.ReLU()`` per instance. Each
            subnet gets its own deep copy, so parametric activations
            (e.g. ``nn.PReLU``) are not tied between the heads.
        :type hidden_layer_activation: Optional[nn.Module]
        :param kwargs: forwarded unchanged to :class:`MLP`.
        """
        # Build the default per instance: a module instance as a default
        # argument would be one object shared by every DisjunctMLP.
        if hidden_layer_activation is None:
            hidden_layer_activation = nn.ReLU()

        # The trunk ends in a disjunction_size-wide, un-activated layer; the
        # non-linearity is applied inside each subnet instead.
        super().__init__(
            output_shape=(disjunction_size,),
            hidden_layer_activation=hidden_layer_activation,
            output_activation=nn.Identity(),
            **kwargs
        )

        # Each head gets its own activation copy so that the two heads stay
        # independent even if the activation carries learnable parameters.
        self.subnet_1 = nn.Sequential(
            nn.Linear(disjunction_size, subnet_size),
            copy.deepcopy(hidden_layer_activation),
            nn.Linear(subnet_size, output_shape[-1]),
        )
        self.subnet_2 = nn.Sequential(
            nn.Linear(disjunction_size, subnet_size),
            copy.deepcopy(hidden_layer_activation),
            nn.Linear(subnet_size, 1),
        )

    def forward(self, *act, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the shared trunk once, then feed its output to both subnets.

        :param act: positional inputs, forwarded unchanged to ``MLP.forward``.
        :type act: Any
        :param kwargs: keyword inputs, forwarded unchanged to ``MLP.forward``.
        :type kwargs: Any
        :return: ``(subnet_1(features), subnet_2(features))``; the second
            element's last dimension is 1.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        # Assumes MLP.forward returns a tensor whose last dimension is
        # disjunction_size (set via output_shape in __init__) -- TODO confirm.
        features = super().forward(*act, **kwargs)
        return self.subnet_1(features), self.subnet_2(features)