# Source code for draugr.torch_utilities.persistence.checkpoint
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 06-04-2021
"""
from pathlib import Path
import torch
from torch.optim.optimizer import Optimizer
__all__ = ["save_checkpoint", "load_checkpoint"]
def save_optimiser(
    *,
    optimiser: Optimizer,
    optimiser_save_path: Path,
    raise_on_existing: bool = False,
) -> None:
    """Serialise the entire optimiser object to disk via ``torch.save``.

    Note: this pickles the whole optimiser instance, not just its
    ``state_dict``, so loading requires the same class to be importable.

    :param optimiser: optimiser instance to persist.
    :param optimiser_save_path: destination file path.
    :param raise_on_existing: when True, refuse to overwrite an existing file.
    :raises FileExistsError: if the target exists and ``raise_on_existing`` is set.
    """
    target_exists = optimiser_save_path.exists()
    if raise_on_existing and target_exists:
        raise FileExistsError(f"{optimiser_save_path} exists!")
    torch.save(optimiser, str(optimiser_save_path))
def save_checkpoint(PATH: Path, epoch, model, optimiser, loss) -> None:
    """Persist a resumable training checkpoint to ``PATH``.

    Stores the epoch counter, the model's and optimiser's ``state_dict``s,
    and the loss value under the keys ``"epoch"``, ``"model_state_dict"``,
    ``"optimiser_state_dict"`` and ``"value"`` (the layout expected by
    :func:`load_checkpoint`).

    :param PATH: destination file path (saved exactly as given).
    :param epoch: epoch number to record.
    :param model: module whose ``state_dict()`` is stored.
    :param optimiser: optimiser whose ``state_dict()`` is stored.
    :param loss: loss value stored under the key ``"value"``.
    """
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimiser_state_dict": optimiser.state_dict(),
            "value": loss,
        },
        PATH,
    )
    # Removed dead statement `PATH.with_suffix(".tar")`: Path methods return a
    # NEW path and the result was discarded, so it had no effect. If a ".tar"
    # suffix is wanted, the caller should pass PATH.with_suffix(".tar") in.
def load_checkpoint(PATH: Path, model, optimizer, *, map_location=None):
    """Restore a checkpoint written by :func:`save_checkpoint`.

    Loads the dict from ``PATH``, applies the stored state dicts to ``model``
    and ``optimizer`` in place, and returns the recorded epoch and loss.

    :param PATH: checkpoint file path.
    :param model: module to receive ``"model_state_dict"`` (mutated in place).
    :param optimizer: optimiser to receive ``"optimiser_state_dict"``
        (mutated in place).
    :param map_location: forwarded to :func:`torch.load`; pass e.g. ``"cpu"``
        to load a GPU-saved checkpoint on a CPU-only machine. Defaults to
        ``None`` (previous behaviour).
    :return: ``(epoch, model, optimizer, loss)``.
    """
    checkpoint = torch.load(PATH, map_location=map_location)
    epoch = checkpoint["epoch"]
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimiser_state_dict"])
    loss = checkpoint["value"]
    return epoch, model, optimizer, loss
if __name__ == "__main__":

    def main() -> None:
        """Placeholder for a single-model save/load demo.

        :rtype: None
        """
        # model = TheModelClass(args, **kwargs)
        # optimizer = TheOptimizerClass(args, **kwargs)
        pass

    def multi() -> None:
        """Placeholder for a multi-model checkpoint demo.

        :rtype: None
        """
        # checkpoint = torch.load(PATH)
        # modelA.load_state_dict(checkpoint['modelA_state_dict'])
        # modelB.load_state_dict(checkpoint['modelB_state_dict'])
        # optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
        # optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
        pass