Multi-GPU Training

In addition to single-GPU training, all TrainLoop versions also support multi-GPU training, which allows even faster training. Following the core PyTorch setup, two multi-GPU training approaches are available:

- DataParallel
- DistributedDataParallel

DataParallel (DP)

To use DataParallel-style multi-GPU training with the TrainLoop, simply set the TrainLoop’s gpu_mode parameter to 'dp':

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from aitoolbox.torchtrain.train_loop import *
from aitoolbox.torchtrain.parallel import TTDataParallel

model = CNNModel()  # user-defined, TTModel-based neural model

train_loader = DataLoader(...)
val_loader = DataLoader(...)
test_loader = DataLoader(...)

optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
criterion = nn.NLLLoss()

tl = TrainLoop(model,
               train_loader, val_loader, test_loader,
               optimizer, criterion,
               gpu_mode='dp')

model = tl.fit(num_epochs=10)  # number of epochs chosen for illustration

Check out a full DataParallel training example.
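For intuition only, the sketch below shows what DataParallel-style training amounts to in plain PyTorch, calling torch.nn.DataParallel directly. It is an assumption about the general mechanism, not the TrainLoop's actual code path. TinyNet and wrap_for_dp are hypothetical names, and the sketch falls back to a single GPU or the CPU when fewer than two GPUs are visible.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a TTModel-based network."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)


def wrap_for_dp(model: nn.Module) -> nn.Module:
    """Hypothetical helper: replicate the model across all visible GPUs if there are several."""
    if torch.cuda.device_count() > 1:
        # DataParallel splits each input batch along dim 0 across the GPUs,
        # runs one model replica per GPU and gathers the outputs on the first GPU.
        model = nn.DataParallel(model.cuda())
    elif torch.cuda.is_available():
        model = model.cuda()  # single GPU: no replication needed
    return model  # CPU-only: model is returned unchanged


model = wrap_for_dp(TinyNet())
device = next(model.parameters()).device  # the (first) device holding the weights
batch = torch.randn(16, 8, device=device)
out = model(batch)
print(out.shape)  # torch.Size([16, 2])
```

Because DataParallel runs in a single process and gathers every output on the first GPU, that device tends to become a bottleneck, which is one reason PyTorch recommends DistributedDataParallel for larger multi-GPU workloads.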


DistributedDataParallel (DDP)

Distributed training on multiple GPUs via DistributedDataParallel is handled by the TrainLoop itself under the hood: it wraps the TTModel-based model into TTDistributedDataParallel. The TrainLoop also automatically spawns the worker processes and initializes them. Inside each spawned process, the model and all other necessary training components are moved to the GPU belonging to that process. Lastly, the TrainLoop automatically adds the PyTorch DistributedSampler to each of the provided data loaders, so that each GPU receives different data batches and there is no overlap between them.
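The no-overlap behaviour of DistributedSampler can be checked in isolation. The minimal sketch below assumes a hypothetical 10-sample toy dataset split across two processes, and passes num_replicas and rank explicitly so that no process group has to be initialized; inside a real DDP worker both values are inferred from the initialized process group.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Hypothetical toy dataset: 10 samples whose values equal their indices.
dataset = TensorDataset(torch.arange(10))
world_size = 2  # e.g. two GPUs, one process per GPU

seen = {}
for rank in range(world_size):
    # Explicit num_replicas/rank: no torch.distributed init needed for this demo.
    # shuffle=False keeps the split deterministic.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    # Each batch is a one-element list holding a tensor of sample values.
    seen[rank] = [int(x) for (batch,) in loader for x in batch]

print(seen)  # {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
```

Each rank receives an interleaved, disjoint share of the indices. When the dataset size is not divisible by the number of replicas, DistributedSampler (with its default drop_last=False) pads the index list by repeating a few samples so that every rank gets the same number of samples.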

To enable distributed training via DistributedDataParallel, set the TrainLoop’s gpu_mode parameter to 'ddp' and describe the node and GPU setup when calling fit():

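import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from aitoolbox.torchtrain.train_loop import *

# TrainLoop spawns worker processes, so guard the entry point against re-execution
if __name__ == '__main__':
    model = CNNModel()  # user-defined, TTModel-based neural model

    train_loader = DataLoader(...)
    val_loader = DataLoader(...)
    test_loader = DataLoader(...)

    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
    criterion = nn.NLLLoss()

    tl = TrainLoop(model,
                   train_loader, val_loader, test_loader,
                   optimizer, criterion,
                   gpu_mode='ddp')

    # num_epochs chosen for illustration; the remaining arguments describe
    # a single machine (node 0) using all of its visible GPUs
    model = tl.fit(num_epochs=10,
                   num_nodes=1, node_rank=0, num_gpus=torch.cuda.device_count())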

Check out a full DistributedDataParallel training example.
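The num_nodes, node_rank and num_gpus arguments describe the cluster layout. As a minimal sketch, assuming the common PyTorch convention of one process per GPU (whether the TrainLoop combines the values exactly this way is an assumption), the total number of processes and the global ranks on a given node work out as follows; ddp_ranks is a hypothetical helper.

```python
from typing import List, Tuple


def ddp_ranks(num_nodes: int, node_rank: int, num_gpus: int) -> Tuple[int, List[int]]:
    """Hypothetical helper: world size and the global ranks of one node's processes.

    Assumes one process per GPU and the usual convention
    global rank = node_rank * num_gpus + local GPU index.
    """
    world_size = num_nodes * num_gpus
    global_ranks = [node_rank * num_gpus + local_rank for local_rank in range(num_gpus)]
    return world_size, global_ranks


# Two machines with 4 GPUs each; this is the second machine (node_rank=1).
print(ddp_ranks(num_nodes=2, node_rank=1, num_gpus=4))  # (8, [4, 5, 6, 7])
```

With the single-machine setup from the example above (num_nodes=1, node_rank=0), the global ranks simply coincide with the local GPU indices.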