model

class aitoolbox.torchtrain.model.TTModel[source]

Bases: Module, ABC

TTModel is an extension of core PyTorch nn.Module

TT in TTModel -> TorchTrain Model

In addition to the common forward() method required by the base torch.nn.Module, the user also needs to implement the additional AIToolbox specific get_loss() and get_predictions() methods. Optionally, the user can also implement a desired get_loss_eval() method for specific loss calculation when in evaluation mode.

Parameters:

transfer_model_attributes (list or tuple) – additional TTModel attributes which need to be transferred to the TTDataParallel level to enable their use in the transferred/exposed class methods. When coding the model's __init__() method, the user should also fill in the string names of the attributes that need to be transferred in case the model gets wrapped for DP/DDP.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
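A minimal sketch of a TTModel subclass is shown below; the network architecture and the assumed (inputs, targets) batch layout are illustrative only:

import torch.nn as nn
from aitoolbox.torchtrain.model import TTModel


class ExampleClassifier(TTModel):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))

    def forward(self, x):
        return self.net(x)

    def get_loss(self, batch_data, criterion, device):
        # Assumes the data loader yields (inputs, targets) batches
        inputs, targets = batch_data
        inputs, targets = inputs.to(device), targets.to(device)
        predictions = self(inputs)
        return criterion(predictions, targets)

    def get_predictions(self, batch_data, device):
        inputs, targets = batch_data
        predictions = self(inputs.to(device))
        # Return y_pred, y_test and an optional metadata dict (here None)
        return predictions.cpu(), targets, None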

abstract get_loss(batch_data, criterion, device)[source]

Get loss during training stage

Called from fit() in TrainLoop

Executed during training stage where model weights are updated based on the loss returned from this function.

Parameters:

  • batch_data – model input data batch

  • criterion – loss criterion

  • device – device on which the model is being trained

Returns:

loss

Return type:

torch.Tensor or MultiLoss

get_loss_eval(batch_data, criterion, device)[source]

Get loss during evaluation stage

Called from evaluate_model_loss() in TrainLoop.

The difference compared with get_loss() is that no backpropagation weight update is done here; this function is executed during the evaluation stage, not training.

For simple examples this function can just call get_loss() and return its result.

Parameters:

  • batch_data – model input data batch

  • criterion – loss criterion

  • device – device on which the model is being evaluated

Returns:

loss

Return type:

torch.Tensor or MultiLoss
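As noted above, when the evaluation-time loss does not differ from the training-time loss, the override can be a simple pass-through; a minimal sketch inside a TTModel subclass:

def get_loss_eval(self, batch_data, criterion, device):
    # No evaluation-specific loss logic needed; reuse the training loss calculation
    return self.get_loss(batch_data, criterion, device)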

abstract get_predictions(batch_data, device)[source]

Get predictions during evaluation stage

Parameters:

  • batch_data – model input data batch

  • device – device on which the model is being evaluated

Returns:

y_pred, y_test and metadata, where metadata is a dict of lists / torch.Tensors / np.arrays

Return type:

(torch.Tensor, torch.Tensor, dict or None)

training: bool
class aitoolbox.torchtrain.model.TTBasicModel[source]

Bases: TTModel

Extension of the TTModel abstract class with already implemented simple loss and prediction calculation functions

The pre-implemented get_loss() and get_predictions() will take all the provided data sources from the data loader except the last one as an input to the model. The last data source from the data loader will be treated as the target variable. (*batch_input_data, targets = batch_data)

This base class is mainly meant to be used for simple models. TTBasicModel removes the need to constantly duplicate code in get_loss and get_predictions.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
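With this unpacking convention only forward() needs to be written; a minimal sketch, assuming the data loader yields (features, targets) batches and using an illustrative single-layer model:

import torch.nn as nn
from aitoolbox.torchtrain.model import TTBasicModel


class BasicRegressor(TTBasicModel):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(20, 1)

    def forward(self, features):
        # All data loader sources except the last are passed in here;
        # the last source is treated as the target by the inherited get_loss()
        return self.linear(features)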

get_loss(batch_data, criterion, device)[source]

Get loss during training stage

Called from fit() in TrainLoop

Executed during training stage where model weights are updated based on the loss returned from this function.

Parameters:

  • batch_data – model input data batch

  • criterion – loss criterion

  • device – device on which the model is being trained

Returns:

loss

Return type:

torch.Tensor or MultiLoss

get_predictions(batch_data, device)[source]

Get predictions during evaluation stage

Parameters:

  • batch_data – model input data batch

  • device – device on which the model is being evaluated

Returns:

y_pred, y_test and metadata, where metadata is a dict of lists / torch.Tensors / np.arrays

Return type:

(torch.Tensor, torch.Tensor, dict or None)

training: bool
class aitoolbox.torchtrain.model.TTBasicMultiGPUModel[source]

Bases: TTBasicModel

Extension of the TTModel abstract class with already implemented simple loss and prediction calculation functions which support balanced utilization when training on multiple GPUs.

The pre-implemented get_loss() and get_predictions() will take all the provided data sources from the data loader except the last one as an input to the model. The last data source from the data loader will be treated as the target variable. (*batch_input_data, targets = batch_data)

In the case of get_loss(), the model's forward() function is additionally provided with the targets and criterion arguments in order to enable calculation of the loss inside the forward() function.

The forward() function should have the following parameter signature and should finish with:

def forward(self, *batch_input_data, targets=None, criterion=None):
    # ... predictions calculation via the computational graph ...

    if criterion is not None:
        return criterion(predictions, targets)
    else:
        return predictions

This base class is mainly meant to be used for simple models. TTBasicMultiGPUModel removes the need to constantly duplicate code in get_loss and get_predictions.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
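A minimal sketch of a model following this pattern; the architecture and the assumed (inputs, targets) batch layout are illustrative:

import torch.nn as nn
from aitoolbox.torchtrain.model import TTBasicMultiGPUModel


class MultiGPUClassifier(TTBasicMultiGPUModel):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))

    def forward(self, x, targets=None, criterion=None):
        predictions = self.net(x)

        # When criterion is provided the loss is computed here, on each GPU replica
        if criterion is not None:
            return criterion(predictions, targets)
        return predictions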

get_loss(batch_data, criterion, device)[source]

Get loss during training stage

Called from fit() in TrainLoop

Executed during training stage where model weights are updated based on the loss returned from this function.

Parameters:

  • batch_data – model input data batch

  • criterion – loss criterion

  • device – device on which the model is being trained

Returns:

loss

Return type:

torch.Tensor or MultiLoss

training: bool
class aitoolbox.torchtrain.model.MultiGPUModelWrap(model)[source]

Bases: TTBasicMultiGPUModel

Model wrapper optimizing the model for multi-GPU training by moving the loss calculation to the GPUs

Parameters:

model (torch.nn.Module or TTModel) – neural network model. The model should follow the basic PyTorch model definition where the forward() function returns predictions
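A sketch of wrapping a plain PyTorch model; the wrapped network itself is illustrative:

import torch.nn as nn
from aitoolbox.torchtrain.model import MultiGPUModelWrap

# Any standard nn.Module whose forward() returns predictions can be wrapped
plain_model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))

# The wrapper adds the TTModel-style get_loss()/get_predictions() and moves
# the loss calculation into forward() so it runs on the GPUs
model = MultiGPUModelWrap(plain_model)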

forward(*input_data, targets=None, criterion=None)[source]

DP-friendly forward() abstraction on top of the wrapped model's usual forward() function

Parameters:
  • *input_data – whatever input data should be passed into the wrapped model’s forward() function

  • targets – target variables which the model is training to fit

  • criterion – loss function

Returns:

PyTorch loss or model output predictions. If the loss criterion is provided, this function returns the calculated loss; otherwise the model output predictions are returned.

training: bool
class aitoolbox.torchtrain.model.ModelWrap(model, batch_model_feed_def)[source]

Bases: object

TrainLoop model wrapper combining PyTorch model and model feed definition

Note

Especially useful when you want to train on multiple GPUs in a setup where the TTModel abstract functions can't be used.

ModelWrap can be used as a replacement of TTModel when using the TrainLoop.

Parameters:
  • model (torch.nn.Module) – neural network model

  • batch_model_feed_def (AbstractModelFeedDefinition or None) – data prep definition for batched data. This definition prepares the data for each batch that then gets fed into the neural network.
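A sketch of assembling a ModelWrap, assuming the AbstractModelFeedDefinition base class lives in aitoolbox.torchtrain.data.batch_model_feed_defs and exposes get_loss(), get_loss_eval() and get_predictions() hooks mirroring TTModel's but receiving the model explicitly; the wrapped model and batch layout are illustrative:

import torch.nn as nn
from aitoolbox.torchtrain.model import ModelWrap
# Assumed import path for the feed definition base class
from aitoolbox.torchtrain.data.batch_model_feed_defs import AbstractModelFeedDefinition


class ClassifierFeedDefinition(AbstractModelFeedDefinition):
    def get_loss(self, model, batch_data, criterion, device):
        # Assumes (inputs, targets) batches from the data loader
        inputs, targets = batch_data
        predictions = model(inputs.to(device))
        return criterion(predictions, targets.to(device))

    def get_loss_eval(self, model, batch_data, criterion, device):
        return self.get_loss(model, batch_data, criterion, device)

    def get_predictions(self, model, batch_data, device):
        inputs, targets = batch_data
        predictions = model(inputs.to(device))
        return predictions.cpu(), targets, None


plain_model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
model_wrap = ModelWrap(model=plain_model, batch_model_feed_def=ClassifierFeedDefinition())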