Skip to content

model

ls_mlkit.model

ModelForPipeline

Bases: Module

Wraps a model whose `forward(**batch)` returns a dict (e.g. `{"x": ...}`).

Source code in `src/ls_mlkit/model/model_for_pipeline.py` (lines 14–43):
class ModelForPipeline(Module):
    """Wraps a model whose ``forward(**batch)`` returns a dict (e.g. ``{"x": ...}``).

    Every keyword argument given to :meth:`forward` is passed unchanged to the
    wrapped model. Hooks registered on ``self.hook_manager`` run immediately
    before and immediately after that call.

    Args:
        model: The module to wrap; stored as ``self.model``.
    """

    def __init__(self, model: Module):
        super().__init__()
        self.model = model
        # Collects hooks added through ``register_hooks``; ``forward`` runs them.
        self.hook_manager = ModelHookManager()

    def get_model_device(self) -> torch.device:
        """Return the device the wrapped model lives on.

        The device of the first parameter is reported. If the model has no
        parameters, the device of its first buffer is used instead. If the
        model's tensors are spread across several devices, only the first one
        found is reported.

        Returns:
            The ``torch.device`` of the first parameter (or, failing that, buffer).

        Raises:
            RuntimeError: If the wrapped model has neither parameters nor
                buffers, so no device can be inferred.
        """
        # Fall back to buffers so parameter-free models still report a device.
        # When nothing is found, raise a descriptive error rather than letting
        # ``next()`` leak a bare StopIteration to the caller.
        for tensor in self.model.parameters():
            return tensor.device
        for tensor in self.model.buffers():
            return tensor.device
        raise RuntimeError(
            "Cannot determine device: the wrapped model has no parameters or buffers."
        )

    def forward(
        self,
        **batch: Any,
    ) -> dict[str, Any]:
        """Run the wrapped model on ``batch``, with hooks before and after the call.

        Hooks for ``ModelHookStageType.PRE_COMPUTE_LOSS`` are run with
        ``model`` and ``batch`` before the wrapped model is called. Hooks for
        ``ModelHookStageType.POST_COMPUTE_LOSS`` are run afterwards and
        additionally receive ``model_output``.

        Args:
            **batch: Keyword arguments passed unchanged to ``self.model``.

        Returns:
            Whatever the wrapped model returns. It is expected to be a dict,
            but this method does not check that.
        """
        model = self.model
        # NOTE(review): the same ``batch`` dict is handed to the pre-stage hooks
        # and then unpacked into the model. This presumably lets those hooks
        # modify it in place before the call; confirm with ModelHookManager.
        self.hook_manager.run_hooks(stage=ModelHookStageType.PRE_COMPUTE_LOSS, model=model, batch=batch)
        model_output = model(**batch)
        self.hook_manager.run_hooks(
            stage=ModelHookStageType.POST_COMPUTE_LOSS,
            model=model,
            batch=batch,
            model_output=model_output,
        )
        return model_output

    def register_hooks(self, hooks: list[ModelHook]) -> list[HookHandler[ModelHookStageType]]:
        """Register ``hooks`` with this wrapper's hook manager.

        Args:
            hooks: Hooks to add. They are run by :meth:`forward` at the stage
                each one is registered for.

        Returns:
            The handlers returned by ``ModelHookManager.register_hooks``.
            Presumably these can be used to remove the hooks later; verify
            against the hook manager.
        """
        # ``cast`` does nothing at runtime. It only adapts the ModelHook list to
        # the manager's generic ``Hook[ModelHookStageType]`` signature for type
        # checkers. This presumes a ModelHook is a Hook[ModelHookStageType].
        typed_hooks = cast(list[Hook[ModelHookStageType]], hooks)
        return self.hook_manager.register_hooks(typed_hooks)