ls_mlkit.util.sde.base_sde module

Abstract SDE classes

Note

t always denotes a continuous time step in this module.

class ls_mlkit.util.sde.base_sde.SDE(ndim_micro_shape: int = 2)

Bases: ABC

Abstract base class for SDEs. All methods operate on a mini-batch of inputs.

abstract property T: float

End time of the SDE.

abstractmethod get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Get the drift and diffusion of the SDE.

Parameters:
  • x (Tensor) – the mini-batch of samples.

  • t (Tensor) – the time step.

  • mask (Tensor, optional) – the mask of the sample. Defaults to None.

Returns:

the drift and diffusion coefficients of the SDE, evaluated at (x, t).

Return type:

Tuple[Tensor, Tensor]
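To illustrate what a concrete subclass has to provide, here is a minimal sketch. It does not import ls_mlkit: `MiniSDE` is a hypothetical local stand-in that only mirrors the `T` property and the `get_drift_and_diffusion` signature documented above. `ToyVPSDE`, `beta_min`, `beta_max` and the linear `beta(t)` schedule are illustrative assumptions (a variance-preserving SDE dx = -0.5 β(t) x dt + sqrt(β(t)) dW), not classes from the library.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple


class MiniSDE(ABC):
    """Hypothetical stand-in mirroring the documented interface (not the real ls_mlkit SDE)."""

    def __init__(self, ndim_micro_shape: int = 2):
        self.ndim_micro_shape = ndim_micro_shape

    @property
    @abstractmethod
    def T(self) -> float:
        """End time of the SDE."""

    @abstractmethod
    def get_drift_and_diffusion(
        self, x: Any, t: Any, mask: Optional[Any] = None
    ) -> Tuple[Any, Any]:
        """Return (drift, diffusion) evaluated at (x, t)."""


class ToyVPSDE(MiniSDE):
    """Hypothetical VP SDE: dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dW."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0, ndim_micro_shape: int = 2):
        super().__init__(ndim_micro_shape)
        self.beta_min = beta_min
        self.beta_max = beta_max

    @property
    def T(self) -> float:
        return 1.0

    def beta(self, t):
        # Linear noise schedule from beta_min at t = 0 to beta_max at t = T = 1.
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def get_drift_and_diffusion(self, x, t, mask=None):
        beta_t = self.beta(t)
        drift = -0.5 * beta_t * x
        # `** 0.5` works for Python floats and torch tensors alike.
        diffusion = beta_t ** 0.5
        if mask is not None:
            # Assumes a 0/1 mask: zero the drift on masked-out entries.
            drift = drift * mask
        return drift, diffusion
```

For batched tensors, `t` would typically need reshaping so that `beta(t)` broadcasts against `x`; that detail is omitted here.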

get_reverse_sde(score=None, score_fn: object = None, use_probability_flow=False)

Create the reverse-time SDE/ODE.

Parameters:
  • score_fn – A time-dependent score-based model that takes (x, t, mask) and returns the score.

  • use_probability_flow – If True, create the reverse-time ODE used for probability flow sampling.
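The reverse-time construction commonly used for score-based models can be sketched as follows, assuming `get_reverse_sde` follows it (the source does not show its implementation). Given the forward drift f(x, t), diffusion g(t) and score s(x, t), the reverse-time SDE has drift f - g² s and diffusion g, while the probability flow ODE has drift f - ½ g² s and zero diffusion; both are integrated from t = T back to t = 0. The function name `reverse_drift_and_diffusion` and its arguments are hypothetical.

```python
from typing import Any, Callable, Optional, Tuple

# (x, t, mask) -> (drift, diffusion), like get_drift_and_diffusion
DriftDiffusionFn = Callable[[Any, Any, Optional[Any]], Tuple[Any, Any]]
# (x, t, mask) -> score, as documented for score_fn
ScoreFn = Callable[[Any, Any, Optional[Any]], Any]


def reverse_drift_and_diffusion(
    drift_and_diffusion: DriftDiffusionFn,
    score_fn: ScoreFn,
    x: Any,
    t: Any,
    mask: Optional[Any] = None,
    use_probability_flow: bool = False,
) -> Tuple[Any, Any]:
    """Hypothetical helper: drift and diffusion of the reverse-time SDE or probability flow ODE.

    reverse SDE:      drift f - g**2 * score,        diffusion g
    probability flow: drift f - 0.5 * g**2 * score,  diffusion 0
    Both are meant to be integrated backward in time, from T down to 0.
    """
    drift, diffusion = drift_and_diffusion(x, t, mask)
    score = score_fn(x, t, mask)
    # The probability flow ODE uses half the score correction of the reverse SDE.
    weight = 0.5 if use_probability_flow else 1.0
    rev_drift = drift - weight * diffusion ** 2 * score
    # The ODE is deterministic, so its diffusion is zero (same type/shape as g).
    rev_diffusion = 0.0 * diffusion if use_probability_flow else diffusion
    return rev_drift, rev_diffusion
```

As a quick check, with f = -0.5 x, g = 1 and the standard-normal score s = -x, the probability flow drift vanishes, consistent with N(0, 1) being the stationary distribution of that forward SDE.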