Source code for aitoolbox.torchtrain.schedulers.basic

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR, StepLR, MultiStepLR

from aitoolbox.torchtrain.callbacks.abstract import AbstractCallback
from aitoolbox.torchtrain.multi_loss_optim import MultiLoss


class AbstractScheduler:
    def __init__(self):
        """Scheduler (callback) base class

        All the scheduler callbacks should, in addition to
        :class:`~aitoolbox.torchtrain.callbacks.abstract.AbstractCallback`, also inherit from this base class.

        This class serves to indicate to torchtrain components which of the used callbacks are schedulers and which
        are normal callbacks unrelated to learning rate scheduling.

        In addition to the above, this scheduler base class also implements the interface methods needed for saving
        and loading the scheduler state_dict's when checkpointing and reloading the scheduler.

        When implementing the actual scheduler callback make sure to assign the created learning rate scheduler
        to the ``self.scheduler`` class member.
        """
        self.scheduler = None

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.scheduler.load_state_dict(state_dict)


class GeneralLRSchedulerCallback(AbstractScheduler, AbstractCallback):
    def __init__(self, scheduler_class, optimizer_idx=None, **kwargs):
        """Learning rate scheduler base class

        Args:
            scheduler_class: PyTorch learning rate scheduler class
            optimizer_idx (int or torch.optim.optimizer.Optimizer or None): index or the actual object reference
                of the paired optimizer when using multiple optimizers
            **kwargs: learning rate scheduler additional parameters
        """
        AbstractScheduler.__init__(self)
        AbstractCallback.__init__(self, 'General learn rate scheduler')
        self.scheduler_args = kwargs
        self.scheduler_class = scheduler_class
        self.optimizer_idx = optimizer_idx

    def register_train_loop_object(self, train_loop_obj):
        """Modified register_train_loop_object method to support scheduler creation

        Args:
            train_loop_obj (aitoolbox.torchtrain.train_loop.TrainLoop): reference to the encapsulating TrainLoop

        Returns:
            AbstractCallback: return the reference to the callback after it is registered
        """
        self.train_loop_obj = train_loop_obj
        self.message_service = train_loop_obj.message_service

        if self.optimizer_idx is None:
            optimizer = self.train_loop_obj.optimizer
        elif isinstance(self.optimizer_idx, int):
            optimizer = self.train_loop_obj.optimizer[self.optimizer_idx]
        else:
            optimizer = self.optimizer_idx

        self.scheduler = self.scheduler_class(optimizer, **self.scheduler_args)
        self.on_train_loop_registration()
        return self

    def on_epoch_end(self):
        self.scheduler.step()

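Any scheduler from torch.optim.lr_scheduler can be wrapped directly with GeneralLRSchedulerCallback: the scheduler kwargs are stored at construction time and the underlying scheduler is only instantiated inside register_train_loop_object(), once the TrainLoop supplies its optimizer. A minimal sketch, using PyTorch's CosineAnnealingLR as an example; how the callback is handed to the TrainLoop (typically together with the other callbacks) is not shown here:

from torch.optim.lr_scheduler import CosineAnnealingLR

from aitoolbox.torchtrain.schedulers.basic import GeneralLRSchedulerCallback

# Step a CosineAnnealingLR schedule at the end of every epoch.
# The CosineAnnealingLR object itself is created later, inside register_train_loop_object(),
# when the TrainLoop hands over its optimizer.
cosine_schedule_cb = GeneralLRSchedulerCallback(CosineAnnealingLR, T_max=10)
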
class ReduceLROnPlateauScheduler(GeneralLRSchedulerCallback):
    def __init__(self, main_multi_loss=None, **kwargs):
        """Learning rate scheduler which reduces the learning rate if the loss performance stops improving

        Args:
            main_multi_loss (str or None): name of the main loss to follow in the scheduler when using the MultiLoss
                setup. If ``None`` is provided, the mean of all the losses in the MultiLoss is taken.
            **kwargs: learning rate scheduler additional parameters
        """
        GeneralLRSchedulerCallback.__init__(self, ReduceLROnPlateau, **kwargs)
        self.callback_name = 'Reduce learn rate if the model hits the plateau'
        self.main_multi_loss = main_multi_loss

    def on_epoch_end(self):
        val_loss_avg = self.train_loop_obj.evaluate_loss_on_validation_set()

        if isinstance(val_loss_avg, MultiLoss):
            if self.main_multi_loss is None:
                val_loss_avg = torch.mean(torch.Tensor(list(val_loss_avg.values())))
            else:
                val_loss_avg = val_loss_avg[self.main_multi_loss]

        self.scheduler.step(val_loss_avg)

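Any ReduceLROnPlateau-specific kwargs (factor, patience, mode, ...) pass straight through to the wrapped PyTorch scheduler; main_multi_loss only matters when the TrainLoop criterion is a MultiLoss. A short sketch under those assumptions, where 'main_loss' is a hypothetical loss name used purely for illustration:

from aitoolbox.torchtrain.schedulers.basic import ReduceLROnPlateauScheduler

# Single-loss setup: halve the learning rate after 3 validation epochs without improvement
plateau_cb = ReduceLROnPlateauScheduler(factor=0.5, patience=3)

# MultiLoss setup: follow only the loss named 'main_loss' (hypothetical name)
# instead of the mean over all the losses
plateau_main_cb = ReduceLROnPlateauScheduler(main_multi_loss='main_loss', factor=0.5, patience=3)
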
class ReduceLROnPlateauMetricScheduler(GeneralLRSchedulerCallback):
    def __init__(self, metric_name, **kwargs):
        """Learning rate scheduler which reduces the learning rate if the performance of the selected metric stops improving

        Needs to be used in combination with ModelPerformanceEvaluation to calculate the metric and fill it
        in the TrainLoop history.

        Args:
            metric_name (str): monitored metric based on which the learning rate scheduler modifies the learning rate
            **kwargs: learning rate scheduler additional parameters
        """
        GeneralLRSchedulerCallback.__init__(self, ReduceLROnPlateau, **kwargs)
        self.metric_name = metric_name
        self.callback_name = 'Reduce learn rate if the model hits the plateau based on metric in TrainLoop history'

    def on_epoch_end(self):
        if self.metric_name not in self.train_loop_obj.train_history:
            raise ValueError(
                f'Metric {self.metric_name} expected for the report missing from TrainLoop.train_history. '
                f'Found only the following: {self.train_loop_obj.train_history.keys()}')

        val_metric_result = self.train_loop_obj.train_history[self.metric_name][-1]
        self.scheduler.step(val_metric_result)

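The metric name has to match a key that an evaluation callback (e.g. ModelPerformanceEvaluation) writes into TrainLoop.train_history every epoch; otherwise on_epoch_end() raises the ValueError above. A sketch assuming a hypothetical 'val_accuracy' history entry; the evaluation callback that would produce it is not shown:

from aitoolbox.torchtrain.schedulers.basic import ReduceLROnPlateauMetricScheduler

# Reduce the learning rate when the tracked metric plateaus.
# 'val_accuracy' is a hypothetical key: it must be filled into TrainLoop.train_history
# each epoch by an evaluation callback such as ModelPerformanceEvaluation.
metric_plateau_cb = ReduceLROnPlateauMetricScheduler(metric_name='val_accuracy',
                                                     mode='max', factor=0.5, patience=2)
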
class LambdaLRScheduler(GeneralLRSchedulerCallback):
    def __init__(self, lr_lambda, execute_epoch_end=True, execute_batch_end=False, **kwargs):
        """Sets the learning rate of each parameter group to the initial lr times a given function

        When last_epoch=-1, sets initial lr as lr.

        Args:
            lr_lambda (callable or list): a function which computes a multiplicative factor given an integer
                parameter epoch, or a list of such functions, one for each group in optimizer.param_groups.
            execute_epoch_end (bool): should the scheduler step be executed at the end of the epoch
            execute_batch_end (bool): should the scheduler step be executed at the end of each batch
            **kwargs: learning rate scheduler additional parameters
        """
        GeneralLRSchedulerCallback.__init__(self, LambdaLR, **dict(kwargs, lr_lambda=lr_lambda))
        self.callback_name = ''
        self.execute_epoch_end = execute_epoch_end
        self.execute_batch_end = execute_batch_end

    def on_epoch_end(self):
        if self.execute_epoch_end:
            self.scheduler.step()

    def on_batch_end(self):
        if self.execute_batch_end:
            if self.train_loop_obj.should_execute_optimizer_update():
                self.scheduler.step()

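Because the scheduler step can be executed after every batch instead of every epoch, LambdaLRScheduler is a convenient way to implement step-based schedules such as learning rate warmup. A sketch with a hypothetical linear warmup over the first 1000 optimizer updates:

from aitoolbox.torchtrain.schedulers.basic import LambdaLRScheduler

def linear_warmup(step, warmup_steps=1000):
    # Multiplicative LR factor: ramp from ~0 up to 1 over the first warmup_steps updates,
    # then stay at 1 (i.e. keep the initial learning rate)
    return min(1.0, (step + 1) / warmup_steps)

# Step the schedule once per optimizer update rather than once per epoch
warmup_cb = LambdaLRScheduler(lr_lambda=linear_warmup,
                              execute_epoch_end=False, execute_batch_end=True)
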
class StepLRScheduler(GeneralLRSchedulerCallback):
    def __init__(self, step_size, **kwargs):
        """Sets the learning rate of each parameter group to the initial lr decayed by gamma every step_size epochs

        When last_epoch=-1, sets initial lr as lr.

        Args:
            step_size (int): period of learning rate decay
            **kwargs: learning rate scheduler additional parameters
        """
        GeneralLRSchedulerCallback.__init__(self, StepLR, **dict(kwargs, step_size=step_size))
        self.callback_name = ''

class MultiStepLRScheduler(GeneralLRSchedulerCallback):
    def __init__(self, milestones_list, **kwargs):
        """Set the learning rate of each parameter group to the initial lr decayed by gamma once the number of
        epochs reaches one of the milestones

        When last_epoch=-1, sets initial lr as lr.

        Args:
            milestones_list (list): list of epoch indices. Must be increasing
            **kwargs: learning rate scheduler additional parameters
        """
        GeneralLRSchedulerCallback.__init__(self, MultiStepLR, **dict(kwargs, milestones=milestones_list))
        self.callback_name = ''

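Both step-based schedulers simply forward their arguments, plus any extra kwargs such as gamma, to the corresponding PyTorch schedulers (StepLR and MultiStepLR) and step once per epoch. A short sketch of typical decay settings:

from aitoolbox.torchtrain.schedulers.basic import StepLRScheduler, MultiStepLRScheduler

# Decay the learning rate by 10x every 30 epochs
step_cb = StepLRScheduler(step_size=30, gamma=0.1)

# Decay the learning rate by 10x when reaching epochs 30 and 80
multi_step_cb = MultiStepLRScheduler(milestones_list=[30, 80], gamma=0.1)
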