parallel
- class aitoolbox.torchtrain.parallel.TTParallelBase(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'))[source]

  Bases: object

  torchtrain parallel base class used to transfer TTModel methods up to the level of the PyTorch parallel wrappers.

  - Parameters:
    - module (aitoolbox.torchtrain.model.TTModel) – neural network model
    - default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class
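The core idea behind this base class can be sketched in plain Python: at wrap time, the wrapped module's methods are made reachable directly on the wrapper object. The sketch below is illustrative only and is not aitoolbox's actual implementation; `ParallelBaseSketch` and `ToyModel` are hypothetical names.

```python
# Minimal sketch (stdlib only, assumed behavior) of the method-transfer
# idea behind TTParallelBase: core TTModel methods plus any extra
# user-defined methods on the wrapped module are exposed on the wrapper.
import inspect

DEFAULT_METHODS = ('get_loss', 'get_loss_eval', 'get_predictions')


class ParallelBaseSketch:
    def __init__(self, module, default_model_methods=DEFAULT_METHODS):
        self.module = module
        # Expose the listed core methods by delegating to the wrapped module
        for name in default_model_methods:
            if hasattr(module, name):
                setattr(self, name, getattr(module, name))
        # Also transfer any extra public methods not yet on the wrapper
        for name, member in inspect.getmembers(module, inspect.ismethod):
            if not name.startswith('_') and not hasattr(self, name):
                setattr(self, name, member)


class ToyModel:
    def get_loss(self, batch):
        return sum(batch)

    def custom_metric(self, batch):
        return max(batch)


wrapped = ParallelBaseSketch(ToyModel())
wrapped.get_loss([1, 2, 3])       # delegated core method -> 6
wrapped.custom_metric([1, 2, 3])  # transferred extra method -> 3
```

Because the transfer happens once at construction, code that previously called methods on the bare model can call the same methods on the wrapper without reaching through a `.module` indirection.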
- class aitoolbox.torchtrain.parallel.TTDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)[source]

  Bases: DataParallel, TTParallelBase

  torchtrain-enabled DataParallel.

  This DataParallel wrapper works in the same way as the original PyTorch torch.nn.DataParallel. Furthermore, it exposes the TTModel batch data feeding definitions (additional abstract methods) to the TrainLoop while still enabling multi-GPU training.

  - Parameters:
    - module (aitoolbox.torchtrain.model.TTModel) – neural network model
    - default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class
    - **kwargs – additional parameters for the underlying nn.DataParallel
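The following sketch shows the gap this wrapper closes. Plain torch.nn.DataParallel does not expose extra methods defined on the wrapped model (such as TTModel's get_loss), so the TrainLoop could not call them on the wrapper directly. `ToyTTModel` and its `get_loss` signature are illustrative assumptions mimicking the TTModel pattern, not aitoolbox's actual class.

```python
# Hedged sketch: plain nn.DataParallel hides custom model methods,
# which is the problem TTDataParallel solves by transferring them
# up to the wrapper. Runs on CPU; no GPUs required.
import torch
import torch.nn as nn


class ToyTTModel(nn.Module):
    """Mimics the TTModel pattern: forward() plus a custom get_loss()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

    def get_loss(self, batch_data, criterion, device):
        inputs, targets = batch_data
        return criterion(self(inputs.to(device)), targets.to(device))


model = ToyTTModel()
dp_model = nn.DataParallel(model)

# Plain DataParallel does not delegate the custom method ...
assert not hasattr(dp_model, 'get_loss')
# ... it is only reachable via the .module indirection:
assert hasattr(dp_model.module, 'get_loss')

# With aitoolbox, wrapping with TTDataParallel instead would keep
# dp_model.get_loss(...) callable on the wrapper itself:
# from aitoolbox.torchtrain.parallel import TTDataParallel
# dp_model = TTDataParallel(model)
```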
- class aitoolbox.torchtrain.parallel.TTDistributedDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)[source]

  Bases: TTParallelBase, DistributedDataParallel

  torchtrain-enabled DistributedDataParallel.

  This DistributedDataParallel wrapper works in the same way as the original PyTorch torch.nn.parallel.DistributedDataParallel. Furthermore, it exposes the TTModel batch data feeding definitions (additional abstract methods) to the TrainLoop while still enabling multi-GPU training.

  - Parameters:
    - module (aitoolbox.torchtrain.model.TTModel) – neural network model
    - default_model_methods (list or tuple) – list of core methods which are also present in the TTModel abstract class
    - **kwargs – additional parameters for the underlying nn.parallel.DistributedDataParallel
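Unlike DataParallel, DistributedDataParallel requires an initialized process group before wrapping. The single-process sketch below shows that setup with the CPU 'gloo' backend and world_size=1 so it runs without GPUs; the port number is an arbitrary assumption, and in real training each spawned process would receive its own rank. Swapping aitoolbox's TTDistributedDataParallel in for nn.parallel.DistributedDataParallel would additionally expose the TTModel methods on the wrapper.

```python
# Hedged single-process sketch of the DistributedDataParallel setup
# that TTDistributedDataParallel builds on. CPU-only, world_size=1.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

# Rendezvous settings for the single local process (illustrative values)
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29511')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

model = nn.Linear(4, 2)
ddp_model = nn.parallel.DistributedDataParallel(model)

out = ddp_model(torch.randn(8, 4))  # gradient sync would happen on backward()

dist.destroy_process_group()
```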