Source code for ls_mlkit.diffusion.base_diffuser
from abc import abstractmethod
from typing import Any, Optional

from torch import Tensor

from ..util.base_class.base_gm_class import BaseGenerativeModel, BaseGenerativeModelConfig
from ..util.decorators import inherit_docstrings
from .time_scheduler import DiffusionTimeScheduler
[docs]
@inherit_docstrings
class BaseDiffuserConfig(BaseGenerativeModelConfig):
    def __init__(
        self,
        ndim_micro_shape: int,
        n_discretization_steps: int,
        n_inference_steps: Optional[int] = None,
        use_batch_flattening: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        r"""Initialize the BaseDiffuserConfig

        Every argument is forwarded unchanged to ``BaseGenerativeModelConfig.__init__``.

        Args:
            ndim_micro_shape (``int``): forwarded to the base config as ``ndim_micro_shape``
            n_discretization_steps (``int``): forwarded to the base config as ``n_discretization_steps``
            n_inference_steps (``Optional[int]``): forwarded as ``n_inference_steps``; defaults to ``None``
            use_batch_flattening (``bool``): forwarded as ``use_batch_flattening``; defaults to ``False``
            *args: extra positional arguments forwarded to the base config
            **kwargs: extra keyword arguments forwarded to the base config
        """
        # NOTE(review): the named arguments are passed by keyword while ``*args`` is
        # unpacked positionally. Python fills positional slots first, so a non-empty
        # ``*args`` binds to the base config's leading positional parameters. If those
        # are the same names as above, this raises "got multiple values for argument".
        # Confirm against BaseGenerativeModelConfig's signature.
        super().__init__(
            ndim_micro_shape=ndim_micro_shape,
            n_discretization_steps=n_discretization_steps,
            n_inference_steps=n_inference_steps,
            use_batch_flattening=use_batch_flattening,
            *args,
            **kwargs,
        )
[docs]
@inherit_docstrings
class BaseDiffuser(BaseGenerativeModel):
    r"""Base class for diffusers built on ``BaseGenerativeModel``.

    Stores a ``BaseDiffuserConfig`` and a ``DiffusionTimeScheduler`` on the instance.

    Abstract methods that subclasses must implement include:
        - :meth:`forward_process`
    """

    def __init__(
        self,
        config: BaseDiffuserConfig,
        time_scheduler: DiffusionTimeScheduler,
    ) -> None:
        r"""Initialize the BaseDiffuser

        Args:
            config (``BaseDiffuserConfig``): the config of the diffuser
            time_scheduler (``DiffusionTimeScheduler``): the time scheduler of the diffuser
        """
        super().__init__(config=config)
        # Store the config with the narrower BaseDiffuserConfig annotation so type
        # checkers see diffuser-specific fields (the base class may already set it).
        self.config: BaseDiffuserConfig = config
        self.time_scheduler: DiffusionTimeScheduler = time_scheduler

    @abstractmethod
    def forward_process(
        self, x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        r"""Forward process, from :math:`x_0` to :math:`x_t`

        Args:
            x_0 (``Tensor``): :math:`x_0`
            discrete_t (``Tensor``): the discrete time steps :math:`t`
            mask (``Tensor``): the mask of the sample
            *args: extra positional arguments for subclass implementations
            **kwargs: extra keyword arguments for subclass implementations

        Returns:
            ``dict``: a dictionary that must contain the key "x_t"

        Raises:
            NotImplementedError: if a subclass does not override this method
        """
        # @abstractmethod only blocks instantiation when the metaclass is ABCMeta.
        # Raising here also covers the case where it is not, and explicit
        # super().forward_process(...) calls. Without it the method would silently
        # return None.
        raise NotImplementedError(f"{type(self).__name__} must implement forward_process")