Skip to content

flow_matching

ls_mlkit.flow_matching

BaseFlow

Bases: BaseGenerativeModel

Abstract methods: `prior_sampling`, `compute_loss`, `step`, `sampling`, `inpainting`.

Source code in `src/ls_mlkit/flow_matching/base_fm.py` (lines 30–43):
@inherit_docstrings
class BaseFlow(BaseGenerativeModel):
    """Common base for the flow-matching models in this package.

    This class only wires up configuration and the time scheduler.
    Concrete subclasses are expected to implement: ``prior_sampling``,
    ``compute_loss``, ``step``, ``sampling`` and ``inpainting``.
    """

    # Instance attributes assigned in ``__init__``.
    config: BaseFlowConfig
    time_scheduler: FlowMatchingTimeScheduler

    def __init__(
        self,
        config: BaseFlowConfig,
        time_scheduler: FlowMatchingTimeScheduler,
    ) -> None:
        """Initialise the base model and keep the flow-specific collaborators.

        Args:
            config: Flow configuration; forwarded to the base class as
                ``config=`` and also stored on ``self.config``.
            time_scheduler: Scheduler stored on ``self.time_scheduler``.
        """
        super().__init__(config=config)
        # Assigned after the base initialiser so this exact object is what
        # ends up on ``self.config``, whatever the base class does with it
        # (its behaviour is not visible here).
        self.config = config
        self.time_scheduler = time_scheduler