model_prediction_store

class aitoolbox.torchtrain.train_loop.components.model_prediction_store.ModelPredictionStore(auto_purge=False)[source]

Bases: object

Service used by the TrainLoop to enable prediction caching

Prediction calculation can be costly, and repeatedly recomputing the same predictions can have severe performance implications. This store caches the predictions already made in the current iteration of the TrainLoop, so the TrainLoop can reuse the cached values when they are available instead of recalculating them.

Parameters:

auto_purge (bool) – should the prediction service cache be automatically purged at the end of each iteration
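
A minimal usage sketch of the caching contract described above. The tuple layout (predictions, targets, metadata) and the iteration index are illustrative assumptions, not part of this class:

    from aitoolbox.torchtrain.train_loop.components.model_prediction_store import ModelPredictionStore

    store = ModelPredictionStore(auto_purge=True)

    iteration_idx = 0
    if not store.has_val_predictions(iteration_idx):
        # In practice the TrainLoop supplies whatever its prediction step returns
        store.insert_val_predictions(([0.1, 0.9], [0, 1], {}), iteration_idx)

    # Later consumers in the same iteration reuse the cached result instead of re-predicting
    y_pred, y_true, metadata = store.get_val_predictions(iteration_idx)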

insert_train_predictions(predictions, iteration_idx, force_prediction=False)[source]

Insert training dataset predictions into the cache

Parameters:
  • predictions (tuple) – model training dataset predictions

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the predicted values even if they are available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns:

None

insert_val_predictions(predictions, iteration_idx, force_prediction=False)[source]

Insert validation dataset predictions into the cache

Parameters:
  • predictions (tuple) – model validation dataset predictions

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the predicted values even if they are available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns:

None

insert_test_predictions(predictions, iteration_idx, force_prediction=False)[source]

Insert test dataset predictions into the cache

Parameters:
  • predictions (tuple) – model test dataset predictions

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the predicted values even if they are available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns:

None
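
A short sketch of the force_prediction flag shared by these insert methods, given a ModelPredictionStore instance store. The prediction tuples and iteration index are illustrative:

    store.insert_test_predictions(([0, 1, 1], [0, 1, 0], {}), iteration_idx=3)

    # With force_prediction=True the cached entry for this iteration is
    # explicitly overwritten with the newly supplied predictions.
    store.insert_test_predictions(([0, 1, 0], [0, 1, 0], {}), iteration_idx=3,
                                  force_prediction=True)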

get_train_predictions(iteration_idx)[source]

Get training dataset predictions out of the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

cached model train dataset predictions

Return type:

tuple

get_val_predictions(iteration_idx)[source]

Get validation dataset predictions out of the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

cached model validation dataset predictions

Return type:

tuple

get_test_predictions(iteration_idx)[source]

Get test dataset predictions out of the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

cached model test dataset predictions

Return type:

tuple

has_train_predictions(iteration_idx)[source]

Are there training dataset predictions in the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

if predictions are in the cache

Return type:

bool

has_val_predictions(iteration_idx)[source]

Are there validation dataset predictions in the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

if predictions are in the cache

Return type:

bool

has_test_predictions(iteration_idx)[source]

Are there test dataset predictions in the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

if predictions are in the cache

Return type:

bool
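
Together with the insert_* and get_* methods above, these checks support a compute-once-per-iteration pattern. A hedged sketch, where compute_predictions stands in for the caller's own prediction routine:

    def val_predictions_for(store, iteration_idx, compute_predictions):
        """Return cached validation predictions, computing them only on a cache miss."""
        if not store.has_val_predictions(iteration_idx):
            store.insert_val_predictions(compute_predictions(), iteration_idx)
        return store.get_val_predictions(iteration_idx)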

insert_train_loss(loss, iteration_idx, force_prediction=False)[source]

Insert training dataset loss into the cache

Parameters:
  • loss (float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss) – model train dataset loss

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.

Returns:

None

insert_val_loss(loss, iteration_idx, force_prediction=False)[source]

Insert validation dataset loss into the cache

Parameters:
  • loss (float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss) – model validation dataset loss

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.

Returns:

None

insert_test_loss(loss, iteration_idx, force_prediction=False)[source]

Insert test dataset loss into the cache

Parameters:
  • loss (float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss) – model test dataset loss

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.

Returns:

None
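
An illustrative sketch of caching loss values for one iteration; the loss numbers and iteration index are made up. Per the signatures, a MultiLoss object can be cached in place of a plain float:

    store.insert_train_loss(0.742, iteration_idx=10)
    store.insert_val_loss(0.913, iteration_idx=10)
    store.insert_test_loss(0.655, iteration_idx=10)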

get_train_loss(iteration_idx)[source]

Get training dataset model loss out of the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

cached model train dataset loss

Return type:

float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss

get_val_loss(iteration_idx)[source]

Get validation dataset model loss out of the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

cached model validation dataset loss

Return type:

float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss

get_test_loss(iteration_idx)[source]

Get test dataset model loss out of the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

cached model test dataset loss

Return type:

float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss

has_train_loss(iteration_idx)[source]

Is there training dataset model loss in the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

if loss value is in the cache

Return type:

bool

has_val_loss(iteration_idx)[source]

Is there validation dataset model loss in the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

if loss value is in the cache

Return type:

bool

has_test_loss(iteration_idx)[source]

Is there test dataset model loss in the cache

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

if loss value is in the cache

Return type:

bool
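
A small sketch combining the has_* and get_* loss methods, for example when logging the float values cached in the sketch above:

    i = 10
    if store.has_train_loss(i) and store.has_val_loss(i):
        print(f'iteration {i}: train loss {store.get_train_loss(i):.4f}, '
              f'val loss {store.get_val_loss(i):.4f}')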

_insert_data(source_name, data, iteration_idx, force_prediction=False)[source]

Insert a general value into the prediction / loss cache

Parameters:
  • source_name (str) – data source name

  • data (tuple or float or dict) – data to be cached

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the data into the cache even if it is already available in the cache. This causes the old cached data under the same source_name to be overwritten.

Returns:

None

_get_data(source_name, iteration_idx)[source]

Get data based on the source name from the cache

Parameters:
  • source_name (str) – data source name

  • iteration_idx (int) – current iteration index of the TrainLoop

Returns:

cached data

Return type:

tuple or float or dict

_has_data(source_name, iteration_idx)[source]

Check if data under the specified source name is currently available in the cache

Parameters:
  • source_name (str) – data source name

  • iteration_idx (int) – current iteration index of the TrainLoop

Returns:

if the requested data is available in the cache

Return type:

bool
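
The public insert/get/has methods are presumably thin wrappers around these three internal helpers, keyed by a per-dataset source name. A hedged sketch of how that delegation could look inside the class; the source name string and the wrapper bodies are assumptions, not the library's actual implementation:

    def insert_train_predictions(self, predictions, iteration_idx, force_prediction=False):
        # Assumed delegation to the generic cache under a fixed source name
        self._insert_data('train_pred', predictions, iteration_idx, force_prediction)

    def has_train_predictions(self, iteration_idx):
        return self._has_data('train_pred', iteration_idx)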

auto_purge(iteration_idx)[source]

Automatically purge the current cache if the given iteration index has moved past the last cached iteration

Parameters:

iteration_idx (int) – current iteration index of the TrainLoop

Returns:

None
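
A brief sketch of the purge behavior; the iteration indices are illustrative, and the exact purge conditions depend on the auto_purge constructor flag and on which iteration is currently cached:

    store = ModelPredictionStore(auto_purge=True)
    store.insert_val_loss(0.5, iteration_idx=0)

    # When the TrainLoop advances to iteration 1, this call is expected to drop
    # the iteration-0 cache so stale predictions or losses are never served.
    store.auto_purge(1)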