Source code for ls_mlkit.flow_matching.base_fm
from typing import Any
from ..util.base_class.base_gm_class import BaseGenerativeModel, BaseGenerativeModelConfig
from ..util.decorators import inherit_docstrings
from .time_scheduler import FlowMatchingTimeScheduler
[docs]
@inherit_docstrings
class BaseFlowConfig(BaseGenerativeModelConfig):
def __init__(
self,
ndim_micro_shape: int,
n_discretization_steps: int,
n_inference_steps: int = None,
*args: list[Any],
**kwargs: dict[Any, Any],
) -> None:
super().__init__(
ndim_micro_shape=ndim_micro_shape,
n_discretization_steps=n_discretization_steps,
n_inference_steps=n_inference_steps,
*args,
**kwargs,
)
[docs]
@inherit_docstrings
class BaseFlow(BaseGenerativeModel):
"""
abstract method: prior_sampling, compute_loss, step, sampling, inpainting
"""
def __init__(
self,
config: BaseFlowConfig,
time_scheduler: FlowMatchingTimeScheduler,
) -> None:
super().__init__(config=config)
self.config: BaseFlowConfig = config
self.time_scheduler: FlowMatchingTimeScheduler = time_scheduler