Source code for ls_mlkit.util.sde.base_sde
r"""Abstract SDE classes
Note:
``t`` is always continous time step in this module.
"""
import abc
from typing import Tuple
from torch import Tensor
[docs]
class SDE(abc.ABC):
r"""
SDE abstract class. Functions are designed for a mini-batch of inputs.
"""
def __init__(self, ndim_micro_shape: int = 2):
r"""Initialize the SDE
Args:
ndim_micro_shape (``int``, *optional*): number of dimensions of a sample.
e.g. for image with shape ``[b, c, h, w]``, ndim_micro_shape = 3
e.g. for protein with shape ``[b, n_res, 3]``, ndim_micro_shape = 2
"""
super().__init__()
self.ndim_micro_shape = ndim_micro_shape
@property
@abc.abstractmethod
def T(self) -> float:
r"""End time of the SDE."""
[docs]
@abc.abstractmethod
def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
r"""Get the drift and diffusion of the SDE.
Args:
x (``Tensor``): the sample.
t (``Tensor``): the time step.
mask (``Tensor``, *optional*): the mask of the sample. Defaults to None.
Returns:
``Tuple[Tensor, Tensor]``: the drift and diffusion of the SDE.
"""
[docs]
def get_reverse_sde(self, score=None, score_fn: object = None, use_probability_flow=False):
r"""Create the reverse-time SDE/ODE.
Args:
score_fn: A time-dependent score-based model that takes (x ,t, mask) and returns the score.
use_probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
"""
T = self.T
ndim_micro_shape = self.ndim_micro_shape
get_forward_drift_and_diffusion = self.get_drift_and_diffusion
# get_forward_discretized_drift_and_diffusion = self.get_discretized_drift_and_diffusion
class RSDE(self.__class__):
def __init__(self):
self.use_probability_flow = use_probability_flow
self.ndim_micro_shape = ndim_micro_shape
def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
r"""
Create the drift and diffusion functions for the reverse SDE/ODE.
$$
\begin{align*}
dx = (f(x,t) - g(x,t)^2 \nabla_x \log p_t(x)) dt + g(x,t) dw
\end{align*}
$$
if use ODE probability flow:
$$
\begin{align*}
dx = (f(x,t) - \frac{1}{2} g(x,t)^2 \nabla_x \log p_t(x)) dt
\end{align*}
$$
"""
nonlocal score, score_fn
assert score is not None or score_fn is not None, "either score or score_fn must be provided"
if score is None:
score = score_fn(x, t, mask)
drift, diffusion = get_forward_drift_and_diffusion(x, t, mask=mask)
rev_diffusion = 0.0 if self.use_probability_flow else diffusion
diffusion = diffusion.view(
*x.shape[: -self.ndim_micro_shape], *[1 for _ in range(self.ndim_micro_shape)]
)
rev_drift = drift - diffusion**2 * score * (0.5 if self.use_probability_flow else 1.0)
# Set the diffusion function to zero for ODEs.
return rev_drift, rev_diffusion
return RSDE()