parallel

class aitoolbox.torchtrain.parallel.TTParallelBase(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'))[source]

Bases: object

torchtrain parallel base class used to transfer the TTModel batch data feeding methods up to the level of the PyTorch parallel wrappers

Parameters:
  • module (aitoolbox.torchtrain.model.TTModel) – neural network model

  • default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class

get_loss(batch_data, criterion, device)[source]
get_loss_eval(batch_data, criterion, device)[source]
get_predictions(batch_data, device)[source]
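
Conceptually, the method transfer can be sketched as follows (an illustrative simplification only, not the library's actual implementation):

    class ParallelBaseSketch:
        """Illustrative only: copy the selected TTModel methods from the
        wrapped module onto the wrapper, so the train loop can call them
        directly on the parallel-wrapped model."""

        def __init__(self, module,
                     default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions')):
            self.module = module
            for name in default_model_methods:
                # Re-bind each method; calls on the wrapper delegate to the
                # original bound method of the wrapped TTModel
                setattr(self, name, getattr(module, name))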
class aitoolbox.torchtrain.parallel.TTDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)[source]

Bases: DataParallel, TTParallelBase

torchtrain-enabled DataParallel

This DataParallel wrapper works in the same way as the original PyTorch torch.nn.DataParallel. In addition, it exposes the TTModel batch data feeding definitions (the additional abstract methods) to the TrainLoop while still enabling multi-GPU training.

Parameters:
  • module (aitoolbox.torchtrain.model.TTModel) – neural network model

  • default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class

  • **kwargs – additional parameters for the underlying nn.DataParallel

training: bool
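
Example (a minimal usage sketch; MyModel, its layers, and the exact return format of get_predictions are illustrative assumptions, not part of this API):

    import torch.nn as nn
    from aitoolbox.torchtrain.model import TTModel
    from aitoolbox.torchtrain.parallel import TTDataParallel


    class MyModel(TTModel):
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(10, 2)

        def forward(self, x):
            return self.net(x)

        def get_loss(self, batch_data, criterion, device):
            x, y = batch_data
            return criterion(self(x.to(device)), y.to(device))

        def get_loss_eval(self, batch_data, criterion, device):
            # Same behaviour as the training loss in this toy example
            return self.get_loss(batch_data, criterion, device)

        def get_predictions(self, batch_data, device):
            x, y = batch_data
            return self(x.to(device)).argmax(dim=1).cpu(), y, {}


    # Wrap the model; the TTModel methods above remain callable on the wrapper
    model = TTDataParallel(MyModel())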
class aitoolbox.torchtrain.parallel.TTDistributedDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)[source]

Bases: TTParallelBase, DistributedDataParallel

torchtrain-enabled DistributedDataParallel

This DistributedDataParallel wrapper works in the same way as the original PyTorch torch.nn.parallel.DistributedDataParallel. In addition, it exposes the TTModel batch data feeding definitions (the additional abstract methods) to the TrainLoop while still enabling multi-GPU training.

Parameters:
  • module (aitoolbox.torchtrain.model.TTModel) – neural network model

  • default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class

  • **kwargs – additional parameters for the underlying nn.parallel.DistributedDataParallel

training: bool
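
Example (a hedged sketch of the distributed setup; assumes the script is launched with torchrun, which sets the rank environment variables, and reuses MyModel from the TTDataParallel example above — in practice the TrainLoop may handle this wrapping for you):

    import os
    import torch
    import torch.distributed as dist
    from aitoolbox.torchtrain.parallel import TTDistributedDataParallel

    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for each spawned process
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = MyModel().to(local_rank)
    # device_ids is forwarded through **kwargs to nn.parallel.DistributedDataParallel
    ddp_model = TTDistributedDataParallel(model, device_ids=[local_rank])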