import os
import posixpath

from aitoolbox.cloud.AWS.data_access import BaseDataLoader
from aitoolbox.experiment.local_load.local_model_load import AbstractLocalModelLoader, PyTorchLocalModelLoader
from aitoolbox.experiment.local_save.folder_create import ExperimentFolder
class BaseModelLoader(BaseDataLoader):
    def __init__(self, local_model_loader, local_model_result_folder_path='~/project/model_result',
                 bucket_name='model-result', cloud_dir_prefix=''):
        """Base saved model loading from S3 storage

        Args:
            local_model_loader (AbstractLocalModelLoader): model loader implementing the loading of the saved model for
                the selected deep learning framework
            local_model_result_folder_path (str): root local path where project folder will be created
            bucket_name (str): name of the bucket in the cloud storage from which the model will be downloaded
            cloud_dir_prefix (str): path to the folder inside the bucket where the experiments are saved

        Raises:
            TypeError: if ``local_model_loader`` is not an instance of AbstractLocalModelLoader
        """
        # Validate the argument first so a misconfigured loader fails fast,
        # before BaseDataLoader runs any of its own setup.
        if not isinstance(local_model_loader, AbstractLocalModelLoader):
            raise TypeError('Provided local_model_loader is not inherited from AbstractLocalModelLoader as required.')

        BaseDataLoader.__init__(self, bucket_name, local_model_result_folder_path)
        # Reuse the local root path as stored by BaseDataLoader (it may have been normalised there).
        self.local_model_result_folder_path = self.local_base_data_folder_path
        self.cloud_dir_prefix = cloud_dir_prefix
        self.local_model_loader = local_model_loader

    def load_model(self, project_name, experiment_name, experiment_timestamp,
                   model_save_dir='checkpoint_model', epoch_num=None,
                   **kwargs):
        """Download and read/load the model

        The checkpoint file is fetched from the S3 bucket into the local experiment folder.
        Reading it is then delegated to ``self.local_model_loader.load_model``.

        Args:
            project_name (str): root name of the project
            experiment_name (str): name of the particular experiment
            experiment_timestamp (str): time stamp at the start of training
            model_save_dir (str): name of the folder inside experiment folder where the model is saved
            epoch_num (int or None): epoch number of the model checkpoint or None if loading final model
            **kwargs: additional local_model_loader parameters

        Returns:
            dict: model representation. (currently only returning dicts as only PyTorch model loading is supported)
        """
        # S3 object keys always use '/' as the separator regardless of the local OS, so the cloud-side
        # paths are joined with posixpath; os.path.join would insert '\\' on Windows.
        cloud_model_folder_path = posixpath.join(self.cloud_dir_prefix,
                                                 project_name,
                                                 f'{experiment_name}_{experiment_timestamp}',
                                                 model_save_dir)

        experiment_dir_path = ExperimentFolder.create_base_folder(project_name, experiment_name, experiment_timestamp,
                                                                  self.local_model_result_folder_path)
        local_model_folder_path = os.path.join(experiment_dir_path, model_save_dir)
        # exist_ok avoids the check-then-create race of os.path.exists() followed by os.mkdir()
        os.makedirs(local_model_folder_path, exist_ok=True)

        if epoch_num is None:
            model_name = f'model_{experiment_name}_{experiment_timestamp}.pth'
        else:
            model_name = f'model_{experiment_name}_{experiment_timestamp}_E{epoch_num}.pth'

        # Copy the model save file from S3 to the local folder
        cloud_model_file_path = posixpath.join(cloud_model_folder_path, model_name)
        local_model_file_path = os.path.join(local_model_folder_path, model_name)
        # NOTE: load_file is expected to skip the download when the local file is already present
        self.load_file(cloud_model_file_path, local_model_file_path)

        # The local loader receives the same identifiers; it is expected to resolve the same local file path
        return self.local_model_loader.load_model(project_name, experiment_name, experiment_timestamp,
                                                  model_save_dir, epoch_num, **kwargs)
class PyTorchS3ModelLoader(BaseModelLoader):
    """S3 model loader for PyTorch checkpoints, backed by a PyTorchLocalModelLoader."""

    def __init__(self, local_model_result_folder_path='~/project/model_result',
                 bucket_name='model-result', cloud_dir_prefix=''):
        """PyTorch S3 model downloader & loader

        Args:
            local_model_result_folder_path (str): root local path where project folder will be created
            bucket_name (str): name of the bucket in the cloud storage from which the model will be downloaded
            cloud_dir_prefix (str): path to the folder inside the bucket where the experiments are saved
        """
        BaseModelLoader.__init__(
            self,
            local_model_loader=PyTorchLocalModelLoader(local_model_result_folder_path),
            local_model_result_folder_path=local_model_result_folder_path,
            bucket_name=bucket_name,
            cloud_dir_prefix=cloud_dir_prefix,
        )

    def init_model(self, model, used_data_parallel=False):
        """Initialize provided PyTorch model with the loaded model weights

        load_model() has to be called beforehand so that the model representation is already read into memory.

        Args:
            model: PyTorch model
            used_data_parallel (bool): whether the saved model was wrapped in nn.DataParallel or was a normal model

        Returns:
            initialized model
        """
        delegate = self.local_model_loader
        return delegate.init_model(model, used_data_parallel)

    def init_optimizer(self, optimizer, device='cuda'):
        """Initialize PyTorch optimizer from the loaded checkpoint

        Args:
            optimizer: PyTorch optimizer instance whose state is restored by the local loader
            device (str): device name forwarded to the local loader
                (presumably where the restored optimizer state is placed; confirm in PyTorchLocalModelLoader)

        Returns:
            initialized optimizer
        """
        delegate = self.local_model_loader
        return delegate.init_optimizer(optimizer, device)

    def init_scheduler(self, scheduler_callbacks_list, ignore_saved=False, ignore_missing_saved=False):
        """Initialize the list of schedulers based on saved model/optimizer/scheduler checkpoint

        Args:
            scheduler_callbacks_list (list): list of scheduler (callbacks)
            ignore_saved (bool): flag forwarded to the local loader that controls the exception raised when
                scheduler snapshots are found in the checkpoint but no schedulers are provided to this method
                (presumably True suppresses it; confirm in PyTorchLocalModelLoader.init_scheduler)
            ignore_missing_saved (bool): flag forwarded to the local loader that controls the exception raised
                when schedulers are provided to this method but no saved scheduler snapshots exist in the checkpoint
                (presumably True suppresses it; confirm in PyTorchLocalModelLoader.init_scheduler)

        Returns:
            list: list of initialized scheduler (callbacks)
        """
        delegate = self.local_model_loader
        return delegate.init_scheduler(scheduler_callbacks_list, ignore_saved, ignore_missing_saved)

    def init_amp(self, amp_scaler):
        """Initialize AMP GradScaler from the loaded checkpoint

        Args:
            amp_scaler (torch.cuda.amp.GradScaler): AMP GradScaler

        Returns:
            torch.cuda.amp.GradScaler: initialized AMP GradScaler
        """
        delegate = self.local_model_loader
        return delegate.init_amp(amp_scaler)