import torch
from aitoolbox.torchtrain.callbacks.abstract import AbstractCallback
from aitoolbox.utils.util import is_empty_function
class CallbacksHandler:
    """Callback handler used for the callback orchestration inside the TrainLoop.

    The handler keeps one list of callbacks per TrainLoop stage (epoch begin/end, train begin/end,
    batch begin/end, after gradient update, after optimizer step, multiprocess start and after batch
    prediction) and offers an ``execute_*`` method for each stage which calls the matching callback
    method on every callback stored in that stage's list.

    A callback is only placed into a stage list when ``is_empty_function`` reports that its method for
    that stage is not empty. Thus the handler doesn't unnecessarily execute callbacks at stages they are
    not implemented at.
    """

    def __init__(self, train_loop_obj):
        """Create an empty handler bound to the given TrainLoop.

        Args:
            train_loop_obj (aitoolbox.torchtrain.train_loop.train_loop.TrainLoop): reference to the
                encapsulating TrainLoop
        """
        self.train_loop_obj = train_loop_obj
        # Callbacks which were handed over with ``cache_callbacks=True`` and are still waiting to be registered
        self.callbacks_cache = []

        self.cbs_on_epoch_begin = []
        self.cbs_on_epoch_end = []
        self.cbs_on_train_begin = []
        self.cbs_on_train_end = []
        self.cbs_on_batch_begin = []
        self.cbs_on_batch_end = []
        self.cbs_on_after_gradient_update = []
        self.cbs_on_after_optimizer_step = []
        self.cbs_on_multiprocess_start = []
        self.cbs_on_after_batch_prediction = []

        self.registered_cbs = self._collect_stage_lists()

    def _collect_stage_lists(self):
        """Return the list holding references to all the per-stage callback lists in their fixed order."""
        return [
            self.cbs_on_epoch_begin, self.cbs_on_epoch_end,
            self.cbs_on_train_begin, self.cbs_on_train_end,
            self.cbs_on_batch_begin, self.cbs_on_batch_end,
            self.cbs_on_after_gradient_update, self.cbs_on_after_optimizer_step,
            self.cbs_on_multiprocess_start,
            self.cbs_on_after_batch_prediction
        ]

    def register_callbacks(self, callbacks, cache_callbacks=False, print_callbacks=False):
        """Register TrainLoop object reference inside the listed callbacks when the TrainLoop is created

        Normally, this is called from inside the train loop by the TrainLoop itself. Basically train loop
        "registers" itself with each of the provided callbacks.

        New provided callbacks are appended to the existing ones.

        Args:
            callbacks (list or None): list of new callbacks to be added (appended)
            cache_callbacks (bool): should the provided callbacks be cached and not yet registered. First
                subsequent time this method is called without ``cache_callbacks`` enabled all the previously
                cached callbacks are added and also registered with the current list of callbacks.
            print_callbacks (bool): after registering the provided callbacks also print the list of registered
                callbacks which will be executed during the run of the train loop

        Returns:
            None
        """
        # Normalise the input once so that ``None`` and non-list iterables behave the same in both branches
        new_callbacks = list(callbacks) if callbacks is not None else []

        if cache_callbacks:
            # Just filling the self.callbacks_cache list with callbacks
            self.callbacks_cache += new_callbacks
        else:
            # Combine any previously cached callbacks with the new callbacks.
            # If there aren't any callbacks cached then the callback cache is just an empty list.
            pending_callbacks = self.callbacks_cache + new_callbacks
            # Empty the callbacks cache
            self.callbacks_cache = []

            if pending_callbacks:
                self.enforce_callbacks_quality(pending_callbacks)

                self.train_loop_obj.callbacks += [
                    cb.register_train_loop_object(self.train_loop_obj) for cb in pending_callbacks
                    if self.should_enable_callback(cb)
                ]

                if not all(0 == cb.execution_order for cb in self.train_loop_obj.callbacks):
                    self.train_loop_obj.callbacks = sorted(
                        self.train_loop_obj.callbacks, key=lambda cb: cb.execution_order
                    )

                # Note: using `pending_callbacks` here instead of `self.train_loop_obj.callbacks` is correct.
                # Only the callbacks handled by this call are split by execution position. In case some
                # callbacks were already registered at some earlier time this prevents their duplication
                # in the execution-position-split self.registered_cbs.
                self.split_on_execution_position(pending_callbacks, register_train_loop=False)

        if print_callbacks:
            self.print_registered_callback_names()

    def should_enable_callback(self, callback):
        """Determine if callback should be enabled and executed to be in accordance with the GPU device setting

        Always true when the train loop's ``device.index`` is ``None``.

        Otherwise, if the callback doesn't have any ``device_idx_execution`` set then it is enabled as well.
        In case the parameter is set in the callback then this function will only be True when the set
        ``device_idx_execution`` in the callback and the train loop's device index match.

        Args:
            callback (AbstractCallback): callback which will be checked if it should be enabled during
                the particular train loop run

        Returns:
            bool: if the provided callback should be enabled or disabled based on (GPU) device index matching.
        """
        device_idx = self.train_loop_obj.device.index
        return (
            device_idx is None
            or callback.device_idx_execution is None
            or callback.device_idx_execution == device_idx
        )

    def execute_epoch_begin(self):
        """Call ``on_epoch_begin()`` on every callback registered for this stage."""
        for callback in self.cbs_on_epoch_begin:
            callback.on_epoch_begin()

    def execute_epoch_end(self):
        """Call ``on_epoch_end()`` on every callback registered for this stage."""
        for callback in self.cbs_on_epoch_end:
            callback.on_epoch_end()

    def execute_train_begin(self):
        """Call ``on_train_begin()`` on every callback registered for this stage."""
        for callback in self.cbs_on_train_begin:
            callback.on_train_begin()

    def execute_train_end(self):
        """Call ``on_train_end()`` on every callback registered for this stage."""
        for callback in self.cbs_on_train_end:
            callback.on_train_end()

    def execute_batch_begin(self):
        """Call ``on_batch_begin()`` on every callback registered for this stage."""
        for callback in self.cbs_on_batch_begin:
            callback.on_batch_begin()

    def execute_batch_end(self):
        """Call ``on_batch_end()`` on every callback registered for this stage."""
        for callback in self.cbs_on_batch_end:
            callback.on_batch_end()

    def execute_gradient_update(self, optimizer_idx=0):
        """Call ``on_after_gradient_update()`` on every callback registered for this stage.

        Args:
            optimizer_idx (int): value forwarded unchanged to each callback's ``on_after_gradient_update``
        """
        for callback in self.cbs_on_after_gradient_update:
            callback.on_after_gradient_update(optimizer_idx)

    def execute_optimizer_step(self):
        """Call ``on_after_optimizer_step()`` on every callback registered for this stage."""
        for callback in self.cbs_on_after_optimizer_step:
            callback.on_after_optimizer_step()

    def execute_multiprocess_start(self):
        """Call ``on_multiprocess_start()`` on every callback registered for this stage."""
        for callback in self.cbs_on_multiprocess_start:
            callback.on_multiprocess_start()

    def execute_after_batch_prediction(self, y_pred_batch, y_test_batch, metadata_batch, dataset_info):
        """Call ``on_after_batch_prediction()`` on every callback registered for this stage.

        All four arguments are forwarded unchanged and in the same order to each callback.
        """
        for callback in self.cbs_on_after_batch_prediction:
            callback.on_after_batch_prediction(y_pred_batch, y_test_batch, metadata_batch, dataset_info)

    def split_on_execution_position(self, callbacks, register_train_loop=False):
        """Append the enabled callbacks to the per-stage lists of the stages they implement

        A callback is appended to a stage list when ``is_empty_function`` returns a falsy value for
        the callback's method belonging to that stage. Callbacks rejected by
        :meth:`should_enable_callback` are skipped entirely. Afterwards every stage list which holds at
        least one callback with a non-zero ``execution_order`` is sorted by ``execution_order``.

        Args:
            callbacks (list or None): callbacks to distribute into the per-stage lists
            register_train_loop (bool): if the train loop object should first be registered with each
                enabled callback via its ``register_train_loop_object`` method

        Returns:
            None
        """
        # Callback method name paired with the stage list collecting the callbacks which implement it
        stage_lists = (
            ('on_epoch_begin', self.cbs_on_epoch_begin),
            ('on_epoch_end', self.cbs_on_epoch_end),
            ('on_train_begin', self.cbs_on_train_begin),
            ('on_train_end', self.cbs_on_train_end),
            ('on_batch_begin', self.cbs_on_batch_begin),
            ('on_batch_end', self.cbs_on_batch_end),
            ('on_after_gradient_update', self.cbs_on_after_gradient_update),
            ('on_after_optimizer_step', self.cbs_on_after_optimizer_step),
            ('on_multiprocess_start', self.cbs_on_multiprocess_start),
            ('on_after_batch_prediction', self.cbs_on_after_batch_prediction),
        )

        for callback in callbacks or ():
            if not self.should_enable_callback(callback):
                continue

            if register_train_loop:
                callback = callback.register_train_loop_object(self.train_loop_obj)

            for method_name, stage_cbs in stage_lists:
                if not is_empty_function(getattr(callback, method_name)):
                    stage_cbs.append(callback)

        for cbs_at_position in self.registered_cbs:
            if not all(0 == cb.execution_order for cb in cbs_at_position):
                cbs_at_position.sort(key=lambda cb: cb.execution_order)

    def mp_filter_callbacks(self):
        """Drop the callbacks which :meth:`should_enable_callback` rejects for the current device

        The filter is applied to the train loop's ``callbacks`` list and to every per-stage list. Because
        the per-stage lists are replaced with new list objects, ``self.registered_cbs`` is rebuilt so that
        it keeps referencing the current lists.

        Returns:
            None
        """
        self.train_loop_obj.callbacks = self._mp_filter_cb_list(self.train_loop_obj.callbacks)

        self.cbs_on_epoch_begin = self._mp_filter_cb_list(self.cbs_on_epoch_begin)
        self.cbs_on_epoch_end = self._mp_filter_cb_list(self.cbs_on_epoch_end)
        self.cbs_on_train_begin = self._mp_filter_cb_list(self.cbs_on_train_begin)
        self.cbs_on_train_end = self._mp_filter_cb_list(self.cbs_on_train_end)
        self.cbs_on_batch_begin = self._mp_filter_cb_list(self.cbs_on_batch_begin)
        self.cbs_on_batch_end = self._mp_filter_cb_list(self.cbs_on_batch_end)
        self.cbs_on_after_gradient_update = self._mp_filter_cb_list(self.cbs_on_after_gradient_update)
        self.cbs_on_after_optimizer_step = self._mp_filter_cb_list(self.cbs_on_after_optimizer_step)
        self.cbs_on_multiprocess_start = self._mp_filter_cb_list(self.cbs_on_multiprocess_start)
        self.cbs_on_after_batch_prediction = self._mp_filter_cb_list(self.cbs_on_after_batch_prediction)

        self.registered_cbs = self._collect_stage_lists()

    def _mp_filter_cb_list(self, callbacks_list):
        """Return a new list with only those callbacks for which :meth:`should_enable_callback` is true."""
        return [cb for cb in callbacks_list if self.should_enable_callback(cb)]

    def enforce_callbacks_quality(self, callbacks):
        """Validate the provided callbacks before they get registered

        Args:
            callbacks (list): callbacks to check

        Returns:
            None

        Raises:
            TypeError: when a callback is not an instance of ``AbstractCallback``
            ValueError: when both the callback's ``device_idx_execution`` and the train loop's device index
                are set and ``device_idx_execution`` is not lower than ``torch.cuda.device_count()``
        """
        for cb in callbacks:
            if not isinstance(cb, AbstractCallback):
                raise TypeError(f'Callback {cb} is not inherited from the AbstractCallback')

            if cb.device_idx_execution is not None and self.train_loop_obj.device.index is not None:
                num_devices = torch.cuda.device_count()
                if cb.device_idx_execution >= num_devices:
                    raise ValueError(f'Selected device_idx_execution of {cb.device_idx_execution} is too high. '
                                     f'There are only {num_devices} available GPU devices. '
                                     f'Select index ranging from 0 to {num_devices - 1}')

    def __str__(self):
        """Build the multi-line listing of the callbacks stored in each per-stage list."""
        stages = (
            ('on_epoch_begin', self.cbs_on_epoch_begin),
            ('on_epoch_end', self.cbs_on_epoch_end),
            ('on_train_begin', self.cbs_on_train_begin),
            ('on_train_end', self.cbs_on_train_end),
            ('on_batch_begin', self.cbs_on_batch_begin),
            ('on_batch_end', self.cbs_on_batch_end),
            ('on_after_gradient_update', self.cbs_on_after_gradient_update),
            ('on_after_optimizer_step', self.cbs_on_after_optimizer_step),
            ('on_multiprocess_start', self.cbs_on_multiprocess_start),
            # The label of the last section intentionally keeps its historical ``cbs_`` prefix so that
            # the produced text stays identical to what it has always been.
            ('cbs_on_after_batch_prediction', self.cbs_on_after_batch_prediction),
        )
        return 'CALLBACKS\n' + ''.join(
            f'At {label}:\n{self.print_callback_info(stage_cbs)}\n' for label, stage_cbs in stages
        )

    @staticmethod
    def print_callback_info(callback_list):
        """Format the callbacks into a newline-separated description string

        Despite its name this method doesn't print anything, it only returns the text.

        Args:
            callback_list (list): callbacks to describe

        Returns:
            str: one tab-indented line per callback with its name, type and execution order
        """
        return '\n'.join([f'\t{callback.callback_name}: {type(callback)}, execution_order: {callback.execution_order}'
                          for callback in callback_list])

    def print_registered_callback_names(self):
        """Print the handler's string representation

        When the train loop's ``ddp_training_mode`` is truthy a header line with the train loop's device
        is printed first.
        """
        if self.train_loop_obj.ddp_training_mode:
            print(f'*** On device {self.train_loop_obj.device.index} ({self.train_loop_obj.device}) ***')
        print(self)

    def __len__(self):
        """Return the number of callbacks in the train loop's ``callbacks`` list."""
        return len(self.train_loop_obj.callbacks)

    def __add__(self, other):
        """Register the given callbacks with this handler

        Note that the handler is modified in place and the very same handler object is returned.

        Args:
            other (list): callbacks list

        Returns:
            CallbacksHandler: this handler
        """
        self.register_callbacks(other)
        return self

    def __iadd__(self, other):
        """Register the given callbacks with this handler in place

        Args:
            other (list): callbacks list

        Returns:
            CallbacksHandler: this handler
        """
        self.register_callbacks(other)
        return self

    def __contains__(self, item):
        """Check if a callback is present in the train loop's ``callbacks`` list

        Args:
            item (str or type): when a string is given it is compared against the ``callback_name`` of
                each callback. Anything else is compared against the exact type of each callback
                (subclasses don't count as a match).

        Returns:
            bool: True if a matching callback was found
        """
        registered = self.train_loop_obj.callbacks
        if isinstance(item, str):
            return any(cb.callback_name == item for cb in registered)
        return any(type(cb) is item for cb in registered)