Source code for draugr.torch_utilities.distributions.entropy
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
import torch
from draugr.torch_utilities.system.constants import torch_pi
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 09/10/2019
"""
__all__ = [
"shannon_entropy",
"log_shannon_entropy",
"normal_entropy",
"differential_entropy_gaussian",
"normal_log_density",
]
[docs]def shannon_entropy(prob: torch.Tensor) -> torch.Tensor:
"""
:param prob:
:type prob:
:return:
:rtype:"""
return -torch.sum(prob * torch.log2(prob), -1)
[docs]def log_shannon_entropy(log_prob: torch.Tensor) -> torch.Tensor:
"""
:param log_prob:
:type log_prob:
:return:
:rtype:"""
return -torch.sum(torch.pow(2, log_prob) * log_prob, -1)
# return - torch.sum(torch.exp(log_prob) * log_prob, -1)
[docs]def normal_entropy(std: torch.Tensor) -> torch.Tensor:
"""
:param std:
:type std:
:return:
:rtype:"""
var = std.pow(2)
ent = 0.5 + 0.5 * torch.log(2 * var * math.pi)
return ent.sum(dim=-1, keepdim=True)
[docs]def differential_entropy_gaussian(std: torch.Tensor) -> torch.Tensor:
"""
:param std:
:type std:
:return:
:rtype:"""
return torch.log(std * torch.sqrt(2 * torch_pi())) + 0.5
[docs]def normal_log_density(x: torch.Tensor, mean, log_std, std) -> torch.Tensor:
"""
:param x:
:type x:
:param mean:
:type mean:
:param log_std:
:type log_std:
:param std:
:type std:
:return:
:rtype:"""
var = std.pow(2)
log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * math.log(2 * math.pi) - log_std
return log_density.sum(1, keepdim=True)