train_loop

class aitoolbox.torchtrain.train_loop.train_loop.TrainLoop(model, train_loader, validation_loader, test_loader, optimizer, criterion, collate_batch_pred_fn=<function append_predictions>, pred_transform_fn=<function torch_cat_transf>, end_auto_eval=True, lazy_experiment_save=False, gpu_mode='single', cuda_device_idx=None, use_amp=False)[source]

Bases: object

Core PyTorch TrainLoop supporting model training and target prediction

Implements the core training procedure: feeding batches into the network as part of the (multi-)epoch train loop and calculating the loss and gradients. Apart from training-related functionality, the TrainLoop also implements the logic needed for predicting target variables.

Parameters
  • model (TTModel or ModelWrap or TTDataParallel) – neural network model

  • train_loader (torch.utils.data.DataLoader) – data loader for train data set

  • validation_loader (torch.utils.data.DataLoader or None) – data loader for validation data set

  • test_loader (torch.utils.data.DataLoader or None) – data loader for test data set

  • optimizer (torch.optim.optimizer.Optimizer or MultiOptimizer) – optimizer algorithm.

  • criterion (torch.nn.modules.loss._Loss or MultiLoss or None) – loss criterion used during the training procedure

  • collate_batch_pred_fn (callable) – collate function transforming batch predictions as they come out from the model

  • pred_transform_fn (callable) – function transforming all the produced predictions after all the batches have been run through the model

  • end_auto_eval (bool or int) – used to optionally disable the otherwise automatic end-of-epoch/training val/test loss calculations. This is useful for saving compute time when conducting very costly experiments. Specify a True/False boolean to always or never run the evaluation after each epoch, or specify an int to run it only every given number of epochs.

  • lazy_experiment_save (bool) – when in lazy mode, experiment tracking components create the experiment folder only after some training results are available (possibly at the end of the first epoch) instead of at the beginning of training.

  • gpu_mode (str) –

    GPU training mode selection. TrainLoop supports different GPU training modes by specifying one of the following:

    • 'single': single GPU training

    • 'dp': multi-GPU training via DataParallel

    • 'ddp': multi-GPU training via DistributedDataParallel

  • cuda_device_idx (int or None) – CUDA device index used when training on multiple GPUs

  • use_amp (bool or dict) –

    use 16-bit Automatic Mixed Precision (AMP)

    To switch to AMP mode either:

    • set this parameter to True to use the default torch.cuda.amp.GradScaler initialization params

    • provide custom torch.cuda.amp.GradScaler initialization parameters given as a dict via this parameter
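
A minimal construction sketch, for illustration only: it assumes a TTModel subclass and the three data loaders are already defined elsewhere, and it imports TrainLoop from the module path shown in the class signature above.

    import torch
    from torch import nn
    from aitoolbox.torchtrain.train_loop.train_loop import TrainLoop

    # model, train_loader, val_loader, test_loader are assumed to be defined elsewhere
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    train_loop = TrainLoop(
        model, train_loader, val_loader, test_loader,
        optimizer, criterion,
        end_auto_eval=2,                  # run the automatic val/test loss evaluation only every 2 epochs
        gpu_mode='single',
        use_amp={'init_scale': 2. ** 16}  # custom GradScaler kwargs; use_amp=True keeps the defaults
    )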

fit(num_epochs=0, num_iterations=0, callbacks=None, grad_accumulation=1, **kwargs)[source]

Train the model using the train loop

This is the general API method which starts the model training. Depending on the training mode selected via the TrainLoop’s gpu_mode parameter, calling this method starts training in one of the following modes:

  • Basic (CPU or single GPU) mode

  • DataParallel mode

  • DistributedDataParallel mode

Parameters
  • num_epochs (int) – how many epochs the network will be trained

  • num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.

  • callbacks (list or None) – callbacks that are executed during the training run

  • grad_accumulation (int) – number of batches the gradients are accumulated before updating weights

  • **kwargs

    additional parameters for training methods:

    • aitoolbox.torchtrain.train_loop.TrainLoop._train_dp()

    • aitoolbox.torchtrain.train_loop.TrainLoop._train_ddp()

    These training methods are called by the TrainLoop depending on the specified setting of the TrainLoop’s gpu_mode parameter.

Returns

trained model

Return type

TTModel or torch.nn.modules.Module or TTDataParallel
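
A short usage sketch of fit(), continuing the construction sketch above (hyperparameter values are hypothetical):

    # Train for 10 epochs, accumulating gradients over 4 batches before each weight update
    trained_model = train_loop.fit(num_epochs=10, grad_accumulation=4)

    # Alternatively, bound the training length by the number of iterations (batches)
    trained_model = train_loop.fit(num_iterations=5000)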

_train(num_epochs, num_iterations, callbacks=None, grad_accumulation=1)[source]

Train the model using the train loop

Parameters
  • num_epochs (int) – how many epochs the network will be trained

  • num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.

  • callbacks (list or None) – callbacks that are executed during the training run

  • grad_accumulation (int) – number of batches the gradients are accumulated before updating weights

Returns

trained model

Return type

TTModel or torch.nn.modules.Module or TTDataParallel

_calculate_batch_loss(batch_data)[source]

Push batch data through the model and calculate the batch loss

Parameters

batch_data – input data batch

Returns

loss calculated on current batch

Return type

loss

_backward_pass(loss_batch, optimizer_idx)[source]

Execute backward pass from the current batch loss

Parameters
  • loss_batch – loss calculated on current batch

  • optimizer_idx (int) – index of the current optimizer. Mostly useful when using multiple optimizers. When only a single optimizer is used this parameter can be ignored.

Returns

None

_optimizer_step(optimizer_idx)[source]

Execute the optimizer step

Parameters

optimizer_idx (int) – index of the current optimizer. Mostly useful when using multiple optimizers. When only a single optimizer is used this parameter can be ignored.

Returns

None

_optimizer_zero_grad(optimizer_idx)[source]

Execute optimizer zero grad

Parameters

optimizer_idx (int) – index of the current optimizer. Mostly useful when using multiple optimizers. When only a single optimizer is used this parameter can be ignored.

Returns

None

auto_execute_end_of_epoch()[source]

Basic performance evaluation executed by default at the end of each epoch

Mainly evaluation of the loss functions which are always present as part of the training loop.

Returns

None

auto_execute_end_of_training()[source]

Basic performance evaluation executed by default at the end of the training process

Returns

None

parse_loss(loss_record)[source]

Helper function to process different possible loss formats

Primarily useful for parsing between single loss representation and the multi-loss representation.

Parameters

loss_record (list) – list of losses from each processed batch

Returns

in the case of a single loss a numpy array is returned, otherwise a dict of multiple losses is returned

Return type

np.array or dict

_print_save_loss(loss_parsed, loss_type_name, loss_print_description)[source]

Helper function which prints information about parsed loss and saves the loss results into the history

Parameters
  • loss_parsed (np.array or dict) – parsed loss result either as a single value or as a dict of multiple losses

  • loss_type_name (str) – type of the provided loss result

  • loss_print_description (str) – presentation description text of the provided loss result

Returns

None

evaluate_loss_on_train_set(force_prediction=False)[source]

Run train dataset through the network without updating the weights and return the loss

Parameters

force_prediction (bool) – recompute the loss even if it is available in the prediction cache. This causes the old cached value to be overwritten.

Returns

loss; in the case of multi-loss, a dict is returned

Return type

float or dict

evaluate_loss_on_validation_set(force_prediction=False)[source]

Run validation dataset through the network without updating the weights and return the loss

Parameters

force_prediction (bool) – recompute the loss even if it is available in the prediction cache. This causes the old cached value to be overwritten.

Returns

loss; in the case of multi-loss, a dict is returned

Return type

float or dict

evaluate_loss_on_test_set(force_prediction=False)[source]

Run test dataset through the network without updating the weights and return the loss

Parameters

force_prediction (bool) – recompute the loss even if it is available in the prediction cache. This causes the old cached value to be overwritten.

Returns

loss; in the case of multi-loss, a dict is returned

Return type

float or dict

evaluate_model_loss(data_loader)[source]

Run given dataset through the network without updating the weights and return the loss

Parameters

data_loader (torch.utils.data.DataLoader) – dataloader containing the data on which the loss is calculated

Returns

loss; in the case of multi-loss, a dict is returned

Return type

float or dict
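
A short sketch of the loss evaluation helpers, continuing the construction sketch above (no weights are updated; results come from the prediction cache unless force_prediction=True; some_other_loader is a hypothetical extra DataLoader):

    # A single-loss criterion returns a float, a MultiLoss criterion returns a dict of losses
    val_loss = train_loop.evaluate_loss_on_validation_set()

    # Recompute instead of reusing the cached value
    val_loss = train_loop.evaluate_loss_on_validation_set(force_prediction=True)

    # The same evaluation can be run on any DataLoader
    other_loss = train_loop.evaluate_model_loss(some_other_loader)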

predict_on_train_set(force_prediction=False)[source]

Run train dataset through the network and return true target values, target predictions and metadata

Parameters

force_prediction (bool) – recompute the output prediction even if it is available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns

y_pred, y_true, metadata

Return type

(torch.Tensor, torch.Tensor, dict)

predict_on_validation_set(force_prediction=False)[source]

Run validation dataset through the network and return true target values, target predictions and metadata

Parameters

force_prediction (bool) – recompute the output prediction even if it is available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns

y_pred, y_true, metadata

Return type

(torch.Tensor, torch.Tensor, dict)

predict_on_test_set(force_prediction=False)[source]

Run test dataset through the network and return true target values, target predictions and metadata

Parameters

force_prediction (bool) – recompute the output prediction even if it is available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns

y_pred, y_true, metadata

Return type

(torch.Tensor, torch.Tensor, dict)

predict_with_model(data_loader)[source]

Run given dataset through the network and return true target values, target predictions and metadata

Parameters

data_loader (torch.utils.data.DataLoader) – dataloader containing the data on which the output predictions are calculated

Returns

y_pred, y_true, metadata

Return type

(torch.Tensor, torch.Tensor, dict)
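
A short prediction sketch, continuing the sketches above (some_other_loader is a hypothetical additional DataLoader):

    # Predictions, true targets and metadata from the cached or freshly computed validation run
    y_pred, y_true, metadata = train_loop.predict_on_validation_set()

    # Run an arbitrary DataLoader through the model
    y_pred, y_true, metadata = train_loop.predict_with_model(some_other_loader)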

insert_metric_result_into_history(metric_name, metric_result)[source]

Insert a metric result into the train history

This is the main and preferred API function for metric insertion as part of the train loop.

Parameters
  • metric_name (str) – name of the metric to be inserted

  • metric_result (float or dict) – new result for the corresponding metric
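
A minimal sketch of recording a custom metric, e.g. from inside a callback; the metric value here is a hypothetical number computed elsewhere:

    val_accuracy = 0.87  # hypothetical externally computed metric value
    train_loop.insert_metric_result_into_history('val_accuracy', val_accuracy)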

get_schedulers()[source]

Get the registered schedulers

Schedulers in TrainLoop training are implemented as callbacks under the hood.

Returns

list of schedulers (callbacks)

Return type

list

get_num_training_steps()[source]

Get the number of actual training steps

Useful in the case of gradient accumulation to determine the number of steps at which the weights are actually updated, in between the accumulation steps.

Returns

number of training steps / iterations

Return type

int
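
A sketch of how gradient accumulation relates to the reported number of training steps; the relation in the comment is only an approximation, the exact value depends on the loader length:

    # With grad_accumulation=4, roughly every 4 batches collapse into a single weight update,
    # so the number of actual training steps is about len(train_loader) * num_epochs / 4
    num_update_steps = train_loop.get_num_training_steps()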

_train_dp(num_epochs, num_iterations, callbacks=None, grad_accumulation=1, dp_model_args=None)[source]

Train the model on multi-GPU with DataParallel auto wrapping

Parameters
  • num_epochs (int) – how many epochs the network will be trained

  • num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.

  • callbacks (list or None) – callbacks that are executed during the training run

  • grad_accumulation (int) – number of batches the gradients are accumulated before updating weights

  • dp_model_args (dict or None) – parameters for aitoolbox.torchtrain.parallel.TTDataParallel / nn.DataParallel DP model wrap.

Returns

trained model

Return type

TTDataParallel or nn.DataParallel

_train_ddp(num_epochs, num_iterations, callbacks=None, grad_accumulation=1, ddp_model_args=None, in_process_data_load=None, num_nodes=1, node_rank=0, num_gpus=0)[source]

Train the model using the train loop in the Distributed Data Parallel setting

During the training, multiple processes will be spawned, one for each of the available GPUs.

Parameters
  • num_epochs (int) – how many epochs the network will be trained

  • num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.

  • callbacks (list or None) – callbacks that are executed during the training run

  • grad_accumulation (int) – number of batches the gradients are accumulated before updating weights

  • ddp_model_args (dict or None) –

    parameters for the DistributedDataParallel model wrap. See the torch.nn.parallel.DistributedDataParallel documentation for the available parameters.

  • in_process_data_load (AbstractCallback or list or None) – in-process data loading logic implemented as a torchtrain callback. The logic should be placed inside the on_multiprocess_start() callback function. When using this data loading option, bear in mind that the loaded dataset will be replicated in memory for every spawned training process. This can in turn cause extensive overall memory consumption.

  • num_nodes (int) – number of nodes in the cluster

  • node_rank (int) – rank of the current node

  • num_gpus (int) – number of GPUs in the node
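
A hedged sketch of starting DDP training, reusing the model, loaders, optimizer and criterion from the construction sketch above; the TrainLoop is built with gpu_mode='ddp' and the DDP-specific arguments are forwarded from fit() to _train_ddp() via **kwargs:

    import torch

    ddp_train_loop = TrainLoop(
        model, train_loader, val_loader, test_loader,
        optimizer, criterion,
        gpu_mode='ddp'
    )
    # One training process is spawned per available GPU on this node
    trained_model = ddp_train_loop.fit(num_epochs=10, num_gpus=torch.cuda.device_count())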

_spawn_fit(gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_accumulation, in_process_data_load)[source]

Helper function that prepares the TrainLoop state inside each of the spawned processes and initiates training

Parameters
  • gpu (int) – provided by the mp.spawn(); index of the GPU allocated to the current process

  • ddp_args (dict) – parameters dict needed for the distributed training setup

  • num_epochs (int) – how many epochs the network will be trained

  • num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.

  • callbacks (list or None) – callbacks that are executed during the training run

  • grad_accumulation (int) – number of batches the gradients are accumulated before updating weights

  • in_process_data_load (list or None) – in-process data loading logic implemented as a torchtrain callback. The logic should be placed inside the on_multiprocess_start() callback function. When using this data loading option, bear in mind that the loaded dataset will be replicated in memory for every spawned training process. This can in turn cause extensive overall memory consumption.

__call__(num_epochs=0, num_iterations=0, callbacks=None, grad_accumulation=1, **kwargs)[source]

Train the model using the train loop

This is a convenience function which calls the main TrainLoop model training method fit().

Parameters
  • num_epochs (int) – how many epochs the network will be trained

  • num_iterations (int) – how many iterations (batches) the network will be trained. This enables more granular specification of the training length than the num_epochs parameter.

  • callbacks (list) – callbacks that are executed during the training run

  • grad_accumulation (int) – number of batches the gradients are accumulated before updating weights

  • **kwargs – additional parameters for _train_dp() and _train_ddp() methods.

Returns

trained model

Return type

TTModel or torch.nn.modules.Module or TTDataParallel
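
A one-line convenience sketch, continuing the sketches above; calling the TrainLoop instance directly is equivalent to calling fit():

    trained_model = train_loop(num_epochs=10, grad_accumulation=4)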