Source code for draugr.torch_utilities.system.data_type
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 15/11/2019
"""
GLOBAL_DTYPE = None
__all__ = ["global_torch_dtype", "set_global_torch_dtype"]
[docs]def global_torch_dtype(
override: torch.dtype = None, verbose: bool = False
) -> torch.dtype:
"""
first time call stores to dtype for global reference, later call must manually override
:param verbose:
:type verbose:
:param override:
:type override:
:return:
:rtype:"""
global GLOBAL_DTYPE
if override is not None:
GLOBAL_DTYPE = override
set_global_torch_dtype(GLOBAL_DTYPE)
if verbose:
print(f"Overriding global torch device to {override}")
elif GLOBAL_DTYPE is None:
GLOBAL_DTYPE = torch.get_default_dtype()
return GLOBAL_DTYPE
[docs]def set_global_torch_dtype(dtype: torch.dtype) -> None:
"""description"""
global GLOBAL_DTYPE
GLOBAL_DTYPE = dtype
torch.set_default_dtype(GLOBAL_DTYPE)
if __name__ == "__main__":
def stest_override() -> None:
"""
:rtype: None
"""
print(global_torch_dtype(verbose=True))
print(global_torch_dtype(override=torch.double, verbose=True))
print(global_torch_dtype(verbose=True))
print(global_torch_dtype())
print(global_torch_dtype())
print(global_torch_dtype(override=torch.half))
print(global_torch_dtype())
stest_override()