training_history
- class aitoolbox.experiment.training_history.TrainingHistory(has_validation=True, strict_content_check=False)
Bases: object
Training history abstraction that adds specific functionality on top of a plain dict.
In many ways the object can be used through the same API as a normal Python dict. However, to support performance tracking in the TrainLoop, TrainingHistory offers additional functions that handle the input, output and quality assurance of the stored results.
- Parameters:
has_validation (bool) – whether the train history should include 'val_loss' by default. This is needed because train loops by default evaluate the loss on the validation set when such a set is available.
strict_content_check (bool) – whether to raise an error and stop when (quality) problems are found, instead of only printing a warning
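The warn-or-raise choice described for strict_content_check can be sketched as follows. This is only an illustration, not the library's code: the helper name `report_problem`, the use of `ValueError`, and the use of the `warnings` module instead of a plain print are all assumptions.

```python
import warnings


def report_problem(message, strict_content_check=False):
    """Hypothetical helper: warn by default, raise when strict."""
    if strict_content_check:
        # Strict mode: stop immediately so the problem cannot go unnoticed.
        raise ValueError(message)
    # Lenient mode: report the problem and let training continue.
    warnings.warn(message)
    return False
```

Calling `report_problem("some problem")` emits a warning and returns False, while passing `strict_content_check=True` raises instead.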
- insert_single_result_into_history(metric_name, metric_result)
Insert a key-value formatted result into the training history.
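A minimal sketch of what inserting a single result might look like. It assumes each metric name maps to a list with one entry per evaluation, as in the example history dict under wrap_pre_prepared_history. The function `insert_result` is a hypothetical stand-in, not the library's implementation.

```python
def insert_result(history, metric_name, metric_result):
    """Hypothetical stand-in: append one result under its metric name."""
    # Create the list on first use, then append the new value.
    history.setdefault(metric_name, []).append(metric_result)
    return history


history = {}
insert_result(history, 'loss', 2.31)
insert_result(history, 'loss', 2.21)
insert_result(history, 'acc', 0.08)
```

After these calls, `history` holds two values under 'loss' and one under 'acc'.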
- get_train_history()
Returns the whole train history dict in its original form, without any transformations.
- Returns:
training history dict
- Return type:
dict
- get_train_history_dict(flatten_dict=False)
Returns the quality-checked (QA-ed) and optionally flattened training history dict.
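The documentation does not state how flattening is done. One plausible reading, shown here purely as an assumption, is that nested metric dicts are collapsed into a single level by joining the keys. The function `flatten_history` and the underscore separator are hypothetical.

```python
def flatten_history(history, parent_key='', sep='_'):
    """Hypothetical sketch: collapse nested dicts into one level."""
    flat = {}
    for key, value in history.items():
        # Prefix the key with its parent's name when nested.
        full_key = f'{parent_key}{sep}{key}' if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into nested results and merge them in.
            flat.update(flatten_history(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


nested = {'loss': [2.3, 2.2], 'val': {'loss': [2.25, 2.14], 'acc': [0.26, 0.36]}}
flat = flatten_history(nested)
```

Here the nested 'val' entries end up under the joined keys 'val_loss' and 'val_acc'.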
- wrap_pre_prepared_history(history)
Wrap an existing history dict into the TrainingHistory object.
- Parameters:
history (dict) – training history base dict
- Returns:
self
- Return type:
TrainingHistory
Examples
Expected history dict to be wrapped:
history = {
    'val_loss': [2.2513437271118164, 2.1482439041137695, 2.0187528133392334, 1.7953970432281494, 1.5492324829101562,
                 1.715561032295227, 1.631982684135437, 1.3721977472305298, 1.039527416229248, 0.9796673059463501],
    'val_acc': [0.25999999046325684, 0.36000001430511475, 0.5, 0.5400000214576721, 0.5400000214576721,
                0.5799999833106995, 0.46000000834465027, 0.699999988079071, 0.7599999904632568, 0.7200000286102295],
    'loss': [2.3088033199310303, 2.2141530513763428, 2.113713264465332, 1.912109375, 1.666761875152588,
             1.460097312927246, 1.6031768321990967, 1.534214973449707, 1.1710081100463867, 0.8969314098358154],
    'acc': [0.07999999821186066, 0.33000001311302185, 0.3100000023841858, 0.5299999713897705, 0.5799999833106995,
            0.6200000047683716, 0.4300000071525574, 0.5099999904632568, 0.6700000166893005, 0.7599999904632568]
}
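The wrap-and-return-self pattern described above can be sketched with a small stand-in class. `HistorySketch` is hypothetical and only mirrors the documented method names. It is not aitoolbox's TrainingHistory and does no quality checks. The short history values are made up for brevity.

```python
class HistorySketch:
    """Stand-in for illustration only; not aitoolbox's TrainingHistory."""

    def __init__(self):
        self.train_history = {}

    def wrap_pre_prepared_history(self, history):
        # Adopt the given dict and return self so calls can be chained.
        self.train_history = history
        return self

    def get_train_history(self):
        # Hand back the stored dict without any transformation.
        return self.train_history


short_history = {'loss': [2.31, 2.21], 'val_loss': [2.25, 2.15]}
wrapped = HistorySketch().wrap_pre_prepared_history(short_history)
```

Because the method returns self, construction and wrapping fit on one line, and `wrapped.get_train_history()` gives back the same dict that was passed in.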