Source code for ls_mlkit.model.model_for_pipeline

from typing import Any

import torch
from torch.nn import Module

from ..util.hook.model_hook import ModelHook, ModelHookHandler, ModelHookManager, ModelHookStageType


class ModelForPipeline(Module):
    """Wrap a model so that registered hooks run around its forward call.

    The wrapped module is stored as ``self.model`` and a fresh
    ``ModelHookManager`` is stored as ``self.hook_manager``.
    """

    def __init__(self, model: Module):
        """Store ``model`` and create an empty hook manager.

        Args:
            model: The module whose forward call is wrapped.
        """
        super().__init__()
        self.model = model
        self.hook_manager = ModelHookManager()

    def get_model_device(self) -> torch.device:
        """Return the device the wrapped model lives on.

        The device of the first parameter is used. If the model has no
        parameters, the first buffer is consulted instead; if it has neither,
        CPU is returned rather than leaking ``StopIteration`` to the caller.

        Returns:
            The device of the first parameter or buffer, or
            ``torch.device("cpu")`` when the model holds no tensors.
        """
        for tensor in self.model.parameters():
            return tensor.device
        # Parameterless modules may still carry state in buffers.
        for tensor in self.model.buffers():
            return tensor.device
        return torch.device("cpu")

    def forward(
        self,
        **batch: Any,
    ) -> dict:
        """Run the wrapped model on ``batch`` with hooks before and after.

        Hooks registered for ``ModelHookStageType.PRE_COMPUTE_LOSS`` run
        first and receive the model and the batch. The model is then called
        with the batch unpacked as keyword arguments. Finally, hooks for
        ``ModelHookStageType.POST_COMPUTE_LOSS`` run and additionally receive
        the model output. Return values of the hook runs are ignored here.

        Args:
            **batch: Keyword arguments forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model returns (annotated as ``dict``).
        """
        model = self.model
        self.hook_manager.run_hooks(
            stage=ModelHookStageType.PRE_COMPUTE_LOSS,
            model=model,
            batch=batch,
        )
        model_output = model(**batch)
        self.hook_manager.run_hooks(
            stage=ModelHookStageType.POST_COMPUTE_LOSS,
            model=model,
            batch=batch,
            model_output=model_output,
        )
        return model_output

    def register_hooks(self, hooks: list[ModelHook]) -> list[ModelHookHandler]:
        """Register ``hooks`` with this wrapper's hook manager.

        Args:
            hooks: Hooks to pass to ``ModelHookManager.register_hooks``.

        Returns:
            The result of ``ModelHookManager.register_hooks`` for ``hooks``.
        """
        return self.hook_manager.register_hooks(hooks)