Source code for draugr.torch_utilities.architectures.distributional.categorical

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import List, Tuple

import numpy
import torch
from torch.distributions import Categorical
from torch.nn import functional

from draugr.torch_utilities.architectures.mlp import MLP
from draugr.torch_utilities.tensors.to_tensor import to_tensor

# Author attribution for this module.
__author__ = "Christian Heider Nielsen"
# The module docstring is assigned explicitly; it currently contains only a newline.
__doc__ = r"""
"""
# Names exported by `from <module> import *`.
__all__ = ["MultipleCategoricalMLP", "CategoricalMLP"]


class MultipleCategoricalMLP(MLP):
    """MLP variant whose forward pass returns a list of ``Categorical`` distributions.

    Each element produced by the parent ``MLP.forward`` is normalised with
    ``log_softmax`` over its last dimension and wrapped in a
    ``torch.distributions.Categorical``.
    """

    @staticmethod
    def sample(distributions) -> Tuple:
        """Sample from the given distributions and compute a log-probability.

        NOTE(review): every distribution is sampled, but only the sample of the
        FIRST distribution is kept. The log-probability is then computed by
        pairing the distributions with the elements obtained by iterating over
        that first sample, and again only the first result is kept. This looks
        like it may not be the intended multi-head behaviour; it is preserved
        as-is because callers may depend on the current return shape -
        TODO confirm intent.

        :param distributions: iterable of distributions exposing ``sample`` and
            ``log_prob`` (e.g. the list returned by :meth:`forward`)
        :type distributions: Iterable
        :return: ``(actions, log_prob)`` where ``actions`` is a list of plain
            Python values taken from the first distribution's sample and
            ``log_prob`` is the first computed log-probability
        :rtype: Tuple
        """
        # All distributions are sampled (each call consumes RNG state); only the
        # first sample is retained.
        actions = [d.sample() for d in distributions][0]

        # zip pairs the i-th distribution with the i-th element of the retained
        # sample; only the first pairing's log-probability is returned.
        log_prob = [d.log_prob(action) for d, action in zip(distributions, actions)][0]

        # Move each sampled element to host memory and convert to Python values.
        actions = [a.to("cpu").numpy().tolist() for a in actions]
        return actions, log_prob

    @staticmethod
    def entropy(distributions) -> torch.Tensor:
        """Return the mean of the entropies of the given distributions.

        The per-distribution entropies are collected into a list, converted with
        ``to_tensor`` and reduced with ``torch.mean``.

        :param distributions: iterable of distributions exposing ``entropy``
        :type distributions: Iterable
        :return: mean entropy over all distributions (and all their elements)
        :rtype: torch.Tensor
        """
        return torch.mean(to_tensor([d.entropy() for d in distributions]))

    def forward(self, *x, **kwargs) -> List[Categorical]:
        """Run the parent MLP and wrap each of its outputs in a ``Categorical``.

        :param x: positional inputs forwarded unchanged to ``MLP.forward``
        :type x: Any
        :param kwargs: keyword arguments forwarded unchanged to ``MLP.forward``
        :type kwargs: Any
        :return: one ``Categorical`` per element yielded by iterating over the
            parent's output (presumably one per output head - TODO confirm)
        :rtype: List[Categorical]
        """
        out = super().forward(*x, **kwargs)
        return [Categorical(logits=functional.log_softmax(o, dim=-1)) for o in out]
class CategoricalMLP(MLP):
    """MLP variant whose forward pass returns a single ``Categorical`` distribution."""

    def forward(self, *x, **kwargs) -> Categorical:
        """Run the parent MLP and wrap its output in a ``Categorical``.

        :param x: positional inputs forwarded unchanged to ``MLP.forward``
        :type x: Any
        :param kwargs: keyword arguments forwarded unchanged to ``MLP.forward``
        :type kwargs: Any
        :return: distribution whose logits are the ``log_softmax`` (over the
            last dimension) of the parent's output
        :rtype: Categorical
        """
        raw_output = super().forward(*x, **kwargs)
        log_probabilities = functional.log_softmax(raw_output, dim=-1)
        return Categorical(logits=log_probabilities)
if __name__ == "__main__":

    def multi_cat():
        """Build a MultipleCategoricalMLP, feed it two random batches and print a sample."""
        input_shape = (2, 2)
        output_shape = (2, 2)
        model = MultipleCategoricalMLP(
            input_shape=input_shape, output_shape=output_shape
        )
        batch = to_tensor(numpy.random.rand(64, input_shape[0]), device="cpu")
        distributions = model(batch, batch)
        print(model.sample(distributions))

    def single_cat():
        """Build a two-input CategoricalMLP and print a sample from its distribution."""
        input_shape = (1, 2)
        output_shape = (2,)
        model = CategoricalMLP(input_shape=input_shape, output_shape=output_shape)
        first_batch = to_tensor(numpy.random.rand(64, input_shape[0]), device="cpu")
        second_batch = to_tensor(numpy.random.rand(64, input_shape[1]), device="cpu")
        distribution = model(first_batch, second_batch)
        print(distribution.sample())

    def single_cat2():
        """Build a single-input CategoricalMLP and print a sample from its distribution."""
        input_shape = (4,)
        output_shape = (2,)
        model = CategoricalMLP(input_shape=input_shape, output_shape=output_shape)
        batch = to_tensor(numpy.random.rand(64, input_shape[0]), device="cpu")
        distribution = model(batch)
        print(distribution.sample())

    # Run the demos in the same order as before.
    multi_cat()
    single_cat()
    single_cat2()