Source code for ls_mlkit.diffusion.lie_group_diffuser

r"""
Lie Group Diffuser
"""

from typing import Any

from ..util.decorators import inherit_docstrings
from ..util.manifold.lie_group import LieGroup
from .manifold_diffuser import RiemannianManifoldDiffuser, RiemannianManifoldDiffuserConfig
from .time_scheduler import DiffusionTimeScheduler


[docs] @inherit_docstrings class LieGroupDiffuserConfig(RiemannianManifoldDiffuserConfig): def __init__( self, ndim_micro_shape: int, n_discretization_steps: int, n_inference_steps: int, *args: list[Any], **kwargs: dict[Any, Any], ): super().__init__( ndim_micro_shape=ndim_micro_shape, n_discretization_steps=n_discretization_steps, n_inference_steps=n_inference_steps, *args, **kwargs, )
[docs] @inherit_docstrings class LieGroupDiffuser(RiemannianManifoldDiffuser): def __init__( self, config: LieGroupDiffuserConfig, time_scheduler: DiffusionTimeScheduler, lie_group: LieGroup, ): super().__init__( config=config, time_scheduler=time_scheduler, riemannian_manifold=lie_group, ) self.lie_group = lie_group