parallel

class aitoolbox.torchtrain.parallel.TTParallelBase(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'))[source]

Bases: object

torchtrain parallel base class used to transfer TTModel methods up to the level of the PyTorch parallel wrappers

Parameters
  • module (aitoolbox.torchtrain.model.TTModel) – neural network model

  • default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class

get_loss(batch_data, criterion, device)[source]
get_loss_eval(batch_data, criterion, device)[source]
get_predictions(batch_data, device)[source]
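
The transfer mechanism can be sketched in plain Python: on construction, any public methods the user defined on the wrapped module beyond the listed defaults are copied onto the wrapper, while the core default methods delegate explicitly. This is a minimal, torch-free sketch of the idea, not aitoolbox's actual implementation; the class names `ParallelWrapperSketch` and `DummyModel` and their internals are illustrative assumptions.

```python
class ParallelWrapperSketch:
    """Torch-free sketch of lifting TTModel methods onto a parallel wrapper."""

    def __init__(self, module,
                 default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions')):
        self.module = module
        # Transfer any extra public methods defined on the wrapped model
        # so they stay callable directly on the wrapper
        for name in dir(module):
            if name.startswith('_') or name in default_model_methods:
                continue
            attr = getattr(module, name)
            if callable(attr):
                setattr(self, name, attr)

    # The core TTModel methods are delegated explicitly to the wrapped module
    def get_loss(self, batch_data, criterion, device):
        return self.module.get_loss(batch_data, criterion, device)

    def get_loss_eval(self, batch_data, criterion, device):
        return self.module.get_loss_eval(batch_data, criterion, device)

    def get_predictions(self, batch_data, device):
        return self.module.get_predictions(batch_data, device)


class DummyModel:
    """Stand-in for a TTModel subclass with one extra user-defined method."""

    def get_loss(self, batch_data, criterion, device):
        return criterion(batch_data)

    def get_loss_eval(self, batch_data, criterion, device):
        return self.get_loss(batch_data, criterion, device)

    def get_predictions(self, batch_data, device):
        return sum(batch_data)

    def custom_metric(self, batch_data):
        return max(batch_data)


wrapped = ParallelWrapperSketch(DummyModel())
```

After wrapping, both the core methods and the extra `custom_metric` remain callable on the wrapper itself, which is what lets a training loop keep talking to the wrapper as if it were the original model.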
class aitoolbox.torchtrain.parallel.TTDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)[source]

Bases: torch.nn.parallel.data_parallel.DataParallel, aitoolbox.torchtrain.parallel.TTParallelBase

torchtrain-enabled DataParallel

This DataParallel wrapper works in the same way as the original PyTorch nn.DataParallel. Furthermore, it exposes TTModel batch data feeding definitions (additional abstract methods) to the TrainLoop while still enabling multi-GPU training.

Parameters
  • module (aitoolbox.torchtrain.model.TTModel) – neural network model

  • default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class

  • **kwargs – additional parameters for underlying nn.DataParallel

training: bool
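
To see what this wrapper adds over the stock nn.DataParallel, note that a custom method defined on a model is not reachable through the stock wrapper, only through its .module attribute; TTDataParallel exists to keep such methods (and the core TTModel ones) callable on the wrapper itself. A minimal torch-only illustration of the underlying limitation (it deliberately does not use aitoolbox; SmallModel is a hypothetical stand-in for a TTModel subclass):

```python
import torch
import torch.nn as nn


class SmallModel(nn.Module):
    """Plain nn.Module with a TTModel-style custom loss method."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

    def get_loss(self, batch_data, criterion, device):
        inputs, targets = batch_data
        return criterion(self(inputs.to(device)), targets.to(device))


model = SmallModel()
data_parallel_model = nn.DataParallel(model)

# The stock wrapper only exposes forward(); the custom method is hidden
print(hasattr(data_parallel_model, 'get_loss'))         # False
print(hasattr(data_parallel_model.module, 'get_loss'))  # True
```

TTDataParallel removes the need for the explicit `.module` indirection, so a TrainLoop can call `get_loss()` and friends directly on the wrapped model.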
class aitoolbox.torchtrain.parallel.TTDistributedDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)[source]

Bases: torch.nn.parallel.distributed.DistributedDataParallel, aitoolbox.torchtrain.parallel.TTParallelBase

torchtrain-enabled DistributedDataParallel

Parameters
  • module (aitoolbox.torchtrain.model.TTModel) – neural network model

  • default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class

  • **kwargs – additional parameters for underlying nn.parallel.DistributedDataParallel

training: bool