TTModel
Torchtrain Model - TTModel for short
To take advantage of the TrainLoop abstraction, the user has to define their model as a class, which is the standard
approach in core PyTorch as well. The only difference is that for TrainLoop-supported training the model class has to
inherit from the AIToolbox-specific aitoolbox.torchtrain.model.TTModel base class instead of PyTorch's torch.nn.Module.
TTModel itself inherits from the normally used nn.Module class, so models retain all the expected PyTorch
functionality. The reason for using the TTModel superclass is that TrainLoop requires users to implement two
additional methods that describe how each batch of data is fed into the model when calculating the loss in training
mode and when making predictions in evaluation mode.
In total, the user has to implement the following three methods when building a new model that inherits from TTModel:
- forward() (inherited from torch.nn.Module.forward())
- get_loss()
- get_predictions()
The code below shows the general skeleton all the TTModels have to follow to enable them to be trained with the TrainLoop:
from aitoolbox.torchtrain.model import TTModel


class MyNeuralModel(TTModel):
    def __init__(self):
        super().__init__()
        # model layers, etc.

    def forward(self, x_data_batch):
        # The same method as required in the base PyTorch nn.Module
        ...
        # return prediction

    def get_loss(self, batch_data, criterion, device):
        # Get loss during training stage, called from fit() in TrainLoop
        ...
        # return batch loss

    def get_loss_eval(self, batch_data, criterion, device):
        # Get loss during evaluation stage. Normally just calls get_loss()
        return self.get_loss(batch_data, criterion, device)

    def get_predictions(self, batch_data, device):
        # Get predictions during evaluation stage
        # + return any metadata potentially needed for evaluation
        ...
        # return predictions, true_targets, metadata
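To make the skeleton more concrete, here is a minimal sketch of a filled-in model. It is only an illustration, not an official AIToolbox example: the class name TinyClassifier, the layer sizes, and the assumption that every batch is an (inputs, targets) tuple are all hypothetical. The fallback to nn.Module exists only so the sketch can run without AIToolbox installed. The methods are called by hand here, whereas in real use TrainLoop would call them.

```python
import torch
import torch.nn as nn

try:
    from aitoolbox.torchtrain.model import TTModel
except ImportError:
    # Assumption for this sketch only: fall back to plain nn.Module
    # so the example still runs when AIToolbox is not installed.
    TTModel = nn.Module


class TinyClassifier(TTModel):
    """Hypothetical linear classifier following the TTModel skeleton."""

    def __init__(self, n_features=4, n_classes=3):
        super().__init__()
        self.linear = nn.Linear(n_features, n_classes)

    def forward(self, x_data_batch):
        # Plain nn.Module forward pass: returns raw class scores (logits)
        return self.linear(x_data_batch)

    def get_loss(self, batch_data, criterion, device):
        # Assumed batch layout: (inputs, targets)
        inputs, targets = batch_data
        inputs, targets = inputs.to(device), targets.to(device)
        logits = self(inputs)
        return criterion(logits, targets)

    def get_predictions(self, batch_data, device):
        inputs, targets = batch_data
        logits = self(inputs.to(device))
        predictions = logits.argmax(dim=1)
        # No extra metadata is needed in this sketch, so return an empty dict
        return predictions.cpu(), targets, {}


# Manual calls with random data, standing in for what a training loop would do
device = torch.device("cpu")
model = TinyClassifier().to(device)
batch = (torch.randn(8, 4), torch.randint(0, 3, (8,)))

loss = model.get_loss(batch, nn.CrossEntropyLoss(), device)

model.eval()
with torch.no_grad():
    predictions, true_targets, metadata = model.get_predictions(batch, device)
```

Keeping the batch unpacking inside get_loss() and get_predictions() means the model itself decides how a batch is interpreted, so the same training loop can drive models with very different input formats.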
For a full working example of a TTModel-based model definition, check out this model example script.