"""Helpers that wrap a score model into model and score functions (ls_mlkit.util.sde.score_fn_utils)."""

import torch

from .sde_lib import VESDE, VPSDE, SubVPSDE


def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
      model: The score model. It is called as ``model(x, labels)`` and must
        provide ``eval()`` and ``train()`` methods.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function ``model_fn(x, labels)``.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps. Should
            be interpreted differently for different models.

        Returns:
          The value returned by ``model(x, labels)``, passed through unchanged.
        """
        # The mode is re-applied on every call, so the wrapper stays correct
        # even if other code toggled the model's mode in between.
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn
def get_score_fn(sde, model, train=False, continuous=False):
    """Wrap the model so that its output corresponds to a real time-dependent score function.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE. Only
        `VPSDE`, `SubVPSDE` and `VESDE` instances are handled.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the score-based model is expected to directly take
        continuous time steps.

    Returns:
      A score function ``score_fn(x, t)``.

    Raises:
      NotImplementedError: If `sde` is not an instance of `VPSDE`, `SubVPSDE`
        or `VESDE`.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, (VPSDE, SubVPSDE)):

        def score_fn(x, t):
            # Scale neural network output by standard deviation and flip sign.
            if continuous or isinstance(sde, SubVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of the time embedding is assumed to be 999
                # for continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                score = model_fn(x, labels)
                # NOTE(review): `.long()` truncates toward zero rather than
                # rounding; kept as in the original -- confirm this is intended.
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

            # Append one singleton axis per non-batch dimension of `x` so the
            # per-sample std broadcasts against the score. For 4-D `x` this is
            # exactly `std[:, None, None, None]`; other ranks are now supported.
            # Assumes `std` is indexed by batch along its first axis.
            std = std[(slice(None),) + (None,) * (x.dim() - 1)]
            return -score / std

    elif isinstance(sde, VESDE):

        def score_fn(x, t):
            if continuous:
                # The model is conditioned on the second element returned by
                # `marginal_prob` (presumably the noise std -- confirm in sde_lib).
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            return model_fn(x, labels)

    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn