batch_model_feed_defs

class aitoolbox.torchtrain.data.batch_model_feed_defs.AbstractModelFeedDefinition

Bases: abc.ABC

Model Feed Definition

The primary way to define the model for TrainLoop training is to use:

aitoolbox.torchtrain.model.TTModel

Using a ModelFeedDefinition is the legacy way of defining the model. However, in scenarios where TTModel could add unnecessary complexity, a ModelFeedDefinition is still useful for augmenting a plain nn.Module with the logic for calculating the loss and the predictions.
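Below is a minimal sketch of a feed definition for a classifier. It assumes that each batch is an `(inputs, targets)` tuple and that the model returns class logits. `ClassificationFeedDefinition` is a hypothetical name used only for illustration. If aitoolbox is not installed, the snippet defines a local stand-in base class with the same three documented methods so that it still runs. That stand-in is an assumption, not the library's code.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

try:
    from aitoolbox.torchtrain.data.batch_model_feed_defs import AbstractModelFeedDefinition
except ImportError:
    # Hypothetical stand-in mirroring the documented interface (used only without aitoolbox)
    class AbstractModelFeedDefinition(ABC):
        @abstractmethod
        def get_loss(self, model, batch_data, criterion, device):
            pass

        def get_loss_eval(self, model, batch_data, criterion, device):
            raise NotImplementedError

        @abstractmethod
        def get_predictions(self, model, batch_data, device):
            pass


class ClassificationFeedDefinition(AbstractModelFeedDefinition):
    """Hypothetical feed definition; assumes batch_data is an (inputs, targets) tuple."""

    def get_loss(self, model, batch_data, criterion, device):
        inputs, targets = batch_data
        inputs, targets = inputs.to(device), targets.to(device)
        logits = model(inputs)
        return criterion(logits, targets)

    def get_loss_eval(self, model, batch_data, criterion, device):
        # No extra logic is needed at evaluation time, so reuse the training loss computation
        return self.get_loss(model, batch_data, criterion, device)

    def get_predictions(self, model, batch_data, device):
        inputs, y_test = batch_data
        logits = model(inputs.to(device))
        y_pred = logits.argmax(dim=1)  # predicted class index per example
        metadata = {}  # nothing extra to report in this sketch
        return y_pred.cpu(), y_test.cpu(), metadata


# Usage on a toy linear classifier with random data
torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Linear(4, 3).to(device)
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
feed_def = ClassificationFeedDefinition()

loss = feed_def.get_loss(model, batch, nn.CrossEntropyLoss(), device)
y_pred, y_true, meta = feed_def.get_predictions(model, batch, device)
```

Keeping the device transfer inside the feed definition means the same `nn.Module` can be reused unchanged. Only the batch unpacking is specific to this data format.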

abstract get_loss(model, batch_data, criterion, device)

Get loss during the training stage

Called from fit() in TrainLoop.

Executed during the training stage, where the model weights are updated based on the loss returned from this function.

Parameters
  • model (nn.Module) – neural network model

  • batch_data – model input data batch

  • criterion – loss criterion

  • device – device on which the model is being trained

Returns

PyTorch loss
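The loss returned by `get_loss()` is what drives the weight update. The simplified sketch below shows what a training iteration does with that loss: backpropagate it, then step the optimizer. `train_step` and `MSEFeedDefinition` are hypothetical names, and this is not TrainLoop's actual implementation.

```python
import torch
import torch.nn as nn


class MSEFeedDefinition:
    """Hypothetical regression feed definition; batch_data is an (inputs, targets) tuple."""

    def get_loss(self, model, batch_data, criterion, device):
        inputs, targets = (t.to(device) for t in batch_data)
        return criterion(model(inputs), targets)


def train_step(feed_def, model, batch_data, criterion, optimizer, device):
    """One simplified training iteration driven by the loss from get_loss()."""
    model.train()
    optimizer.zero_grad()
    loss = feed_def.get_loss(model, batch_data, criterion, device)
    loss.backward()   # backpropagate through the returned PyTorch loss
    optimizer.step()  # update the model weights
    return loss.item()


torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = (torch.randn(16, 2), torch.randn(16, 1))

# Repeatedly fitting the same batch should drive the loss down
losses = [train_step(MSEFeedDefinition(), model, batch, nn.MSELoss(), optimizer, device)
          for _ in range(20)]
```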

get_loss_eval(model, batch_data, criterion, device)

Get loss during the evaluation stage

Called from evaluate_model_loss() in TrainLoop.

The difference from get_loss() is that the returned loss is not used for backpropagation or a weight update. This function is executed in the evaluation stage, not during training.

In simple cases, this function can just call get_loss() and return its result.

Parameters
  • model (nn.Module) – neural network model

  • batch_data – model input data batch

  • criterion – loss criterion

  • device – device on which the model is being trained

Returns

PyTorch loss
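The sketch below illustrates the simple case described above: `get_loss_eval()` delegates to `get_loss()`. The caller averages the loss over a few batches without a backward pass or optimizer step, so the weights stay unchanged. `EvalFeedDefinition` and `evaluate_loss` are hypothetical names. Wrapping the loop in `torch.no_grad()` is an assumption of this sketch, not documented TrainLoop behavior.

```python
import torch
import torch.nn as nn


class EvalFeedDefinition:
    """Hypothetical feed definition; batch_data is an (inputs, targets) tuple."""

    def get_loss(self, model, batch_data, criterion, device):
        inputs, targets = (t.to(device) for t in batch_data)
        return criterion(model(inputs), targets)

    def get_loss_eval(self, model, batch_data, criterion, device):
        # Simple case: the evaluation loss is computed exactly like the training loss
        return self.get_loss(model, batch_data, criterion, device)


def evaluate_loss(feed_def, model, batches, criterion, device):
    """Average evaluation loss over batches; no backward pass and no weight update."""
    model.eval()
    total = 0.0
    with torch.no_grad():  # gradients are not needed at evaluation time
        for batch_data in batches:
            total += feed_def.get_loss_eval(model, batch_data, criterion, device).item()
    return total / len(batches)


torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Linear(2, 1)
weights_before = model.weight.detach().clone()
batches = [(torch.randn(4, 2), torch.randn(4, 1)) for _ in range(3)]

avg_loss = evaluate_loss(EvalFeedDefinition(), model, batches, nn.MSELoss(), device)
```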

abstract get_predictions(model, batch_data, device)

Get predictions during the evaluation stage

Parameters
  • model (nn.Module) – neural network model

  • batch_data – model input data batch

  • device – device on which the model is being trained

Returns

y_pred.cpu(), y_test.cpu(), metadata

Return type

np.array, np.array, dict
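`get_predictions()` returns one `(y_pred, y_test, metadata)` tuple per batch. The sketch below shows one way a caller might assemble these for a whole dataset. It assumes 1-D per-batch prediction and target tensors, concatenates the CPU tensors, and converts them to NumPy arrays, in line with the documented `np.array, np.array, dict` return type. `ArgmaxFeedDefinition`, `collect_predictions`, and the metadata merge strategy (extending a list per key) are hypothetical, not TrainLoop's actual aggregation logic.

```python
import numpy as np
import torch
import torch.nn as nn


class ArgmaxFeedDefinition:
    """Hypothetical feed definition; batch_data is an (inputs, targets) tuple."""

    def get_predictions(self, model, batch_data, device):
        inputs, y_test = batch_data
        y_pred = model(inputs.to(device)).argmax(dim=1)
        metadata = {"batch_size": [len(y_test)]}  # illustrative per-batch metadata
        return y_pred.cpu(), y_test.cpu(), metadata


def collect_predictions(feed_def, model, batches, device):
    """Concatenate per-batch outputs into NumPy arrays and merge metadata lists by key."""
    preds, trues, merged_meta = [], [], {}
    model.eval()
    with torch.no_grad():
        for batch_data in batches:
            y_pred, y_test, metadata = feed_def.get_predictions(model, batch_data, device)
            preds.append(y_pred)
            trues.append(y_test)
            for key, values in metadata.items():
                merged_meta.setdefault(key, []).extend(values)
    return torch.cat(preds).numpy(), torch.cat(trues).numpy(), merged_meta


torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Linear(4, 3)
batches = [(torch.randn(5, 4), torch.randint(0, 3, (5,))) for _ in range(2)]

y_pred, y_test, meta = collect_predictions(ArgmaxFeedDefinition(), model, batches, device)
```

Moving the tensors to the CPU inside `get_predictions()` keeps the accumulated outputs off the accelerator's memory. It also makes the final `.numpy()` conversion valid.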