# Source code for draugr.torch_utilities.sessions.device_sessions
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 11/05/2020
"""
import torch
from torch.nn import Module
from draugr.torch_utilities import global_torch_device
from warg import AlsoDecorator
__all__ = ["TorchCpuSession", "TorchCudaSession", "TorchDeviceSession"]
class TorchCudaSession(AlsoDecorator):
    """
    Sets global torch devices to cuda if available.

    On entry the global torch device is overridden with ``cuda`` when
    ``torch.cuda.is_available()`` reports a usable device, and with ``cpu``
    otherwise. On exit the global device is either restored to the device
    captured at construction time (``no_side_effect=True``) or set to ``cpu``."""

    def __init__(self, model: Module = None, no_side_effect: bool = True):
        """
        :param model: optional module that is moved to the session device on
            entry and to the exit device when the session ends
        :param no_side_effect: when True, the global device returned by
            ``global_torch_device()`` at construction time is remembered and
            re-applied on exit; when False, the global device is set to cpu
            on exit
        """
        self._model = model
        self._no_side_effect = no_side_effect
        if no_side_effect:
            # NOTE: captured here, at construction, not when the session is entered.
            self.prev_dev = global_torch_device()

    def __enter__(self):
        """
        Override the global device and move the model (if any) onto it.

        :return: always True
        """
        # Honour the documented "if available" contract instead of forcing
        # cuda, which makes ``model.to`` fail on hosts without CUDA.
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device = global_torch_device(override=target)
        # Explicit None check: container modules define ``__len__`` and are
        # falsy when empty, which must not be mistaken for "no model given".
        if self._model is not None:
            self._model.to(device)
        return True

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore (or reset to cpu) the global device and move the model along.

        :return: always False, so exceptions raised inside the session propagate
        """
        if self._no_side_effect:
            device = global_torch_device(override=self.prev_dev)
        else:
            device = global_torch_device(override=torch.device("cpu"))
        if self._model is not None:
            self._model.to(device)
        return False
class TorchCpuSession(AlsoDecorator):
    """
    Sets global torch devices to cpu.

    On entry the global torch device is overridden with ``cpu``. On exit the
    global device is either restored to the device captured at construction
    time (``no_side_effect=True``) or switched to ``cuda`` when a CUDA device
    is available (staying on ``cpu`` otherwise)."""

    def __init__(self, model: Module = None, no_side_effect: bool = True):
        """
        :param model: optional module that is moved to cpu on entry and to the
            exit device when the session ends
        :param no_side_effect: when True, the global device returned by
            ``global_torch_device()`` at construction time is remembered and
            re-applied on exit; when False, the global device is switched to
            cuda (if available) on exit
        """
        self._model = model
        self._no_side_effect = no_side_effect
        if no_side_effect:
            # NOTE: captured here, at construction, not when the session is entered.
            self.prev_dev = global_torch_device()

    def __enter__(self):
        """
        Override the global device with cpu and move the model (if any) onto it.

        :return: always True
        """
        device = global_torch_device(override=torch.device("cpu"))
        # Explicit None check: container modules define ``__len__`` and are
        # falsy when empty, which must not be mistaken for "no model given".
        if self._model is not None:
            self._model.to(device)
        return True

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore (or switch away from cpu) the global device and move the model along.

        :return: always False, so exceptions raised inside the session propagate
        """
        if self._no_side_effect:
            device = global_torch_device(override=self.prev_dev)
        else:
            # Only leave cpu when cuda can actually be used; forcing cuda on a
            # host without it makes the following ``model.to`` raise.
            fallback = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            device = global_torch_device(override=fallback)
        if self._model is not None:
            self._model.to(device)
        return False
class TorchDeviceSession(AlsoDecorator):
    """
    Sets global torch devices to the supplied device.

    On entry the global torch device is overridden with ``device``. On exit
    the device captured at construction time is restored when
    ``no_side_effect=True``; otherwise the session device is left in place."""

    def __init__(
        self, device: torch.device, model: Module = None, no_side_effect: bool = True
    ):
        """
        :param device: device to install as the global torch device while the
            session is active
        :param model: optional module that is moved to ``device`` on entry and
            back to the restored device on exit
        :param no_side_effect: when True, the global device returned by
            ``global_torch_device()`` at construction time is remembered and
            re-applied on exit; when False, nothing is restored on exit
        """
        self._model = model
        self._no_side_effect = no_side_effect
        self._device = device
        if no_side_effect:
            # NOTE: captured here, at construction, not when the session is entered.
            self.prev_dev = global_torch_device()

    def __enter__(self):
        """
        Override the global device and move the model (if any) onto it.

        :return: always True
        """
        device = global_torch_device(override=self._device)
        # Explicit None check: container modules define ``__len__`` and are
        # falsy when empty, which must not be mistaken for "no model given".
        if self._model is not None:
            self._model.to(device)
        return True

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore the previous global device when side effects are disallowed.

        :return: always False, so exceptions raised inside the session propagate
        """
        if self._no_side_effect:
            device = global_torch_device(override=self.prev_dev)
            # The model only needs to move when a device was actually restored.
            # Previously this ran unconditionally and hit an unbound ``device``
            # whenever ``no_side_effect`` was False and a model was supplied.
            if self._model is not None:
                self._model.to(device)
        return False
if __name__ == "__main__":
    # Smoke run: enter and leave a session bound to the current global device.
    _current_device = global_torch_device()
    _session = TorchDeviceSession(_current_device)
    with _session:
        pass