Source code for draugr.torch_utilities.architectures.experimental.heads

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from torch import nn

__author__ = "Christian Heider Nielsen"
__doc__ = ""

from draugr.torch_utilities.architectures.mlp import MLP


[docs]class MultiHeadedMLP(MLP): """description"""
[docs] def __init__(self, *, heads_hidden_sizes=(32, 64), heads=(2, 1), **kwargs): super().__init__(**kwargs) assert len(heads_hidden_sizes) == len(heads) self._heads_hidden_sizes = heads_hidden_sizes self._heads = heads self.num_of_heads = len(self._heads) if self.num_of_heads > 0: for i in range(1, self.num_of_heads + 1): head_hidden = nn.Linear( self._output_shape, self._heads_hidden_sizes[i - 1], bias=self._use_bias, ) setattr(self, f"subhead{str(i)}_hidden", head_hidden) head = nn.Linear( self._heads_hidden_sizes[i - 1], self._heads[i - 1], bias=self._use_bias, ) setattr(self, f"subhead{str(i)}", head) else: raise ValueError("Number of heads must be >0")
[docs] def forward(self, x, **kwargs): """ :param x: :type x: :param kwargs: :type kwargs: :return: :rtype: """ x = super().forward(x, **kwargs) output = [] for i in range(1, self.num_of_heads + 1): head_hidden = getattr(self, f"subhead{str(i)}_hidden") x_s = head_hidden(x) head = getattr(self, f"subhead{str(i)}") sub_res = head(x_s) # if not isinstance(sub_res, list): # sub_res = [sub_res] output.append(sub_res) return output