Source code for ls_mlkit.pipeline.callback

from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import List


[docs] class CallbackEvent(Enum): # training TRAINING_START = "training_start" TRAINING_END = "training_end" EPOCH_START = "epoch_start" EPOCH_END = "epoch_end" STEP_START = "step_start" STEP_END = "step_end" # save, load PRE_SAVE = "pre_save" POST_SAVE = "post_save" PRE_LOAD = "pre_load" POST_LOAD = "post_load" # optimize PRE_COMPUTE_LOSS = "pre_compute_loss" POST_COMPUTE_LOSS = "post_compute_loss" PRE_BACKWARD = "pre_backward" POST_BACKWARD = "post_backward" PRE_OPTIMIZER_STEP = "pre_optimizer_step" POST_OPTIMIZER_STEP = "post_optimizer_step" # eval PRE_EVAL = "pre_eval" POST_EVAL = "post_eval" PRE_EVAL_STEP = "pre_eval_step" POST_EVAL_STEP = "post_eval_step"
[docs] class BaseCallback(metaclass=ABCMeta):
[docs] @abstractmethod def on_event(self, event: CallbackEvent, *args, **kwargs): """On event Args: event (CallbackEvent): the event to trigger *args: the arguments to pass to the callback **kwargs: the keyword arguments to pass to the callback """
[docs] class CallbackManager: def __init__(self): self.callbacks: List[BaseCallback] = []
[docs] def add_callback(self, callback: BaseCallback): """Add a callback Args: callback (BaseCallback): the callback to add """ if callback is not None: self.callbacks.append(callback)
[docs] def add_callbacks(self, callbacks: List[BaseCallback]): """Add a list of callbacks Args: callbacks (List[BaseCallback]): the callbacks to add """ if callbacks is not None and len(callbacks) > 0: self.callbacks.extend(callbacks)
[docs] def trigger(self, event: CallbackEvent, *args, **kwargs): """Trigger all callbacks for a given event Args: event (CallbackEvent): the event to trigger *args: the arguments to pass to the callback **kwargs: the keyword arguments to pass to the callback """ for callback in self.callbacks: callback.on_event(event, *args, **kwargs)