ls_mlkit.util.sde package

Submodules

Module contents

class ls_mlkit.util.sde.Corrector(sde: SDE, score_fn: object, snr: float, n_steps: int)

Bases: ABC

The abstract class for a corrector algorithm.

abstractmethod update_fn(x: Tensor, t: Tensor, mask=None)

Perform one update step of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – The mask of the sample. Defaults to None.

Returns:

  • x – A PyTorch tensor holding the next state.

  • x_mean – A PyTorch tensor holding the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]

class ls_mlkit.util.sde.LangevinCorrector(sde: SDE, score_fn: object, snr: float, n_steps: int, ndim_micro_shape: int = 3)

Bases: Corrector

update_fn(x: Tensor, t: Tensor, mask=None)

Perform one update step of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – The mask of the sample. Defaults to None.

Returns:

  • x – A PyTorch tensor holding the next state.

  • x_mean – A PyTorch tensor holding the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]
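LangevinCorrector has no docstring of its own. The sketch below shows the standard Langevin corrector step from the predictor-corrector framework of Song et al. (2021), where the step size is derived from snr; it assumes LangevinCorrector follows the same rule (the score_sde reference implementation does). It works on Python lists instead of batched tensors, so mask and ndim_micro_shape handling are omitted, and alpha is a hypothetical parameter standing in for the VP-specific scaling used by the reference code (it is 1 for VE SDEs).

```python
import math
import random


def langevin_correct(x, score_fn, snr, n_steps, alpha=1.0, noise_fn=None):
    """Langevin corrector steps on a list of floats; returns (x, x_mean).

    `alpha` is a hypothetical stand-in for the VP-specific step scaling.
    """
    if noise_fn is None:
        def noise_fn():
            return random.gauss(0.0, 1.0)
    x_mean = list(x)
    for _ in range(n_steps):
        grad = score_fn(x)
        noise = [noise_fn() for _ in x]
        grad_norm = max(math.sqrt(sum(g * g for g in grad)), 1e-12)  # avoid division by zero
        noise_norm = math.sqrt(sum(z * z for z in noise))
        # For alpha = 1, ||step * grad|| / ||sqrt(2 * step) * noise|| equals snr.
        step = 2.0 * alpha * (snr * noise_norm / grad_norm) ** 2
        # Deterministic part of the update: move along the score.
        x_mean = [xi + step * gi for xi, gi in zip(x, grad)]
        # Stochastic part: add Gaussian noise with variance 2 * step.
        x = [mi + math.sqrt(2.0 * step) * zi for mi, zi in zip(x_mean, noise)]
    return x, x_mean


# Usage: correct samples towards a standard normal target, whose score is -x.
x, x_mean = langevin_correct([2.0, 2.0], lambda v: [-vi for vi in v], snr=0.16, n_steps=5)
```

Returning x_mean alongside x mirrors the documented (x, x_mean) contract: x_mean is the last update before noise was added.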

class ls_mlkit.util.sde.NoneCorrector(sde, score_fn, snr, n_steps)

Bases: Corrector

An empty corrector that does nothing.

update_fn(x, t, mask=None)

Perform one update step of the corrector.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – The mask of the sample. Defaults to None.

Returns:

  • x – A PyTorch tensor holding the next state.

  • x_mean – A PyTorch tensor holding the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]
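To make the corrector contract concrete, here is a minimal standalone sketch that mirrors the documented constructor and update_fn signature using Python lists instead of tensors; it does not import ls_mlkit. NoOpCorrector is a hypothetical stand-in for NoneCorrector, and returning the input as both x and x_mean is an assumption, since the documentation only says that NoneCorrector does nothing.

```python
from abc import ABC, abstractmethod


class Corrector(ABC):
    """Standalone mirror of the documented Corrector interface (not ls_mlkit's class)."""

    def __init__(self, sde, score_fn, snr, n_steps):
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abstractmethod
    def update_fn(self, x, t, mask=None):
        """Return (x, x_mean): the next state and its noise-free counterpart."""


class NoOpCorrector(Corrector):
    """Hypothetical no-op corrector: assumed to return its input as both outputs."""

    def update_fn(self, x, t, mask=None):
        return x, x


corrector = NoOpCorrector(sde=None, score_fn=None, snr=0.16, n_steps=1)
x_next, x_mean = corrector.update_fn([0.5, -1.0], t=0.3)
```

Because update_fn is abstract, Corrector itself cannot be instantiated; every concrete corrector has to provide it.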

class ls_mlkit.util.sde.NonePredictor(sde, score_fn, use_probability_flow=False)

Bases: Predictor

update_fn(x, t, mask=None)

Perform one update step of the predictor.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – The mask of the sample. Defaults to None.

Returns:

  • x – A PyTorch tensor holding the next state.

  • x_mean – A PyTorch tensor holding the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]

class ls_mlkit.util.sde.Predictor(sde: SDE, score_fn: object, use_probability_flow=False)

Bases: ABC

The abstract class for a predictor algorithm.

abstractmethod update_fn(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Perform one update step of the predictor.

Parameters:
  • x – A PyTorch tensor representing the current state.

  • t – A PyTorch tensor representing the current time step.

  • mask – The mask of the sample. Defaults to None.

Returns:

  • x – A PyTorch tensor holding the next state.

  • x_mean – A PyTorch tensor holding the next state without random noise; useful for denoising.

Return type:

Tuple[Tensor, Tensor]

class ls_mlkit.util.sde.ReverseDiffusionPredictor(sde: SDE, score_fn, use_probability_flow=False, n_dim: int = 3)

Bases: Predictor

update_fn(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

\[
\begin{aligned}
x_{t+\Delta t} &= x_t + f(x_t, t)\,\Delta t + g(x_t, t)\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, |\Delta t|\, I)\\
f &= f(x_t, t)\,|\Delta t|\\
g &= g(x_t, t)\sqrt{|\Delta t|}
\end{aligned}
\]

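In the reverse-diffusion predictor of Song et al. (2021), f and g in the discretization above are the discretized forward drift and diffusion, and the step runs backward in time with the reverse drift f − g²∇ log p_t(x). The sketch below assumes ReverseDiffusionPredictor follows that rule; reverse_diffusion_step, drift_fn, diffusion_fn and noise_fn are hypothetical names, dt > 0 is the step size, and plain lists replace tensors.

```python
import math
import random


def reverse_diffusion_step(x, t, dt, drift_fn, diffusion_fn, score_fn, noise_fn=None):
    """One reverse-diffusion predictor step on a list of floats; returns (x, x_mean).

    Discretized forward quantities: f = drift * dt, g = diffusion * sqrt(dt).
    Reverse step: x_mean = x - (f - g^2 * score), x = x_mean + g * z.
    """
    if noise_fn is None:
        def noise_fn():
            return random.gauss(0.0, 1.0)
    drift = drift_fn(x, t)
    g = diffusion_fn(t) * math.sqrt(dt)
    score = score_fn(x, t)
    # Discretized drift of the reverse-time SDE.
    rev_f = [d * dt - g * g * s for d, s in zip(drift, score)]
    x_mean = [xi - fi for xi, fi in zip(x, rev_f)]
    # Same noise scale g as the forward discretization.
    x_next = [mi + g * noise_fn() for mi in x_mean]
    return x_next, x_mean
```

With zero drift, unit diffusion and the standard-normal score −x, a step of dt = 0.25 from x = 1 moves x_mean to 0.75 before noise is added.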
class ls_mlkit.util.sde.SDE(ndim_micro_shape: int = 2)

Bases: ABC

Abstract base class for SDEs. All methods operate on a mini-batch of inputs.

abstract property T: float

End time of the SDE.

abstractmethod get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Get the drift and diffusion of the SDE.

Parameters:
  • x (Tensor) – the sample.

  • t (Tensor) – the time step.

  • mask (Tensor, optional) – the mask of the sample. Defaults to None.

Returns:

the drift and diffusion of the SDE.

Return type:

Tuple[Tensor, Tensor]

get_reverse_sde(score=None, score_fn: object = None, use_probability_flow=False)

Create the reverse-time SDE/ODE.

Parameters:
  • score_fn – A time-dependent score-based model that takes (x, t, mask) and returns the score.

  • use_probability_flow – If True, create the reverse-time ODE used for probability flow sampling.
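For a forward SDE dx = f dt + g dw, the reverse-time SDE has drift f − g²∇ log p_t(x) and the same diffusion g, while the probability flow ODE has drift f − ½g²∇ log p_t(x) and no noise (Song et al., 2021). The sketch below assumes get_reverse_sde builds these quantities; reverse_drift and reverse_diffusion are hypothetical helper names operating on lists and a scalar diffusion.

```python
def reverse_drift(forward_drift, diffusion, score, use_probability_flow=False):
    """Drift of the reverse-time SDE, or of the probability flow ODE."""
    # The ODE halves the score correction; the SDE uses the full g^2 term.
    weight = 0.5 if use_probability_flow else 1.0
    return [f - weight * diffusion ** 2 * s for f, s in zip(forward_drift, score)]


def reverse_diffusion(diffusion, use_probability_flow=False):
    """The probability flow ODE is deterministic, so its diffusion is zero."""
    return 0.0 if use_probability_flow else diffusion
```

For example, with forward drift 1, diffusion 2 and score 0.5, the reverse SDE drift is 1 − 4 · 0.5 = −1, and the probability flow drift is 1 − 0.5 · 4 · 0.5 = 0.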

class ls_mlkit.util.sde.SubVPSDE(beta_min: float = 0.1, beta_max: float = 20, ndim_micro_shape: int = 2)

Bases: SDE

property T: float

End time of the SDE.

get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Get the drift and diffusion of the SDE.

Parameters:
  • x (Tensor) – the sample.

  • t (Tensor) – the time step.

  • mask (Tensor, optional) – the mask of the sample. Defaults to None.

Returns:

the drift and diffusion of the SDE.

Return type:

Tuple[Tensor, Tensor]
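The SubVPSDE equations are not written out here. Assuming the class implements the sub-VP SDE of Song et al. (2021) with the same linear schedule β_t = β_0 + t(β_1 − β_0) as VPSDE, its drift and diffusion are

\[dx = -\frac{1}{2}\beta_t x\,dt + \sqrt{\beta_t\left(1 - e^{-2\int_0^t \beta_s\,ds}\right)}\,dw\]

A scalar sketch of that formula (subvp_drift_and_diffusion is a hypothetical name):

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # SubVPSDE defaults for beta_min and beta_max


def subvp_drift_and_diffusion(x, t):
    """Drift (list) and diffusion (scalar) of the sub-VP SDE at time t."""
    beta_t = BETA_0 + t * (BETA_1 - BETA_0)
    # Closed-form integral of the linear schedule: int_0^t beta(s) ds.
    int_beta = BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t ** 2
    drift = [-0.5 * beta_t * xi for xi in x]
    diffusion = math.sqrt(beta_t * (1.0 - math.exp(-2.0 * int_beta)))
    return drift, diffusion
```

The drift matches the VP SDE, but the diffusion is always smaller than √β_t and vanishes at t = 0.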

marginal_prob(x, t, mask=None)
prior_logp(z)
prior_sampling(shape)
class ls_mlkit.util.sde.VESDE(sigma_min=0.01, sigma_max=50, n_discretization_steps=1000, ndim_micro_shape=2, drop_first_step=False)

Bases: SDE

property T: float

End time of the SDE.

get_discretized_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

SMLD (NCSN) discretization:

\[
\begin{aligned}
x_t &= x_{t-1} + g\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\\
x_t &\sim \mathcal{N}(x_0, \sigma_t^2 I)\\
\sigma_t^2 &= \sigma_{t-1}^2 + g^2\\
g &= \sqrt{\sigma_t^2 - \sigma_{t-1}^2}
\end{aligned}
\]

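Because each step adds variance g² = σ_t² − σ_{t−1}², the squared step sizes telescope to the final variance. The sketch below assumes geometrically spaced noise levels between sigma_min and sigma_max over n_discretization_steps and σ_{−1} = 0 for the first step, as in the NCSN/score_sde setup; the effect of drop_first_step is not modeled, and smld_sigmas and smld_diffusion are hypothetical names.

```python
import math


def smld_sigmas(sigma_min=0.01, sigma_max=50.0, n=1000):
    """Geometric noise levels from sigma_min to sigma_max (assumed schedule)."""
    log_min, log_max = math.log(sigma_min), math.log(sigma_max)
    return [math.exp(log_min + i * (log_max - log_min) / (n - 1)) for i in range(n)]


def smld_diffusion(sigmas, i):
    """Discretized diffusion g for step i, taking sigma_{-1} = 0."""
    prev = sigmas[i - 1] if i > 0 else 0.0
    return math.sqrt(sigmas[i] ** 2 - prev ** 2)


sigmas = smld_sigmas()
# Summing g^2 over all steps recovers sigma_max^2.
total_variance = sum(smld_diffusion(sigmas, i) ** 2 for i in range(len(sigmas)))
```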
get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

\[
\begin{aligned}
dx &= 0\,dt + \sigma_{\min}\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^t \sqrt{2\log\frac{\sigma_{\max}}{\sigma_{\min}}}\,dw\\
\sigma_t &= \sigma_{\min}\left(\frac{\sigma_{\max}}{\sigma_{\min}}\right)^t\\
\text{diffusion} &= \sigma_t\sqrt{2\log\frac{\sigma_{\max}}{\sigma_{\min}}}
\end{aligned}
\]
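The diffusion term is chosen so that the variance of the zero-drift process grows exactly like σ_t², i.e. d(σ_t²)/dt = diffusion². A scalar sketch with the default sigma_min and sigma_max (ve_sigma and ve_drift_and_diffusion are hypothetical names), followed by a finite-difference check of that identity:

```python
import math

SIGMA_MIN, SIGMA_MAX = 0.01, 50.0  # VESDE defaults


def ve_sigma(t):
    """Noise level sigma_t = sigma_min * (sigma_max / sigma_min) ** t."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def ve_drift_and_diffusion(x, t):
    """Zero drift and diffusion sigma_t * sqrt(2 log(sigma_max / sigma_min))."""
    drift = [0.0 for _ in x]
    diffusion = ve_sigma(t) * math.sqrt(2.0 * math.log(SIGMA_MAX / SIGMA_MIN))
    return drift, diffusion


# Central difference of sigma_t^2 at t = 0.5 should match diffusion^2.
h = 1e-6
fd = (ve_sigma(0.5 + h) ** 2 - ve_sigma(0.5 - h) ** 2) / (2 * h)
_, g = ve_drift_and_diffusion([1.0], 0.5)
```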
marginal_prob(x, t, mask=None)
prior_logp(z)
prior_sampling(shape)
class ls_mlkit.util.sde.VPSDE(beta_min: float = 0.1, beta_max: float = 20, ndim_micro_shape: int = 2)

Bases: SDE

property T: float

End time of the SDE.

forward_from_t1_to_t2(x_t1: Tensor, t1: Tensor, t2: Tensor) → Tensor
forward_process(x_0: Tensor, t: Tensor, mask: Tensor = None) → Tuple[Tensor, Tensor]

Perturbation kernel \(p_{0t}(x_t \mid x_0) = \mathcal{N}(x_t;\ \text{mean},\ \text{std}^2 I)\) with

\[
\begin{aligned}
\gamma &= -\frac{1}{4}t^2(\beta_1 - \beta_0) - \frac{1}{2}t\beta_0\\
\text{mean} &= e^{\gamma} x_0\\
\text{std} &= \sqrt{1 - e^{2\gamma}}
\end{aligned}
\]
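The kernel formulas translate directly into code. This sketch uses the default beta_min and beta_max as β_0 and β_1 and plain lists for x_0 (vp_mean_coeff_and_std and vp_perturbation_kernel are hypothetical names); it also illustrates the variance-preserving property e^{2γ} + std² = 1.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # VPSDE defaults for beta_min and beta_max


def vp_mean_coeff_and_std(t):
    """Return (e^gamma, sqrt(1 - e^{2 gamma})) for continuous time t."""
    gamma = -0.25 * t ** 2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    return math.exp(gamma), math.sqrt(1.0 - math.exp(2.0 * gamma))


def vp_perturbation_kernel(x0, t):
    """Mean (list) and std (scalar) of p_0t(x_t | x_0)."""
    coeff, std = vp_mean_coeff_and_std(t)
    return [coeff * xi for xi in x0], std


coeff_1, std_1 = vp_mean_coeff_and_std(1.0)
# At t = 1 almost all signal is gone: coeff_1 is about 0.0066, std_1 is close to 1.
```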
get_a_b(t: Tensor) → Tuple[Tensor, Tensor]

Return the coefficients a and b of the reparameterized forward process

\[x_t = a\,x_0 + b\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]

Parameters:

t (Tensor) – continuous time

Returns:

a, b

Return type:

Tuple[Tensor, Tensor]
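With this reparameterization, a sample x_t can be drawn in one step, and the noise can be recovered exactly as ε = (x_t − a x_0)/b; in the usual DDPM-style setup that noise is the regression target of a noise-prediction model. A list-based sketch, assuming a and b are the VP mean coefficient e^γ and standard deviation √(1 − e^{2γ}) from forward_process (sample_x_t is a hypothetical name):

```python
import math
import random

BETA_0, BETA_1 = 0.1, 20.0  # VPSDE defaults for beta_min and beta_max


def get_a_b(t):
    """Coefficients a, b with x_t = a * x_0 + b * eps (scalar t)."""
    gamma = -0.25 * t ** 2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    a = math.exp(gamma)
    return a, math.sqrt(1.0 - a * a)


def sample_x_t(x0, t, rng):
    """Reparameterized sample of x_t given x_0; also returns the noise used."""
    a, b = get_a_b(t)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    return [a * xi + b * ei for xi, ei in zip(x0, eps)], eps


x0 = [1.0, -2.0]
x_t, eps = sample_x_t(x0, 0.5, random.Random(0))
a, b = get_a_b(0.5)
recovered = [(xt - a * xi) / b for xt, xi in zip(x_t, x0)]
```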

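get_drift_and_diffusion(x: Tensor, t: Tensor, mask=None) → Tuple[Tensor, Tensor]

Continuous-time DDPM (VP) SDE:

\[dx = -\frac{1}{2}\beta_t x\,dt + \sqrt{\beta_t}\,dw, \qquad \beta_t = \beta_0 + t(\beta_1 - \beta_0)\]

Parameters:
  • x – the sample.

  • t – the time step, of shape macro_shape.

  • mask – the mask of the sample. Defaults to None.

Returns:

  • drift – shape = x.shape

  • diffusion – shape = x.macro_shape

Return type:

Tuple[Tensor, Tensor]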
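The drift and the forward_process kernel must agree: for the linear SDE dx = −½β_t x dt, the mean coefficient e^γ satisfies dγ/dt = −½β_t. The sketch below (hypothetical helper names, default schedule) computes drift and diffusion for a list-valued sample and checks that identity with a central difference.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # VPSDE defaults for beta_min and beta_max


def vp_beta(t):
    """Linear noise schedule beta_t = beta_0 + t * (beta_1 - beta_0)."""
    return BETA_0 + t * (BETA_1 - BETA_0)


def vp_drift_and_diffusion(x, t):
    """Drift -beta_t * x / 2 (list) and diffusion sqrt(beta_t) (scalar)."""
    beta_t = vp_beta(t)
    return [-0.5 * beta_t * xi for xi in x], math.sqrt(beta_t)


def vp_gamma(t):
    """Log of the mean coefficient of p_0t, as in forward_process."""
    return -0.25 * t ** 2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0


h, t = 1e-5, 0.3
fd = (vp_gamma(t + h) - vp_gamma(t - h)) / (2 * h)  # should equal -beta_t / 2
```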

get_score(x_t, mean, std) → Tensor

Score of the Gaussian perturbation kernel:

\[\nabla_{x_t} \ln p_{0t}(x_t \mid x_0) = -\frac{x_t - \text{mean}}{\text{std}^2}\]
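A minimal check of the score formula, assuming get_score evaluates the gradient of the log-density of N(mean, std² I): the sketch compares the closed form with a finite-difference derivative of the log-density along the first coordinate. gaussian_score and gaussian_log_density are hypothetical names.

```python
import math


def gaussian_score(x_t, mean, std):
    """Score of N(mean, std^2 I) evaluated at x_t (lists of floats)."""
    return [-(xi - mi) / std ** 2 for xi, mi in zip(x_t, mean)]


def gaussian_log_density(x_t, mean, std):
    """Log-density of N(mean, std^2 I) at x_t."""
    k = len(x_t)
    sq = sum((xi - mi) ** 2 for xi, mi in zip(x_t, mean))
    return -0.5 * k * math.log(2.0 * math.pi * std ** 2) - 0.5 * sq / std ** 2


x_t, mean, std, h = [0.3, -1.2], [0.1, 0.4], 0.7, 1e-5
# Central difference of the log-density along the first coordinate.
fd = (gaussian_log_density([x_t[0] + h, x_t[1]], mean, std)
      - gaussian_log_density([x_t[0] - h, x_t[1]], mean, std)) / (2 * h)
```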
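prior_logp(z: Tensor) → Tensor

Log-density of the standard Gaussian prior, i.e. the logarithm of

\[(2\pi)^{-k/2} \det(\Sigma)^{-1/2} \exp\left( -\frac{1}{2} (\mathbf{z} - \boldsymbol{\mu})^\mathrm{T} \Sigma^{-1} (\mathbf{z} - \boldsymbol{\mu}) \right)\]

where \(\Sigma = I\), \(\boldsymbol{\mu} = 0\) and \(k\) is the dimension of \(\mathbf{z}\), so that \(\log p(\mathbf{z}) = -\frac{k}{2}\log(2\pi) - \frac{1}{2}\lVert\mathbf{z}\rVert^2\).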
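With Σ = I and μ = 0 the determinant term is 1 and the quadratic form reduces to the squared norm. A sketch for a single sample flattened into a list, so that k = len(z); how ls_mlkit reduces over batch and micro-shape dimensions is not stated here and is left out.

```python
import math


def prior_logp(z):
    """Log-density of N(0, I) for one flattened sample z (list of floats)."""
    k = len(z)
    return -0.5 * k * math.log(2.0 * math.pi) - 0.5 * sum(zi * zi for zi in z)


# The origin of a 3-dimensional standard normal has log-density -1.5 * log(2 * pi).
logp_origin = prior_logp([0.0, 0.0, 0.0])
```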

prior_sampling(shape: Tuple) → Tensor

\[\epsilon \sim \mathcal{N}(0, I)\]
ls_mlkit.util.sde.get_model_fn(model, train=False)

Create a function that returns the output of the score-based model.

Parameters:
  • model – The score model.

  • train – True for training and False for evaluation.

Returns:

A model function.
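A minimal sketch of such a wrapper, assuming it follows the score_sde pattern of switching the model into training or evaluation mode before calling it. DummyModel is a hypothetical stand-in that only exposes the train()/eval() switches of a torch.nn.Module, and the (x, t) call signature is an assumption.

```python
class DummyModel:
    """Hypothetical stand-in exposing train()/eval() like torch.nn.Module."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def __call__(self, x, t):
        return x * t


def get_model_fn(model, train=False):
    """Return a closure that sets the model mode and then evaluates it."""

    def model_fn(x, t):
        if train:
            model.train()
        else:
            model.eval()
        return model(x, t)

    return model_fn


model = DummyModel()
eval_fn = get_model_fn(model, train=False)
```

Setting the mode inside the closure, rather than once at creation time, keeps the returned function correct even if other code toggles the model between calls.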

ls_mlkit.util.sde.get_pc_sampler(sde: SDE, shape: Tuple[int, ...], predictor_class: Predictor, corrector_class: Corrector, inverse_scaler: Callable, snr: float, n_correct_steps: int = 1, use_probability_flow: bool = False, denoise_at_final: bool = True, eps: float = 0.001, device: str = 'cuda')

Create a Predictor-Corrector (PC) sampler.

Parameters:
  • sde – An SDE object representing the forward SDE.

  • shape – A sequence of integers giving the shape of the sampled batch; the first dimension is the batch size.

  • predictor_class – A subclass of Predictor representing the predictor algorithm.

  • corrector_class – A subclass of Corrector representing the corrector algorithm.

  • inverse_scaler – The inverse data normalizer.

  • snr – A float. The signal-to-noise ratio used to configure the corrector.

  • n_correct_steps – An integer. The number of corrector steps per predictor update.

  • use_probability_flow – If True, solve the reverse-time probability flow ODE when running the predictor.

  • denoise_at_final – If True, add one-step denoising to the final samples.

  • eps – A float. The reverse-time SDE and ODE are integrated down to t = eps instead of 0 to avoid numerical issues.

  • device – PyTorch device.

Returns:

A sampling function that returns samples and the number of function evaluations during sampling.
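To show how predictor, corrector, eps and denoise_at_final fit together, here is a self-contained toy PC sampler in plain Python. It is a sketch, not ls_mlkit's implementation: it assumes a 1-D VP SDE with the default β schedule, replaces the learned score with the exact score of a hypothetical Gaussian data distribution N(0, 0.5²), pairs a reverse-diffusion predictor with one Langevin corrector step per iteration (alpha fixed to 1, norms averaged over the batch), treats inverse_scaler as the identity, and counts function evaluations as n_steps × (n_correct_steps + 1), the convention of the score_sde reference code.

```python
import math
import random

BETA_0, BETA_1 = 0.1, 20.0  # VP schedule (VPSDE defaults)
DATA_STD = 0.5              # hypothetical toy data distribution N(0, DATA_STD^2)


def beta(t):
    return BETA_0 + t * (BETA_1 - BETA_0)


def score(x, t):
    """Exact score of the toy marginal p_t; stands in for a trained score model."""
    gamma = -0.25 * t ** 2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    a2 = math.exp(2.0 * gamma)
    var = a2 * DATA_STD ** 2 + (1.0 - a2)
    return -x / var


def pc_sample(batch_size=1000, n_steps=200, n_correct_steps=1, snr=0.16,
              eps=1e-3, denoise_at_final=True, seed=0):
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(batch_size)]  # prior sample, N(0, 1)
    x_mean = x
    dt = (1.0 - eps) / n_steps  # integrate from t = 1 down to t = eps
    for i in range(n_steps):
        t = 1.0 - i * dt
        # Corrector: Langevin steps with batch-averaged norms (alpha fixed to 1).
        for _ in range(n_correct_steps):
            grad = [score(xi, t) for xi in x]
            noise = [rng.gauss(0.0, 1.0) for _ in x]
            grad_norm = sum(abs(g) for g in grad) / batch_size
            noise_norm = sum(abs(z) for z in noise) / batch_size
            step = 2.0 * (snr * noise_norm / grad_norm) ** 2
            x_mean = [xi + step * g for xi, g in zip(x, grad)]
            x = [m + math.sqrt(2.0 * step) * z for m, z in zip(x_mean, noise)]
        # Predictor: reverse-diffusion step from t to t - dt.
        g2 = beta(t) * dt  # squared discretized diffusion
        x_mean = [xi + 0.5 * g2 * xi + g2 * score(xi, t) for xi in x]
        x = [m + math.sqrt(g2) * rng.gauss(0.0, 1.0) for m in x_mean]
    n_evals = n_steps * (n_correct_steps + 1)
    return (x_mean if denoise_at_final else x), n_evals


samples, nfe = pc_sample()
```

With the default settings the sample variance should land close to DATA_STD² = 0.25, and nfe is 200 × 2 = 400.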

ls_mlkit.util.sde.get_score_fn(sde, model, train=False, continuous=False)

Wrap model so that its output corresponds to a proper time-dependent score function.

Parameters:
  • sde – An SDE object that represents the forward SDE.

  • model – A score model.

  • train – True for training and False for evaluation.

  • continuous – If True, the score-based model is expected to directly take continuous time steps.

Returns:

A score function.
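For a VP SDE, the score_sde version of this wrapper treats the network output as a noise prediction and divides it by the marginal standard deviation, because x_t = mean + std · ε gives ∇ log p_{0t}(x_t | x_0) = −ε/std. The sketch below assumes ls_mlkit does the same for VPSDE; get_score_fn_vp and vp_std are hypothetical names, and the VE branch and the continuous/discrete time-label handling are omitted.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # VPSDE defaults for beta_min and beta_max


def vp_std(t):
    """Marginal standard deviation sqrt(1 - e^{2 gamma}) of the VP SDE."""
    gamma = -0.25 * t ** 2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0
    return math.sqrt(1.0 - math.exp(2.0 * gamma))


def get_score_fn_vp(model):
    """Hypothetical VP-only wrapper: the model is assumed to predict the noise."""

    def score_fn(x, t):
        eps_pred = model(x, t)
        std = vp_std(t)  # must be > 0, which is one reason sampling stops at t = eps
        return [-e / std for e in eps_pred]

    return score_fn


# An exact noise predictor for data fixed at x_0 = 0, where x_t = std * eps.
score_fn = get_score_fn_vp(lambda x, t: [xi / vp_std(t) for xi in x])
```

For that toy data the wrapped output equals the exact score −x / std².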