Source code for aitoolbox.torchtrain.callbacks.model_load

from aitoolbox.torchtrain.callbacks.abstract import AbstractExperimentCallback
from aitoolbox.cloud.AWS.model_load import PyTorchS3ModelLoader
from aitoolbox.cloud.GoogleCloud.model_load import PyTorchGoogleStorageModelLoader
from aitoolbox.experiment.local_load.local_model_load import PyTorchLocalModelLoader
from aitoolbox.cloud import s3_available_options, gcs_available_options


class ModelLoadContinueTraining(AbstractExperimentCallback):
    def __init__(self, saved_experiment_timestamp, saved_model_dir='checkpoint_model', epoch_num=None,
                 ignore_saved_schedulers=False, ignore_missing_saved_schedulers=False,
                 used_data_parallel=False, custom_local_loader_class=None,
                 project_name=None, experiment_name=None, local_model_result_folder_path=None,
                 cloud_save_mode=None, bucket_name=None, cloud_dir_prefix=None,
                 **kwargs):
        """(Down)load a previously trained and saved model and resume training from that snapshot

        Instead of starting the training from scratch, the model, optimizer, (optionally) AMP scaler and
        schedulers of the train loop are re-initialized from the saved checkpoint.

        Args:
            saved_experiment_timestamp (str): timestamp of the saved model experiment
            saved_model_dir (str): folder inside the main experiment folder where the saved model file is located
            epoch_num (int or None): when loading a checkpoint model instead of the final model, the training
                epoch from which the model should be loaded
            ignore_saved_schedulers (bool): controls whether an exception is raised when scheduler snapshots
                are found in the checkpoint, but no schedulers are provided to the scheduler initialization
            ignore_missing_saved_schedulers (bool): controls whether an exception is raised when schedulers are
                provided to the scheduler initialization, but no saved scheduler snapshots can be found
                in the checkpoint
            used_data_parallel (bool): if the saved model was nn.DataParallel or a normal model
            custom_local_loader_class (AbstractLocalModelLoader class or None): custom local PyTorch model
                loader definition to be used when the default one is not suitable for a particular use case.
                For example, in the case of complex custom optimizer initialization.
            project_name (str or None): root name of the project
            experiment_name (str or None): name of the particular experiment
            local_model_result_folder_path (str or None): root local path where the project folder will be created
            cloud_save_mode (str or None): Storage destination selector.
                For AWS S3: 's3' / 'aws_s3' / 'aws'
                For Google Cloud Storage: 'gcs' / 'google_storage' / 'google storage'
                Everything else results just in local storage to disk
            bucket_name (str): name of the bucket in the cloud storage
            cloud_dir_prefix (str): path to the folder inside the bucket where the experiments are saved
            **kwargs: additional parameters forwarded to the model loader's load_model() function
        """
        AbstractExperimentCallback.__init__(
            self, 'Model loading and initialization from checkpoint before training',
            project_name, experiment_name, local_model_result_folder_path,
            cloud_save_mode, bucket_name, cloud_dir_prefix,
            execution_order=-10
        )

        # Identification of the saved snapshot which should be loaded
        self.saved_experiment_timestamp = saved_experiment_timestamp
        self.saved_model_dir = saved_model_dir
        self.epoch_num = epoch_num

        # Scheduler restoring behaviour, used later in on_train_begin()
        self.ignore_saved_schedulers = ignore_saved_schedulers
        self.ignore_missing_saved_schedulers = ignore_missing_saved_schedulers

        self.used_data_parallel = used_data_parallel
        self.custom_local_loader_class = custom_local_loader_class
        self.local_loader_kwargs = kwargs

        # The actual loader object gets created in init_model_loader()
        self.model_loader = None

    def on_train_loop_registration(self):
        """Load the saved snapshot and re-initialize the train loop's model, optimizer and AMP scaler from it

        Returns:
            None
        """
        self.try_infer_experiment_details(infer_cloud_details=True)
        self.init_model_loader()

        loaded_snapshot = self.model_loader.load_model(
            self.project_name, self.experiment_name,
            self.saved_experiment_timestamp, self.saved_model_dir, self.epoch_num,
            **self.local_loader_kwargs
        )

        self.train_loop_obj.model = self.model_loader.init_model(
            self.train_loop_obj.model, self.used_data_parallel
        )
        self.train_loop_obj.optimizer = self.model_loader.init_optimizer(self.train_loop_obj.optimizer)

        # The AMP scaler is restored only when the train loop is set up to use AMP
        if self.train_loop_obj.use_amp:
            self.train_loop_obj.amp_scaler = self.model_loader.init_amp(self.train_loop_obj.amp_scaler)

        # Continue with the epoch following the one stored in the loaded snapshot
        self.train_loop_obj.epoch = loaded_snapshot['epoch'] + 1

    def on_train_begin(self):
        """Re-initialize the train loop's schedulers via the model loader

        Returns:
            None
        """
        # Not doing this in on_train_loop_registration() in order to ensure
        # schedulers are initialised inside the scheduler callbacks
        train_loop_schedulers = self.train_loop_obj.get_schedulers()

        self.model_loader.init_scheduler(
            train_loop_schedulers,
            ignore_saved=self.ignore_saved_schedulers,
            ignore_missing_saved=self.ignore_missing_saved_schedulers
        )

    def init_model_loader(self):
        """Initialize the model loader object based on the arguments provided to the callback object

        An S3 or Google Storage loader is created when ``cloud_save_mode`` selects one of them, otherwise
        a purely local loader is used. When a custom local loader class is given, it either replaces the
        cloud loader's ``local_model_loader`` or is used directly as the local loader.

        Returns:
            None
        """
        if self.cloud_save_mode in s3_available_options:
            cloud_loader_class = PyTorchS3ModelLoader
        elif self.cloud_save_mode in gcs_available_options:
            cloud_loader_class = PyTorchGoogleStorageModelLoader
        else:
            cloud_loader_class = None

        if cloud_loader_class is None:
            # Local-only loading: custom loader class takes precedence over the default local loader
            local_loader_class = self.custom_local_loader_class
            if local_loader_class is None:
                local_loader_class = PyTorchLocalModelLoader

            self.model_loader = local_loader_class(self.local_model_result_folder_path)
            return

        self.model_loader = cloud_loader_class(
            self.local_model_result_folder_path, self.bucket_name, self.cloud_dir_prefix
        )

        # Swap the cloud loader's internal local loader for the user provided one
        if self.custom_local_loader_class is not None:
            self.model_loader.local_model_loader = self.custom_local_loader_class(
                self.local_model_result_folder_path
            )