from enum import Enum
[docs]class MessageHandling(Enum):
KEEP_FOREVER = 'keep_forever'
UNTIL_END_OF_EPOCH = 'until_end_of_epoch'
UNTIL_READ = 'until_read'
OVERWRITE = 'overwrite'
[docs]class Message:
def __init__(self, key, value, msg_handling_settings):
"""Wrapper object to represent the messages in the MessageService together with their handling settings
Args:
key (str): message key
value: message value
msg_handling_settings (MessageHandling or list[MessageHandling]): selected message handling settings
for this particular message
"""
self.key = key
self.value = value
self.msg_handling_settings = msg_handling_settings \
if type(msg_handling_settings) is list else [msg_handling_settings]
[docs]class MessageService:
def __init__(self):
"""Message Passing Service
Primarily intended for passing the messages in the TrainLoop, especially for communication or data sharing
between different callbacks.
"""
self.message_store = {}
[docs] def read_messages(self, key):
"""Read messages by key from the TrainLoop message service
Args:
key (str): message key
Returns:
list or None: if message key present return content, otherwise return None
"""
if key in self.message_store:
messages = [msg.value for msg in self.message_store[key]]
self.message_store[key] = [
msg for msg in self.message_store[key]
if MessageHandling.UNTIL_READ not in msg.msg_handling_settings
]
return messages
else:
return None
[docs] def write_message(self, key, value, msg_handling_settings=MessageHandling.UNTIL_END_OF_EPOCH):
"""Write a new message to the message service
Args:
key (str): message key
value: message content
msg_handling_settings (MessageHandling or list[MessageHandling]): setting how to handle the lifespan of
the message. Can use one of the following message lifecycle handling settings which are variables
imported from this script file and can be found defined at the beginning of the script:
* ``KEEP_FOREVER``
* ``UNTIL_END_OF_EPOCH``
* ``UNTIL_READ``
* ``OVERWRITE``
Returns:
None
"""
msg_handling_settings = self.validate_msg_handling_settings(msg_handling_settings)
if key not in self.message_store:
self.message_store[key] = []
message = Message(key, value, msg_handling_settings)
if MessageHandling.OVERWRITE in msg_handling_settings:
self.message_store[key] = [message]
else:
self.message_store[key].append(message)
[docs] def end_of_epoch_trigger(self):
"""Purging of the message service at the end of the epoch
Normally executed by the TrainLoop automatically after all the callbacks were executed at the end of every epoch
Returns:
None
"""
for key, msgs_list in list(self.message_store.items()):
self.message_store[key] = [msg for msg in self.message_store[key]
if MessageHandling.UNTIL_END_OF_EPOCH not in msg.msg_handling_settings]
if len(self.message_store[key]) == 0:
del self.message_store[key]
[docs] @staticmethod
def validate_msg_handling_settings(msg_handling_settings):
if type(msg_handling_settings) == list:
for msg_setting in msg_handling_settings:
if type(msg_setting) != MessageHandling:
raise TypeError('msg_setting is not of the correct MessageHandling type. '
f'It is {type(msg_setting)}.')
if len(msg_handling_settings) > 1 and MessageHandling.OVERWRITE not in msg_handling_settings:
raise ValueError(f'Provided two incompatible msg_handling_settings {msg_handling_settings}. '
'Only OVERRIDE setting can currently be combined with another available setting.')
elif type(msg_handling_settings) != MessageHandling:
raise TypeError(f'Provided msg_handling_settings {msg_handling_settings} type not of the supported '
'MessageHandling or list of MessageHandling.')
return msg_handling_settings if type(msg_handling_settings) is list else [msg_handling_settings]