Source code for draugr.torch_utilities.sessions.model_sessions

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 11/05/2020
           """

from collections import OrderedDict
from itertools import tee

import torch

from draugr.torch_utilities.optimisation.parameters.freezing import freeze_parameters
from warg import AlsoDecorator

__all__ = [
    "TorchEvalSession",
    "TorchTrainSession",
    "TorchFrozenModelSession",
    "TorchTrainingSession",
]


[docs]class TorchEvalSession(AlsoDecorator): """ # speed up evaluating after training finished"""
[docs] def __init__(self, model: torch.nn.Module, no_side_effect: bool = True): self.model = model self._no_side_effect = no_side_effect if no_side_effect: self.prev_state = model.training
def __enter__(self): # self.model.eval() self.model.train(False) return True def __exit__(self, exc_type, exc_val, exc_tb): if self._no_side_effect: self.model.train(self.prev_state) else: self.model.train(True)
[docs]class TorchTrainSession(AlsoDecorator): """ # speed up evaluating after training finished"""
[docs] def __init__(self, model: torch.nn.Module, no_side_effect: bool = True): self.model = model self._no_side_effect = no_side_effect if no_side_effect: self.prev_state = model.training
def __enter__(self): self.model.train(True) return True def __exit__(self, exc_type, exc_val, exc_tb): if self._no_side_effect: self.model.train(self.prev_state) else: self.model.train(False)
TorchTrainingSession = TorchTrainSession
[docs]class TorchFrozenModelSession(AlsoDecorator): """description"""
[docs] def __init__(self, model: torch.nn.Module, no_side_effect: bool = True): self.model = model self._no_side_effect = no_side_effect self.params_1, self.params_2, self.params_3 = tee(model.parameters(True), 3) if no_side_effect: self.previous_states = [a.requires_grad for a in self.params_3]
def __enter__(self): freeze_parameters(self.params_1, True) return True def __exit__(self, exc_type, exc_val, exc_tb): if self._no_side_effect: [p.requires_grad_(rg) for p, rg in zip(self.params_2, self.previous_states)] else: freeze_parameters(self.params_2, False)
if __name__ == "__main__": def main() -> None: """ :rtype: None """ a = torch.nn.Sequential( OrderedDict(l1=torch.nn.Linear(3, 5), l2=torch.nn.Linear(5, 2)) ) p_iter = iter(a.parameters(True)) l1_w = next(p_iter) l1_bias = next(p_iter) l1_bias.requires_grad_(False) def initial(): """description""" for p in a.parameters(True): print(p.requires_grad) @TorchFrozenModelSession(a) def frozen(): """description""" for p in a.parameters(True): print(p.requires_grad) def frozen_session(): """description""" with TorchFrozenModelSession(a): for p in a.parameters(True): print(p.requires_grad) initial() print() frozen() print() initial() print() frozen_session() print() initial() main()