#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 20/07/2020
"""
import datetime
import os
from typing import Optional, Tuple, Union
import torch
from torch.nn.modules.module import Module
from torch.optim import Optimizer
from draugr.torch_utilities.persistence.config import (
ensure_directory_exist,
save_config,
)
from warg.decorators.kw_passing import drop_unused_kws
# File suffixes for the artefacts written by save_model_parameters /
# save_parameters_and_configuration and searched for by
# load_latest_model_parameters. They form the on-disk naming contract, so
# changing them makes previously saved checkpoints undiscoverable.
parameter_extension = ".parameters"  # model state_dict, written with torch.save
config_extension = ".py"  # configuration file, written via save_config
optimiser_extension = ".optimiser"  # optimiser state_dict, written with torch.save
# Public API of this module; load_model_parameters is an alias of
# load_latest_model_parameters.
__all__ = [
"load_model_parameters",
"load_latest_model_parameters",
"save_parameters_and_configuration",
"save_model_parameters",
]
from pathlib import Path
@drop_unused_kws
def load_latest_model_parameters(
    model: torch.nn.Module,
    *,
    optimiser: Optional[Optimizer] = None,
    model_name: str,
    model_directory: Path,
) -> Tuple[
    Union[torch.nn.Module, Tuple[torch.nn.Module, Optimizer]],
    Union[bool, Tuple[bool, bool]],
]:
    """
    Load the most recent saved parameters for ``model_name`` into ``model`` in place,
    and the matching optimiser state into ``optimiser`` when one is given.

    Candidate parameter files are ``model_directory / model_name / *<parameter_extension>``.
    The newest one is picked by ``os.path.getctime``: creation time on Windows, last
    metadata change on Unix. The optimiser state is read from the file with the same
    stem and the ``optimiser_extension`` suffix, if that file exists.

    :param model: module whose state dict is replaced in place; if None nothing is loaded.
    :param optimiser: optional optimiser whose state dict is replaced in place.
    :param model_name: sub-directory of ``model_directory`` holding the saved files.
    :param model_directory: root directory of the saved models.
    :return: ``(model, model_loaded)`` when no optimiser is given, otherwise
        ``((model, optimiser), (model_loaded, optimiser_loaded))``."""
    model_loaded = False
    optimiser_loaded = False
    # `is not None` rather than truthiness: containers such as an empty
    # nn.Sequential define __len__ and would otherwise be treated as "no model".
    if model is not None:
        model_path = model_directory / model_name
        parameter_files = list(model_path.glob(f"*{parameter_extension}"))
        if not parameter_files:
            print(
                f"Found no previous models with extension {parameter_extension} in {model_path}"
            )
        else:
            latest_parameter_file = max(parameter_files, key=os.path.getctime)
            print(f"loading previous model parameters: {latest_parameter_file}")
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            model.load_state_dict(torch.load(str(latest_parameter_file)))
            model_loaded = True
            if optimiser is not None:
                # The optimiser state is saved next to the parameters under the same stem.
                optimiser_state_file = latest_parameter_file.with_suffix(
                    optimiser_extension
                )
                if optimiser_state_file.exists():
                    optimiser.load_state_dict(torch.load(str(optimiser_state_file)))
                    print(f"loading previous optimiser state: {optimiser_state_file}")
                    optimiser_loaded = True
    if optimiser is not None:
        return (model, optimiser), (model_loaded, optimiser_loaded)
    return model, model_loaded


# Backwards-compatible alias.
load_model_parameters = load_latest_model_parameters
def save_parameters_and_configuration(
    *,
    model: Module,
    model_save_path: Path,
    optimiser: Optional[Optimizer] = None,
    optimiser_save_path: Optional[Path] = None,
    config_save_path: Optional[Path] = None,
    loaded_config_file_path: Optional[Path] = None,
) -> None:
    """
    Save ``model``'s state dict and, optionally, the optimiser state and the configuration.

    :param model: module whose ``state_dict()`` is written to ``model_save_path`` with ``torch.save``.
    :param model_save_path: destination file for the model parameters.
    :param optimiser: optional optimiser whose ``state_dict()`` is written to ``optimiser_save_path``.
    :param optimiser_save_path: destination for the optimiser state; required when ``optimiser`` is given.
    :param config_save_path: destination passed to ``save_config``; required when
        ``loaded_config_file_path`` is given.
    :param loaded_config_file_path: configuration file handed to ``save_config``; if falsy,
        no configuration is saved.
    :raises ValueError: if ``optimiser`` or ``loaded_config_file_path`` is given without
        its destination path.
    :return: None"""
    # Validate before writing anything. Otherwise str(None) would silently create a
    # file literally named "None" in the working directory, and a partial checkpoint
    # would be left on disk.
    if optimiser is not None and optimiser_save_path is None:
        raise ValueError("optimiser_save_path is required when an optimiser is given")
    if loaded_config_file_path and config_save_path is None:
        raise ValueError(
            "config_save_path is required when loaded_config_file_path is given"
        )

    torch.save(model.state_dict(), str(model_save_path))
    if optimiser is not None:
        torch.save(optimiser.state_dict(), str(optimiser_save_path))
    if loaded_config_file_path:
        save_config(config_save_path, loaded_config_file_path)
def _save_with_standard_suffixes(
    base_path: Path,
    *,
    model: Module,
    optimiser: Optional[Optimizer],
    config_file_path: Optional[Path],
) -> bool:
    """
    Save parameters, optimiser state and configuration side by side as ``base_path``
    with ``parameter_extension`` / ``optimiser_extension`` / ``config_extension``.

    :return: True on success; False if saving raised FileNotFoundError (the error is printed)."""
    ensure_directory_exist(base_path.parent)
    try:
        save_parameters_and_configuration(
            model=model,
            model_save_path=base_path.with_suffix(parameter_extension),
            optimiser=optimiser,
            optimiser_save_path=base_path.with_suffix(optimiser_extension),
            loaded_config_file_path=config_file_path,
            config_save_path=base_path.with_suffix(config_extension),
        )
    except FileNotFoundError as e:
        print(e)
        return False
    return True


@drop_unused_kws
def save_model_parameters(
    model: Module,
    *,
    model_name: str,
    save_directory: Path,
    optimiser: Optional[Optimizer] = None,
    config_file_path: Optional[Path] = None,
    verbose: bool = False,
) -> None:
    """
    Save ``model`` (plus optional optimiser state and configuration) under
    ``save_directory / model_name / <YYYYmmddHHMMSS>`` with the standard suffixes.

    If saving fails with FileNotFoundError, the user is prompted via ``input`` for another
    base path until a save succeeds. All three artefacts are then written next to that
    path. ``input`` raises EOFError when stdin is closed, which propagates.

    :param model: module whose parameters are saved.
    :param model_name: sub-directory of ``save_directory`` for this model.
    :param save_directory: root directory for saved models.
    :param optimiser: optional optimiser whose state is saved alongside.
    :param config_file_path: optional configuration file to save alongside.
    :param verbose: print the saved file names on success.
    :return: None"""
    model_time_rep = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    model_save_path = save_directory / model_name / model_time_rep
    save_kws = dict(model=model, optimiser=optimiser, config_file_path=config_file_path)

    saved = _save_with_standard_suffixes(model_save_path, **save_kws)
    while not saved:
        # Interactive fallback: every artefact is derived from the user-supplied path,
        # so the reported names below match what was actually written.
        model_save_path = (
            Path(input("Enter another file path: ")).expanduser().resolve()
        )
        saved = _save_with_standard_suffixes(model_save_path, **save_kws)

    # The loop above only exits after a successful save.
    if verbose:
        print(
            f"Successfully saved model parameters, optimiser state and configuration at names {[model_save_path.with_suffix(a) for a in (parameter_extension, optimiser_extension, config_extension)]}"
        )