Source code for draugr.torch_utilities.sessions.type_sessions

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 15/06/2020
           """

import torch

from warg import AlsoDecorator

__all__ = ["DefaultTypeSession"]


# torch.set_default_tensor_type('torch.cuda.FloatTensor') # Legacy


class DefaultTypeSession(AlsoDecorator):
    """
    Temporarily set torch's global default floating-point dtype.

    Usable as a ``with`` context manager. The ``AlsoDecorator`` base (from ``warg``)
    presumably also lets an instance be used as a function decorator — confirm against warg.

    On entry the default dtype is switched to ``dtype`` via ``torch.set_default_dtype``.
    On exit:

    * ``no_side_effect=True`` (default): the default dtype that was active when the session
      was *entered* is restored. Previous values are kept on a per-entry stack, so nested or
      repeated use of the same instance unwinds correctly.
    * ``no_side_effect=False``: the default dtype is reset to ``torch.float32``.

    Exceptions raised inside the session are not suppressed.

    :param dtype: dtype to make the default inside the session; ``torch.set_default_dtype``
      only accepts floating-point dtypes.
    :param no_side_effect: whether to restore the previously active default dtype on exit.
    """

    def __init__(self, dtype: torch.dtype = torch.float32, no_side_effect: bool = True):
        self._dtype = dtype
        self._no_side_effect = no_side_effect
        # One entry per active __enter__. A stack rather than a single slot lets the same
        # instance be nested or re-entered (e.g. a decorated recursive function).
        self._prev_states = []
        if no_side_effect:
            # Kept for backward compatibility; refreshed on every __enter__.
            self.prev_state = torch.get_default_dtype()

    def __enter__(self):
        # Capture the dtype at entry time, not construction time, so a change made between
        # constructing and entering the session is what gets restored.
        prev = torch.get_default_dtype()
        torch.set_default_dtype(self._dtype)
        # Record only after set_default_dtype succeeded: if it raises, __exit__ is never
        # called and a pushed entry would never be popped.
        self._prev_states.append(prev)
        self.prev_state = prev
        return True

    def __exit__(self, exc_type, exc_val, exc_tb):
        prev = self._prev_states.pop()
        if self._no_side_effect:
            torch.set_default_dtype(prev)
        else:
            torch.set_default_dtype(torch.float32)
if "__main__" == __name__:
    # Nothing to execute when run directly; this module only provides DefaultTypeSession.
    ...