# Source code for ls_mlkit.util.sde.score_fn_utils
import torch
from .sde_lib import VESDE, VPSDE, SubVPSDE
# [docs]
def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
        model: The score model. It is called as `model(x, labels)` and must
            provide `train()` and `eval()` methods.
        train: `True` for training and `False` for evaluation.

    Returns:
        A model function `model_fn(x, labels)`.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        The model's mode is (re)set on every call, not when the wrapper is
        created, so the wrapper stays correct even if other code toggles the
        model's mode between calls.

        Args:
            x: A mini-batch of input data.
            labels: A mini-batch of conditioning variables for time steps.
                Should be interpreted differently for different models.

        Returns:
            The model output, i.e. whatever `model(x, labels)` returns.
        """
        # Select the mode first, then run a single forward call.
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn
# [docs]
def get_score_fn(sde, model, train=False, continuous=False):
    """Wrap a score model so its output corresponds to a time-dependent score function.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE. Only
            `VPSDE`, `SubVPSDE` and `VESDE` instances are supported.
        model: A score model, called as `model(x, labels)`.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly
            take continuous time steps.

    Returns:
        A score function `score_fn(x, t)`.

    Raises:
        NotImplementedError: If `sde` is not an instance of a supported class.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, (VPSDE, SubVPSDE)):
        # subVP SDEs always take the continuous path, regardless of the
        # `continuous` flag. This does not change between calls, so it is
        # decided once here rather than on every score evaluation.
        use_continuous = continuous or isinstance(sde, SubVPSDE)

        def score_fn(x, t):
            # Scale neural network output by standard deviation and flip sign.
            if use_continuous:
                # For VP-trained models, t=0 corresponds to the lowest noise
                # level. The maximum value of the time embedding is assumed
                # to be 999 for continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise
                # level.
                labels = t * (sde.N - 1)
                score = model_fn(x, labels)
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
            # Append one singleton axis per non-batch dimension of the score
            # so that `std` broadcasts over it. For a 4-D score this is
            # exactly `std[:, None, None, None]`; other ranks work as well.
            trailing_axes = (None,) * (score.dim() - 1)
            score = -score / std[(slice(None), *trailing_axes)]
            return score

    elif isinstance(sde, VESDE):

        def score_fn(x, t):
            if continuous:
                # The model is conditioned on the marginal standard deviation
                # returned by the SDE at time t.
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise
                # level.
                labels = (sde.T - t) * (sde.N - 1)
                labels = torch.round(labels).long()
            score = model_fn(x, labels)
            return score

    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn