Source code for draugr.torch_utilities.sessions.cache_sessions
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 20/03/2020
"""
import torch
from draugr.torch_utilities.sessions.device_sessions import (
    TorchCpuSession,
    TorchCudaSession,
    global_torch_device,
)
from draugr.torch_utilities.sessions.model_sessions import (
    TorchEvalSession,
    TorchTrainSession,
)
__all__ = ["TorchCacheSession"]
from warg import AlsoDecorator

class TorchCacheSession(AlsoDecorator):
    """
    Speeds up evaluation after training has finished.

    NOTE: has the side effect of clearing the CUDA cache, which is not recoverable."""

    def __init__(self, using_cuda: bool = global_torch_device().type == "cuda"):
        self.using_cuda = using_cuda

    def __enter__(self):
        if self.using_cuda:
            torch.cuda.empty_cache()
        return True

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.using_cuda:
            torch.cuda.empty_cache()

if __name__ == "__main__":

    def a() -> None:
        """
        :rtype: None
        """
        # memory_cached is deprecated in newer PyTorch in favour of memory_reserved
        print(torch.cuda.memory_cached(global_torch_device()))
        with TorchCacheSession():
            torch.tensor([0.0], device=global_torch_device())
            print(torch.cuda.memory_cached(global_torch_device()))
        print(torch.cuda.memory_cached(global_torch_device()))
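
    # A minimal usage sketch, assuming warg's AlsoDecorator lets the session
    # double as a decorator that clears the CUDA cache around each call.
    # The decorator form and the helper name `a_decorated` are assumptions for
    # illustration; the context-manager form is demonstrated in `a` above.
    @TorchCacheSession()
    def a_decorated() -> None:
        """
        :rtype: None
        """
        torch.tensor([0.0], device=global_torch_device())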

    def b() -> None:
        """
        :rtype: None
        """
        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Dropout(0.1))
        print(model.training)
        with TorchEvalSession(model):
            print(model.training)
        print(model.training)

    def c() -> None:
        """
        :rtype: None
        """
        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Dropout(0.1))
        model.eval()
        print(model.training)
        with TorchTrainSession(model):
            print(model.training)
        print(model.training)

    def d() -> None:
        """
        :rtype: None
        """
        print(
            global_torch_device(override=global_torch_device(device_preference=False))
        )
        print(global_torch_device())
        with TorchCudaSession():
            print(global_torch_device())
        print(global_torch_device())

    def e() -> None:
        """
        :rtype: None
        """
        print(global_torch_device(override=global_torch_device(device_preference=True)))
        print(global_torch_device())
        with TorchCpuSession():
            print(global_torch_device())
        print(global_torch_device())

    # a()
    # b()
    # c()
    d()
    e()