batch_model_feed_defs

class aitoolbox.torchtrain.data.batch_model_feed_defs.AbstractModelFeedDefinition

Bases: ABC

Model Feed Definition

Note

The primary way of defining the model for TrainLoop training is to use aitoolbox.torchtrain.model.TTModel.

Using AbstractModelFeedDefinition is the legacy way of defining the model. However, in scenarios where TTModel would add complexity, a model feed definition is still useful for augmenting a plain torch.nn.Module with the logic that calculates the loss and the predictions.
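The class below is a minimal sketch of that division of responsibilities. The network stays a plain torch.nn.Module, and the loss and prediction logic live in a feed definition that implements the three methods documented on this page. The sketch assumes each batch is an (inputs, targets) pair. ClassifierFeedDefinition is a hypothetical name. So that the sketch runs without aitoolbox installed, the class does not inherit from AbstractModelFeedDefinition; in real code you would subclass it.

```python
import torch
import torch.nn as nn


# Hypothetical example class. In real code it would subclass
# aitoolbox.torchtrain.data.batch_model_feed_defs.AbstractModelFeedDefinition;
# it is standalone here so the sketch runs without aitoolbox installed.
class ClassifierFeedDefinition:
    """Assumes each batch is an (inputs, targets) pair."""

    def get_loss(self, model, batch_data, criterion, device):
        inputs, targets = batch_data
        inputs, targets = inputs.to(device), targets.to(device)
        return criterion(model(inputs), targets)

    def get_loss_eval(self, model, batch_data, criterion, device):
        # Simple case: the evaluation loss is computed like the training loss.
        return self.get_loss(model, batch_data, criterion, device)

    def get_predictions(self, model, batch_data, device):
        inputs, targets = batch_data
        logits = model(inputs.to(device))
        y_pred = logits.argmax(dim=1).cpu()
        return y_pred, targets.cpu(), None  # no metadata in this sketch


# The model itself is a plain nn.Module with no loss or prediction logic.
model = nn.Linear(4, 3)
feed_def = ClassifierFeedDefinition()
device = torch.device("cpu")
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))

loss = feed_def.get_loss(model, batch, nn.CrossEntropyLoss(), device)
y_pred, y_test, metadata = feed_def.get_predictions(model, batch, device)
```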

abstract get_loss(model, batch_data, criterion, device)

Get loss during training stage

Called from fit() in TrainLoop.

Executed during the training stage, where the model weights are updated based on the loss returned from this function.

Parameters:
  • model (torch.nn.Module) – neural network model

  • batch_data (torch.Tensor) – model input data batch

  • criterion – loss criterion

  • device (torch.device) – device on which the model is being trained

Returns:

PyTorch loss
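The sketch below shows how the returned loss drives the weight update. The hand-rolled step (zero_grad, backward, optimizer step) is an assumption that stands in for what fit() does with the loss; it is not the TrainLoop's actual code. The (inputs, targets) batch layout and the RegressionFeedDefinition name are also hypothetical. Because the loss is backpropagated, get_loss() should return a tensor that is still attached to the autograd graph, not a Python float obtained with .item().

```python
import torch
import torch.nn as nn


class RegressionFeedDefinition:
    """Hypothetical feed definition; batches are assumed to be (inputs, targets) pairs."""

    def get_loss(self, model, batch_data, criterion, device):
        inputs, targets = batch_data
        predictions = model(inputs.to(device))
        return criterion(predictions, targets.to(device))


torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Linear(3, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
feed_def = RegressionFeedDefinition()

inputs = torch.randn(16, 3)
targets = inputs.sum(dim=1, keepdim=True)  # simple linear regression target
batch = (inputs, targets)

weights_before = model.weight.detach().clone()

# One hand-rolled training step, standing in for what fit() does with the loss.
model.train()
optimizer.zero_grad()
loss = feed_def.get_loss(model, batch, criterion, device)
loss.backward()
optimizer.step()

weights_changed = not torch.equal(weights_before, model.weight)
```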

get_loss_eval(model, batch_data, criterion, device)

Get loss during evaluation stage

Called from evaluate_model_loss() in TrainLoop.

The difference from get_loss() is that no backpropagation weight update is performed here. This function is executed in the evaluation stage, not in training.

For simple cases, this function can just call get_loss() and return its result.

Parameters:
  • model (torch.nn.Module) – neural network model

  • batch_data (torch.Tensor) – model input data batch

  • criterion – loss criterion

  • device (torch.device) – device on which the model is being trained

Returns:

PyTorch loss
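For the simple case described above, get_loss_eval() can delegate to get_loss(). The sketch below does that. It then evaluates a few batches with no backward pass and no optimizer step, and checks that the weights are left unchanged. The torch.no_grad() context and the model.eval() call are added explicitly here; this page does not say whether evaluate_model_loss() applies them itself. The class name and the (inputs, targets) batch layout are hypothetical.

```python
import torch
import torch.nn as nn


class RegressionFeedDefinition:
    """Hypothetical feed definition; batches are assumed to be (inputs, targets) pairs."""

    def get_loss(self, model, batch_data, criterion, device):
        inputs, targets = batch_data
        return criterion(model(inputs.to(device)), targets.to(device))

    def get_loss_eval(self, model, batch_data, criterion, device):
        # Simple case: reuse the training-time loss computation unchanged.
        return self.get_loss(model, batch_data, criterion, device)


device = torch.device("cpu")
model = nn.Linear(3, 1)
feed_def = RegressionFeedDefinition()
batches = [(torch.randn(4, 3), torch.randn(4, 1)) for _ in range(5)]

weights_before = model.weight.detach().clone()

# Evaluation: no backward pass and no optimizer step, so the weights stay untouched.
model.eval()
with torch.no_grad():
    losses = [feed_def.get_loss_eval(model, b, nn.MSELoss(), device).item() for b in batches]
mean_eval_loss = sum(losses) / len(losses)

weights_unchanged = torch.equal(weights_before, model.weight)
```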

abstract get_predictions(model, batch_data, device)

Get predictions during evaluation stage

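Parameters:
  • model (torch.nn.Module) – neural network model

  • batch_data (torch.Tensor) – model input data batch

  • device (torch.device) – device on which the model is running
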
Returns:

y_pred, y_test, metadata

Return type:

(torch.Tensor, torch.Tensor, dict or None)
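Below is a minimal sketch of get_predictions() for a classifier. It assumes (inputs, targets) batches. y_pred holds the argmax class indices, y_test holds the ground-truth labels, and metadata is an optional dict. The "probabilities" key is a hypothetical example; return None when there is nothing extra to report. The loop that concatenates the per-batch outputs and computes accuracy is only a stand-in for whatever aggregation TrainLoop performs, which this page does not describe.

```python
import torch
import torch.nn as nn


class ClassifierFeedDefinition:
    """Hypothetical feed definition; batches are assumed to be (inputs, targets) pairs."""

    def get_predictions(self, model, batch_data, device):
        inputs, targets = batch_data
        logits = model(inputs.to(device))
        # Predicted class index per example, moved to CPU for metric computation.
        y_pred = logits.argmax(dim=1).cpu()
        y_test = targets.cpu()
        # Optional per-batch extras (hypothetical key); None is also allowed.
        metadata = {"probabilities": logits.softmax(dim=1).cpu()}
        return y_pred, y_test, metadata


device = torch.device("cpu")
model = nn.Linear(5, 3).eval()
feed_def = ClassifierFeedDefinition()
batches = [(torch.randn(6, 5), torch.randint(0, 3, (6,))) for _ in range(4)]

preds, targets = [], []
with torch.no_grad():
    for batch in batches:
        y_pred, y_test, metadata = feed_def.get_predictions(model, batch, device)
        preds.append(y_pred)
        targets.append(y_test)

# Concatenate the per-batch outputs into dataset-level tensors.
all_preds = torch.cat(preds)
all_targets = torch.cat(targets)
accuracy = (all_preds == all_targets).float().mean().item()
```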