# Source code for draugr.torch_utilities.persistence.model
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import datetime
import sys
from typing import Optional
import torch
from torch.nn.modules.module import Module
from draugr.torch_utilities.persistence.config import (
ensure_directory_exist,
save_config,
)
from warg import latest_file
from warg import passes_kws_to
from warg.decorators.kw_passing import drop_unused_kws
__author__ = "Christian Heider Nielsen"
model_extension = ".model"
config_extension = ".py"
__all__ = [
"load_model",
"load_latest_model",
"save_model_and_configuration",
"save_model",
"convert_saved_model_to_cpu",
]
from pathlib import Path
@drop_unused_kws
def load_latest_model(
    *, model_name: str, model_directory: Path, raise_on_failure: bool = True
) -> Optional[torch.nn.Module]:
    """
    Load the most recently saved model for ``model_name``.

    Checkpoints are stored under ``model_directory / model_name`` with a
    timestamp-derived filename, so the file returned by ``latest_file`` is
    the newest checkpoint.

    :param model_name: name of the model; the sub-directory to search.
    :param model_directory: root directory containing per-model sub-directories.
    :param raise_on_failure: when True, let ``latest_file`` raise if no model
        file exists; when False, fall through and return ``None``.
    :return: the deserialised model, or ``None`` when no checkpoint was found.
    """
    model_path = model_directory / model_name
    latest_model_ = latest_file(
        model_path,
        extension=model_extension,
        raise_on_failure=raise_on_failure,
    )
    print(f"loading previous model: {latest_model_}")
    if latest_model_:
        # NOTE(review): torch.load of a whole Module requires
        # weights_only=False semantics; newer torch defaults may need an
        # explicit argument here — confirm against the pinned torch version.
        return torch.load(str(latest_model_))
    return None


# Backwards-compatible alias for the older public name.
load_model = load_latest_model
@passes_kws_to(save_config)
def save_model_and_configuration(
    *,
    model: Module,
    model_save_path: Path,
    config_save_path: Path = None,
    loaded_config_file_path: Path = None,
    raise_on_existing: bool = False,
) -> None:
    """
    Serialise ``model`` to ``model_save_path`` and optionally persist its
    configuration file alongside it.

    :param model: module to serialise with ``torch.save``.
    :param model_save_path: destination path for the serialised model.
    :param config_save_path: destination for the saved configuration; only
        used (and then required to be valid) when ``loaded_config_file_path``
        is truthy.
    :param loaded_config_file_path: source configuration file to persist via
        ``save_config``; when falsy, no configuration is saved.
    :param raise_on_existing: refuse to overwrite an existing model file.
    :raises FileExistsError: when ``raise_on_existing`` is True and
        ``model_save_path`` already exists.
    """
    if raise_on_existing and model_save_path.exists():
        raise FileExistsError(f"{model_save_path} exists!")
    torch.save(model, str(model_save_path))
    if loaded_config_file_path:
        save_config(config_save_path, loaded_config_file_path)
@drop_unused_kws
@passes_kws_to(save_model_and_configuration)
def save_model(
    model: Module,
    *,
    model_name: str,
    save_directory: Path,  # TODO: RENAME to model_directory for consistency
    config_file_path: Path = None,
    prompt_on_failure: bool = True,
    verbose: bool = False,
) -> None:
    """
    Save a model (and optionally its configuration) under a timestamped
    filename so it can later be retrieved with :func:`load_latest_model`.

    :param model: module to serialise.
    :param model_name: name of the model; a sub-directory of ``save_directory``.
    :param save_directory: root directory for per-model sub-directories.
    :param config_file_path: configuration file to persist next to the model.
    :param prompt_on_failure: on ``FileNotFoundError``, interactively ask the
        user for an alternative save path instead of raising.
    :param verbose: print the outcome of the save attempt.
    :raises FileNotFoundError: re-raised when saving fails and
        ``prompt_on_failure`` is False.
    """
    time_stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    model_save_path = save_directory / model_name / f"{time_stamp}{model_extension}"
    config_save_path = (
        save_directory / model_name / f"{time_stamp}{config_extension}"
    )
    ensure_directory_exist(model_save_path.parent)

    saved = False
    try:
        save_model_and_configuration(
            model=model,
            model_save_path=model_save_path,
            loaded_config_file_path=config_file_path,
            config_save_path=config_save_path,
        )
        saved = True
    except FileNotFoundError as e:
        if not prompt_on_failure:
            # Bare raise preserves the original traceback.
            raise
        print(e)
        # Keep prompting until a writable location is provided.
        while not saved:
            file_path = input("Enter another file path: ")
            model_save_path = Path(file_path).expanduser().resolve()
            parent = model_save_path.parent
            ensure_directory_exist(parent)
            config_save_path = parent / f"{model_save_path.name}{config_extension}"
            try:
                save_model_and_configuration(
                    model=model,
                    model_save_path=model_save_path,
                    loaded_config_file_path=config_file_path,
                    config_save_path=config_save_path,
                )
                saved = True
            except FileNotFoundError as e2:
                print(e2)

    if verbose:
        if saved:
            print(
                f"Successfully saved model and configuration respectively at {model_save_path} and {config_save_path}"
            )
        else:
            print("Was unsuccessful at saving model or configuration")
def convert_saved_model_to_cpu(path: Path) -> None:
    """
    Re-save a serialised model so it can be loaded on CPU-only machines.

    Loads the model with every storage mapped to CPU and writes it next to
    the original as ``<path>.cpu.model`` (the original file is untouched).

    :param path: path to an existing ``torch.save``-serialised model file.
    """
    # map_location keeps storages on CPU regardless of where they were saved.
    model = torch.load(path, map_location=lambda storage, loc: storage)
    torch.save(model, f"{path}.cpu{model_extension}")
if __name__ == "__main__":
    # CLI usage: python model.py <path-to-saved-model>
    # Writes a CPU-loadable copy of the given checkpoint next to the original.
    convert_saved_model_to_cpu(Path(sys.argv[1]))