train_loop
- class aitoolbox.torchtrain.train_loop.train_loop.TrainLoop(model, train_loader, validation_loader, test_loader, optimizer, criterion, collate_batch_pred_fn=<function append_predictions>, pred_transform_fn=<function torch_cat_transf>, end_auto_eval=True, lazy_experiment_save=False, print_callbacks=False, gpu_mode='single', cuda_device_idx=None, use_amp=False)[source]
Bases: object
Core PyTorch TrainLoop supporting the model training and target prediction
Implements the core training procedures: batch feeding into the network as part of the (multi-)epoch train loop and calculation of the loss and gradients. Apart from training-related functionality, the TrainLoop also implements the logic needed for predicting target variables.
- Parameters:
model (TTModel or ModelWrap or TTDataParallel) – neural network model
train_loader (torch.utils.data.DataLoader) – data loader for train data set
validation_loader (torch.utils.data.DataLoader or None) – data loader for validation data set
test_loader (torch.utils.data.DataLoader or None) – data loader for test data set
optimizer (torch.optim.Optimizer or MultiOptimizer) – optimizer algorithm.
criterion (torch.nn.Module or MultiLoss or None) – loss criterion used during the training procedure
collate_batch_pred_fn (callable) – collate function transforming batch predictions as they come out from the model
pred_transform_fn (callable) – function transforming all the produced predictions after all the batches have been run through the model
end_auto_eval (bool or int) – used to optionally disable the otherwise automatic end-of-epoch/training val/test loss calculations. This is useful for saving compute time when conducting very costly experiments. Specify a True/False boolean to always or never run the evaluation after each epoch, or specify an int to run it only every given number of epochs.
lazy_experiment_save (bool) – when in lazy mode experiment tracking components will create the experiment folder only after some training results are available (possibly at the end of the first epoch) instead of at the beginning of training.
print_callbacks (bool) – at the start of training print the list of registered callbacks which will be executed during the run of the train loop
gpu_mode (str) – GPU training mode selection. TrainLoop supports different GPU training modes by specifying one of the following:
'single': single-GPU training
'dp': multi-GPU training via DataParallel
'ddp': multi-GPU training via DistributedDataParallel
cuda_device_idx (int or None) – CUDA device index used when training on multiple GPUs
use_amp (bool or dict) – use 16-bit Automatic Mixed Precision (AMP). To switch to AMP mode either:
set this parameter to True to use the default AMP GradScaler initialization params, or
provide custom AMP GradScaler initialization parameters as a dict via this parameter
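A minimal usage sketch of constructing and running a TrainLoop is shown below. The ExampleModel, the synthetic data, and the TTModel import path are illustrative assumptions; the get_loss()/get_predictions() hooks follow the TTModel pattern this library is assumed to use:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from aitoolbox.torchtrain.train_loop.train_loop import TrainLoop
from aitoolbox.torchtrain.model import TTModel  # assumed import path


class ExampleModel(TTModel):
    # Hypothetical TTModel subclass implementing the assumed training hooks
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def get_loss(self, batch_data, criterion, device):
        # Compute the batch loss used by the TrainLoop during training
        x, y = [t.to(device) for t in batch_data]
        return criterion(self.linear(x), y)

    def get_predictions(self, batch_data, device):
        # Return y_pred, y_true and a metadata dict for the prediction loop
        x, y = batch_data
        return self.linear(x.to(device)).cpu(), y, {}


model = ExampleModel()
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=16
)

train_loop = TrainLoop(
    model,
    train_loader=data_loader, validation_loader=None, test_loader=None,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    criterion=nn.MSELoss(),
)
trained_model = train_loop.fit(num_epochs=5)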
- fit(num_epochs=0, num_iterations=0, callbacks=None, grad_accumulation=1, **kwargs)[source]
Train the model using the train loop
This is the general API method which starts the model training. Depending on the training mode selected via the TrainLoop's gpu_mode parameter, the training will start in one of the following modes:
Basic (CPU or single GPU) mode
DataParallel mode
DistributedDataParallel mode
- Parameters:
num_epochs (int) – how many epochs the network will be trained
num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.
callbacks (list or None) – callbacks that are executed during the training run
grad_accumulation (int) – number of batches over which the gradients are accumulated before updating the weights
**kwargs – additional parameters for the underlying training methods. These are called by the TrainLoop depending on the specified gpu_mode parameter.
- Returns:
trained model
- Return type:
TTModel or torch.nn.Module
- _train(num_epochs, num_iterations, callbacks=None, grad_accumulation=1)[source]
Train the model using the train loop
- Parameters:
num_epochs (int) – how many epochs the network will be trained
num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.
callbacks (list or None) – callbacks that are executed during the training run
grad_accumulation (int) – number of batches over which the gradients are accumulated before updating the weights
- Returns:
trained model
- Return type:
TTModel or torch.nn.Module
- _calculate_batch_loss(batch_data)[source]
Push batch data through the model and calculate the batch loss
- Parameters:
batch_data (torch.Tensor) – input data batch
- Returns:
loss calculated on current batch
- Return type:
loss (torch.Tensor or MultiLoss)
- _backward_pass(loss_batch, optimizer_idx)[source]
Execute backward pass from the current batch loss
- Parameters:
loss_batch (torch.Tensor or MultiLoss) – loss calculated on current batch
optimizer_idx (int) – index of the current optimizer. Mostly useful when using multiple optimizers. When only a single optimizer is used this parameter can be ignored.
- Returns:
None
- _optimizer_step(optimizer_idx)[source]
Execute the optimizer step
- Parameters:
optimizer_idx (int) – index of the current optimizer. Mostly useful when using multiple optimizers. When only a single optimizer is used this parameter can be ignored.
- Returns:
None
- _optimizer_zero_grad(optimizer_idx)[source]
Execute optimizer zero grad
- Parameters:
optimizer_idx (int) – index of the current optimizer. Mostly useful when using multiple optimizers. When only a single optimizer is used this parameter can be ignored.
- Returns:
None
- should_execute_optimizer_update()[source]
Determine if optimizer update based on calculated gradients should be done at the current iteration
Together with the optimizer update we normally also execute zero_grad as well as any gradient clipping operations.
This method is especially important when gradient accumulation is used in training, as it determines at which iterations the model parameters are updated via the optimizer based on the accumulated gradients.
Note
Switched from the simple condition to a better one which also covers the final incomplete batch:
if (self.iteration + 1) % self.grad_accumulation == 0
was replaced by:
if (self.iteration + 1) % self.grad_accumulation == 0 or self.iteration == len(self.train_loader) - 1
- Returns:
whether a model parameter update via the optimizer should be done in the current iteration
- Return type:
bool
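A standalone sketch of the update condition described above (a hypothetical helper, not part of the TrainLoop API):

def should_update(iteration, grad_accumulation, num_batches):
    # Update on every grad_accumulation-th batch, or on the final
    # (possibly incomplete) batch of the epoch
    return (iteration + 1) % grad_accumulation == 0 or iteration == num_batches - 1

# With 10 batches and grad_accumulation=4, updates happen at iterations 3, 7 and 9
assert [i for i in range(10) if should_update(i, 4, 10)] == [3, 7, 9]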
- auto_execute_end_of_epoch()[source]
Basic performance evaluation executed by default at the end of each epoch
Mainly evaluation of the loss functions which are always present as part of the training loop.
- Returns:
None
- auto_execute_end_of_training()[source]
Basic performance evaluation executed by default at the end of the training process
- Returns:
None
- parse_loss(loss_record)[source]
Helper function to process different possible loss formats
Primarily useful for parsing between single loss representation and the multi-loss representation.
- Parameters:
loss_record (list) – list of Tensor losses from each processed batch.
If a single loss is used, loss_record is a list of Tensors where each element is the loss of a single batch.
If multiple losses wrapped inside MultiLoss() are used, these behave the same way as normal dicts, since MultiLoss subclasses dict and thus implements the dict protocol. Consequently, loss_record can be thought of as a list of (MultiLoss) dicts, where each dict holds the losses of a single batch:
[MultiLoss({'loss_1': Tensor(1.), 'loss_2': Tensor(33.)}), MultiLoss({ ... })]
- Returns:
in the case of a single loss a torch Tensor is returned, otherwise a dict of multiple losses is returned where each value is again a torch Tensor
Note
All the returned loss Tensors are left on their original device (e.g. a GPU).
- Return type:
torch.DoubleTensor or MultiLoss
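An illustrative sketch of the two loss_record formats (plain dicts stand in for MultiLoss, which behaves like a dict; the averaging shown is only an assumption of how such records could be aggregated):

import torch

# Single-loss format: a list of per-batch loss Tensors
single_loss_record = [torch.tensor(1.2), torch.tensor(0.9)]

# Multi-loss format: a list of per-batch dict-like records
multi_loss_record = [
    {'loss_1': torch.tensor(1.0), 'loss_2': torch.tensor(33.0)},
    {'loss_1': torch.tensor(0.8), 'loss_2': torch.tensor(30.0)},
]

avg_single = torch.mean(torch.stack(single_loss_record))
avg_multi = {
    name: torch.mean(torch.stack([batch[name] for batch in multi_loss_record]))
    for name in multi_loss_record[0]
}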
- _print_save_loss(loss_parsed, loss_type_name, loss_print_description)[source]
Helper function which prints information about parsed loss and saves the loss results into the history
- Parameters:
loss_parsed (torch.Tensor or MultiLoss) – parsed loss result either as a single value or as MultiLoss in case of multiple losses
loss_type_name (str) – type of the provided loss result
loss_print_description (str) – presentation description text of the provided loss result
- Returns:
None
- evaluate_loss_on_train_set(force_prediction=False, float_dict_format=False)[source]
Run train dataset through the network without updating the weights and return the loss
- Parameters:
force_prediction (bool) – recompute the loss even if it is available in the prediction cache. This causes the old cached value to be overwritten.
float_dict_format (bool) – if true, a simplified loss representation is returned: in the case of a single loss a float, in the case of multi-loss a dict extracted from the MultiLoss wrapper. If false, the standard torch.Tensor or MultiLoss is returned.
- Returns:
train set loss. Returned tensors are on the CPU. Depending on the float_dict_format parameter, either the standard (torch.Tensor / MultiLoss) or the simplified (float / dict) loss representation is returned.
- Return type:
torch.Tensor or MultiLoss or float or dict
- evaluate_loss_on_validation_set(force_prediction=False, float_dict_format=False)[source]
Run validation dataset through the network without updating the weights and return the loss
- Parameters:
force_prediction (bool) – recompute the loss even if it is available in the prediction cache. This causes the old cached value to be overwritten.
float_dict_format (bool) – if true, a simplified loss representation is returned: in the case of a single loss a float, in the case of multi-loss a dict extracted from the MultiLoss wrapper. If false, the standard torch.Tensor or MultiLoss is returned.
- Returns:
validation set loss. Returned tensors are on the CPU. Depending on the float_dict_format parameter, either the standard (torch.Tensor / MultiLoss) or the simplified (float / dict) loss representation is returned.
- Return type:
torch.Tensor or MultiLoss or float or dict
- evaluate_loss_on_test_set(force_prediction=False, float_dict_format=False)[source]
Run test dataset through the network without updating the weights and return the loss
- Parameters:
force_prediction (bool) – recompute the loss even if it is available in the prediction cache. This causes the old cached value to be overwritten.
float_dict_format (bool) – if true, a simplified loss representation is returned: in the case of a single loss a float, in the case of multi-loss a dict extracted from the MultiLoss wrapper. If false, the standard torch.Tensor or MultiLoss is returned.
- Returns:
test set loss. Returned tensors are on the CPU. Depending on the float_dict_format parameter, either the standard (torch.Tensor / MultiLoss) or the simplified (float / dict) loss representation is returned.
- Return type:
torch.Tensor or MultiLoss or float or dict
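A usage sketch of the float_dict_format and force_prediction switches, continuing the hypothetical train_loop from the first example:

# Standard representation: torch.Tensor (single loss) or MultiLoss (multi-loss)
train_loss = train_loop.evaluate_loss_on_train_set()

# Simplified representation: float (single loss) or plain dict (multi-loss)
train_loss_simple = train_loop.evaluate_loss_on_train_set(float_dict_format=True)

# Recompute instead of reading from the prediction cache
train_loss_fresh = train_loop.evaluate_loss_on_train_set(force_prediction=True)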
- evaluate_model_loss(data_loader, move_to_cpu=False, dataset_info=None)[source]
Run given dataset through the network without updating the weights and return the loss
- Parameters:
data_loader (torch.utils.data.DataLoader) – dataloader containing the data on which the loss is calculated
move_to_cpu (bool) – should the loss result be moved to the CPU. Otherwise, the returned loss is kept on the original device (which can also be a GPU).
dataset_info (dict or None) – additional information describing the dataset inside the provided dataloader. One such dataset info is the dataset type ("train", "validation", or "test") set by the evaluate_loss_on_train_set(), evaluate_loss_on_validation_set() and evaluate_loss_on_test_set() methods.
- Returns:
Calculated average loss over all the batches. In the case of multi loss, the MultiLoss wrapper gets returned.
Note
By default the returned loss tensors are left on the device where they were computed, meaning the returned values can potentially still be on the GPU.
- Return type:
torch.Tensor or MultiLoss
- predict_on_train_set(force_prediction=False, execute_callbacks=False)[source]
Run train dataset through the network and return true target values, target predictions and metadata
- Parameters:
force_prediction (bool) – recompute the output prediction even if it is available in the prediction cache. This causes the old cached predictions to be overwritten.
execute_callbacks (bool) – if true, the prediction loop will execute the provided callbacks after the predictions for each batch have been made. Otherwise, the callbacks at this position are ignored.
- Returns:
y_pred, y_true, metadata in the form of dict of lists/torch.Tensors/np.arrays
- Return type:
(torch.Tensor, torch.Tensor, dict)
- predict_on_validation_set(force_prediction=False, execute_callbacks=False)[source]
Run validation dataset through the network and return true target values, target predictions and metadata
- Parameters:
force_prediction (bool) – recompute the output prediction even if it is available in the prediction cache. This causes the old cached predictions to be overwritten.
execute_callbacks (bool) – if true, the prediction loop will execute the provided callbacks after the predictions for each batch have been made. Otherwise, the callbacks at this position are ignored.
- Returns:
y_pred, y_true, metadata in the form of dict of lists/torch.Tensors/np.arrays
- Return type:
(torch.Tensor, torch.Tensor, dict)
- predict_on_test_set(force_prediction=False, execute_callbacks=False)[source]
Run test dataset through the network and return true target values, target predictions and metadata
- Parameters:
force_prediction (bool) – recompute the output prediction even if it is available in the prediction cache. This causes the old cached predictions to be overwritten.
execute_callbacks (bool) – if true, the prediction loop will execute the provided callbacks after the predictions for each batch have been made. Otherwise, the callbacks at this position are ignored.
- Returns:
y_pred, y_true, metadata in the form of dict of lists/torch.Tensors/np.arrays
- Return type:
(torch.Tensor, torch.Tensor, dict)
- predict_with_model(data_loader, execute_callbacks=False, move_to_cpu=False, dataset_info=None)[source]
Run given dataset through the network and return true target values, target predictions and metadata
- Parameters:
data_loader (torch.utils.data.DataLoader) – dataloader containing the data on which the output predictions are calculated
execute_callbacks (bool) – if true, the prediction loop will execute the provided callbacks after the predictions for each batch have been made. Otherwise, the callbacks at this position are ignored.
move_to_cpu (bool) – should the predicted returned results be moved to the CPU. Otherwise, the returned results are kept on the original device (which can also be a GPU).
dataset_info (dict or None) – additional information describing the dataset inside the provided dataloader. One such dataset info is the dataset type ("train", "validation", or "test") set by the predict_on_train_set(), predict_on_validation_set() and predict_on_test_set() methods.
- Returns:
y_pred, y_true, metadata in the form of dict of lists/torch.Tensors/np.arrays
- Return type:
(torch.Tensor, torch.Tensor, dict)
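A sketch of consuming the prediction API (some_other_loader is a hypothetical extra DataLoader; train_loop continues the first example):

# Predictions on the registered train data loader
y_pred, y_true, metadata = train_loop.predict_on_train_set()

# Predictions on an arbitrary DataLoader, with results moved to the CPU
y_pred, y_true, metadata = train_loop.predict_with_model(
    some_other_loader, move_to_cpu=True
)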
- insert_metric_result_into_history(metric_name, metric_result)[source]
Insert a metric result into the train history
This is the main and preferred API method for metric insertion as part of the train loop.
- Parameters:
metric_name (str) – name of the inserted metric
metric_result (float or dict) – value of the metric result
- Returns:
None
- get_schedulers()[source]
Get the registered schedulers
Schedulers in TrainLoop training are implemented as callbacks under the hood.
- Returns:
list of schedulers (implemented as callbacks)
- Return type:
list
- get_num_training_steps()[source]
Get the number of actual training steps
Especially useful when gradient accumulation is used, to learn the number of steps at which the model weights are actually updated, as opposed to the number of processed batches.
- Returns:
number of training steps / iterations
- Return type:
int
- is_main_process()[source]
Is current process the main training process
In the case of single GPU/CPU training there is a single process, so this function always returns True. For DDP training, the process at rank 0 is treated as the main process.
- Returns:
whether the current process is the main training process (in the case of DDP, the process at rank 0)
- Return type:
bool
- static convert_loss_to_float_dict_format(loss)[source]
Util method for converting loss records in Tensor/MultiLoss format into simpler float/dict format
- Parameters:
loss (torch.Tensor or MultiLoss) – more complex loss representation. In case of single loss it is torch Tensor. In case of multi-loss it is MultiLoss wrapper.
- Returns:
simplified loss representation. In case of single loss it is a single float value. In case of multi-loss it is a dict extracted out from the given MultiLoss wrapper.
- Return type:
float or dict
- _train_dp(num_epochs, num_iterations, callbacks=None, grad_accumulation=1, dp_model_args=None)[source]
Train the model on multi-GPU with DataParallel auto wrapping
- Parameters:
num_epochs (int) – how many epochs the network will be trained
num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.
callbacks (list or None) – callbacks that are executed during the training run
grad_accumulation (int) – number of batches over which the gradients are accumulated before updating the weights
dp_model_args (dict or None) – parameters for the aitoolbox.torchtrain.parallel.TTDataParallel / torch.nn.DataParallel DP model wrap
- Returns:
trained model
- Return type:
TTModel or torch.nn.Module
- _train_ddp(num_epochs, num_iterations, callbacks=None, grad_accumulation=1, ddp_model_args=None, in_process_data_load=None, num_nodes=1, node_rank=0, num_gpus=0, backend='nccl', init_method='env://', on_gpu=True)[source]
Train the model using the train loop in the Distributed Data Parallel setting
During the training, multiple processes will be spawned, one for each of the available GPUs.
- Parameters:
num_epochs (int) – how many epochs the network will be trained
num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.
callbacks (list or None) – callbacks that are executed during the training run
grad_accumulation (int) – number of batches over which the gradients are accumulated before updating the weights
ddp_model_args (dict or None) – parameters for the underlying PyTorch DistributedDataParallel model
in_process_data_load (AbstractCallback or list or None) – in-process data loading logic implemented as a torchtrain callback. The logic should be placed inside the on_multiprocess_start() callback function. When using this data loading option bear in mind that the loaded dataset will be replicated in memory for every spawned training process, which can in turn cause extensive overall memory consumption.
num_nodes (int) – number of nodes in the cluster
node_rank (int) – rank of the current node
num_gpus (int) – number of GPUs in the node
backend (str) – the backend to use. For more information look up the documentation for torch.distributed.init_process_group(). Valid values include mpi, gloo, and nccl.
init_method (str) – URL specifying how to initialize the process group. For more information look up the documentation for torch.distributed.init_process_group().
on_gpu (bool) – whether the DDP training is executed on the GPU or on the CPU
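A hedged sketch of launching DDP training through fit(): the extra keyword arguments are forwarded to _train_ddp(), and the model, data loader, optimizer and criterion reuse the hypothetical objects from the first example:

ddp_train_loop = TrainLoop(
    model, data_loader, None, None,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    criterion=nn.MSELoss(),
    gpu_mode='ddp',
)
trained_model = ddp_train_loop.fit(
    num_epochs=5,
    num_nodes=1, node_rank=0, num_gpus=torch.cuda.device_count(),
    backend='nccl', init_method='env://',
)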
- _spawn_fit(gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_accumulation, in_process_data_load)[source]
Helper function that prepares the TrainLoop state inside each of the spawned processes and initiates training
- Parameters:
gpu (int) – provided by the mp.spawn(); index of the GPU allocated to the current process
ddp_args (dict) – parameters dict needed for the distributed training setup
num_epochs (int) – how many epochs the network will be trained
num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.
callbacks (list or None) – callbacks that are executed during the training run
grad_accumulation (int) – number of batches over which the gradients are accumulated before updating the weights
in_process_data_load (list or None) – in-process data loading logic implemented as a torchtrain callback. The logic should be placed inside the on_multiprocess_start() callback function. When using this data loading option bear in mind that the loaded dataset will be replicated in memory for every spawned training process, which can in turn cause extensive overall memory consumption.
- __call__(num_epochs=0, num_iterations=0, callbacks=None, grad_accumulation=1, **kwargs)[source]
Train the model using the train loop
This is a convenience function which calls the main TrainLoop model training method fit().
- Parameters:
num_epochs (int) – how many epochs the network will be trained
num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.
callbacks (list or None) – callbacks that are executed during the training run
grad_accumulation (int) – number of batches over which the gradients are accumulated before updating the weights
**kwargs – additional parameters for the _train_dp() and _train_ddp() methods
- Returns:
trained model
- Return type:
TTModel or torch.nn.Module
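Since __call__() simply delegates to fit(), the two invocations below are equivalent (continuing the hypothetical train_loop from the first example):

trained_model = train_loop.fit(num_epochs=5, grad_accumulation=2)
trained_model = train_loop(num_epochs=5, grad_accumulation=2)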