model_prediction_store

class aitoolbox.torchtrain.train_loop.components.model_prediction_store.ModelPredictionStore(auto_purge=False)

Bases: object

Service for the TrainLoop that enables prediction caching

Computing predictions can be costly, and recalculating the same predictions repeatedly can seriously degrade performance. This store caches the predictions already made in the current iteration of the TrainLoop, so the TrainLoop can reuse the cached values when they are available instead of recalculating them.

Parameters

auto_purge (bool) – whether the prediction cache should be automatically purged at the end of each iteration
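The caching idea described above follows a cache-aside pattern: check the store first, and only compute and insert predictions on a miss. The sketch below is a minimal, hypothetical stand-in, not the aitoolbox implementation. `PredictionCacheSketch`, `predict_train_cached` and `expensive_predict` are illustrative names; only the `has_train_predictions` / `get_train_predictions` / `insert_train_predictions` method names come from this page, and keying entries by iteration index is an assumption.

```python
class PredictionCacheSketch:
    """Hypothetical stand-in that mirrors three documented ModelPredictionStore method names."""

    def __init__(self):
        # Assumption: entries are keyed by iteration index (the real store may keep only one iteration)
        self._store = {}

    def has_train_predictions(self, iteration_idx):
        return iteration_idx in self._store

    def get_train_predictions(self, iteration_idx):
        return self._store[iteration_idx]

    def insert_train_predictions(self, predictions, iteration_idx, force_prediction=False):
        # Overwrite an existing entry only when explicitly forced
        if force_prediction or not self.has_train_predictions(iteration_idx):
            self._store[iteration_idx] = predictions


def predict_train_cached(store, iteration_idx, compute_fn):
    """Return cached train predictions, calling compute_fn only on a cache miss."""
    if store.has_train_predictions(iteration_idx):
        return store.get_train_predictions(iteration_idx)
    predictions = compute_fn()
    store.insert_train_predictions(predictions, iteration_idx)
    return predictions


calls = []


def expensive_predict():
    calls.append(1)
    # Assumed tuple layout for illustration: (y_pred, y_true, metadata)
    return ([0.9, 0.1], [1, 0], {})


store = PredictionCacheSketch()
first = predict_train_cached(store, 0, expensive_predict)   # computes and caches
second = predict_train_cached(store, 0, expensive_predict)  # served from the cache
print(len(calls))  # 1
```

Two consumers asking for the same iteration's predictions (for example, two callbacks computing different metrics) trigger only one forward pass over the dataset.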

insert_train_predictions(predictions, iteration_idx, force_prediction=False)

Insert training dataset predictions into the cache

Parameters
  • predictions (tuple) – model training dataset predictions

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the predicted values even if they are available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns

None
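The effect of `force_prediction` can be illustrated with a small hypothetical single-iteration cache (`TrainPredictionCache` is an illustrative name, not part of aitoolbox). This page only states that `force_prediction=True` overwrites existing predictions; the sketch assumes that an unforced insert into an occupied slot is silently ignored, whereas the real class might warn or raise instead.

```python
class TrainPredictionCache:
    """Hypothetical single-iteration cache illustrating force_prediction."""

    def __init__(self):
        self.iteration_idx = None
        self.predictions = None

    def has_train_predictions(self, iteration_idx):
        return self.predictions is not None and self.iteration_idx == iteration_idx

    def insert_train_predictions(self, predictions, iteration_idx, force_prediction=False):
        # Assumption: without force_prediction an existing entry is kept as-is
        if self.has_train_predictions(iteration_idx) and not force_prediction:
            return
        self.iteration_idx = iteration_idx
        self.predictions = predictions


cache = TrainPredictionCache()
cache.insert_train_predictions(("old",), iteration_idx=3)
cache.insert_train_predictions(("new",), iteration_idx=3)  # ignored: already cached
print(cache.predictions)  # ('old',)
cache.insert_train_predictions(("new",), iteration_idx=3, force_prediction=True)  # overwrites
print(cache.predictions)  # ('new',)
```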

insert_val_predictions(predictions, iteration_idx, force_prediction=False)

Insert validation dataset predictions into the cache

Parameters
  • predictions (tuple) – model validation dataset predictions

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the predicted values even if they are available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns

None

insert_test_predictions(predictions, iteration_idx, force_prediction=False)

Insert test dataset predictions into the cache

Parameters
  • predictions (tuple) – model test dataset predictions

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the predicted values even if they are available in the prediction cache. This causes the old cached predictions to be overwritten.

Returns

None

get_train_predictions(iteration_idx)

Get training dataset predictions out of the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

cached model train dataset predictions

Return type

tuple

get_val_predictions(iteration_idx)

Get validation dataset predictions out of the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

cached model validation dataset predictions

Return type

tuple

get_test_predictions(iteration_idx)

Get test dataset predictions out of the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

cached model test dataset predictions

Return type

tuple

has_train_predictions(iteration_idx)

Check whether training dataset predictions are in the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

True if predictions are in the cache, False otherwise

Return type

bool

has_val_predictions(iteration_idx)

Check whether validation dataset predictions are in the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

True if predictions are in the cache, False otherwise

Return type

bool

has_test_predictions(iteration_idx)

Check whether test dataset predictions are in the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

True if predictions are in the cache, False otherwise

Return type

bool
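The `has_*_predictions` checks pair naturally with the matching `get_*_predictions` calls. As a hypothetical example, a reporting step could gather whichever splits are already cached for the current iteration. `SplitPredictionStub` and `collect_cached_predictions` are illustrative names; the stub only reuses the documented method names and is not the real store.

```python
class SplitPredictionStub:
    """Hypothetical stub exposing the documented has_/get_ method names for three splits."""

    def __init__(self, cached, iteration_idx):
        self._cached = cached  # e.g. {"train": (...), "val": (...)}
        self._iteration_idx = iteration_idx

    def _has(self, split, iteration_idx):
        return iteration_idx == self._iteration_idx and split in self._cached

    def has_train_predictions(self, iteration_idx):
        return self._has("train", iteration_idx)

    def has_val_predictions(self, iteration_idx):
        return self._has("val", iteration_idx)

    def has_test_predictions(self, iteration_idx):
        return self._has("test", iteration_idx)

    def get_train_predictions(self, iteration_idx):
        return self._cached["train"]

    def get_val_predictions(self, iteration_idx):
        return self._cached["val"]

    def get_test_predictions(self, iteration_idx):
        return self._cached["test"]


def collect_cached_predictions(store, iteration_idx, splits=("train", "val", "test")):
    """Return {split: predictions} for every split already cached at this iteration."""
    found = {}
    for split in splits:
        # Guard each get_* call with the matching has_* check
        if getattr(store, f"has_{split}_predictions")(iteration_idx):
            found[split] = getattr(store, f"get_{split}_predictions")(iteration_idx)
    return found


store = SplitPredictionStub({"train": ([0.2], [0]), "val": ([0.7], [1])}, iteration_idx=5)
print(sorted(collect_cached_predictions(store, 5)))  # ['train', 'val']
print(collect_cached_predictions(store, 6))          # {}
```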

insert_train_loss(loss, iteration_idx, force_prediction=False)

Insert training dataset loss into the cache

Parameters
  • loss (float or dict) – model train dataset loss

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.

Returns

None
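Because a cached loss may be either a float or a dict (for example, when a model reports several named loss components), code that reads it back should handle both shapes. The helper below, `format_cached_loss`, is a hypothetical illustration for logging purposes, and the dict-of-named-floats layout is an assumption.

```python
def format_cached_loss(loss):
    """Render a cached loss that may be a single float or a dict of named losses."""
    if isinstance(loss, dict):
        # Assumed layout: {loss_name: float}; sort names for stable output
        return ", ".join(f"{name}: {value:.4f}" for name, value in sorted(loss.items()))
    return f"{loss:.4f}"


print(format_cached_loss(0.25))                       # 0.2500
print(format_cached_loss({"ce": 0.5, "aux": 0.125}))  # aux: 0.1250, ce: 0.5000
```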

insert_val_loss(loss, iteration_idx, force_prediction=False)

Insert validation dataset loss into the cache

Parameters
  • loss (float or dict) – model validation dataset loss

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.

Returns

None

insert_test_loss(loss, iteration_idx, force_prediction=False)

Insert test dataset loss into the cache

Parameters
  • loss (float or dict) – model test dataset loss

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.

Returns

None

get_train_loss(iteration_idx)

Get training dataset model loss out of the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

cached model train dataset loss

Return type

float or dict

get_val_loss(iteration_idx)

Get validation dataset model loss out of the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

cached model validation dataset loss

Return type

float or dict

get_test_loss(iteration_idx)

Get test dataset model loss out of the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

cached model test dataset loss

Return type

float or dict

has_train_loss(iteration_idx)

Check whether training dataset model loss is in the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

True if the loss value is in the cache, False otherwise

Return type

bool

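has_val_loss(iteration_idx)

Check whether validation dataset model loss is in the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

True if the loss value is in the cache, False otherwise

Return type

bool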

has_test_loss(iteration_idx)

Check whether test dataset model loss is in the cache

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

True if the loss value is in the cache, False otherwise

Return type

bool

_insert_data(source_name, data, iteration_idx, force_prediction=False)

Insert a general value into the prediction / loss cache

Parameters
  • source_name (str) – data source name

  • data (tuple or float or dict) – data to be cached

  • iteration_idx (int) – current iteration index of the TrainLoop

  • force_prediction (bool) – insert the data into the cache even if it is already available in the cache. This causes the old cached data under the same source_name to be overwritten.

Returns

None
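The private `_insert_data` / `_get_data` / `_has_data` helpers suggest that the public `insert_*`, `get_*` and `has_*` methods are thin wrappers that pass a data source name to one generic implementation. The sketch below shows that design under stated assumptions: `KeyedCacheSketch` is a hypothetical class, the source names `"train_pred"` and `"val_loss"` are invented labels, each entry is tagged with the iteration it was cached for, and raising `KeyError` on a miss is an assumed behavior this page does not specify.

```python
class KeyedCacheSketch:
    """Hypothetical sketch: public insert_/get_/has_ methods delegate to generic helpers."""

    def __init__(self):
        self.store = {}  # source_name -> (iteration_idx, data)

    def _has_data(self, source_name, iteration_idx):
        # Data only counts as cached if it was stored for this same iteration
        entry = self.store.get(source_name)
        return entry is not None and entry[0] == iteration_idx

    def _insert_data(self, source_name, data, iteration_idx, force_prediction=False):
        if force_prediction or not self._has_data(source_name, iteration_idx):
            self.store[source_name] = (iteration_idx, data)

    def _get_data(self, source_name, iteration_idx):
        # Assumed behavior on a miss: raise KeyError
        if not self._has_data(source_name, iteration_idx):
            raise KeyError(f"{source_name} not cached for iteration {iteration_idx}")
        return self.store[source_name][1]

    # Thin public wrappers; "train_pred" and "val_loss" are assumed source names
    def insert_train_predictions(self, predictions, iteration_idx, force_prediction=False):
        self._insert_data("train_pred", predictions, iteration_idx, force_prediction)

    def get_train_predictions(self, iteration_idx):
        return self._get_data("train_pred", iteration_idx)

    def insert_val_loss(self, loss, iteration_idx, force_prediction=False):
        self._insert_data("val_loss", loss, iteration_idx, force_prediction)

    def has_val_loss(self, iteration_idx):
        return self._has_data("val_loss", iteration_idx)


cache = KeyedCacheSketch()
cache.insert_train_predictions(([0.3], [0], {}), iteration_idx=2)
cache.insert_val_loss(0.42, iteration_idx=2)
print(cache.get_train_predictions(2))                # ([0.3], [0], {})
print(cache.has_val_loss(2), cache.has_val_loss(3))  # True False
```

Keeping one generic code path means the overwrite and iteration rules are defined once and shared by all twelve public methods.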

_get_data(source_name, iteration_idx)

Get data based on the source name from the cache

Parameters
  • source_name (str) – data source name

  • iteration_idx (int) – current iteration index of the TrainLoop

Returns

cached data

Return type

tuple or float or dict

_has_data(source_name, iteration_idx)

Check if data under the specified source name is currently available in the cache

Parameters
  • source_name (str) – data source name

  • iteration_idx (int) – current iteration index of the TrainLoop

Returns

True if the requested data is available in the cache, False otherwise

Return type

bool

auto_purge(iteration_idx)

Automatically purge the current cache if the given iteration index has moved past the last cached iteration

Parameters

iteration_idx (int) – current iteration index of the TrainLoop

Returns

None
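The purge rule above (clear the cache once the iteration index moves past the last cached iteration) can be sketched as follows. `AutoPurgeCacheSketch` and its `insert` method are hypothetical; the `do_auto_purge` attribute name is an assumption chosen so the flag does not clash with the `auto_purge` method, and calling the purge check before every insert is an assumed trigger point.

```python
class AutoPurgeCacheSketch:
    """Hypothetical single-iteration cache applying the documented purge rule."""

    def __init__(self, auto_purge=False):
        self.do_auto_purge = auto_purge  # assumed attribute name
        self.last_cached_iteration = -1
        self.store = {}  # source_name -> data cached for last_cached_iteration

    def auto_purge(self, iteration_idx):
        # Drop everything once the TrainLoop has moved past the last cached iteration
        if iteration_idx > self.last_cached_iteration:
            self.store = {}
            self.last_cached_iteration = iteration_idx

    def insert(self, source_name, data, iteration_idx):
        if self.do_auto_purge:
            self.auto_purge(iteration_idx)
        self.store[source_name] = data
        self.last_cached_iteration = max(self.last_cached_iteration, iteration_idx)


purging = AutoPurgeCacheSketch(auto_purge=True)
keeping = AutoPurgeCacheSketch(auto_purge=False)
for c in (purging, keeping):
    c.insert("train_loss", 0.7, iteration_idx=0)
    c.insert("val_loss", 0.8, iteration_idx=0)
    c.insert("train_loss", 0.6, iteration_idx=1)
print(sorted(purging.store))  # ['train_loss']
print(sorted(keeping.store))  # ['train_loss', 'val_loss']
```

With purging enabled, the stale iteration-0 validation loss is dropped as soon as iteration 1 data arrives; without it, old entries linger until they are overwritten.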