# Source code for ls_mlkit.util.sde.sde_lib: VP, sub-VP and VE stochastic differential equations.

from typing import Tuple

import numpy as np
import torch
from torch import Tensor

from .base_sde import SDE


class VPSDE(SDE):
    r"""Variance Preserving SDE (continuous-time DDPM).

    .. math::
        dx = -\frac{1}{2}\beta_t x\,dt + \sqrt{\beta_t}\,dw,
        \qquad \beta_t = \beta_0 + t(\beta_1 - \beta_0)
    """

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20, ndim_micro_shape: int = 2):
        r"""Construct a Variance Preserving SDE.

        Args:
            beta_min: value of beta(0)
            beta_max: value of beta(1)
            ndim_micro_shape: number of trailing dimensions that make up one sample; the
                leading dimensions of ``x`` form the "macro" (batch) shape
        """
        super().__init__(ndim_micro_shape=ndim_micro_shape)
        self.beta_0 = beta_min
        self.beta_1 = beta_max

    @property
    def T(self) -> float:
        """Terminal time of the process."""
        return 1

    def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Drift and diffusion of the continuous DDPM SDE.

        .. math::
            dx = -\frac{1}{2}\beta_t x\,dt + \sqrt{\beta_t}\,dw

        Args:
            x: current state, shape ``(*macro_shape, *micro_shape)``
            t: continuous time, shape ``macro_shape``
            mask: unused

        Returns:
            drift: shape ``x.shape``
            diffusion: shape ``macro_shape``
        """
        macro_shape = x.shape[: -self.ndim_micro_shape]
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        # Only the drift needs beta_t broadcast over the per-sample (micro) dimensions.
        beta_b = beta_t.view(*macro_shape, *([1] * self.ndim_micro_shape))
        drift = -0.5 * beta_b * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def get_score(self, x_t, mean, std) -> Tensor:
        r"""Score of a Gaussian transition kernel with the given mean and std.

        .. math::
            \nabla_{x_t} \ln p_{0t}(x_t \mid x_0) = -\frac{x_t - \mu}{\sigma^2}

        ``mean`` and ``std`` must broadcast against ``x_t`` (e.g. the ``"mean"`` and
        ``"std"`` entries returned by :meth:`forward_process`).
        """
        return -(x_t - mean) / std**2

    def get_a_b(self, t: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Coefficients of the marginal ``x_t = a * x_0 + b * epsilon``, ``epsilon ~ N(0, I)``.

        .. math::
            \gamma = -\frac{1}{4}t^2(\beta_1 - \beta_0) - \frac{1}{2}t\beta_0,\quad
            a = e^{\gamma},\quad b = \sqrt{1 - e^{2\gamma}}

        Args:
            t (``Tensor``): continuous time, shape ``macro_shape``

        Returns:
            ``Tuple[Tensor, Tensor]``: ``a`` and ``b``, each of shape ``(*macro_shape, 1, ..., 1)``
            with ``ndim_micro_shape`` trailing singleton dimensions.
        """
        macro_shape = t.shape
        log_mean_coeff = -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        log_mean_coeff = log_mean_coeff.view(*macro_shape, *([1] * self.ndim_micro_shape))
        a = torch.exp(log_mean_coeff)
        # -expm1(2*gamma) equals 1 - exp(2*gamma) but stays accurate for small t, where the
        # subtraction would cancel catastrophically and corrupt b (and hence the score).
        b = torch.sqrt(-torch.expm1(2.0 * log_mean_coeff))
        return a, b

    def forward_process(self, x_0: Tensor, t: Tensor, mask: Tensor = None) -> dict:
        r"""Sample from the forward transition kernel :math:`p_{0t}(x_t \mid x_0)`.

        .. math::
            \gamma = -\frac{1}{4}t^2 (\beta_1 - \beta_0) - \frac{1}{2} t \beta_0,\quad
            \mu = e^{\gamma} x_0,\quad \sigma = \sqrt{1 - e^{2\gamma}}

        Args:
            x_0: clean data, shape ``(*macro_shape, *micro_shape)``
            t: continuous time, shape ``macro_shape``
            mask: unused

        Returns:
            dict with keys ``"x_t"`` (noisy sample), ``"mean"``, ``"std"`` and the raw
            coefficients ``"a"`` and ``"b"`` (``"std"`` is the same tensor as ``"b"``).
        """
        a, b = self.get_a_b(t)
        mean = a * x_0
        x_t = mean + b * torch.randn_like(x_0)
        return {
            "x_t": x_t,
            "mean": mean,
            "std": b,
            "a": a,
            "b": b,
        }

    def forward_from_t1_to_t2(self, x_t1: Tensor, t1: Tensor, t2: Tensor) -> Tensor:
        r"""Sample :math:`x_{t_2} \sim p(x_{t_2} \mid x_{t_1})` for ``t1 <= t2``.

        The transition is linear-Gaussian: ``x_t2 = a12 * x_t1 + b12 * epsilon`` with
        ``a12 = a2 / a1`` and ``b12**2 = b2**2 - a12**2 * b1**2``.
        """
        assert (t1 <= t2).all(), "t1 must be less than or equal to t2"
        a1, b1 = self.get_a_b(t1)
        a2, b2 = self.get_a_b(t2)
        a12 = a2 / a1
        # For t1 == t2 (or nearly so) rounding can make this difference slightly negative,
        # which would turn the sqrt into NaN; the true value is >= 0.
        var_ratio = ((b2 / a2) ** 2 - (b1 / a1) ** 2).clamp_min(0.0)
        b12 = a2 * torch.sqrt(var_ratio)
        x_t2 = a12 * x_t1 + b12 * torch.randn_like(x_t1)
        return x_t2

    def prior_sampling(self, shape: Tuple) -> Tensor:
        r"""Sample from the standard normal prior.

        .. math::
            \epsilon \sim \mathcal{N}(0, I)
        """
        return torch.randn(*shape)

    def prior_logp(self, z: torch.Tensor) -> Tensor:
        r"""Log-density of ``z`` under the standard normal prior, one value per batch entry.

        .. math::
            \log \mathcal{N}(z; 0, I) = -\frac{k}{2}\log(2\pi) - \frac{1}{2}\lVert z \rVert^2

        where ``k`` is the number of elements per batch entry (all dims except the first).
        Works for any input rank; for 1-D input each element is its own entry.
        """
        N = np.prod(z.shape[1:])
        sq_norm = z**2
        if z.dim() > 1:
            # Reduce every non-batch dimension (not a fixed (1, 2, 3)) so any rank works.
            sq_norm = sq_norm.sum(dim=tuple(range(1, z.dim())))
        logps = -N / 2.0 * np.log(2 * np.pi) - sq_norm / 2.0
        return logps
class SubVPSDE(SDE):
    r"""sub-VP SDE.

    .. math::
        dx = -\frac{1}{2}\beta_t x\,dt
             + \sqrt{\beta_t\left(1 - e^{-2\int_0^t \beta_s\,ds}\right)}\,dw,
        \qquad \beta_t = \beta_0 + t(\beta_1 - \beta_0)
    """

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20, ndim_micro_shape: int = 2):
        """Construct the sub-VP SDE that excels at likelihoods.

        Args:
            beta_min: value of beta(0)
            beta_max: value of beta(1)
            ndim_micro_shape: number of trailing dimensions that make up one sample
        """
        super().__init__(ndim_micro_shape=ndim_micro_shape)
        self.beta_0 = beta_min
        self.beta_1 = beta_max

    @property
    def T(self) -> float:
        """Terminal time of the process."""
        return 1

    def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Drift and diffusion of the sub-VP SDE.

        Args:
            x: current state, shape ``(*macro_shape, *micro_shape)``
            t: continuous time, shape ``macro_shape``
            mask: unused

        Returns:
            drift: shape ``x.shape``
            diffusion: shape ``macro_shape``
        """
        macro_shape = x.shape[: -self.ndim_micro_shape]
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * beta_t.view(*macro_shape, *([1] * self.ndim_micro_shape)) * x
        # 1 - exp(-2 * int_0^t beta_s ds); expm1 keeps it accurate near t = 0.
        discount = -torch.expm1(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t**2)
        # beta_t and discount both have shape macro_shape, so the product does too
        # (expanding only one of them would broadcast into a wrong outer product).
        diffusion = torch.sqrt(beta_t * discount)
        return drift, diffusion

    def marginal_prob(self, x, t, mask=None):
        r"""Mean and std of :math:`p_{0t}(x_t \mid x_0)`.

        .. math::
            \gamma = -\frac{1}{4}t^2(\beta_1 - \beta_0) - \frac{1}{2}t\beta_0,\quad
            \mu = e^{\gamma} x_0,\quad \sigma = 1 - e^{2\gamma}

        Returns:
            mean: shape ``x.shape``
            std: shape ``(*macro_shape, 1, ..., 1)``, broadcastable against ``x``
        """
        log_mean_coeff = -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        macro_shape = x.shape[: -self.ndim_micro_shape]
        log_mean_coeff = log_mean_coeff.view(*macro_shape, *([1] * self.ndim_micro_shape))
        mean = torch.exp(log_mean_coeff) * x
        # No square root here: for the sub-VP SDE the variance is (1 - e^{2 gamma})^2,
        # so the std itself is 1 - e^{2 gamma} (computed via expm1 for accuracy).
        std = -torch.expm1(2.0 * log_mean_coeff)
        return mean, std

    def prior_sampling(self, shape):
        """Sample from the standard normal prior."""
        return torch.randn(*shape)

    def prior_logp(self, z):
        r"""Log-density of ``z`` under :math:`\mathcal{N}(0, I)`, one value per batch entry.

        All dimensions except the first are reduced, so any input rank is supported.
        """
        N = np.prod(z.shape[1:])
        sq_norm = z**2
        if z.dim() > 1:
            sq_norm = sq_norm.sum(dim=tuple(range(1, z.dim())))
        return -N / 2.0 * np.log(2 * np.pi) - sq_norm / 2.0
class VESDE(SDE):
    r"""Variance Exploding SDE (continuous-time SMLD / NCSN).

    .. math::
        dx = \sigma_t \sqrt{2 \log\frac{\sigma_{max}}{\sigma_{min}}}\,dw,
        \qquad \sigma_t = \sigma_{min}\left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t
    """

    def __init__(
        self, sigma_min=0.01, sigma_max=50, n_discretization_steps=1000, ndim_micro_shape=2, drop_first_step=False
    ):
        """Construct a Variance Exploding SDE.

        Args:
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            n_discretization_steps: number of discretization steps
            ndim_micro_shape: number of dimensions of a sample
            drop_first_step: if True, ``discrete_sigmas`` is built from
                ``n_discretization_steps + 1`` geometrically spaced values with the first
                (``sigma_min``) dropped; otherwise it holds ``n_discretization_steps``
                geometrically spaced values from ``sigma_min`` to ``sigma_max`` inclusive.
        """
        super().__init__(n_discretization_steps=n_discretization_steps, ndim_micro_shape=ndim_micro_shape)
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.drop_first_step = drop_first_step
        sigma_min = torch.tensor(sigma_min)
        sigma_max = torch.tensor(sigma_max)
        if drop_first_step:
            self.discrete_sigmas = (
                10 ** torch.linspace(torch.log10(sigma_min), torch.log10(sigma_max), n_discretization_steps + 1)[1:]
            )
        else:
            self.discrete_sigmas = torch.exp(
                torch.linspace(torch.log(sigma_min), torch.log(sigma_max), n_discretization_steps)
            )

    @property
    def T(self) -> float:
        """Terminal time of the process."""
        return 1

    def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Drift (zero) and diffusion of the continuous VE SDE.

        .. math::
            \sigma_t = \sigma_{min}\left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t,\quad
            g(t) = \sigma_t \sqrt{2 \log\frac{\sigma_{max}}{\sigma_{min}}}

        Returns:
            drift: zeros with shape ``x.shape``
            diffusion: shape of ``t``
        """
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        drift = torch.zeros_like(x)
        diffusion = sigma * torch.sqrt(
            torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device)
        )
        return drift, diffusion

    def get_discretized_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""SMLD (NCSN) discretization of the VE SDE.

        .. math::
            x_i = x_{i-1} + g_i \epsilon,\quad
            g_i = \sqrt{\sigma_i^2 - \sigma_{i-1}^2},\quad \sigma_{-1} := 0

        where ``i = floor(t * (n_discretization_steps - 1) / T)`` indexes ``discrete_sigmas``.

        Returns:
            f: zeros with shape ``x.shape``
            g: shape of ``t``
        """
        timestep = (t * (self.n_discretization_steps - 1) / self.T).long()
        # Move the table to t's device *before* indexing: indexing a CPU tensor with an
        # index tensor on another device fails.
        sigmas = self.discrete_sigmas.to(t.device)
        sigma = sigmas[timestep]
        # For timestep == 0, timestep - 1 wraps to the last entry; torch.where masks it out.
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(sigma), sigmas[timestep - 1])
        f = torch.zeros_like(x)
        g = torch.sqrt(sigma**2 - adjacent_sigma**2)
        return f, g

    def marginal_prob(self, x, t, mask=None):
        r"""Mean and std of :math:`p_{0t}(x_t \mid x_0) = \mathcal{N}(x_0, \sigma_t^2 I)`.

        Returns:
            mean: ``x`` itself
            std: :math:`\sigma_t` with the shape of ``t`` (not expanded over the sample
                dimensions, so callers must broadcast it themselves)
        """
        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        mean = x
        return mean, std

    def prior_sampling(self, shape):
        r"""Sample from the prior :math:`\mathcal{N}(0, \sigma_{max}^2 I)`."""
        return torch.randn(*shape) * self.sigma_max

    def prior_logp(self, z):
        r"""Log-density of ``z`` under :math:`\mathcal{N}(0, \sigma_{max}^2 I)`, one value per batch entry.

        All dimensions except the first are reduced, so any input rank is supported.
        """
        N = np.prod(z.shape[1:])
        sq_norm = z**2
        if z.dim() > 1:
            sq_norm = sq_norm.sum(dim=tuple(range(1, z.dim())))
        return -N / 2.0 * np.log(2 * np.pi * self.sigma_max**2) - sq_norm / (2 * self.sigma_max**2)