# Source code for ls_mlkit.util.hook.base_hook
from enum import Enum
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
HookStageType = TypeVar("HookStageType", bound=Enum)


class Hook(Generic[HookStageType]):
    """A named callback attached to a single stage.

    The object is callable: calling it forwards keyword arguments to ``fn``
    unless the hook has been disabled.

    Args:
        name: Identifier of the hook.
        stage: Enum member the hook is attached to.
        fn: Callable invoked with keyword arguments when the hook fires.
        priority: Sort key stored on the hook. Defaults to 0.
        enabled: Whether calling the hook actually invokes ``fn``. Defaults to True.
    """

    def __init__(
        self,
        name: str,
        stage: HookStageType,
        fn: Callable[..., Optional[Any]],
        priority: int = 0,
        enabled: bool = True,
    ):
        self.name, self.stage, self.fn = name, stage, fn
        self.priority, self.enabled = priority, enabled

    def __call__(self, **kwargs) -> Optional[Any]:
        """Return ``fn(**kwargs)``, or None without calling ``fn`` when the hook is disabled."""
        return self.fn(**kwargs) if self.enabled else None

    def __repr__(self):
        # The enabled flag is intentionally not part of the representation.
        return f"Hook(name={self.name}, stage={self.stage}, priority={self.priority})"
[docs]
class HookHandler(Generic[HookStageType]):
    """Handle pairing a hook with the manager it was registered on.

    It lets the holder enable, disable or remove that hook without keeping a
    separate reference to the manager.

    Args:
        manager: The manager the hook was registered on.
        hook: The hook this handle controls.
    """

    def __init__(self, manager: "HookManager[HookStageType]", hook: Hook[HookStageType]):
        self._manager, self._hook = manager, hook

    @property
    def hook(self) -> Hook[HookStageType]:
        """The hook controlled by this handle."""
        return self._hook

    def enable(self) -> None:
        """Mark the underlying hook as enabled."""
        self._hook.enabled = True

    def disable(self) -> None:
        """Mark the underlying hook as disabled."""
        self._hook.enabled = False

    def remove(self) -> None:
        """Ask the manager to unregister the hook, identified by its name and stage.

        Removal is by name rather than by object identity.
        """
        target = self._hook
        self._manager.unregister_hook(name=target.name, stage=target.stage)

    def __repr__(self):
        if self._hook.enabled:
            state = "enabled"
        else:
            state = "disabled"
        return f"<HookHandler {self._hook.name} ({state})>"
[docs]
class HookManager(Generic[HookStageType]):
    """Registry that groups hooks by stage and runs them in ascending priority order."""

    def __init__(self) -> None:
        # Maps each stage to its hooks, kept sorted by ascending priority.
        self._hooks: Dict[HookStageType, list[Hook[HookStageType]]] = {}

    def register_hook(self, hook: Hook[HookStageType]) -> HookHandler[HookStageType]:
        """Add ``hook`` under its stage and return a handle for it.

        Args:
            hook: The hook to register.

        Returns:
            A ``HookHandler`` bound to this manager and ``hook``.
        """
        stage_hooks = self._hooks.setdefault(hook.stage, [])
        stage_hooks.append(hook)
        # list.sort is stable, so hooks with equal priority keep registration order.
        stage_hooks.sort(key=lambda h: h.priority)
        return HookHandler(self, hook=hook)

    def register_hooks(self, hooks: list[Hook[HookStageType]]) -> list[HookHandler[HookStageType]]:
        """Register every hook in ``hooks`` and return their handles in the same order."""
        return [self.register_hook(hook) for hook in hooks]

    def unregister_hook(self, name: str, stage: Optional[HookStageType] = None) -> None:
        """Remove every hook called ``name``.

        Args:
            name: Name of the hook(s) to remove.
            stage: Restrict removal to this stage. When None, all stages are searched.
                An unknown stage is silently ignored.
        """
        # Compare against None explicitly: a stage member may itself be falsy
        # (for example an IntEnum or Flag member whose value is 0).
        if stage is not None:
            stages = [stage] if stage in self._hooks else []
        else:
            stages = list(self._hooks)
        for stage_key in stages:
            self._hooks[stage_key] = [h for h in self._hooks[stage_key] if h.name != name]

    def enable_hook(
        self, name: Optional[str] = None, stage: Optional[HookStageType] = None, enabled: bool = True
    ) -> None:
        """Set the ``enabled`` flag on all matching hooks.

        Args:
            name: Only touch hooks with this name. None matches every name.
            stage: Only touch hooks in this stage. None matches every stage.
            enabled: Value written to each matching hook's ``enabled`` flag.

        Raises:
            ValueError: If no hook matched ``name`` and ``stage``.
        """
        hook_found = False
        for stage_key, hooks in self._hooks.items():
            if stage is not None and stage_key != stage:
                continue
            for h in hooks:
                if name is None or h.name == name:
                    h.enabled = enabled
                    hook_found = True
        if not hook_found:
            raise ValueError(f"Hook with name {name} not found.")

    def disable_hook(self, name: Optional[str] = None, stage: Optional[HookStageType] = None) -> None:
        """Disable all matching hooks; same matching rules and errors as ``enable_hook``."""
        self.enable_hook(name=name, stage=stage, enabled=False)

    def run_hooks(self, stage: HookStageType, tgt_key_name=None, **kwargs) -> Optional[Any]:
        """Run every enabled hook registered for ``stage``, threading state between them.

        Each hook is called with the current keyword arguments.

        * With ``tgt_key_name`` set, each hook's return value (including None) is
          stored under that key before the next hook runs, and the final value
          of that key is returned.
        * Without ``tgt_key_name``, a hook's non-None return value replaces the
          keyword arguments passed to the next hook, so it must be a mapping if
          another hook follows. A hook that returns None leaves the keyword
          arguments unchanged. The final keyword arguments are returned.

        Args:
            stage: Stage whose hooks should run. None or an unknown stage runs nothing.
            tgt_key_name (``str``, *optional*): Key in ``kwargs`` that receives each
                hook's output. Defaults to None.
            **kwargs: Initial keyword arguments handed to the first hook.

        Returns:
            The value stored under ``tgt_key_name`` (None if the key was never
            supplied and no hook ran), or the final keyword arguments when
            ``tgt_key_name`` is None.
        """
        if stage is not None and stage in self._hooks:
            # Iterate over a snapshot so that a hook which registers or removes
            # hooks while running cannot disturb this pass.
            for hook in tuple(self._hooks[stage]):
                if not hook.enabled:
                    continue
                hook_output = hook(**kwargs)
                if tgt_key_name is not None:
                    kwargs[tgt_key_name] = hook_output
                elif hook_output is not None:
                    # A None result means "nothing to replace"; unpacking None
                    # into the next hook would raise TypeError.
                    kwargs = hook_output
        if tgt_key_name is not None:
            # .get avoids a KeyError when the key was not supplied and no hook ran.
            return kwargs.get(tgt_key_name)
        return kwargs

    def list_hooks(self) -> None:
        """Print every registered hook, grouped by stage, with its enabled state."""
        for stage, hooks in self._hooks.items():
            print(f"[{stage}]")
            for h in hooks:
                print(f" - {h} {'(enabled)' if h.enabled else '(disabled)'}")