Bases: Module
Wraps a model whose forward(**batch) returns a dict (e.g. {"x": ...}).
Source code in src/ls_mlkit/model/model_for_pipeline.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class ModelForPipeline(Module):
    """Wraps a model whose ``forward(**batch)`` returns a dict (e.g. ``{"x": ...}``).

    Every keyword argument passed to :meth:`forward` is forwarded unchanged to
    the wrapped model. The call is bracketed by two hook stages run through a
    ``ModelHookManager``:

    * ``ModelHookStageType.PRE_COMPUTE_LOSS``: run before the model call,
      with ``model`` and ``batch``.
    * ``ModelHookStageType.POST_COMPUTE_LOSS``: run after the model call,
      with ``model``, ``batch`` and ``model_output``.

    Args:
        model: The module to wrap; stored as ``self.model``.
    """

    def __init__(self, model: Module):
        super().__init__()
        self.model = model
        # Each wrapper owns its own hook registry; hooks are added through
        # ``register_hooks``.
        self.hook_manager = ModelHookManager()

    def get_model_device(self) -> torch.device:
        """Return the device the wrapped model lives on.

        The device of the first parameter is used. If the model has no
        parameters, the device of its first buffer is used instead. Only one
        tensor is inspected, so a model spread over several devices reports
        just one of them.

        Returns:
            The ``torch.device`` of the first parameter (or buffer).

        Raises:
            RuntimeError: If the wrapped model has neither parameters nor
                buffers. The device is undefined in that case. (Previously a
                bare ``StopIteration`` escaped here.)
        """
        # ``next(iterator, None)`` avoids leaking StopIteration, which PEP 479
        # would turn into an unrelated-looking RuntimeError inside generators.
        param = next(self.model.parameters(), None)
        if param is not None:
            return param.device
        buffer = next(self.model.buffers(), None)
        if buffer is not None:
            return buffer.device
        raise RuntimeError(
            f"Cannot determine the device of {type(self.model).__name__}: "
            "it has no parameters or buffers."
        )

    def forward(
        self,
        **batch: Any,
    ) -> dict[str, Any]:
        """Run the wrapped model on ``batch``, surrounded by the registered hooks.

        The same ``batch`` dict object is handed to the pre-stage hooks and is
        then unpacked into the model call. If the hook manager passes that
        object through to hooks, in-place edits they make are seen by the
        model. (TODO confirm against ``ModelHookManager.run_hooks``.)

        The return values of ``run_hooks`` are discarded. Hooks therefore
        cannot substitute a different batch or output except by mutating
        objects in place.

        Args:
            **batch: Keyword arguments forwarded verbatim as ``model(**batch)``.

        Returns:
            Whatever the wrapped model returns, unchanged. It is annotated as a
            dict, but this method does not check that.
        """
        model = self.model
        self.hook_manager.run_hooks(stage=ModelHookStageType.PRE_COMPUTE_LOSS, model=model, batch=batch)
        model_output = model(**batch)
        self.hook_manager.run_hooks(
            stage=ModelHookStageType.POST_COMPUTE_LOSS,
            model=model,
            batch=batch,
            model_output=model_output,
        )
        return model_output

    def register_hooks(self, hooks: list[ModelHook]) -> list[HookHandler[ModelHookStageType]]:
        """Register ``hooks`` with this wrapper's hook manager.

        Args:
            hooks: Hooks to register.

        Returns:
            The handles returned by ``ModelHookManager.register_hooks``.
        """
        # ``list`` is invariant, so a ``list[ModelHook]`` is not accepted where a
        # ``list[Hook[ModelHookStageType]]`` is expected. Presumably ModelHook is
        # a Hook[ModelHookStageType] subtype (TODO confirm). ``cast`` only
        # informs the type checker; at runtime it returns ``hooks`` unchanged.
        typed_hooks = cast(list[Hook[ModelHookStageType]], hooks)
        return self.hook_manager.register_hooks(typed_hooks)
|