ddp_handler

class aitoolbox.torchtrain.train_loop.components.ddp_handler.DDPHandler(train_loop_obj)

Bases: object

Distributed Data Parallel process handler for the TrainLoop

Parameters

train_loop_obj (aitoolbox.torchtrain.train_loop.TrainLoop) – reference to the encapsulating TrainLoop

add_distributed_samplers(world_size, rank)

Add the distributed samplers needed for DDP to the regular single-process DataLoader provided to the TrainLoop

Parameters
  • world_size (int) – world size (total number of processes) for the distributed training

  • rank (int) – rank of the current process

static build_loader_sampler(data_loader, shuffle, world_size, rank)

Replicate given data loader with added distributed sampler

Parameters
  • data_loader (DataLoader) – original single process data loader without the distributed sampler

  • shuffle (bool) – whether the added sampler should return examples in shuffled order

  • world_size (int) – world size (total number of processes) for the distributed training

  • rank (int) – rank of the current process

Returns

new data loader with the sampler, and a reference to the distributed sampler included in the new data loader

Return type

DataLoader, DistributedSampler
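
To make the roles of world_size, rank and shuffle more concrete, the following is a minimal pure-Python sketch of how a distributed sampler typically splits dataset indices between processes. It is not the aitoolbox implementation and does not use PyTorch; shard_indices is a hypothetical helper introduced only for illustration, and the sketch assumes the dataset has at least as many items as there are processes.

```python
import math
import random


def shard_indices(dataset_len, world_size, rank, shuffle=False, seed=0):
    """Return the dataset indices one process would read (hypothetical helper)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank uses the same seed, so all ranks agree on the permutation
        random.Random(seed).shuffle(indices)
    # Pad by repeating the leading indices so the total divides evenly
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices += indices[: total - len(indices)]
    # Strided slice: rank r takes positions r, r + world_size, ...
    return indices[rank:total:world_size]


# 10 examples split across 4 processes, 3 indices per process
shards = [shard_indices(10, world_size=4, rank=r) for r in range(4)]
```

Each process gets a shard of equal length, and together the shards cover the whole dataset; a few indices are repeated when the dataset size is not divisible by the world size.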

mp_sync(data, concat_mp_data=True)

Multiprocess data sync

Share input data between all the active processes so that every process has all the values from all the processes. This way we can achieve the same state of the data across all the parallel processes.

Parameters
  • data (torch.Tensor, list, float, int) – data to be synchronized between processes. If this is a torch.Tensor, its device location is preserved in the resulting output.

  • concat_mp_data (bool) – whether the returned list of collected tensors should be concatenated into a single list of values

Returns

list of data variable values synced across all the active processes

Return type

torch.Tensor
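
The effect of the sync described above can be illustrated with a small single-process simulation. This is only a sketch of the gather semantics under stated assumptions: simulated_sync is a hypothetical stand-in, plain lists are used instead of tensors, and no real inter-process communication takes place, so it says nothing about how mp_sync() actually exchanges data.

```python
def simulated_sync(per_process_data, concat=True):
    """Mimic the gather: every process ends up with everyone's values.

    per_process_data[i] stands in for the list held by process i
    (hypothetical setup, all "processes" live in one Python process).
    """
    gathered = [list(chunk) for chunk in per_process_data]
    if concat:
        # Flatten in rank order into one list of values
        gathered = [value for chunk in gathered for value in chunk]
    # Every process receives its own copy of the same result
    return [list(gathered) for _ in per_process_data]


# Three simulated processes, each holding different values
results = simulated_sync([[0.1, 0.2], [0.3], [0.4, 0.5]])
nested = simulated_sync([[0.1, 0.2], [0.3], [0.4, 0.5]], concat=False)
```

After the call every simulated process holds the same combined data; with concat=False the per-process grouping is kept instead of being flattened.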

mp_sync_dict_of_lists(dict_list_data)

Multiprocess dict of lists sync

Convenience wrapper around the mp_sync() for the specific case of dict of lists syncing.

Parameters

dict_list_data (dict) – dict of lists to be synchronized across the processes

Returns

synchronized dict of lists with combined values gathered from all the active processes

Return type

dict
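
The expected shape of the result can be sketched with another single-process simulation. The helper simulated_sync_dict_of_lists and the example keys are hypothetical, and the sketch assumes every process holds the same set of keys; it shows only the combined output, not the actual communication performed by the method.

```python
def simulated_sync_dict_of_lists(per_process_dicts):
    """Combine dicts of lists from all simulated processes, key by key.

    Assumes every dict has the same keys (an assumption of this sketch).
    """
    keys = per_process_dicts[0].keys()
    return {
        # Values for each key are joined in rank order
        key: [value for d in per_process_dicts for value in d[key]]
        for key in keys
    }


# Two simulated processes, each with partial results
process_0 = {"loss": [0.9, 0.7], "accuracy": [0.5, 0.6]}
process_1 = {"loss": [0.8], "accuracy": [0.55]}
combined = simulated_sync_dict_of_lists([process_0, process_1])
```

Each key in the combined dict maps to the values collected from all processes, so every process would see the full set of values after the sync.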