Source code for draugr.torch_utilities.architectures.mlp_variants.fourier

__all__ = ["BaseFourierFeatureMLP"]
__doc__ = (
    r""""Module containing models based upon the Fourier Feature Networks template."""
)

import math
from typing import List, Optional

import torch
from torch import nn


class BaseFourierFeatureMLP(nn.Module):
    """MLP which uses Fourier features as a preprocessing step."""
    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        a_values: Optional[torch.Tensor],
        b_values: Optional[torch.Tensor],
        layer_channels: List[int],
    ):
        """Constructor.

        Args:
            num_inputs (int): Number of dimensions in the input
            num_outputs (int): Number of dimensions in the output
            a_values (Optional[torch.Tensor]): a values for the encoding, or None to disable encoding
            b_values (Optional[torch.Tensor]): b values for the encoding, or None to disable encoding
            layer_channels (List[int]): Number of channels per layer.
        """
        nn.Module.__init__(self)
        self.params = {
            "num_inputs": num_inputs,
            "num_outputs": num_outputs,
            "a_values": None if a_values is None else a_values.tolist(),
            "b_values": None if b_values is None else b_values.tolist(),
            "layer_channels": layer_channels,
        }
        self.num_inputs = num_inputs

        if b_values is None:
            self.a_values = None
            self.b_values = None
        else:
            assert b_values.shape[0] == num_inputs
            assert a_values.shape[0] == b_values.shape[1]
            self.a_values = nn.Parameter(a_values, requires_grad=False)
            self.b_values = nn.Parameter(b_values, requires_grad=False)
            num_inputs = b_values.shape[1] * 2

        self.layers = nn.ModuleList()
        for num_channels in layer_channels:
            self.layers.append(nn.Linear(num_inputs, num_channels))
            num_inputs = num_channels

        self.layers.append(nn.Linear(num_inputs, num_outputs))
        self.use_view = False
        self.keep_activations = False
        self.activations = []
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Predicts outputs from the provided uv input."""
        if self.b_values is None:
            output = inputs
        else:
            # NB: the below should be 2*math.pi, but the values coming in are
            # already in the range of -1 to 1 or 0 to 2, so we want to keep the
            # range so that it does not exceed 2*pi
            encoded = (math.pi * inputs) @ self.b_values
            output = torch.cat(
                [self.a_values * encoded.cos(), self.a_values * encoded.sin()], dim=-1
            )

        self.activations.clear()
        for layer in self.layers[:-1]:
            output = torch.relu(layer(output))
            if self.keep_activations:
                self.activations.append(output.detach().cpu().numpy())

        output = self.layers[-1](output)
        return output
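
    # For reference, shapes through the encoding in forward above, with
    # F = b_values.shape[1], for a batch of N inputs:
    #   inputs  : (N, num_inputs)
    #   encoded : (N, F)      # (pi * inputs) @ b_values
    #   output  : (N, 2 * F)  # cos/sin features concatenated, scaled by a_values
    # which matches the first Linear layer built in __init__ (in-features = 2 * F).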

    def save(self, path: str):
        """Saves the model to the specified path.

        Args:
            path (str): Path to the model file on disk
        """
        state_dict = self.state_dict()
        state_dict["type"] = "fourier"
        state_dict["params"] = self.params
        torch.save(state_dict, path)
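

# Hypothetical loading counterpart (a minimal sketch; the ``load_fourier_mlp``
# name and behaviour are assumptions, not part of the library API): it consumes
# the extra "type"/"params" entries written by ``BaseFourierFeatureMLP.save``
# to rebuild a model from disk.
def load_fourier_mlp(path: str) -> BaseFourierFeatureMLP:
    state_dict = torch.load(path)
    state_dict.pop("type")  # marker written by save(), unused here
    params = state_dict.pop("params")
    model = BaseFourierFeatureMLP(
        params["num_inputs"],
        params["num_outputs"],
        None if params["a_values"] is None else torch.tensor(params["a_values"]),
        None if params["b_values"] is None else torch.tensor(params["b_values"]),
        params["layer_channels"],
    )
    model.load_state_dict(state_dict)  # remaining entries are the model weights
    return model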


class MLP(BaseFourierFeatureMLP):
    """Unencoded FFN, essentially a standard MLP."""

    def __init__(
        self, num_inputs: int, num_outputs: int, num_layers=3, num_channels=256
    ):
        """Constructor.

        Args:
            num_inputs (int): Number of dimensions in the input
            num_outputs (int): Number of dimensions in the output
            num_layers (int, optional): Number of layers in the MLP. Defaults to 3.
            num_channels (int, optional): Number of channels in the MLP. Defaults to 256.
        """
        super().__init__(
            num_inputs, num_outputs, None, None, [num_channels] * num_layers
        )


class BasicFMLP(BaseFourierFeatureMLP):
    """Basic version of FFN in which inputs are projected onto the unit circle."""

    def __init__(
        self, num_inputs: int, num_outputs: int, num_layers=3, num_channels=256
    ):
        """Constructor.

        Args:
            num_inputs (int): Number of dimensions in the input
            num_outputs (int): Number of dimensions in the output
            num_layers (int, optional): Number of layers in the MLP. Defaults to 3.
            num_channels (int, optional): Number of channels in the MLP. Defaults to 256.
        """
        a_values = torch.ones(num_inputs)
        b_values = torch.eye(num_inputs)
        super().__init__(
            num_inputs,
            num_outputs,
            a_values,
            b_values,
            [num_channels] * num_layers,
        )


class PositionalFMLP(BaseFourierFeatureMLP):
    """Version of FFN with positional encoding."""

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        max_log_scale: float,
        num_layers=3,
        num_channels=256,
        embedding_size=256,
    ):
        """Constructor.

        Args:
            num_inputs (int): Number of dimensions in the input
            num_outputs (int): Number of dimensions in the output
            max_log_scale (float): Maximum log scale for the embedding
            num_layers (int, optional): Number of layers in the MLP. Defaults to 3.
            num_channels (int, optional): Number of channels in the MLP. Defaults to 256.
            embedding_size (int, optional): The size of the feature embedding. Defaults to 256.
        """
        b_values = self._encoding(max_log_scale, embedding_size, num_inputs)
        a_values = torch.ones(b_values.shape[1])
        super().__init__(
            num_inputs,
            num_outputs,
            a_values,
            b_values,
            [num_channels] * num_layers,
        )

    @staticmethod
    def _encoding(max_log_scale: float, embedding_size: int, num_inputs: int):
        """Produces the encoding b_values matrix."""
        embedding_size = embedding_size // num_inputs
        frequencies_matrix = 2.0 ** torch.linspace(0, max_log_scale, embedding_size)
        frequencies_matrix = frequencies_matrix.reshape(-1, 1, 1)
        frequencies_matrix = torch.eye(num_inputs) * frequencies_matrix
        frequencies_matrix = frequencies_matrix.reshape(-1, num_inputs)
        frequencies_matrix = frequencies_matrix.transpose(0, 1)
        return frequencies_matrix


class GaussianFMLP(BaseFourierFeatureMLP):
    """Version of FFN using a full Gaussian matrix for encoding."""

    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        sigma: float,
        num_layers=3,
        num_channels=256,
        embedding_size=256,
    ):
        """Constructor.

        Args:
            num_inputs (int): Number of dimensions in the input
            num_outputs (int): Number of dimensions in the output
            sigma (float): Standard deviation of the Gaussian distribution
            num_layers (int, optional): Number of layers in the MLP. Defaults to 3.
            num_channels (int, optional): Number of channels in the MLP. Defaults to 256.
            embedding_size (int, optional): Number of frequencies to use for the encoding. Defaults to 256.
        """
        b_values = torch.normal(0, sigma, size=(num_inputs, embedding_size))
        a_values = torch.ones(b_values.shape[1])
        super().__init__(
            num_inputs,
            num_outputs,
            a_values,
            b_values,
            [num_channels] * num_layers,
        )
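

# Illustrative usage sketch: builds a small GaussianFMLP on 2-D inputs and runs
# a forward pass; all sizes below are arbitrary example values, not defaults
# prescribed by the library.
if __name__ == "__main__":
    model = GaussianFMLP(
        num_inputs=2,
        num_outputs=3,
        sigma=10.0,
        num_layers=3,
        num_channels=64,
        embedding_size=64,
    )
    uv = torch.rand(8, 2)  # batch of 8 two-dimensional coordinates
    rgb = model(uv)
    print(rgb.shape)  # torch.Size([8, 3])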