Multi-Loss and Multi-Optimizer
TrainLoop supports training using multiple separate losses and/or multiple different optimizers at the same time.
The multi-loss and multi-optimizer functionality is achieved by wrapping multiple loss or optimizer objects in the MultiLoss and MultiOptimizer wrappers, respectively, both provided in aitoolbox.torchtrain.multi_loss_optim.
Multi-Loss Training
To implement training with multiple losses, use aitoolbox.torchtrain.multi_loss_optim.MultiLoss to wrap the separately calculated losses together and return the wrapper from the model's get_loss() function.
The train loop then automatically runs backpropagation through each of the wrapped losses.
Multiple losses need to be provided to the MultiLoss as a dict:
MultiLoss({'main_loss': main_loss, 'aux_loss': aux_loss})
When more elaborate backprop logic is needed, override the aitoolbox.torchtrain.multi_loss_optim.MultiLoss.backward() method with the desired logic.
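The override pattern described above can be illustrated with a small stand-alone sketch. It deliberately does not import AIToolbox or PyTorch: FakeLoss, SketchMultiLoss and MainOnlyMultiLoss are hypothetical stand-ins written only for this illustration, and the real MultiLoss.backward() may take additional arguments, so check its signature before overriding it.

```python
class FakeLoss:
    """Stand-in for a loss tensor; counts how often backward() is called."""

    def __init__(self, value):
        self.value = value
        self.backward_calls = 0

    def backward(self):
        self.backward_calls += 1


class SketchMultiLoss:
    """Hypothetical dict-based wrapper, NOT AIToolbox's actual MultiLoss."""

    def __init__(self, loss_dict):
        self.loss_dict = dict(loss_dict)

    def backward(self):
        # Default behaviour: backprop through every wrapped loss in turn.
        for loss in self.loss_dict.values():
            loss.backward()


class MainOnlyMultiLoss(SketchMultiLoss):
    """Custom logic: backprop only through the loss named 'main_loss'."""

    def backward(self):
        self.loss_dict['main_loss'].backward()


# Default wrapper: both losses receive a backward() call.
main_loss, aux_loss = FakeLoss(0.7), FakeLoss(0.2)
SketchMultiLoss({'main_loss': main_loss, 'aux_loss': aux_loss}).backward()
default_calls = (main_loss.backward_calls, aux_loss.backward_calls)

# Overridden wrapper: only the main loss receives a backward() call.
main_loss, aux_loss = FakeLoss(0.7), FakeLoss(0.2)
MainOnlyMultiLoss({'main_loss': main_loss, 'aux_loss': aux_loss}).backward()
custom_calls = (main_loss.backward_calls, aux_loss.backward_calls)
```

With real PyTorch tensors that share a computation graph, calling backward() on several losses in sequence typically requires retain_graph=True on all but the last call; whether the library handles this for you is not stated here.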
Multi-Optimizer Training
To use multiple optimizers, for example each one optimizing a different part of the model, define multiple optimizers, each with access to different parameters of the model. These separate optimizers need to be provided in a list to the aitoolbox.torchtrain.multi_loss_optim.MultiOptimizer wrapper.
The MultiOptimizer can subsequently be given to the TrainLoop in the same way as a normal single optimizer.
MultiOptimizer definition example:
MultiOptimizer([optimizer_1, optimizer_2])
When more advanced multi-optimizer training logic is required, the user can override the aitoolbox.torchtrain.multi_loss_optim.MultiOptimizer.step() and/or the aitoolbox.torchtrain.multi_loss_optim.MultiOptimizer.zero_grad() methods as needed.
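As a rough illustration of overriding step(), the sketch below alternates between the wrapped optimizers instead of stepping all of them on every call. FakeOptimizer, SketchMultiOptimizer and AlternatingMultiOptimizer are hypothetical stand-ins written for this example, not AIToolbox classes, and the real step() and zero_grad() methods may accept extra arguments.

```python
class FakeOptimizer:
    """Stand-in for a torch optimizer; logs the calls it receives."""

    def __init__(self, name):
        self.name = name
        self.calls = []

    def step(self):
        self.calls.append('step')

    def zero_grad(self):
        self.calls.append('zero_grad')


class SketchMultiOptimizer:
    """Hypothetical list-based wrapper, NOT AIToolbox's actual MultiOptimizer."""

    def __init__(self, optimizer_list):
        self.optimizer_list = list(optimizer_list)

    def step(self):
        # Default behaviour: step every wrapped optimizer.
        for optimizer in self.optimizer_list:
            optimizer.step()

    def zero_grad(self):
        # Default behaviour: clear gradients for every wrapped optimizer.
        for optimizer in self.optimizer_list:
            optimizer.zero_grad()


class AlternatingMultiOptimizer(SketchMultiOptimizer):
    """Custom logic: step only one optimizer per call, in round-robin order."""

    def __init__(self, optimizer_list):
        super().__init__(optimizer_list)
        self.call_count = 0

    def step(self):
        idx = self.call_count % len(self.optimizer_list)
        self.optimizer_list[idx].step()
        self.call_count += 1


optimizer_1, optimizer_2 = FakeOptimizer('opt_1'), FakeOptimizer('opt_2')
multi_optimizer = AlternatingMultiOptimizer([optimizer_1, optimizer_2])

# Three simulated training iterations.
for _ in range(3):
    multi_optimizer.step()
    multi_optimizer.zero_grad()
```

Only step() is overridden here, so zero_grad() keeps the default behaviour of clearing gradients on every wrapped optimizer at each iteration.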
Lastly, when using the MultiOptimizer, training state checkpoint saving is also handled automatically by the train loop. As part of this, the train loop stores the state of each of the optimizers wrapped inside the MultiOptimizer. The same functionality is provided when loading the saved model.
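Conceptually, saving and restoring the state of several optimizers could look like the sketch below. It relies only on the state_dict() and load_state_dict() pair that PyTorch optimizers expose. FakeOptimizer and the two helper functions are hypothetical names for this illustration, and storing the states as an ordered list is an assumption rather than a description of how AIToolbox lays out its checkpoints.

```python
class FakeOptimizer:
    """Stand-in exposing the state_dict()/load_state_dict() pair of torch optimizers."""

    def __init__(self, lr):
        self.lr = lr

    def state_dict(self):
        return {'lr': self.lr}

    def load_state_dict(self, state):
        self.lr = state['lr']


def collect_optimizer_states(optimizers):
    """Hypothetical helper: gather one state dict per optimizer, in order."""
    return [optimizer.state_dict() for optimizer in optimizers]


def restore_optimizer_states(optimizers, states):
    """Hypothetical helper: load each saved state into the matching optimizer."""
    if len(optimizers) != len(states):
        raise ValueError('Number of saved states does not match number of optimizers')
    for optimizer, state in zip(optimizers, states):
        optimizer.load_state_dict(state)


# "Save": capture the state of two optimizers with different learning rates.
saved_states = collect_optimizer_states([FakeOptimizer(0.01), FakeOptimizer(0.001)])

# "Load": freshly created optimizers pick up the saved values.
restored = [FakeOptimizer(1.0), FakeOptimizer(1.0)]
restore_optimizer_states(restored, saved_states)
restored_lrs = [optimizer.lr for optimizer in restored]
```

Because the states are matched by position in this sketch, the optimizers would have to be wrapped in the same order when loading as when saving.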