Source code for draugr.torch_utilities.architectures.experimental.recurrent
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = ""
import torch
from torch import nn
from torch.nn import functional
from draugr.torch_utilities.architectures.mlp import MLP


class RecurrentCategoricalMLP(MLP):
    """MLP with a recurrent head: the previous hidden state is concatenated
    with the MLP output and carried internally between calls to forward."""

    def __init__(self, r_hidden_layers=10, **kwargs):
        super().__init__(**kwargs)
        self._r_hidden_layers = r_hidden_layers
        self._r_input_shape = self._output_shape + r_hidden_layers

        self.hidden = nn.Linear(
            self._r_input_shape, r_hidden_layers, bias=self._use_bias
        )
        self.out = nn.Linear(  # projects back to the categorical output size
            self._r_input_shape, self._output_shape, bias=self._use_bias
        )
        self._prev_hidden_x = torch.zeros(
            1, r_hidden_layers
        )  # 2-d so it can be concatenated with a batched (1, output) tensor
    def forward(self, x, **kwargs):
        """Run the wrapped MLP, then one recurrent step.

        :param x: input tensor
        :type x: torch.Tensor
        :param kwargs: forwarded to MLP.forward
        :return: log-probabilities over the output categories
        :rtype: torch.Tensor
        """
        x = super().forward(x, **kwargs)
        combined = torch.cat((x, self._prev_hidden_x), dim=1)
        out_x = self.out(combined)
        hidden_x = self.hidden(combined)
        self._prev_hidden_x = hidden_x.detach()  # keep the state but do not backpropagate across calls
        return functional.log_softmax(out_x, dim=-1)
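
# A minimal, self-contained sketch (hypothetical sizes, independent of the
# exact MLP constructor) of the recurrence pattern above: the MLP output is
# concatenated with the previous hidden state, and two linear heads produce
# the categorical log-probabilities and the next hidden state respectively.
#
#     mlp_out, hidden_size, n_categories = 3, 10, 3
#     out_head = nn.Linear(mlp_out + hidden_size, n_categories)
#     hidden_head = nn.Linear(mlp_out + hidden_size, hidden_size)
#     h = torch.zeros(1, hidden_size)
#     for step_features in torch.randn(5, 1, mlp_out):  # 5 time steps, batch of 1
#         combined = torch.cat((step_features, h), dim=1)
#         log_probs = functional.log_softmax(out_head(combined), dim=-1)
#         h = hidden_head(combined).detach()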


class ExposedRecurrentCategoricalMLP(RecurrentCategoricalMLP):
    """Variant of RecurrentCategoricalMLP that takes and returns the hidden
    state explicitly instead of storing it on the module."""

    def forward(self, x, hidden_x, **kwargs):
        """
        :param x: input tensor
        :type x: torch.Tensor
        :param hidden_x: hidden state from the previous step
        :type hidden_x: torch.Tensor
        :param kwargs: forwarded to RecurrentCategoricalMLP.forward
        :return: (log-probabilities, new hidden state)
        :rtype: tuple
        """
        self._prev_hidden_x = hidden_x
        out_x = super().forward(x, **kwargs)  # parent forward already applies log_softmax
        return out_x, self._prev_hidden_x
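
# Hedged usage sketch for the exposed variant (net, observations and shapes
# are assumptions): the caller owns the hidden state and threads it through
# explicitly, so the module itself stays effectively stateless between calls.
#
#     net = ExposedRecurrentCategoricalMLP(...)  # constructed like RecurrentCategoricalMLP
#     h = torch.zeros(1, 10)                     # matches r_hidden_layers
#     for x in observations:                     # iterable of (1, *input_shape) tensors
#         log_probs, h = net(x, h)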


class RecurrentBase(nn.Module):
    """Base module that optionally wraps a GRUCell for recurrent policies."""

    def __init__(self, recurrent, recurrent_input_size, hidden_size):
        super().__init__()
        self._hidden_size = hidden_size
        self._recurrent = recurrent

        if recurrent:
            self.gru = nn.GRUCell(recurrent_input_size, hidden_size)
            # orthogonal weight initialisation, zero biases
            nn.init.orthogonal_(self.gru.weight_ih.data)
            nn.init.orthogonal_(self.gru.weight_hh.data)
            self.gru.bias_ih.data.fill_(0)
            self.gru.bias_hh.data.fill_(0)
    def _forward_gru(self, x, hxs, masks):
        if x.size(0) == hxs.size(0):
            # single step: one observation per environment; masks zero the
            # hidden state of environments that just finished an episode
            x = hxs = self.gru(x, hxs * masks)
        else:
            # x is a (T, N, -1) tensor that has been flattened to (T * N, -1)
            N = hxs.size(0)
            T = int(x.size(0) / N)

            # unflatten
            x = x.reshape(T, N, x.size(1))
            # same deal with masks
            masks = masks.reshape(T, N, 1)

            outputs = []
            for i in range(T):
                hx = hxs = self.gru(x[i], hxs * masks[i])
                outputs.append(hx)

            # assert len(outputs) == T
            # x is a (T, N, -1) tensor again
            x = torch.stack(outputs, dim=0)
            # flatten back to (T * N, -1)
            x = x.reshape(T * N, -1)
        return x, hxs
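
# Sketch of the two call patterns _forward_gru supports (sizes are
# assumptions): during acting, x holds one step for all N environments;
# during training, a rollout of T steps is passed flattened to (T * N, -1)
# and masks zero the hidden state at episode boundaries.
#
#     base = RecurrentBase(recurrent=True, recurrent_input_size=8, hidden_size=8)
#     N, T = 4, 5
#     # 1) acting: x is (N, 8), one hidden state per environment
#     x, hxs = base._forward_gru(torch.randn(N, 8), torch.zeros(N, 8), torch.ones(N, 1))
#     # 2) training: x is (T * N, 8); hxs holds the states from the rollout start
#     x, hxs = base._forward_gru(torch.randn(T * N, 8), torch.zeros(N, 8), torch.ones(T * N, 1))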