model_prediction_store
- class aitoolbox.torchtrain.train_loop.components.model_prediction_store.ModelPredictionStore(auto_purge=False)[source]
Bases: object
Service for the TrainLoop that enables prediction caching.
Computing predictions can be costly, and calculating the same predictions repeatedly can severely hurt performance. This store caches the predictions already made in the current iteration of the TrainLoop, so the TrainLoop can reuse the cached values when they are available instead of recalculating them.
- Parameters:
auto_purge (bool) – whether the prediction cache should be automatically purged at the end of each iteration
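To illustrate the caching idea, here is a minimal sketch under stated assumptions: the class `PredictionCacheSketch` and its `get_or_compute` method are hypothetical and are not part of aitoolbox. Entries are keyed by source name and iteration index, and `auto_purge` is approximated by dropping older iterations' entries once a newer iteration index is seen.

```python
from typing import Any, Callable, Dict, Tuple


class PredictionCacheSketch:
    """Hypothetical, simplified cache; not the aitoolbox implementation."""

    def __init__(self, auto_purge: bool = False) -> None:
        self.auto_purge = auto_purge
        # (source name, iteration index) -> cached value
        self.cache: Dict[Tuple[str, int], Any] = {}

    def get_or_compute(self, source_name: str, iteration_idx: int,
                       compute: Callable[[], Any]) -> Any:
        if self.auto_purge:
            # Approximate the end-of-iteration purge: drop entries from earlier iterations
            self.cache = {k: v for k, v in self.cache.items() if k[1] >= iteration_idx}
        key = (source_name, iteration_idx)
        if key not in self.cache:
            self.cache[key] = compute()  # the expensive call runs only on a cache miss
        return self.cache[key]


calls = []


def expensive_predict():
    calls.append(1)  # record every real model run
    return [0.1, 0.9]


store = PredictionCacheSketch(auto_purge=True)
store.get_or_compute("train_pred", 0, expensive_predict)
store.get_or_compute("train_pred", 0, expensive_predict)  # cache hit, no recomputation
store.get_or_compute("train_pred", 1, expensive_predict)  # new iteration, computed again
print(len(calls))         # 2
print(list(store.cache))  # [('train_pred', 1)]
```

With `auto_purge=False` the sketch keeps every iteration's entries, trading memory for never discarding anything.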
- insert_train_predictions(predictions, iteration_idx, force_prediction=False)[source]
Insert training dataset predictions into the cache
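- Parameters:
predictions – model train dataset predictions
iteration_idx (int) – current iteration index of the TrainLoop
force_prediction (bool) – insert the predictions even if they are already available in the prediction cache. This causes the old cached predictions to be overwritten.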
- Returns:
None
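The following sketch illustrates the `force_prediction` flag, assuming it behaves for predictions as documented for the loss methods (an existing cached value is kept unless insertion is forced). The free function and the `(y_pred, y_true)` tuple layout of `predictions` are hypothetical illustrations, not the library's code.

```python
from typing import Any, Dict, Tuple


def insert_train_predictions(cache: Dict[Tuple[str, int], Any], predictions: Any,
                             iteration_idx: int, force_prediction: bool = False) -> None:
    """Hypothetical free-function version of the documented method."""
    key = ("train_predictions", iteration_idx)
    if key in cache and not force_prediction:
        return  # a value is already cached for this iteration: keep it
    cache[key] = predictions  # first insert, or a forced overwrite


cache: Dict[Tuple[str, int], Any] = {}
insert_train_predictions(cache, ([1, 0], [1, 1]), iteration_idx=3)
insert_train_predictions(cache, ([0, 0], [1, 1]), iteration_idx=3)  # ignored
print(cache[("train_predictions", 3)])  # ([1, 0], [1, 1])
insert_train_predictions(cache, ([0, 0], [1, 1]), iteration_idx=3, force_prediction=True)
print(cache[("train_predictions", 3)])  # ([0, 0], [1, 1])
```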
- insert_val_predictions(predictions, iteration_idx, force_prediction=False)[source]
Insert validation dataset predictions into the cache
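- Parameters:
predictions – model validation dataset predictions
iteration_idx (int) – current iteration index of the TrainLoop
force_prediction (bool) – insert the predictions even if they are already available in the prediction cache. This causes the old cached predictions to be overwritten.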
- Returns:
None
- insert_test_predictions(predictions, iteration_idx, force_prediction=False)[source]
Insert test dataset predictions into the cache
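- Parameters:
predictions – model test dataset predictions
iteration_idx (int) – current iteration index of the TrainLoop
force_prediction (bool) – insert the predictions even if they are already available in the prediction cache. This causes the old cached predictions to be overwritten.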
- Returns:
None
- insert_train_loss(loss, iteration_idx, force_prediction=False)[source]
Insert training dataset loss into the cache
- Parameters:
loss (float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss) – model train dataset loss
iteration_idx (int) – current iteration index of the TrainLoop
force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.
- Returns:
None
- insert_val_loss(loss, iteration_idx, force_prediction=False)[source]
Insert validation dataset loss into the cache
- Parameters:
loss (float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss) – model validation dataset loss
iteration_idx (int) – current iteration index of the TrainLoop
force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.
- Returns:
None
- insert_test_loss(loss, iteration_idx, force_prediction=False)[source]
Insert test dataset loss into the cache
- Parameters:
loss (float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss) – model test dataset loss
iteration_idx (int) – current iteration index of the TrainLoop
force_prediction (bool) – insert the loss value even if it is available in the loss cache. This causes the old cached loss value to be overwritten.
- Returns:
None
- get_train_loss(iteration_idx)[source]
Get training dataset model loss out of the cache
- Parameters:
iteration_idx (int) – current iteration index of the TrainLoop
- Returns:
cached model train dataset loss
- Return type:
float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss
- get_val_loss(iteration_idx)[source]
Get validation dataset model loss out of the cache
- Parameters:
iteration_idx (int) – current iteration index of the TrainLoop
- Returns:
cached model validation dataset loss
- Return type:
float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss
- get_test_loss(iteration_idx)[source]
Get test dataset model loss out of the cache
- Parameters:
iteration_idx (int) – current iteration index of the TrainLoop
- Returns:
cached model test dataset loss
- Return type:
float or aitoolbox.torchtrain.multi_loss_optim.MultiLoss
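A lookup in the style of the get_*_loss methods could look like the sketch below. What the real methods do when no loss is cached for the requested iteration is not documented here, so raising `KeyError` on a miss is purely an assumption of this hypothetical function.

```python
from typing import Any, Dict, Tuple


def get_train_loss(cache: Dict[Tuple[str, int], Any], iteration_idx: int) -> Any:
    """Hypothetical lookup; the library's behaviour on a cache miss is not documented here."""
    key = ("train_loss", iteration_idx)
    if key not in cache:
        # Assumption: report a miss explicitly instead of returning another iteration's loss
        raise KeyError(f"train loss not cached for iteration {iteration_idx}")
    return cache[key]


cache: Dict[Tuple[str, int], Any] = {("train_loss", 5): 0.42}
print(get_train_loss(cache, 5))  # 0.42
try:
    get_train_loss(cache, 6)
except KeyError as err:
    print("cache miss:", err)
```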
- _insert_data(source_name, data, iteration_idx, force_prediction=False)[source]
Insert a general value into the prediction / loss cache
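- Parameters:
source_name – name of the cache entry under which the data is stored
data – prediction or loss value to be cached
iteration_idx (int) – current iteration index of the TrainLoop
force_prediction (bool) – insert the data even if it is already available in the cache. This causes the old cached value to be overwritten.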
- Returns:
None
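One plausible way for the public insert_* methods to share a single generic helper is sketched below. `StoreSketch` is a hypothetical class, not the aitoolbox source; keying entries by `(source_name, iteration_idx)` is an assumption that keeps train, validation, and test values from colliding within the same iteration.

```python
from typing import Any, Dict, Tuple


class StoreSketch:
    """Hypothetical: each public insert_* method delegates to one private helper."""

    def __init__(self) -> None:
        self.store: Dict[Tuple[str, int], Any] = {}

    def _insert_data(self, source_name: str, data: Any, iteration_idx: int,
                     force_prediction: bool = False) -> None:
        key = (source_name, iteration_idx)
        if force_prediction or key not in self.store:
            self.store[key] = data  # first insert, or a forced overwrite

    def insert_train_loss(self, loss: Any, iteration_idx: int, force_prediction: bool = False) -> None:
        self._insert_data("train_loss", loss, iteration_idx, force_prediction)

    def insert_val_loss(self, loss: Any, iteration_idx: int, force_prediction: bool = False) -> None:
        self._insert_data("val_loss", loss, iteration_idx, force_prediction)


s = StoreSketch()
s.insert_train_loss(0.5, iteration_idx=0)
s.insert_val_loss(0.7, iteration_idx=0)
s.insert_train_loss(0.9, iteration_idx=0)  # ignored: already cached and not forced
print(s.store)  # {('train_loss', 0): 0.5, ('val_loss', 0): 0.7}
```

Centralizing the overwrite rule in one helper means the train, validation, and test variants cannot drift apart in how they treat `force_prediction`.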