Source code for draugr.torch_utilities.optimisation.parameters.freezing.parameters

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 15/02/2020
           """

__all__ = ["frozen_parameters", "freeze_parameters"]

from contextlib import contextmanager
from itertools import tee
from typing import Iterator, Optional

from torch.nn import Parameter


def freeze_parameters(params: Iterator[Parameter], value: Optional[bool] = None) -> None:
    """
    Set requires_grad on every parameter in params.

    :param params: iterator of parameters to modify
    :param value: if a bool, freeze (True) or unfreeze (False) all parameters; if None, toggle each parameter's current state
    :return: None
    """
    if isinstance(value, bool):
        for p in params:
            p.requires_grad = not value
    else:  # no explicit value given, toggle each parameter's current state
        for p in params:
            p.requires_grad = not p.requires_grad
@contextmanager
def frozen_parameters(params: Iterator[Parameter], enabled: bool = True) -> Iterator[bool]:
    """
    Context manager that freezes params on entry and unfreezes them on exit.

    :param params: iterator of parameters to freeze for the duration of the block
    :type params: Iterator[Parameter]
    :param enabled: if False, the context manager leaves the parameters untouched
    :type enabled: bool
    :return: yields True inside the context
    """
    params_1, params_2 = tee(params)  # the iterator is consumed twice: once on entry, once on exit
    if enabled:
        freeze_parameters(params_1, True)
    yield True
    if enabled:
        freeze_parameters(params_2, False)
if __name__ == "__main__":
    from torch import nn

    def asd21312a() -> None:
        """Freeze a Linear layer's parameters inside the context manager.

        :rtype: None"""
        a = nn.Linear(10, 5)
        print(a.weight.requires_grad)
        with frozen_parameters(a.parameters()):
            print(a.weight.requires_grad)
        print(a.weight.requires_grad)

    def afsda32() -> None:
        """Same as above; shows the gradient flag being restored on exit.

        :rtype: None"""
        a = nn.Linear(10, 5)
        print(a.weight.requires_grad)
        with frozen_parameters(a.parameters()):
            print(a.weight.requires_grad)
        print(a.weight.requires_grad)

    def afsda12332_toogle() -> None:
        """Toggle the gradient flag by calling freeze_parameters without an explicit value.

        :rtype: None"""
        a = nn.Linear(10, 5)
        print(a.weight.requires_grad)
        freeze_parameters(a.parameters())
        print(a.weight.requires_grad)
        freeze_parameters(a.parameters())
        print(a.weight.requires_grad)

    def afsda12332_explicit() -> None:
        """Freeze and then unfreeze with explicit boolean values.

        :rtype: None"""
        a = nn.Linear(10, 5)
        print(a.weight.requires_grad)
        freeze_parameters(a.parameters(), True)
        print(a.weight.requires_grad)
        freeze_parameters(a.parameters(), False)
        print(a.weight.requires_grad)

    def seq_no_context() -> None:
        """Freeze and unfreeze a Sequential model's parameters without the context manager.

        :rtype: None"""
        a = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5))
        print(next(a.parameters()).requires_grad)
        freeze_parameters(a.parameters(), True)
        print(next(a.parameters()).requires_grad)
        freeze_parameters(a.parameters(), False)
        print(next(a.parameters()).requires_grad)

    def seq_context() -> None:
        """Freeze a Sequential model's parameters using the context manager.

        :rtype: None"""
        a = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 5))
        print(next(a.parameters()).requires_grad)
        with frozen_parameters(a.parameters()):
            print(next(a.parameters()).requires_grad)
        print(next(a.parameters()).requires_grad)

    print("\n")
    asd21312a()
    print("\n")
    afsda32()
    print("\n")
    afsda12332_toogle()
    print("\n")
    afsda12332_explicit()
    print("\n")
    seq_no_context()
    print("\n")
    seq_context()
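
The sketch below is not part of the module; it is a minimal usage example assuming the functions are imported from the module path shown in the title. The backbone/head layers, tensor shapes and learning rate are illustrative placeholders. It shows the common fine-tuning pattern of freezing one sub-network while its parameters remain registered with the optimiser.

# Minimal usage sketch; layer names, shapes and hyperparameters are assumptions.
import torch
from torch import nn
from torch.optim import SGD

from draugr.torch_utilities.optimisation.parameters.freezing.parameters import (
    frozen_parameters,
)

backbone = nn.Linear(10, 5)  # stand-in for a pretrained feature extractor
head = nn.Linear(5, 2)  # new task-specific head
optimiser = SGD(list(backbone.parameters()) + list(head.parameters()), lr=0.1)

x = torch.randn(4, 10)
target = torch.randn(4, 2)

with frozen_parameters(backbone.parameters()):
    # Inside the block the backbone parameters have requires_grad == False,
    # so backward() only accumulates gradients for the head.
    loss = nn.functional.mse_loss(head(backbone(x)), target)
    loss.backward()
    optimiser.step()  # SGD skips parameters whose .grad is None
# On exit the backbone parameters are trainable again.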