Source code for draugr.torch_utilities.optimisation.updates.copying
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
import torch
# Public API: the two helpers below, which copy weights / state from one module into another.
__all__ = ["copy_parameters", "copy_state"]
def copy_parameters(
    target: torch.nn.Module, source: torch.nn.Module
) -> torch.nn.Module:
    """Copy the parameter values of ``source`` into ``target`` in place.

    Parameters are paired positionally, in ``Module.parameters()`` iteration
    order, so both modules are expected to share the same architecture.
    Pairing stops at the shorter of the two parameter sequences (``zip``
    semantics). Buffers (e.g. batch-norm running statistics) are not part of
    ``parameters()`` and are therefore not copied; ``copy_state`` transfers
    the full state dict instead.

    :param target: module whose parameters are overwritten
    :param source: module whose parameter values are read
    :return: ``target`` (the same object, modified in place)
    """
    # Write under ``torch.no_grad()`` rather than through ``.data``: the copy is
    # still kept out of the autograd graph, but the parameters' version
    # counters are bumped, so autograd can detect if a tensor it saved for
    # backward has been overwritten (``.data`` hides that).
    with torch.no_grad():
        for dst, src in zip(target.parameters(), source.parameters()):
            dst.copy_(src)
    return target
def copy_state(*, target: torch.nn.Module, source: torch.nn.Module) -> torch.nn.Module:
    """Load the complete ``state_dict`` of ``source`` into ``target``.

    Unlike a parameter-only copy, this transfers everything registered in the
    state dict (parameters and persistent buffers). ``load_state_dict`` is
    called with its default ``strict=True``, so mismatched keys between the
    two modules raise instead of being silently skipped.

    :param target: module that receives the state (keyword-only)
    :param source: module whose state is read (keyword-only)
    :return: ``target``, modified in place
    """
    source_state = source.state_dict()
    target.load_state_dict(source_state)
    return target