training_history

class aitoolbox.experiment.training_history.TrainingHistory(has_validation=True, strict_content_check=False)[source]

Bases: object

Training history abstraction adding specific functionality to the simple dict

In many ways the object can be used through the same API as a normal Python dict. To support performance tracking in the TrainLoop, however, TrainingHistory adds functions for inserting results, returning them, and quality-checking what is stored.

Parameters:
  • has_validation (bool) – whether the training history should include ‘val_loss’ by default. This is needed because train loops evaluate the loss on the validation set by default when such a set is available.

  • strict_content_check (bool) – whether quality problems found in the stored results should raise an error and stop execution instead of only printing a warning.
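
The two modes described for strict_content_check can be pictured with a small stand-in function. This is only a sketch of the warn-or-raise pattern, not the library's code: the name report_problem and the choice of ValueError are assumptions made for illustration.

```python
def report_problem(msg, strict_content_check=False):
    """Hypothetical stand-in for the warn-or-raise behaviour described above."""
    if strict_content_check:
        # Strict mode: stop execution by raising (exception type is an assumption).
        raise ValueError(msg)
    # Lenient mode: only print a warning and carry on.
    print(f"WARNING: {msg}")


problem = "metric 'val_loss' has no recorded values"

# Lenient mode prints the message and returns normally.
report_problem(problem)

# Strict mode raises, so the caller has to handle the error explicitly.
try:
    report_problem(problem, strict_content_check=True)
except ValueError as err:
    caught = str(err)
```

Strict mode is mostly useful in automated runs, where a printed warning would be easy to miss.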

insert_single_result_into_history(metric_name, metric_result)[source]

Insert a key-value formatted result into the training history

Parameters:
  • metric_name (str) – name of the metric to be stored.

  • metric_result (float or dict) – metric performance result to be stored.
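
As a rough picture of key-value insertion, the sketch below uses a plain dict and assumes that each call appends one entry to a per-metric list. That assumption is not confirmed by this page; the helper name insert_result and the numeric values are illustrative only.

```python
def insert_result(history, metric_name, metric_result):
    """Hypothetical helper: append one result under its metric name."""
    # Create the list on first use, then append the new result.
    history.setdefault(metric_name, []).append(metric_result)


history = {}

# One float result per epoch (illustrative values).
insert_result(history, 'loss', 2.31)
insert_result(history, 'loss', 2.21)

# A dict result, matching the "float or dict" type noted for metric_result.
insert_result(history, 'val_metrics', {'acc': 0.26})
```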

get_train_history()[source]

Returns the whole train history dict in its original form without any transformations

Returns:

training history dict

Return type:

dict

get_train_history_dict(flatten_dict=False)[source]

Returns QA-ed and optionally flattened training history dict

Parameters:

flatten_dict (bool) – whether the returned training history dict should be flattened so that it contains no nested dicts. The keys of nested dicts are concatenated with “_” and moved into the single-level dict.

Returns:

training history dict

Return type:

dict
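
The flattening described for flatten_dict can be sketched as a standalone function. This is a minimal illustration of joining nested keys with "_", written under the assumption that nesting occurs as dicts inside dicts. The helper name flatten_history is hypothetical, and the library may treat other shapes, such as dict results stored inside lists, differently.

```python
def flatten_history(history, sep="_"):
    """Hypothetical helper: move nested dict entries into a single-level dict."""
    flat = {}
    for key, value in history.items():
        if isinstance(value, dict):
            # Recurse, then prefix every nested key with the parent key.
            for sub_key, sub_value in flatten_history(value, sep).items():
                flat[f"{key}{sep}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


nested = {
    'loss': [0.9, 0.7],
    'val': {'acc': [0.5, 0.6], 'f1': {'macro': [0.4, 0.5]}},
}
flat = flatten_history(nested)
```

A flat dict of this shape is convenient when each metric should become one column in a table or one line in a plot.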

wrap_pre_prepared_history(history)[source]

Wrap existing history dict into the TrainingHistory object

Parameters:

history (dict) – training history base dict

Returns:

self

Return type:

TrainingHistory

Examples

Expected history dict to be wrapped:

history = {
    'val_loss': [2.2513437271118164, 2.1482439041137695, 2.0187528133392334, 1.7953970432281494,
                 1.5492324829101562, 1.715561032295227, 1.631982684135437, 1.3721977472305298,
                 1.039527416229248, 0.9796673059463501],
    'val_acc': [0.25999999046325684, 0.36000001430511475, 0.5, 0.5400000214576721, 0.5400000214576721,
                0.5799999833106995,  0.46000000834465027, 0.699999988079071, 0.7599999904632568,
                0.7200000286102295],
    'loss': [2.3088033199310303, 2.2141530513763428, 2.113713264465332, 1.912109375, 1.666761875152588,
             1.460097312927246, 1.6031768321990967, 1.534214973449707, 1.1710081100463867,
             0.8969314098358154],
    'acc': [0.07999999821186066, 0.33000001311302185, 0.3100000023841858, 0.5299999713897705,
            0.5799999833106995, 0.6200000047683716, 0.4300000071525574, 0.5099999904632568,
            0.6700000166893005, 0.7599999904632568]
}
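
Each key in a dict of this shape holds one value per epoch, so the series can be inspected with plain Python before wrapping. The sketch below copies two of the series from the example and looks up the best epochs. The equal-length check is an illustrative sanity check of my own, not the check the library performs.

```python
history = {
    'val_loss': [2.2513437271118164, 2.1482439041137695, 2.0187528133392334, 1.7953970432281494,
                 1.5492324829101562, 1.715561032295227, 1.631982684135437, 1.3721977472305298,
                 1.039527416229248, 0.9796673059463501],
    'val_acc': [0.25999999046325684, 0.36000001430511475, 0.5, 0.5400000214576721, 0.5400000214576721,
                0.5799999833106995, 0.46000000834465027, 0.699999988079071, 0.7599999904632568,
                0.7200000286102295],
}

# Every metric should have one entry per epoch.
lengths = {name: len(values) for name, values in history.items()}
same_length = len(set(lengths.values())) == 1

# Zero-based index of the epoch with the lowest validation loss.
best_loss_epoch = min(range(len(history['val_loss'])), key=history['val_loss'].__getitem__)

# Zero-based index of the epoch with the highest validation accuracy.
best_acc_epoch = max(range(len(history['val_acc'])), key=history['val_acc'].__getitem__)
```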

qa_check_history_records()[source]

Quality-check the stored history records

Returns:

None

warn_about_result_data_problem(msg)[source]

keys()[source]

items()[source]

add_history_dict(other)[source]

Add another training history dict to this training history

Parameters:

other (dict) – another training history dict

Returns:

None