ls_mlkit.util.sde.base_sde module
Abstract SDE classes
Note
t always denotes continuous time in this module.
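Because t is continuous, time values are real numbers in an interval ending at the SDE's end time T, not integer step indices. The following is a minimal sketch of two common ways such times are produced for a mini-batch, assuming `Tensor` is `torch.Tensor`. The helper names `sample_continuous_times` and `reverse_time_grid` and the `eps` cutoffs are hypothetical illustrations, not part of ls_mlkit.

```python
import torch


def sample_continuous_times(batch_size: int, T: float, eps: float = 1e-5) -> torch.Tensor:
    """Draw one continuous time per sample in the mini-batch, roughly uniform on [eps, T]."""
    # torch.rand returns floats in [0, 1); rescale them into [eps, T].
    # eps (an assumption) keeps t away from 0, where some SDEs are degenerate.
    return eps + (T - eps) * torch.rand(batch_size)


def reverse_time_grid(T: float, num_steps: int, eps: float = 1e-3) -> torch.Tensor:
    """Decreasing float grid T = t_0 > t_1 > ... > t_N = eps, e.g. for reverse-time sampling."""
    return torch.linspace(T, eps, num_steps + 1)


t = sample_continuous_times(batch_size=4, T=1.0)  # shape (4,), float values
grid = reverse_time_grid(T=1.0, num_steps=10)     # 11 decreasing float times
```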
- class ls_mlkit.util.sde.base_sde.SDE(ndim_micro_shape: int = 2)
Bases: ABC
Abstract SDE class. All methods are designed to operate on a mini-batch of inputs.
- abstract property T: float
End time of the SDE.
- abstractmethod get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]
Get the drift and diffusion of the SDE.
- Parameters:
x (Tensor) – the sample.
t (Tensor) – the time step.
mask (Tensor, optional) – the mask of the sample. Defaults to None.
- Returns:
the drift and diffusion of the SDE.
- Return type:
Tuple[Tensor, Tensor]
- get_reverse_sde(score=None, score_fn: object = None, use_probability_flow=False)
Create the reverse-time SDE/ODE.
- Parameters:
score_fn – A time-dependent score-based model that takes (x, t, mask) and returns the score.
use_probability_flow – If True, create the reverse-time ODE used for probability flow sampling.
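The sketch below shows what a concrete subclass and a reverse-time construction might compute, assuming `Tensor` is `torch.Tensor`. `SDESketch` is a hypothetical stand-in that mirrors the signatures documented above rather than importing the real class. `VPSDE`, `beta_min`, `beta_max`, `reverse_drift_and_diffusion`, and `reverse_euler_step` are illustrative names, not part of ls_mlkit, and the sketch does not claim to reproduce what `get_reverse_sde` actually returns. It also assumes the diffusion is returned in a shape that broadcasts against `x`, and it ignores `mask`. The reverse drift follows the standard score-based formulas: f − g²·score for the reverse-time SDE, and f − ½·g²·score with zero noise for the probability flow ODE.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from torch import Tensor


class SDESketch(ABC):
    """Hypothetical stand-in mirroring the documented SDE interface."""

    @property
    @abstractmethod
    def T(self) -> float: ...

    @abstractmethod
    def get_drift_and_diffusion(
        self, x: Tensor, t: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]: ...


class VPSDE(SDESketch):
    """dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dW, with beta(t) linear in t."""

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    @property
    def T(self) -> float:
        return 1.0

    def get_drift_and_diffusion(self, x, t, mask=None):
        # t has shape (B,): one continuous time per sample. mask is ignored here.
        beta = self.beta_min + t * (self.beta_max - self.beta_min)
        # Reshape to (B, 1, ..., 1) so beta broadcasts over each sample's shape.
        beta = beta.reshape(-1, *([1] * (x.dim() - 1)))
        return -0.5 * beta * x, torch.sqrt(beta)


def reverse_drift_and_diffusion(sde, score_fn, x, t, mask=None, use_probability_flow=False):
    """Reverse-time SDE:     dx = [f - g^2 * score] dt + g dW_bar
    Probability-flow ODE: dx = [f - 0.5 * g^2 * score] dt (no noise)."""
    f, g = sde.get_drift_and_diffusion(x, t, mask)
    score = score_fn(x, t, mask)
    coeff = 0.5 if use_probability_flow else 1.0
    drift = f - coeff * g ** 2 * score
    diffusion = torch.zeros_like(g) if use_probability_flow else g
    return drift, diffusion


def reverse_euler_step(sde, score_fn, x, t, dt, mask=None, use_probability_flow=False):
    """One Euler(-Maruyama) step backward in time, from t to t - dt (dt > 0)."""
    drift, diffusion = reverse_drift_and_diffusion(
        sde, score_fn, x, t, mask, use_probability_flow
    )
    noise = torch.randn_like(x)
    # Time decreases, so the drift term enters with a minus sign.
    return x - drift * dt + diffusion * (dt ** 0.5) * noise
```

A convenient sanity check: if the data are standard normal, the VP SDE keeps every marginal at N(0, I), so the exact score is `-x`. With `score_fn = lambda x, t, mask=None: -x`, the probability-flow drift is then zero and `reverse_euler_step(..., use_probability_flow=True)` returns `x` unchanged.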