Source code for draugr.torch_utilities.system.seeding

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import functools
from typing import Any, Callable

import numpy
import torch

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Reproducibility helpers that seed PyTorch's (and, via ``Seed``, numpy's)
random number generators.

           Created on 09/10/2019
           """
# Public API for star-imports. NOTE(review): ``Seed`` is a public class in this
# module but is not listed here — confirm whether that is intentional before
# exporting it (star-importing packages could otherwise see a new name).
__all__ = ["torch_seed"]

from torch import Tensor


def torch_seed(s: int = 72163) -> torch.Generator:
    """Seed PyTorch's random number generators for reproducibility.

    Seeds the default CPU generator via ``torch.manual_seed``. When CUDA is
    available, it also seeds every CUDA device and puts cuDNN into
    deterministic mode.

    :param s: Seed value forwarded to ``torch.manual_seed``.
    :return: The default ``torch.Generator`` returned by ``torch.manual_seed``.
    """
    generator = torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
        # ``deterministic`` makes the chosen cuDNN algorithm deterministic.
        # Benchmark mode may still choose a *different* algorithm on each run,
        # so it is disabled as well. This matches ``Seed(cuda=True)``.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return generator
class Seed:
    r"""**Seed PyTorch and numpy.**

    This code is based on PyTorch's reproducibility guide:
    https://pytorch.org/docs/stable/notes/randomness.html

    It can be used as a standard seeding procedure, as a context manager (the
    seed is changed only within the block) or as a function decorator.

    **Standard seed**::

        Seed(0)  # seeds immediately; nothing is restored afterwards

    **Used as context manager**::

        with Seed(1):
            ...  # your operations

        print(torch.initial_seed())  # back to the seed used before the block

    **Used as function decorator**::

        @Seed(1)  # re-seeded on every call, RNG state restored afterwards
        def foo():
            return 42

    **Important:** Constructing ``Seed`` seeds the global generators right away.
    As a result, ``@Seed(1)`` also seeds them once at decoration time. Leaving a
    ``with`` block or a decorated call (also via an exception) restores:

    - the exact torch CPU generator state;
    - the legacy global ``numpy.random`` state;
    - the CUDA generator states, if CUDA was already initialised. Otherwise
      the CUDA generators are re-seeded with the previous torch seed.

    The cuDNN flags set by ``cuda=True`` are *not* restored.

    Parameters
    ----------
    value: int
        Seed value used in ``numpy.random.seed`` and ``torch.manual_seed``.
        It must be accepted by both, i.e. ``0 <= value < 2**32``.
    cuda: bool, optional
        Whether to put PyTorch's CUDA backend into deterministic mode (setting
        cudnn.benchmark to `False` and cudnn.deterministic to `True`).
        If `False`, consecutive runs may be slightly different.
        If `True`, automatic autotuning for convolution layers with consistent
        input shape is turned off.
        Default: `False`
    """

    def __init__(self, value: int, cuda: bool = False):
        self.value = value
        self.cuda = cuda
        # When True, leaving a ``with`` block or a decorated call restores the
        # RNG state captured below. A plain ``Seed(x)`` never triggers a restore.
        self.no_side_effect = True
        self._capture_rng_state()

        numpy.random.seed(self.value)
        torch.manual_seed(self.value)

        if self.cuda:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

    def _capture_rng_state(self) -> None:
        """Snapshot the global torch, numpy and (initialised) CUDA RNG states."""
        self._last_seed = torch.initial_seed()
        # Save full generator states rather than seeds. Re-seeding would rewind
        # the stream, so numbers drawn before the block would repeat after it.
        self._torch_rng_state = torch.get_rng_state()
        self._numpy_rng_state = numpy.random.get_state()
        # Querying CUDA generators forces CUDA initialisation, so only snapshot
        # them when CUDA has already been initialised.
        self._cuda_rng_states = (
            torch.cuda.get_rng_state_all()
            if torch.cuda.is_available() and torch.cuda.is_initialized()
            else None
        )

    def _restore_rng_state(self) -> None:
        """Restore the states captured by ``_capture_rng_state``."""
        torch.set_rng_state(self._torch_rng_state)
        numpy.random.set_state(self._numpy_rng_state)
        if self._cuda_rng_states is not None:
            torch.cuda.set_rng_state_all(self._cuda_rng_states)
        elif torch.cuda.is_available():
            # No snapshot was taken. Fall back to re-seeding the CUDA
            # generators with the previous torch seed.
            torch.cuda.manual_seed_all(self._last_seed)

    def __enter__(self) -> "Seed":
        return self

    def __exit__(self, *_, **__) -> bool:
        if self.no_side_effect:
            self._restore_rng_state()
        return False  # never swallow exceptions raised inside the block

    def __call__(self, function: Callable) -> Callable:
        """Wrap ``function`` so that each call runs under a fresh seed scope."""

        @functools.wraps(function)
        def decorated(*args, **kwargs) -> Any:
            """Run ``function`` seeded with ``value``; restore RNG state after."""
            scope = type(self)(self.value, cuda=self.cuda)
            scope.no_side_effect = self.no_side_effect
            # ``with`` guarantees the restore even if ``function`` raises.
            with scope:
                return function(*args, **kwargs)

        return decorated


if __name__ == "__main__":

    @Seed(1)  # Seed only within function
    def foo() -> Tensor:
        """Draw a random tensor inside a decorator-seeded call.

        :rtype: Tensor
        """
        return torch.randint(5, (2, 2))

    def bar() -> Tensor:
        """Draw a random tensor inside a ``with Seed(1)`` block.

        :rtype: Tensor
        """
        with Seed(1):
            return torch.randint(5, (2, 2))

    def buzz() -> Tensor:
        """Draw a random tensor after a plain ``Seed(1)`` call.

        :rtype: Tensor
        """
        Seed(1)
        return torch.randint(5, (2, 2))

    for f in [foo, bar, buzz]:
        print(f())