"""
Utility functions for saving and loading model weights, training checkpoints and model parameters.
:author: Julian M. Kleber
"""
import os
import torch
from amarium.utils import (
prepare_file_name_saving,
load_json_from_file,
make_full_filename,
save_json_to_file,
check_make_dir,
)
from typing import Tuple, Type, Dict, Any
from carate.models.base_model import Model
[docs]
def load_model(
    model_path: str,
    model_net: Type[torch.nn.Module],
    map_location: Any = None,
) -> torch.nn.Module:
    """
    Instantiate a network class and load saved weights into it for inference.

    ``model_net`` is called without arguments, the ``state_dict`` stored at
    ``model_path`` is loaded into the new instance, and the instance is
    switched to evaluation mode.

    :param model_path: Path to a file containing a ``state_dict`` written with
        ``torch.save`` (for example by ``save_model``).
    :param model_net: The network class to instantiate; it must be
        constructible with no arguments.
    :param map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to
        load weights that were saved on a GPU onto a CPU-only machine.
        ``None`` (the default) keeps ``torch.load``'s default behaviour.
    :return: The instantiated model holding the loaded weights, in eval mode.
    :doc-author: Julian M. Kleber
    """
    model = model_net()
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    # Inference use: disable dropout / use running batch-norm statistics.
    model.eval()
    return model
[docs]
def load_model_training_checkpoint(
    checkpoint_path: str,
    model_net: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    map_location: Any = None,
) -> Tuple[torch.nn.Module, torch.optim.Optimizer]:
    """
    Restore model and optimizer state from a training checkpoint to resume training.

    The checkpoint must be a dict containing at least the keys
    ``model_state_dict`` and ``optimizer_state_dict`` (the layout written by
    ``save_model_training_checkpoint``). Both objects are updated in place
    and the model is put back into training mode. The stored ``epoch`` and
    ``loss`` entries are not read.

    :param checkpoint_path: Path of the checkpoint file.
    :param model_net: An already constructed model instance (not a class).
    :param optimizer: An already constructed optimizer; it should be built over
        the parameters of ``model_net``.
    :param map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to resume
        a GPU-saved checkpoint on a CPU-only machine. ``None`` keeps the
        ``torch.load`` default.
    :return: The same ``(model_net, optimizer)`` objects, now holding the
        restored state.
    :doc-author: Julian M. Kleber
    """
    # For any bug fixing please consult the PyTorch documentation: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model_net.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Resuming training: re-enable dropout / batch-norm updates.
    model_net.train()
    return model_net, optimizer
[docs]
def save_model_training_checkpoint(
    result_save_dir: str,
    dataset_name: str,
    num_cv: int,
    num_epoch: int,
    model_net: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss: float,
    override: bool,
) -> None:
    """
    Save a resumable training checkpoint for one cross-validation fold.

    The checkpoint is a dict with the keys ``epoch``, ``cv``,
    ``model_state_dict``, ``optimizer_state_dict`` and ``loss`` (the layout
    read by ``load_model_training_checkpoint``). It is written below
    ``<result_save_dir>/checkpoints/CV_<num_cv>``, on which ``check_make_dir``
    is called first; the final path is built by ``prepare_file_name_saving``
    with the suffix ``.tar``.

    :param result_save_dir: Root directory for results of the run.
    :param dataset_name: Dataset name, used as the base of the file name.
    :param num_cv: Index of the cross-validation fold; selects the sub-directory.
    :param num_epoch: Current epoch; stored in the checkpoint and, unless
        ``override`` is truthy, appended to the file name as ``_Epoch-<n>``.
    :param model_net: The model instance whose ``state_dict`` is saved.
    :param optimizer: The optimizer instance whose ``state_dict`` is saved.
    :param loss: Loss value stored alongside the state.
    :param override: If truthy, the file name is just ``dataset_name``
        (presumably so each call replaces the previous checkpoint); otherwise
        every epoch gets its own ``_Epoch-<n>`` file.
    :return: None
    :doc-author: Julian M. Kleber
    """
    # For any bug fixing please refer to https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
    prefix = result_save_dir + "/checkpoints/CV_" + str(num_cv)
    check_make_dir(prefix)
    # A single truth test (instead of separate ``is True`` / ``is False``
    # checks) guarantees the file name is always defined, even for non-bool
    # flags such as 0/1.
    if override:
        file_name = dataset_name
    else:
        file_name = dataset_name + "_Epoch-" + str(num_epoch)
    save_path = prepare_file_name_saving(
        prefix=prefix,
        file_name=file_name,
        suffix=".tar",
    )
    torch.save(
        {
            "epoch": num_epoch,
            "cv": num_cv,
            "model_state_dict": model_net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        save_path,
    )
[docs]
def save_model(
    result_save_dir: str,
    dataset_name: str,
    num_cv: int,
    num_epoch: int,
    model_net: torch.nn.Module,
) -> str:
    """
    Save the weights (``state_dict``) of a model for one CV fold and epoch.

    The file goes below ``<result_save_dir>/checkpoints/CV_<num_cv>`` (on which
    ``check_make_dir`` is called first); its path is built by
    ``prepare_file_name_saving`` from ``<dataset_name>_Epoch-<num_epoch>``
    with the suffix ``.pt``. Only the weights are stored, so the file can be
    read back with ``load_model``.

    :param result_save_dir: Root directory for results of the run.
    :param dataset_name: Dataset name, used as the base of the file name.
    :param num_cv: Index of the cross-validation fold; selects the sub-directory.
    :param num_epoch: Epoch number, appended to the file name.
    :param model_net: The model instance whose ``state_dict`` is saved.
    :return: The path the weights were written to.
    :doc-author: Julian M. Kleber
    """
    prefix = result_save_dir + "/checkpoints/CV_" + str(num_cv)
    check_make_dir(prefix)
    save_path = prepare_file_name_saving(
        prefix=prefix,
        file_name=dataset_name + "_Epoch-" + str(num_epoch),
        suffix=".pt",
    )
    torch.save(model_net.state_dict(), save_path)
    # Returned so callers can log or reload the exact file that was written.
    return save_path
[docs]
def load_model_parameters(model_params_file_path: str) -> Dict[Any, Any]:
    """
    Read model parameters that were stored as a JSON file.

    :param model_params_file_path: Path of the JSON file to read.
    :return: The data decoded by ``load_json_from_file`` (expected to be a
        dict of model parameters).
    :doc-author: Julian M. Kleber
    """
    parameters = load_json_from_file(model_params_file_path)
    return parameters
[docs]
def save_model_parameters(model_net: Model, save_dir: str) -> None:
    """
    Save the instance attributes of a model object to a JSON file.

    ``model_net.__dict__`` is written with ``save_json_to_file`` to the path
    built by ``prepare_file_name_saving`` from the prefix
    ``<save_dir>/model_parameters/``, the name ``model_architecture`` and the
    suffix ``.json``.

    :param model_net: The model whose attributes are saved.
    :param save_dir: Directory under which ``model_parameters/`` is placed.
    :return: None
    :doc-author: Julian M. Kleber
    """
    prefix = save_dir + "/model_parameters/"
    # Ensure the target directory exists, as the other save functions in this
    # module do before calling prepare_file_name_saving.
    check_make_dir(prefix)
    # NOTE(review): __dict__ must be JSON-serialisable. If Model is a
    # torch.nn.Module it also holds internal containers (parameters, hooks);
    # confirm save_json_to_file handles the concrete Model class.
    model_architecture = model_net.__dict__
    file_name = prepare_file_name_saving(
        prefix=prefix, file_name="model_architecture", suffix=".json"
    )
    save_json_to_file(model_architecture, file_name=file_name)
[docs]
def _checkpoint_sort_key(file_name: str) -> Tuple[int, str]:
    """Sort key: (epoch parsed from a trailing ``_Epoch-<n>``, or -1; file name)."""
    stem = os.path.splitext(file_name)[0]
    _, separator, tail = stem.rpartition("_Epoch-")
    epoch_number = int(tail) if separator and tail.isdecimal() else -1
    return epoch_number, file_name


def get_latest_checkpoint(
    search_dir: str, num_cv: int, epoch: int, suffix: str = ""
) -> str:
    """
    Return the file name of the most recent checkpoint for one CV fold.

    Looks in ``<search_dir>/checkpoints/CV_<num_cv>`` and picks the file with
    the highest ``_Epoch-<n>`` number. Epochs are compared numerically, so
    ``..._Epoch-10.tar`` beats ``..._Epoch-9.tar`` (a plain lexicographic sort
    gets this wrong). Files without an ``_Epoch-<n>`` tag rank below every
    tagged file; ties are broken by file name.

    :param search_dir: Result directory that contains the ``checkpoints`` folder.
    :param num_cv: Index of the cross-validation fold.
    :param epoch: Currently unused; kept for interface compatibility.
    :param suffix: Only consider files ending with this suffix (e.g. ``".tar"``
        for training checkpoints, ``".pt"`` for bare state dicts). The
        default ``""`` considers every file.
    :return: The bare file name (not the full path) of the latest checkpoint.
    :raises FileNotFoundError: If the ``checkpoints`` directory does not exist
        or the fold directory contains no matching file.
    :raises ValueError: If the ``CV_<num_cv>`` directory does not exist.
    """
    checkpoint_root = os.path.join(search_dir, "checkpoints")
    cv_dir_name = "CV_" + str(num_cv)
    if cv_dir_name not in os.listdir(checkpoint_root):
        raise ValueError(f"No {cv_dir_name!r} directory in {checkpoint_root!r}")
    cv_dir = os.path.join(checkpoint_root, cv_dir_name)
    candidates = [name for name in os.listdir(cv_dir) if name.endswith(suffix)]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint files found in {cv_dir!r}")
    return max(candidates, key=_checkpoint_sort_key)