ddp_handler

class aitoolbox.torchtrain.train_loop.components.ddp_handler.DDPHandler(train_loop_obj)[source]

Bases: object

Distributed Data Parallel process handler for the TrainLoop

Parameters:

train_loop_obj (aitoolbox.torchtrain.train_loop.train_loop.TrainLoop) – reference to the encapsulating TrainLoop

add_distributed_samplers(world_size, rank)[source]

Add the DistributedSamplers needed for DDP to the normal single-process DataLoaders provided to the TrainLoop

Parameters:
  • world_size (int) – world size of the distributed training

  • rank (int) – rank of the current process
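
A minimal usage sketch, assuming an already constructed TrainLoop instance named train_loop and that this code runs inside an already spawned DDP worker process (rank holds the current process rank):

    from aitoolbox.torchtrain.train_loop.components.ddp_handler import DDPHandler

    # `train_loop` is an already constructed TrainLoop; its construction is omitted here.
    # The handler swaps the TrainLoop's single-process DataLoaders for DDP-aware ones
    # so that each rank iterates only over its own shard of the dataset.
    ddp_handler = DDPHandler(train_loop)
    ddp_handler.add_distributed_samplers(world_size=4, rank=rank)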

static build_loader_sampler(data_loader, shuffle, world_size, rank)[source]

Replicate the given data loader, adding a distributed sampler to it

Parameters:
  • data_loader (torch.utils.data.DataLoader) – original single process data loader without the distributed sampler

  • shuffle (bool) – should the added sampler return examples in shuffled order

  • world_size (int) – world size of the distributed training

  • rank (int) – rank of the current process

Returns:

new data loader with the distributed sampler included, and a reference to that distributed sampler

Return type:

DataLoader, DistributedSampler
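
An illustrative sketch of the static helper used in isolation, assuming the distributed sampler is built from the explicit world_size and rank arguments (so no process group is required for the sampler itself):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from aitoolbox.torchtrain.train_loop.components.ddp_handler import DDPHandler

    dataset = TensorDataset(torch.arange(100.).unsqueeze(1), torch.zeros(100))
    single_process_loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Replicate the loader for rank 0 of a hypothetical 4-process run
    ddp_loader, ddp_sampler = DDPHandler.build_loader_sampler(
        single_process_loader, shuffle=True, world_size=4, rank=0
    )

    # The returned sampler reference is handy for per-epoch re-shuffling in DDP
    for epoch in range(3):
        ddp_sampler.set_epoch(epoch)
        for batch in ddp_loader:
            ...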

mp_sync(data, double_precision=False, concat_mp_data=True, return_tensor=True)[source]

Multiprocess data sync

Share the input data between all the active processes so that every process holds the values from every other process. This way the data reaches the same state across all the parallel processes.

Parameters:
  • data (torch.Tensor or numpy.ndarray or list or float or int or bool) – data to be synchronized between the processes. If this is a torch.Tensor, the device location of the input tensor is preserved in the resulting output.

  • double_precision (bool) – if the data parameter is not already a Tensor, the function wraps the given data into a Tensor. By default it creates a 32-bit float tensor (the PyTorch default). If this parameter is set to True, a double precision 64-bit tensor is created instead. This is useful, for example, when the input data is already in 64-bit precision and we want to prevent a precision reduction while syncing the data across the workers.

  • concat_mp_data (bool) – should the returned list of collected tensors be concatenated into a single list of values

  • return_tensor (bool) – should the synced data be returned as a tensor or converted back to the same data type as the input data

Returns:

data variable values synced across all the active processes

Return type:

torch.Tensor or numpy.ndarray or list
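
A hedged sketch of how mp_sync() might be called from inside a DDP worker process, assuming ddp_handler is the handler attached to the running TrainLoop (as in the sketch above):

    # Each rank holds only its own partial results
    local_batch_losses = [0.52, 0.47, 0.44]

    # Gather the values from every rank; by default the result is a single
    # concatenated tensor (concat_mp_data=True, return_tensor=True)
    all_batch_losses = ddp_handler.mp_sync(local_batch_losses)

    # Alternatively, convert the synced data back to the input data type (a list here)
    all_batch_losses_list = ddp_handler.mp_sync(local_batch_losses, return_tensor=False)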

mp_sync_dict(dict_data)[source]

Multiprocess sync of a dict

Convenience wrapper around mp_sync() for the specific case of syncing a dict.

Parameters:

dict_data (dict) – dict to be synchronized across the processes

Returns:

synchronized dict of tensors with combined values gathered from all the active processes

Return type:

dict
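
A short sketch of the dict variant under the same assumptions as above:

    # Metrics computed only on the current rank
    local_metrics = {'loss': 0.42, 'accuracy': 0.91}

    # Every rank receives a dict where each key maps to a tensor of the values
    # gathered from all the active processes
    synced_metrics = ddp_handler.mp_sync_dict(local_metrics)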