parallel
- class aitoolbox.torchtrain.parallel.TTParallelBase(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'))
Bases: object
torchtrain parallel base class used for transferring TTModel functions to the level of the PyTorch parallel wrappers
- Parameters:
module (aitoolbox.torchtrain.model.TTModel) – neural network model
default_model_methods (list or tuple) – list of core methods that are also present in the TTModel abstract class
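The method transfer described above can be illustrated with a minimal sketch in plain Python. This is not the aitoolbox source: `MethodForwardingBase` and `ToyModel` are hypothetical stand-ins, and the method signatures on `ToyModel` are assumed for illustration only. The idea shown is that each name in `default_model_methods` is looked up on the wrapped model and re-bound on the wrapper.

```python
class MethodForwardingBase:
    """Hypothetical stand-in for a wrapper base class; not the aitoolbox source."""

    def __init__(self, module,
                 default_model_methods=("get_loss", "get_loss_eval", "get_predictions")):
        self.module = module
        # Re-bind each named method of the wrapped model onto the wrapper,
        # so callers can use wrapper.get_loss(...) directly.
        for method_name in default_model_methods:
            setattr(self, method_name, getattr(module, method_name))


class ToyModel:
    """Hypothetical model; the method signatures are assumed for illustration."""

    def get_loss(self, batch_data, criterion, device):
        return criterion(batch_data)

    def get_loss_eval(self, batch_data, criterion, device):
        return criterion(batch_data)

    def get_predictions(self, batch_data, device):
        return [2 * x for x in batch_data]


wrapper = MethodForwardingBase(ToyModel())
loss = wrapper.get_loss([1, 2, 3], sum, "cpu")
predictions = wrapper.get_predictions([1, 2, 3], "cpu")
```

Because the methods are bound to the original model instance, calling them through the wrapper runs the model's own logic unchanged.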
- class aitoolbox.torchtrain.parallel.TTDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)
Bases: DataParallel, TTParallelBase
torchtrain-enabled DataParallel
This DataParallel wrapper works in the same way as the original PyTorch torch.nn.DataParallel. Furthermore, it exposes TTModel batch data feeding definitions (additional abstract methods) to the TrainLoop while still enabling multi-GPU training.
- Parameters:
module (aitoolbox.torchtrain.model.TTModel) – neural network model
default_model_methods (list or tuple) – list of core methods that are also present in the TTModel abstract class
**kwargs – additional parameters for the underlying nn.DataParallel
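As a rough illustration of how such a wrapper can combine a parallel container with method forwarding, here is a sketch in plain Python under stated assumptions. `FakeDataParallel` is a stand-in for torch.nn.DataParallel that only records its keyword arguments, and `ForwardingMixin`, `SketchDataParallel` and `ToyModel` are hypothetical names, not the library's implementation. `device_ids` is used as an example of a keyword argument passed through `**kwargs`.

```python
class FakeDataParallel:
    """Stand-in for torch.nn.DataParallel; it only stores what it is given."""

    def __init__(self, module, **kwargs):
        self.module = module
        self.parallel_kwargs = kwargs

    def __call__(self, *inputs):
        # The real wrapper would scatter inputs across GPUs; here we just delegate.
        return self.module(*inputs)


class ForwardingMixin:
    """Hypothetical base that re-binds selected model methods on the wrapper."""

    def __init__(self, module, default_model_methods):
        for method_name in default_model_methods:
            setattr(self, method_name, getattr(module, method_name))


class SketchDataParallel(FakeDataParallel, ForwardingMixin):
    def __init__(self, module,
                 default_model_methods=("get_loss", "get_loss_eval", "get_predictions"),
                 **kwargs):
        # Extra keyword arguments go to the parallel container only.
        FakeDataParallel.__init__(self, module, **kwargs)
        ForwardingMixin.__init__(self, module, default_model_methods)


class ToyModel:
    """Hypothetical model; signatures are assumed for illustration."""

    def __call__(self, batch):
        return [x + 1 for x in batch]

    def get_loss(self, batch_data, criterion, device):
        return criterion(self(batch_data))

    def get_loss_eval(self, batch_data, criterion, device):
        return criterion(self(batch_data))

    def get_predictions(self, batch_data, device):
        return self(batch_data)


wrapper = SketchDataParallel(ToyModel(), device_ids=[0, 1])
forward_out = wrapper([1, 2])
loss = wrapper.get_loss([1, 2], sum, "cpu")
```

The wrapper stays callable like the container it extends, while the model's batch-feeding methods remain reachable on the wrapper itself.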
- class aitoolbox.torchtrain.parallel.TTDistributedDataParallel(module, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs)
Bases: TTParallelBase, DistributedDataParallel
torchtrain-enabled DistributedDataParallel
This DistributedDataParallel wrapper works in the same way as the original PyTorch torch.nn.parallel.DistributedDataParallel. Furthermore, it exposes TTModel batch data feeding definitions (additional abstract methods) to the TrainLoop while still enabling multi-GPU training.
- Parameters:
module (aitoolbox.torchtrain.model.TTModel) – neural network model
default_model_methods (list or tuple) – list of core methods that are also present in the TTModel abstract class
**kwargs – additional parameters for the underlying nn.parallel.DistributedDataParallel