pred_collate_fns

aitoolbox.torchtrain.train_loop.components.pred_collate_fns.append_predictions(y_batch, predictions)
Parameters:
  • y_batch (torch.Tensor) – predictions for the new batch

  • predictions (list) – accumulation list where all the batched predictions are appended

Returns:

predictions list with the new tensor appended

Return type:

list
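A minimal sketch of how an accumulator like this is typically used, assuming the function adds each batch tensor as a single new list element. `append_predictions_sketch` is a hypothetical stand-in written for illustration, not the library's own code, and the batch shape (4 examples, 2 classes) is made up.

```python
import torch


def append_predictions_sketch(y_batch, predictions):
    # Assumed behavior: the whole batch tensor becomes one new list element
    predictions.append(y_batch)
    return predictions


predictions = []
for _ in range(3):
    # Hypothetical model output: a batch of 4 examples with 2 class scores each
    y_batch = torch.zeros(4, 2)
    predictions = append_predictions_sketch(y_batch, predictions)

print(len(predictions))  # 3: one entry per batch
```

Keeping one entry per batch defers any merging to a later collate step, such as a concatenation over the whole list.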

aitoolbox.torchtrain.train_loop.components.pred_collate_fns.append_concat_predictions(y_batch, predictions)
Parameters:
  • y_batch (torch.Tensor or list) – predictions for the new batch, given either as a single tensor or as a list

  • predictions (list) – accumulation list where all the batched predictions are added

Returns:

predictions list with the new batch predictions added

Return type:

list
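A sketch of one plausible reading of this function, based only on its name and the "torch.Tensor or list" parameter type: a list batch is merged into the accumulator element by element, while a tensor is appended as one item. This behavior is an assumption, and `append_concat_predictions_sketch` is a hypothetical helper, not the library's implementation.

```python
import torch


def append_concat_predictions_sketch(y_batch, predictions):
    # Assumption: list batches (e.g. decoded text outputs) are merged
    # element by element; tensor batches are appended as a single item
    if isinstance(y_batch, list):
        predictions.extend(y_batch)
    else:
        predictions.append(y_batch)
    return predictions


predictions = []
predictions = append_concat_predictions_sketch(torch.ones(2, 3), predictions)
predictions = append_concat_predictions_sketch(["cat", "dog"], predictions)
print(len(predictions))  # 3: one tensor plus two list elements
```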

aitoolbox.torchtrain.train_loop.components.pred_collate_fns.torch_cat_transf(predictions)

Concatenate the given list of tensors into a single tensor using PyTorch.

Parameters:

predictions (list) – list of torch.Tensor objects to concatenate

Returns:

a single tensor formed by concatenating the provided tensors

Return type:

torch.Tensor
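A minimal sketch of the concatenation step, assuming the tensors are joined along the first (batch) dimension, which is the default `dim=0` of `torch.cat`. `torch_cat_transf_sketch` is a hypothetical stand-in. The last batch is deliberately smaller, as often happens at the end of an epoch.

```python
import torch


def torch_cat_transf_sketch(predictions):
    # Assumption: join along dim 0 (torch.cat's default); all tensors
    # must then agree in every dimension except the first
    return torch.cat(predictions)


# Two full batches of 4 plus a final partial batch of 1, each with 2 outputs
batches = [torch.zeros(4, 2), torch.zeros(4, 2), torch.zeros(1, 2)]
all_preds = torch_cat_transf_sketch(batches)
print(all_preds.shape)  # torch.Size([9, 2])
```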

aitoolbox.torchtrain.train_loop.components.pred_collate_fns.keep_list_transf(predictions)

Identity transformation that returns the predictions unchanged.

Parameters:

predictions (list) – list of predictions

Returns:

the unaltered list of predictions

Return type:

list
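An identity transformation is useful when the accumulated predictions cannot be merged into one tensor. The sketch below assumes per-batch outputs whose second dimension differs (for example, hypothetical sequence lengths of 5 and 7), so `torch.cat` raises a `RuntimeError`, while the identity keeps the list intact. `keep_list_transf_sketch` is a hypothetical stand-in, not the library's code.

```python
import torch


def keep_list_transf_sketch(predictions):
    # Identity: return the accumulated list without modification
    return predictions


# Hypothetical per-batch outputs with different sequence lengths (5 vs 7)
batches = [torch.zeros(4, 5), torch.zeros(4, 7)]

try:
    torch.cat(batches)
except RuntimeError:
    print("torch.cat fails: sizes differ outside dim 0")

kept = keep_list_transf_sketch(batches)
print(kept is batches)  # True
```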