basic
- class aitoolbox.torchtrain.schedulers.basic.AbstractScheduler
Bases: object
Scheduler (callback) base class
All scheduler callbacks should inherit from this base class in addition to AbstractCallback. This class tells the torchtrain components which of the used callbacks are schedulers and which are normal callbacks unrelated to learning rate scheduling. It also implements the interface methods needed to save and load the scheduler's state_dict when checkpointing and reloading the scheduler.
When implementing the actual scheduler callback, make sure to assign the created learning rate scheduler to the self.scheduler attribute.
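The marker-class idea can be illustrated with a minimal, self-contained sketch. Every class here (ToyCallback, ToyAbstractScheduler, ToyLRScheduler, ToySchedulerCallback) and the split_callbacks helper is hypothetical and only mimics the roles described above; none of them is part of aitoolbox. The sketch assumes the wrapped scheduler exposes state_dict()/load_state_dict() the way PyTorch schedulers do.

```python
# Hypothetical stand-ins, NOT the aitoolbox classes: they only model the
# "marker base class + state_dict forwarding" pattern.


class ToyCallback:
    """Plays the role of AbstractCallback: a generic training callback."""

    def on_epoch_end(self):
        pass


class ToyAbstractScheduler:
    """Plays the role of AbstractScheduler: marks a callback as a scheduler
    and forwards checkpointing calls to self.scheduler."""

    def __init__(self):
        # Concrete scheduler callbacks must assign the real scheduler here.
        self.scheduler = None

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.scheduler.load_state_dict(state_dict)


class ToyLRScheduler:
    """Tiny scheduler object with a PyTorch-like state_dict interface."""

    def __init__(self):
        self.last_epoch = 0

    def step(self):
        self.last_epoch += 1

    def state_dict(self):
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state_dict):
        self.last_epoch = state_dict["last_epoch"]


class ToySchedulerCallback(ToyAbstractScheduler, ToyCallback):
    def __init__(self):
        super().__init__()
        self.scheduler = ToyLRScheduler()  # assigned to self.scheduler

    def on_epoch_end(self):
        self.scheduler.step()


def split_callbacks(callbacks):
    """Separate scheduler callbacks from ordinary ones via the marker class."""
    schedulers = [cb for cb in callbacks if isinstance(cb, ToyAbstractScheduler)]
    others = [cb for cb in callbacks if not isinstance(cb, ToyAbstractScheduler)]
    return schedulers, others
```

Because the scheduler callback forwards state_dict() to its wrapped scheduler, a checkpoint taken after two epochs can be restored into a freshly built callback with load_state_dict().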
- class aitoolbox.torchtrain.schedulers.basic.GeneralLRSchedulerCallback(scheduler_class, optimizer_idx=None, **kwargs)
Bases: AbstractScheduler, AbstractCallback
Learning rate scheduler base class
- Parameters:
scheduler_class – PyTorch learning rate scheduler class
optimizer_idx (int or torch.optim.optimizer.Optimizer or None) – index or the actual object reference of the paired optimizer when using multiple optimizers
**kwargs – learning rate scheduler additional parameters
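A rough sketch of why the scheduler is built when the callback is registered rather than in the constructor: the constructor only receives the scheduler class and its keyword arguments, while a PyTorch scheduler needs the optimizer as its first argument, which only becomes reachable through the train loop. DeferredSchedulerCallback, ToyTrainLoop and ToyStepScheduler are hypothetical stand-ins, and the assumption that the train loop exposes its optimizer (or list of optimizers) as an `optimizer` attribute is made for illustration only.

```python
# Hypothetical sketch: none of these classes belong to aitoolbox or PyTorch.


class ToyStepScheduler:
    """Stands in for a PyTorch scheduler class, built as cls(optimizer, **kwargs)."""

    def __init__(self, optimizer, step_size, gamma=0.1):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma


class ToyTrainLoop:
    def __init__(self, optimizer):
        # Either a single optimizer or a list of optimizers (assumption).
        self.optimizer = optimizer


class DeferredSchedulerCallback:
    def __init__(self, scheduler_class, optimizer_idx=None, **kwargs):
        # Only remember how to build the scheduler; no optimizer exists yet.
        self.scheduler_class = scheduler_class
        self.optimizer_idx = optimizer_idx
        self.scheduler_args = kwargs
        self.scheduler = None

    def register_train_loop_object(self, train_loop_obj):
        self.train_loop_obj = train_loop_obj
        optimizer = self._select_optimizer(train_loop_obj.optimizer)
        # Now that the paired optimizer is known, create the scheduler.
        self.scheduler = self.scheduler_class(optimizer, **self.scheduler_args)
        return self

    def _select_optimizer(self, optimizers):
        if self.optimizer_idx is None:
            return optimizers  # single-optimizer setup
        if isinstance(self.optimizer_idx, int):
            return optimizers[self.optimizer_idx]  # pick by list index
        return self.optimizer_idx  # the optimizer object itself was passed
```

The optimizer_idx branch mirrors the documented options: None for a single optimizer, an integer index, or the optimizer object itself.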
- register_train_loop_object(train_loop_obj)
Modified register_train_loop_object method that also creates the learning rate scheduler when the callback is registered with the TrainLoop
- Parameters:
train_loop_obj (aitoolbox.torchtrain.train_loop.TrainLoop) – reference to the encapsulating TrainLoop
- Returns:
reference to the callback after it has been registered
- Return type:
- class aitoolbox.torchtrain.schedulers.basic.ReduceLROnPlateauScheduler(main_multi_loss=None, **kwargs)
Bases: GeneralLRSchedulerCallback
Learning rate scheduler that reduces the learning rate when the loss stops improving
- Parameters:
main_multi_loss (str or None) – name of the main loss to follow in the scheduler when using a MultiLoss setup. If None is provided, the mean of all the MultiLosses is taken.
**kwargs – learning rate scheduler additional parameters
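The following sketch shows, under simplifying assumptions, which value the plateau scheduler could follow and how a plateau triggers a reduction. It assumes MultiLoss values arrive as a dict of named floats; monitored_loss and SimplePlateau are hypothetical helpers, and SimplePlateau is a stripped-down "min"-mode version of the PyTorch ReduceLROnPlateau logic (no threshold, cooldown or min_lr). Options such as factor and patience would presumably be passed through **kwargs.

```python
from statistics import mean


def monitored_loss(loss, main_multi_loss=None):
    """Pick the value the plateau scheduler follows.

    loss is a float or, for a MultiLoss setup, a dict of named losses
    (a simplifying assumption for this sketch).
    """
    if not isinstance(loss, dict):
        return loss
    if main_multi_loss is None:
        return mean(loss.values())  # mean of all the losses
    return loss[main_multi_loss]


class SimplePlateau:
    """Simplified 'min'-mode plateau logic (no threshold, cooldown or min_lr)."""

    def __init__(self, lr, factor=0.1, patience=2):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, value):
        if value < self.best:
            self.best = value
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only once the loss has failed to improve for more than
        # `patience` consecutive epochs, then start counting again.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr
```

With patience=1, two consecutive non-improving epochs are needed before the learning rate is multiplied by factor.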
- class aitoolbox.torchtrain.schedulers.basic.ReduceLROnPlateauMetricScheduler(metric_name, **kwargs)
Bases: GeneralLRSchedulerCallback
Learning rate scheduler that reduces the learning rate when the selected metric stops improving
Must be used together with ModelPerformanceEvaluation, which calculates the metric and stores it in the TrainLoop history.
- Parameters:
metric_name (str) – monitored metric based on which the learning rate scheduler modifies the learning rate
**kwargs – learning rate scheduler additional parameters
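For the metric-based variant, the monitored value comes from the TrainLoop history rather than from the loss. The sketch assumes the history behaves like a dict mapping metric names to per-epoch lists; latest_metric, is_improvement and the metric name "val_acc" are hypothetical. Accuracy-like metrics improve upward, so such a scheduler would typically need mode="max" (a PyTorch ReduceLROnPlateau argument), presumably passed through **kwargs.

```python
def latest_metric(history, metric_name):
    """Return the most recent value of metric_name from a history-like dict.

    history is assumed to map metric names to lists of per-epoch values.
    """
    if metric_name not in history or not history[metric_name]:
        # Typical cause: no evaluation callback filled in this metric.
        raise ValueError(f"Metric '{metric_name}' not found in history")
    return history[metric_name][-1]


def is_improvement(new, best, mode="max"):
    """Higher is better in 'max' mode (accuracy), lower in 'min' mode (loss)."""
    return new > best if mode == "max" else new < best
```

A missing metric fails loudly here, which matches the requirement that the metric has to be computed and stored in the history before the scheduler can use it.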
- class aitoolbox.torchtrain.schedulers.basic.LambdaLRScheduler(lr_lambda, execute_epoch_end=True, execute_batch_end=False, **kwargs)
Bases: GeneralLRSchedulerCallback
Sets the learning rate of each parameter group to the initial lr times a given function
When last_epoch=-1, sets initial lr as lr.
- Parameters:
lr_lambda (callable or list) – a function that computes a multiplicative factor given an integer parameter epoch, or a list of such functions, one for each group in optimizer.param_groups
execute_epoch_end (bool) – whether the scheduler step is executed at the end of each epoch
execute_batch_end (bool) – whether the scheduler step is executed at the end of each batch
**kwargs – learning rate scheduler additional parameters
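The LambdaLR rule and the two execution flags can be modelled in plain Python. lambda_lrs and count_steps are hypothetical helpers, not aitoolbox functions; the sketch assumes that each enabled flag triggers one scheduler step at its point in the loop. Under that assumption, with execute_batch_end=True the integer passed to lr_lambda would count scheduler steps (batches) rather than epochs.

```python
def lambda_lrs(initial_lrs, lr_lambda, step):
    """LR of each param group: initial lr times lr_lambda(step).

    lr_lambda is one function applied to every group, or a list with one
    function per parameter group.
    """
    if callable(lr_lambda):
        lambdas = [lr_lambda] * len(initial_lrs)
    else:
        if len(lr_lambda) != len(initial_lrs):
            raise ValueError("Expected one lr_lambda per parameter group")
        lambdas = list(lr_lambda)
    return [lr * fn(step) for lr, fn in zip(initial_lrs, lambdas)]


def count_steps(num_epochs, batches_per_epoch,
                execute_epoch_end=True, execute_batch_end=False):
    """Number of scheduler steps the two flags would trigger (assumed model)."""
    steps = 0
    for _ in range(num_epochs):
        if execute_batch_end:
            steps += batches_per_epoch  # one step after every batch
        if execute_epoch_end:
            steps += 1  # one step after the epoch
    return steps
```

For example, two epochs of ten batches give 2 steps with the defaults, but 20 steps when only execute_batch_end is enabled, so the same lr_lambda decays ten times faster per epoch.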
- class aitoolbox.torchtrain.schedulers.basic.StepLRScheduler(step_size, **kwargs)
Bases: GeneralLRSchedulerCallback
Sets the learning rate of each parameter group to the initial lr decayed by gamma every step_size epochs
When last_epoch=-1, sets initial lr as lr.
- Parameters:
step_size (int) – period of learning rate decay
**kwargs – learning rate scheduler additional parameters
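Step decay has a simple closed form, sketched by the hypothetical step_lr helper: the learning rate is multiplied by gamma once every step_size epochs. The default gamma=0.1 mirrors the PyTorch StepLR default; with this callback, gamma would presumably be supplied through **kwargs.

```python
def step_lr(initial_lr, epoch, step_size, gamma=0.1):
    """Closed form of step decay: one factor of gamma per completed step_size block."""
    return initial_lr * gamma ** (epoch // step_size)


# Learning rate for epochs 0..6 with step_size=3 and gamma=0.5.
schedule = [step_lr(1.0, epoch, step_size=3, gamma=0.5) for epoch in range(7)]
```

Integer division makes the rate constant within each block of step_size epochs and drop only at the block boundaries.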
- class aitoolbox.torchtrain.schedulers.basic.MultiStepLRScheduler(milestones_list, **kwargs)
Bases: GeneralLRSchedulerCallback
Sets the learning rate of each parameter group to the initial lr decayed by gamma once the number of epochs reaches one of the milestones.
When last_epoch=-1, sets initial lr as lr.
- Parameters:
milestones_list (list) – list of epoch indices. Must be increasing
**kwargs – learning rate scheduler additional parameters
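Multi-step decay applies one factor of gamma for every milestone already reached, which bisect_right counts directly on a sorted list. The multistep_lr helper is hypothetical; its milestones argument plays the role of milestones_list, and gamma would presumably be passed through **kwargs.

```python
from bisect import bisect_right


def multistep_lr(initial_lr, epoch, milestones, gamma=0.1):
    """Decay by gamma once for every milestone <= epoch."""
    if list(milestones) != sorted(milestones):
        # Reject milestones that are out of order.
        raise ValueError("milestones must be increasing")
    return initial_lr * gamma ** bisect_right(milestones, epoch)


# Learning rate for epochs 0..5 with milestones at epochs 2 and 4.
schedule = [multistep_lr(1.0, epoch, [2, 4], gamma=0.5) for epoch in range(6)]
```

Unlike step decay, the drops happen at arbitrary, user-chosen epochs rather than at a fixed period.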