batch_model_feed_defs
- class aitoolbox.torchtrain.data.batch_model_feed_defs.AbstractModelFeedDefinition
Bases: ABC
Model Feed Definition
Note
The primary way of defining the model for TrainLoop training is to use aitoolbox.torchtrain.model.TTModel. Using AbstractModelFeedDefinition is the legacy way of defining the model. However, in scenarios where TTModel might add unnecessary complexity, a model feed definition is still useful for augmenting a plain torch.nn.Module with the logic to calculate loss and predictions.
- abstract get_loss(model, batch_data, criterion, device)
Get loss during the training stage.
Called from fit() in TrainLoop.
Executed during the training stage, where the model weights are updated based on the loss returned from this function.
- Parameters:
model (torch.nn.Module) – neural network model
batch_data (torch.Tensor) – model input data batch
criterion – loss criterion
device (torch.device) – device on which the model is being trained
- Returns:
PyTorch loss
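As a minimal sketch of what a get_loss() body might compute, the function below moves the batch to device, runs the forward pass and applies the criterion. It assumes each batch arrives as an (inputs, targets) pair, a common DataLoader layout, so adjust the unpacking to your data. It is written as a free function with the same signature so that it runs without aitoolbox; in practice this would be the body of the method in your subclass. The small linear model and MSE criterion are illustrative choices only.

```python
import torch
import torch.nn as nn


def get_loss(model, batch_data, criterion, device):
    """Sketch of a get_loss() body; assumes batch_data is (inputs, targets)."""
    inputs, targets = batch_data
    inputs, targets = inputs.to(device), targets.to(device)
    predictions = model(inputs)
    # Return the loss tensor; the training loop performs backprop and the weight update
    return criterion(predictions, targets)


# Usage with an illustrative model and a random batch
device = torch.device("cpu")
model = nn.Linear(4, 1).to(device)
batch = (torch.randn(8, 4), torch.randn(8, 1))
loss = get_loss(model, batch, nn.MSELoss(), device)
print(loss.item())
```

The function returns the loss tensor itself rather than a Python float, so that gradients can still flow back through it when the caller performs the update.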
- get_loss_eval(model, batch_data, criterion, device)
Get loss during the evaluation stage.
Called from evaluate_model_loss() in TrainLoop.
The difference from get_loss() is that the returned loss is not used for backpropagation or a weight update: this function is executed in the evaluation stage, not in training.
In simple cases this function can just call get_loss() and return its result.
- Parameters:
model (torch.nn.Module) – neural network model
batch_data (torch.Tensor) – model input data batch
criterion – loss criterion
device (torch.device) – device on which the model is being trained
- Returns:
PyTorch loss
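Following the suggestion to reuse get_loss(), a simple get_loss_eval() can delegate to the training loss computation. The sketch below again assumes (inputs, targets) batches and uses free functions so it runs without aitoolbox. It additionally wraps the call in torch.no_grad(); this is an optional design choice, not something the docstring requires, that avoids building the autograd graph since no backpropagation happens at evaluation time.

```python
import torch
import torch.nn as nn


def get_loss(model, batch_data, criterion, device):
    """Training-stage loss; assumes batch_data is (inputs, targets)."""
    inputs, targets = batch_data
    inputs, targets = inputs.to(device), targets.to(device)
    return criterion(model(inputs), targets)


def get_loss_eval(model, batch_data, criterion, device):
    """Evaluation-stage loss: same computation, without tracking gradients."""
    with torch.no_grad():
        return get_loss(model, batch_data, criterion, device)


# Usage with an illustrative model and a random batch
device = torch.device("cpu")
model = nn.Linear(4, 1).to(device)
batch = (torch.randn(8, 4), torch.randn(8, 1))
eval_loss = get_loss_eval(model, batch, nn.MSELoss(), device)
print(eval_loss.item())
```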
- abstract get_predictions(model, batch_data, device)
Get predictions during the evaluation stage.
- Parameters:
model (torch.nn.Module) – neural network model
batch_data (torch.Tensor) – model input data batch
device (torch.device) – device on which the model is being trained
- Returns:
y_pred (model predictions), y_test (ground-truth targets), metadata
- Return type:
(torch.Tensor, torch.Tensor, dict or None)
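Putting the three methods together, a minimal subclass for a classifier might look like the sketch below. It assumes (inputs, targets) batches and a model that outputs class logits; the class name ClassificationFeedDefinition is hypothetical. The import falls back to a local stand-in ABC with the same abstract method names when aitoolbox is not installed, so the example runs on its own; that stand-in is not the library's implementation. get_predictions() returns predicted class indices, the ground-truth labels and None as metadata, matching the documented return type.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

try:
    from aitoolbox.torchtrain.data.batch_model_feed_defs import AbstractModelFeedDefinition
except ImportError:
    # Local stand-in used only when aitoolbox is not installed (not the library code)
    class AbstractModelFeedDefinition(ABC):
        @abstractmethod
        def get_loss(self, model, batch_data, criterion, device):
            pass

        @abstractmethod
        def get_predictions(self, model, batch_data, device):
            pass


class ClassificationFeedDefinition(AbstractModelFeedDefinition):
    """Hypothetical feed definition; assumes (inputs, targets) classification batches."""

    def get_loss(self, model, batch_data, criterion, device):
        inputs, targets = batch_data
        logits = model(inputs.to(device))
        return criterion(logits, targets.to(device))

    def get_loss_eval(self, model, batch_data, criterion, device):
        # Evaluation reuses the training computation without tracking gradients
        with torch.no_grad():
            return self.get_loss(model, batch_data, criterion, device)

    def get_predictions(self, model, batch_data, device):
        inputs, targets = batch_data
        with torch.no_grad():
            logits = model(inputs.to(device))
        y_pred = logits.argmax(dim=1).cpu()  # predicted class indices
        y_test = targets.cpu()  # ground-truth labels
        return y_pred, y_test, None  # no metadata for this model


# Usage with an illustrative 3-class linear model and a random batch
device = torch.device("cpu")
model = nn.Linear(4, 3).to(device)
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))
feed_def = ClassificationFeedDefinition()
criterion = nn.CrossEntropyLoss()
loss = feed_def.get_loss(model, batch, criterion, device)
eval_loss = feed_def.get_loss_eval(model, batch, criterion, device)
y_pred, y_test, metadata = feed_def.get_predictions(model, batch, device)
```

Moving y_pred and y_test to the CPU is a convenience for downstream metric code that expects host tensors; drop the .cpu() calls if your evaluation runs on the device.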