Source code for aitoolbox.torchtrain.model

from abc import ABC, abstractmethod
import torch.nn as nn

from aitoolbox.torchtrain.data.batch_model_feed_defs import AbstractModelFeedDefinition


class TTModel(nn.Module, ABC):
    """TTModel is an extension of the core PyTorch nn.Module

    *TT in TTModel --> TorchTrain Model*

    In addition to the common :meth:`~torch.nn.Module.forward` method required by the base
    :class:`torch.nn.Module`, the user also needs to implement the additional AIToolbox-specific
    :meth:`~aitoolbox.torchtrain.model.TTModel.get_loss` and
    :meth:`~aitoolbox.torchtrain.model.TTModel.get_predictions` methods.

    Optionally, the user can also implement the desired
    :meth:`~aitoolbox.torchtrain.model.TTModel.get_loss_eval` method for specific loss calculation
    when in evaluation mode.

    Attributes:
        transfer_model_attributes (list or tuple): additional TTModel attributes which need to be
            transferred to the TTDataParallel level to enable their use in the transferred/exposed
            class methods. When coding the model's ``__init__()`` method, the user should also fill
            in the string names of the attributes that should be transferred in case the model is
            wrapped for DP/DDP.
    """
    def __init__(self):
        super().__init__()
        self.transfer_model_attributes = []
    @abstractmethod
    def get_loss(self, batch_data, criterion, device):
        """Get loss during training stage

        Called from fit() in TrainLoop

        Executed during training stage where model weights are updated based on the loss returned
        from this function.

        Args:
            batch_data (torch.Tensor or list or tuple or dict): model input data batch
            criterion (torch.nn.Module): loss criterion
            device (torch.device): device on which the model is being trained

        Returns:
            torch.Tensor or :class:`~aitoolbox.torchtrain.multi_loss_optim.MultiLoss`: loss
        """
        pass
    def get_loss_eval(self, batch_data, criterion, device):
        """Get loss during evaluation stage

        Called from evaluate_model_loss() in TrainLoop.

        The difference compared with get_loss() is that here the backprop weight update is not done.
        This function is executed in the evaluation stage, not in training.

        For simple examples this function can just call
        :meth:`~aitoolbox.torchtrain.model.TTModel.get_loss` and return its result.

        Args:
            batch_data (torch.Tensor or list or tuple or dict): model input data batch
            criterion (torch.nn.Module): loss criterion
            device (torch.device): device on which the model is being evaluated

        Returns:
            torch.Tensor or :class:`~aitoolbox.torchtrain.multi_loss_optim.MultiLoss`: loss
        """
        return self.get_loss(batch_data, criterion, device)
    @abstractmethod
    def get_predictions(self, batch_data, device):
        """Get predictions during evaluation stage

        Args:
            batch_data (torch.Tensor or list or tuple or dict): model input data batch
            device (torch.device): device on which the model is making the prediction

        Returns:
            (torch.Tensor, torch.Tensor, dict or None): y_pred, y_test, metadata
            in the form of dict of lists/torch.Tensors/np.arrays
        """
        pass
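
A minimal sketch of a concrete TTModel subclass, assuming a simple classifier trained on
``(inputs, targets)`` batches; the ``ExampleClassifier`` name and its single linear layer are
illustrative, not part of the library::

    import torch.nn as nn
    from aitoolbox.torchtrain.model import TTModel

    class ExampleClassifier(TTModel):
        def __init__(self, input_dim, num_classes):
            super().__init__()
            self.net = nn.Linear(input_dim, num_classes)

        def forward(self, inputs):
            return self.net(inputs)

        def get_loss(self, batch_data, criterion, device):
            # move the batch to the training device and compute the loss
            inputs, targets = [data.to(device) for data in batch_data]
            predictions = self(inputs)
            return criterion(predictions, targets)

        def get_predictions(self, batch_data, device):
            # return (y_pred, y_true, metadata dict) as expected by the TrainLoop
            inputs, targets = batch_data
            predictions = self(inputs.to(device))
            return predictions.cpu(), targets, {}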
class TTBasicModel(TTModel):
    """Extension of the TTModel abstract class with already implemented simple loss and prediction
    calculation functions

    The pre-implemented get_loss() and get_predictions() will take all the provided data sources
    from the data loader except the last one as an input to the model. The last data source from
    the data loader will be treated as the target variable. (*batch_input_data, targets = batch_data)

    This base class is mainly meant to be used for simple models. TTBasicModel removes the need to
    constantly duplicate code in get_loss and get_predictions.
    """
    def get_loss(self, batch_data, criterion, device):
        *batch_input_data, targets = [data.to(device) for data in batch_data]
        predictions = self(*batch_input_data)
        loss = criterion(predictions, targets)
        return loss
    def get_predictions(self, batch_data, device):
        *batch_input_data, targets = batch_data
        batch_input_data = [data.to(device) for data in batch_input_data]
        predictions = self(*batch_input_data)
        return predictions, targets, {}
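
With TTBasicModel the get_loss()/get_predictions() plumbing above comes pre-implemented, so a
subclass only needs to define forward(). A sketch, assuming the data loader yields
``(inputs, targets)`` batches and using an illustrative ``ExampleBasicRegressor`` name::

    import torch.nn as nn
    from aitoolbox.torchtrain.model import TTBasicModel

    class ExampleBasicRegressor(TTBasicModel):
        def __init__(self, input_dim):
            super().__init__()
            self.net = nn.Linear(input_dim, 1)

        def forward(self, inputs):
            # the inherited get_loss()/get_predictions() call into this forward()
            return self.net(inputs)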
class TTBasicMultiGPUModel(TTBasicModel):
    """Extension of the TTModel abstract class with already implemented simple loss and prediction
    calculation functions which support leveled utilization when training on multi-GPU

    The pre-implemented get_loss() and get_predictions() will take all the provided data sources
    from the data loader except the last one as an input to the model. The last data source from
    the data loader will be treated as the target variable. (*batch_input_data, targets = batch_data)

    In the case of :meth:`~aitoolbox.torchtrain.model.TTModel.get_loss`, the input into the model's
    :meth:`~torch.nn.Module.forward` function will also provide the ``targets`` and ``criterion``
    arguments in order to enable the calculation of the loss inside the
    :meth:`~torch.nn.Module.forward` function.

    The forward() function should have the following parameter signature and should finish with::

        def forward(self, *batch_input_data, targets=None, criterion=None):
            ... predictions calculation via the computational graph ...

            if criterion is not None:
                return criterion(predictions, targets)
            else:
                return predictions

    This base class is mainly meant to be used for simple models. TTBasicMultiGPUModel removes the
    need to constantly duplicate code in get_loss and get_predictions.
    """
    def get_loss(self, batch_data, criterion, device):
        *batch_input_data, targets = [data.to(device) for data in batch_data]
        loss = self(*batch_input_data, targets=targets, criterion=criterion)
        return loss
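
A sketch of a TTBasicMultiGPUModel subclass whose forward() follows the targets/criterion pattern
described in the class docstring; the ``ExampleMultiGPUClassifier`` name and its layer are
illustrative::

    import torch.nn as nn
    from aitoolbox.torchtrain.model import TTBasicMultiGPUModel

    class ExampleMultiGPUClassifier(TTBasicMultiGPUModel):
        def __init__(self, input_dim, num_classes):
            super().__init__()
            self.net = nn.Linear(input_dim, num_classes)

        def forward(self, inputs, targets=None, criterion=None):
            predictions = self.net(inputs)

            # when criterion is supplied (training), the loss is computed on the GPU
            # holding this replica, keeping multi-GPU utilization level
            if criterion is not None:
                return criterion(predictions, targets)
            return predictions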
class MultiGPUModelWrap(TTBasicMultiGPUModel):
    def __init__(self, model):
        """Model wrapper optimizing the model for multi-GPU training by moving the loss calculation
        to the GPUs

        Args:
            model (torch.nn.Module or TTModel): neural network model. The model should follow the
                basic PyTorch model definition where the ``forward()`` function returns predictions
        """
        TTBasicMultiGPUModel.__init__(self)

        if not isinstance(model, nn.Module):
            raise TypeError('Provided model is not inherited from nn.Module')

        self.model = model
    def forward(self, *input_data, targets=None, criterion=None):
        """DP friendly forward abstraction on top of the wrapped model's usual forward() function

        Args:
            *input_data: whatever input data should be passed into the wrapped model's forward() function
            targets: target variables which the model is training to fit
            criterion: loss function

        Returns:
            PyTorch loss or model output predictions. If loss function criterion is provided this
            function returns the calculated loss, otherwise the model output predictions are returned
        """
        predictions = self.model(*input_data)

        if criterion is not None:
            return criterion(predictions, targets)

        return predictions
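
A usage sketch wrapping a plain nn.Module with MultiGPUModelWrap so that the loss is computed
inside forward() on each replica; the model architecture and tensor shapes below are illustrative::

    import torch
    import torch.nn as nn
    from aitoolbox.torchtrain.model import MultiGPUModelWrap

    plain_model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
    model = MultiGPUModelWrap(plain_model)

    criterion = nn.CrossEntropyLoss()
    inputs = torch.randn(32, 128)
    targets = torch.randint(0, 10, (32,))

    loss = model(inputs, targets=targets, criterion=criterion)   # returns the calculated loss
    predictions = model(inputs)                                   # returns raw predictions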
class ModelWrap:
    def __init__(self, model, batch_model_feed_def):
        """TrainLoop model wrapper combining the PyTorch model and the model feed definition

        Note:
            Especially useful in the case when you want to train on multi-GPU where TTModel
            abstract functions can't be used. ModelWrap can be used as a replacement for TTModel
            when using the TrainLoop.

        Args:
            model (torch.nn.Module): neural network model
            batch_model_feed_def (AbstractModelFeedDefinition or None): data prep definition for
                batched data. This definition prepares the data for each batch that then gets fed
                into the neural network.
        """
        if not isinstance(model, nn.Module):
            raise TypeError('Provided model is not inherited from the base PyTorch nn.Module')
        if not isinstance(batch_model_feed_def, AbstractModelFeedDefinition):
            raise TypeError('Provided the base PyTorch model but did not give '
                            'the batch_model_feed_def inherited from AbstractModelFeedDefinition')

        self.model = model
        self.batch_model_feed_def = batch_model_feed_def
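
A usage sketch for ModelWrap, assuming AbstractModelFeedDefinition exposes get_loss(),
get_loss_eval() and get_predictions() methods that mirror the TTModel ones but receive the model
as their first argument (see aitoolbox.torchtrain.data.batch_model_feed_defs for the authoritative
interface); the ``ExampleFeedDefinition`` class and all names inside it are illustrative::

    import torch.nn as nn
    from aitoolbox.torchtrain.data.batch_model_feed_defs import AbstractModelFeedDefinition
    from aitoolbox.torchtrain.model import ModelWrap

    class ExampleFeedDefinition(AbstractModelFeedDefinition):
        def get_loss(self, model, batch_data, criterion, device):
            inputs, targets = [data.to(device) for data in batch_data]
            return criterion(model(inputs), targets)

        def get_loss_eval(self, model, batch_data, criterion, device):
            return self.get_loss(model, batch_data, criterion, device)

        def get_predictions(self, model, batch_data, device):
            inputs, targets = batch_data
            return model(inputs.to(device)).cpu(), targets, {}

    model_wrap = ModelWrap(model=nn.Linear(128, 10),
                           batch_model_feed_def=ExampleFeedDefinition())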