Source code for ls_mlkit.diffusion.euclidean_vpsde_diffuser

from typing import Any, Callable, Optional, Tuple, cast

import torch
from torch import Tensor
from torch.nn import Module

from ..util.base_class.base_gm_class import GMHook, GMHookStageType
from ..util.decorators import inherit_docstrings
from ..util.mask.masker_interface import MaskerInterface
from ..util.sde.corrector import LangevinCorrector
from ..util.sde.sde_lib import VPSDE
from .conditioner import Conditioner
from .conditioner.utils import get_accumulated_conditional_score
from .euclidean_diffuser import EuclideanDiffuser, EuclideanDiffuserConfig
from .time_scheduler import DiffusionTimeScheduler


@inherit_docstrings
class EuclideanVPSDEConfig(EuclideanDiffuserConfig):
    def __init__(
        self,
        n_discretization_steps: int = 1000,
        ndim_micro_shape: int = 2,
        use_probability_flow: bool = False,
        beta_min: float = 0.1,
        beta_max: float = 20.0,
        n_correct_steps: int = 1,
        snr: float = 1.0,
        *args,
        **kwargs,
    ):
        super().__init__(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
        )
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.sde = VPSDE(
            beta_min=beta_min,
            beta_max=beta_max,
            ndim_micro_shape=ndim_micro_shape,
        )
        self.use_probability_flow = use_probability_flow
        self.n_correct_steps = n_correct_steps
        self.snr = snr

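# --- Usage sketch (illustrative; not part of the module) ---------------------
# A minimal construction of the config with its defaults. The `get_a_b` call
# mirrors its use in `get_posterior_mean_fn` below and returns the VPSDE
# marginal coefficients a(t), b(t) at a continuous time (assumed in [0, 1]):
#
#     config = EuclideanVPSDEConfig(beta_min=0.1, beta_max=20.0)
#     a, b = config.sde.get_a_b(torch.tensor([0.5]))
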
@inherit_docstrings
class EuclideanVPSDEDiffuser(EuclideanDiffuser):
    def __init__(
        self,
        config: EuclideanVPSDEConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: MaskerInterface,
        model: Module,
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_truth, padding_mask)
    ):
        """Initialize the EuclideanVPSDEDiffuser.

        Args:
            config (EuclideanVPSDEConfig): the config of the diffuser
            time_scheduler (DiffusionTimeScheduler): the time scheduler of the diffuser
            masker (MaskerInterface): the masker of the diffuser
            model (Module): the model of the diffuser
            loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): the loss function of the diffuser

        Returns:
            None
        """
        super().__init__(config=config, time_scheduler=time_scheduler, masker=masker)
        self.config: EuclideanVPSDEConfig = config
        self.sde = config.sde
        self.model = model
        self.loss_fn = loss_fn

        def score_fn(x: Tensor, t: Tensor, mask: Tensor) -> Tensor:
            # The wrapped model returns its score estimate under the "x" key.
            return self.model(x, t.long(), mask)["x"]

        self.corrector = LangevinCorrector(
            sde=self.sde,
            score_fn=score_fn,
            snr=self.config.snr,
            n_steps=self.config.n_correct_steps,
            ndim_micro_shape=self.config.ndim_micro_shape,
        )

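    # --- Usage sketch (illustrative; not part of the module) -----------------
    # `loss_fn` must follow the (predicted, ground_truth, padding_mask)
    # signature annotated above. A masked-MSE sketch; the mask broadcasting
    # and mean reduction are assumptions, not the library's choice:
    #
    #     def masked_mse(predicted, ground_truth, padding_mask):
    #         m = padding_mask.unsqueeze(-1).to(predicted.dtype)
    #         return ((predicted - ground_truth) ** 2 * m).sum() / m.sum().clamp_min(1.0)
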
    def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
        """Draw a sample of the given shape from the VPSDE prior."""
        return self.sde.prior_sampling(shape)

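    # Illustrative sketch (hypothetical names): the VPSDE prior is a standard
    # normal, so starting a reverse trajectory reduces to
    #
    #     x_T = diffuser.prior_sampling((batch_size, n_nodes, 3))
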
    def forward_process(
        self, x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: Any, **kwargs: Any
    ) -> dict:
        """Perturb x_0 to the (discrete) time t under the VPSDE forward process."""
        t = self.time_scheduler.discrete_time_to_continuous_time(discrete_t)
        forward_result = self.sde.forward_process(x_0, t, mask)
        return {
            "x_t": forward_result["x_t"],
            "mean": forward_result["mean"],
            "std": forward_result["std"],
            "a": forward_result["a"],
            "b": forward_result["b"],
        }

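    # Illustrative sketch: perturb clean data to a uniformly sampled discrete
    # time and recover the exact Gaussian score from the returned marginal
    # statistics, mirroring `compute_loss` below (`x_0`, `mask` hypothetical):
    #
    #     t = diffuser.time_scheduler.sample_a_discrete_time_step_uniformly(
    #         diffuser.get_macro_shape(x_0)
    #     )
    #     out = diffuser.forward_process(x_0, t, mask)
    #     gt_score = diffuser.sde.get_score(x_t=out["x_t"], mean=out["mean"], std=out["std"])
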
    def compute_loss(self, batch: dict[str, Any], *args: Any, **kwargs: Any) -> dict:
        batch = self.model.prepare_batch_data_for_input(batch)
        assert isinstance(batch, dict), "batch must be a dictionary"
        x_0 = batch["gt_data"]
        padding_mask = batch["padding_mask"]
        device = x_0.device
        macro_shape = self.get_macro_shape(x_0)
        t = batch.get("t", None)
        if t is None:
            t = self.time_scheduler.sample_a_discrete_time_step_uniformly(macro_shape).to(device)
        self.config = self.config.to(t)
        forward_result = self.forward_process(x_0, t, padding_mask)
        x_t = forward_result["x_t"]
        mean = forward_result["mean"]
        std = forward_result["std"]
        a = forward_result["a"]
        b = forward_result["b"]
        gt_uc_score = self.sde.get_score(x_t=x_t, mean=mean, std=std)
        # Forward the remaining batch entries to the model, dropping the keys
        # already consumed above.
        model_input_dict = batch
        model_input_dict.pop("gt_data")
        model_input_dict.pop("padding_mask")
        model_input_dict.pop("t", None)
        model_output = self.model(x_t, t, padding_mask, **model_input_dict)
        p_uc_score = model_output["x"]
        # Scale both scores by b (the marginal std) so the regression target
        # keeps roughly unit scale across timesteps.
        gt_uc_score = b * gt_uc_score
        p_uc_score = b * p_uc_score
        loss = self.loss_fn(p_uc_score, gt_uc_score, padding_mask)
        return {
            "loss": loss,
            "gt_data": x_0,
            "t": t,
            "x_t": x_t,
            "padding_mask": padding_mask,
            "gt_uc_score": gt_uc_score,
            "p_uc_score": p_uc_score,
            "a": a,
            "b": b,
            "loss_fn": self.loss_fn,
            "config": self.config,
        }

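    # Illustrative training-step sketch, assuming the batch provides the
    # "gt_data" and "padding_mask" keys read above and a model whose
    # `prepare_batch_data_for_input` accepts it (the optimizer is hypothetical):
    #
    #     result = diffuser.compute_loss({"gt_data": x_0, "padding_mask": mask})
    #     result["loss"].backward()
    #     optimizer.step()
    #     optimizer.zero_grad()
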
    def forward_process_n_step(
        self, x: Tensor, t: Tensor, next_t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any
    ) -> Tensor:
        assert (next_t > t).all()
        assert (t >= 0).all()
        assert (next_t < self.config.n_discretization_steps).all()
        continuous_t1 = self.time_scheduler.discrete_time_to_continuous_time(t)
        continuous_t2 = self.time_scheduler.discrete_time_to_continuous_time(next_t)
        x_t2 = self.sde.forward_from_t1_to_t2(x, continuous_t1, continuous_t2)
        return x_t2

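    # Illustrative sketch: push an already-noised sample from discrete step t
    # forward to a strictly later step without revisiting x_0
    # (`x_t`, `t`, `mask` are hypothetical names):
    #
    #     x_later = diffuser.forward_process_n_step(x_t, t, t + 10, mask)
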
    def step(self, x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) -> dict:
        r"""Run one reverse-diffusion predictor update, followed by Langevin correction.

        Args:
            x_t (Tensor): the sample at timestep t
            t (Tensor): the timestep
            padding_mask (Tensor): the padding mask

        Returns:
            dict: a dictionary whose key ``"x"`` holds the sample at timestep t-1
        """
        # All samples in the batch must share the same timestep.
        assert torch.all(t == t.view(-1)[0]).item()
        device = x_t.device
        idx = kwargs.get("idx")
        ones = torch.ones_like(t)
        t_start = self.time_scheduler.get_continuous_timesteps_schedule().to(device)[idx] * ones
        t_end = self.time_scheduler.get_continuous_timesteps_schedule().to(device)[idx + 1] * ones
        config = cast(EuclideanVPSDEConfig, self.config.to(device))
        model_output = self.model(x_t, t.long(), padding_mask, *args, **kwargs)
        p_uc_score = model_output["x"]
        # score hook start =====================================================
        hook_input = {
            "p_uc_score": p_uc_score,
            "x_t": x_t,
            "t": t,
            "padding_mask": padding_mask,
            "config": config,
            "sampling_condition": kwargs.get("sampling_condition"),
        }
        hook_output = self.hook_manager.run_hooks(GMHookStageType.PRE_UPDATE_IN_STEP_FN, **hook_input)
        if hook_output is not None:
            p_uc_score = hook_output
        # score hook end =======================================================
        rsde = self.sde.get_reverse_sde(
            score=p_uc_score, score_fn=None, use_probability_flow=config.use_probability_flow
        )
        delta_t = t_end - t_start  # typically negative: the schedule runs backwards in time
        delta_t = self.complete_micro_shape(delta_t)
        f, g = rsde.get_drift_and_diffusion(x_t, t_start, mask=padding_mask)
        g = self.complete_micro_shape(g)
        z = torch.randn_like(x_t)
        x_mean = x_t + f * delta_t
        if (t > 0).all():
            x = x_mean + g * z * torch.sqrt(delta_t.abs())
        else:
            x = x_mean  # return the noise-free mean at the final step
        if (t > 0).all():
            x, _ = self.corrector.update_fn(x, t - 1, padding_mask)
        return {
            "x": x,
        }

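    # Illustrative driver sketch. The actual sampling loop is expected to live
    # in the base class; this only suggests how `step` is fed the schedule
    # index `idx` alongside a constant-per-batch timestep (all names
    # hypothetical):
    #
    #     x = diffuser.prior_sampling(shape)
    #     for idx, t in enumerate(discrete_timesteps):  # timesteps descending
    #         x = diffuser.step(x, t * torch.ones(macro_shape), mask, idx=idx)["x"]
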
    def get_posterior_mean_fn(self, score: Optional[Tensor] = None, score_fn: Optional[Callable] = None):
        r"""Get the posterior mean function.

        Args:
            score (Tensor, optional): the score of the sample
            score_fn (Callable, optional): the function to compute the score

        Returns:
            Callable: the posterior mean function
        """

        def _posterior_mean_fn(
            x_t: Tensor,
            t: Tensor,
            padding_mask: Tensor,
        ):
            r"""
            Args:
                x_t: shape=(..., n_nodes, 3)
                t: shape=(...), dtype=torch.long

            For the case of VPSDE sampling, with forward marginal
            :math:`x_t = a x_0 + b z`, the posterior mean is given by

            .. math::
                E[x_0|x_t] = \frac{b^2}{a} \nabla_{x_t}\log p_t(x_t) + \frac{x_t}{a}
            """
            assert score is not None or score_fn is not None, "either score or score_fn must be provided"
            # Evaluate the score into a local; writing it back through
            # ``nonlocal`` would cache the first result and serve it for every
            # later x_t.
            _score = score if score is not None else score_fn(x_t, t, padding_mask)
            sde = cast(EuclideanVPSDEConfig, self.config.to(t)).sde
            continuous_t = self.time_scheduler.discrete_time_to_continuous_time(t)
            a, b = sde.get_a_b(continuous_t)
            E_x0_xt = b**2 / a * _score + x_t / a
            return E_x0_xt

        return _posterior_mean_fn

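    # Sanity check for the identity above: with x_t = a * x_0 + b * z, the
    # exact Gaussian score is -(x_t - a * x_0) / b**2, and substituting it
    # gives b**2 / a * score + x_t / a = (x_t - (x_t - a * x_0)) / a = x_0,
    # i.e. the posterior mean recovers x_0 exactly when the score is exact.
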
    def get_condition_post_compute_loss_hook(self, conditioner_list: list[Conditioner]):
        def _hook_fn(**kwargs):
            nonlocal conditioner_list
            loss = kwargs.get("loss")
            x_0 = kwargs.get("gt_data")
            x_t = kwargs.get("x_t")
            t = kwargs.get("t", None)
            padding_mask = kwargs.get("padding_mask")
            loss_fn = kwargs.get("loss_fn")
            config = kwargs.get("config")
            p_uc_score = kwargs.get("p_uc_score")
            gt_uc_score = kwargs.get("gt_uc_score")
            a = kwargs.get("a")
            b = kwargs.get("b")
            tgt_mask = padding_mask
            for conditioner in conditioner_list:
                if not conditioner.is_enabled():
                    continue
                conditioner.set_condition(
                    **conditioner.prepare_condition_dict(
                        train=True,
                        tgt_mask=tgt_mask,
                        gt_data=x_0,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(score=p_uc_score, score_fn=None),
                    )
                )
            acc_c_score = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)
            gt_score = gt_uc_score + acc_c_score
            # Scale and compute the conditioned loss
            p_uc_score = b * p_uc_score
            gt_score = b * gt_score
            total_loss = loss_fn(p_uc_score, gt_score, padding_mask)
            kwargs["loss"] = total_loss
            return kwargs

        return GMHook(
            name="VPSDE_condition_post_compute_loss_hook",
            stage=GMHookStageType.POST_COMPUTE_LOSS,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )

    def get_condition_pre_update_in_step_fn_hook(self, conditioner_list: list[Conditioner]):
        def _hook_fn(**kwargs):
            nonlocal conditioner_list
            p_uc_score = kwargs.get("p_uc_score")
            x_t = kwargs.get("x_t")
            t = kwargs.get("t", None)
            padding_mask = kwargs.get("padding_mask")
            config = kwargs.get("config")
            sampling_condition = kwargs.get("sampling_condition")
            tgt_mask = padding_mask
            for conditioner in conditioner_list:
                if not conditioner.is_enabled():
                    continue
                conditioner.set_condition(
                    **conditioner.prepare_condition_dict(
                        train=False,
                        tgt_mask=tgt_mask,
                        sampling_condition=sampling_condition,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(score=p_uc_score, score_fn=None),
                    )
                )
            acc_c_score = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)
            # Guide sampling by adding the accumulated conditional score to the
            # unconditional prediction.
            p_score = p_uc_score + acc_c_score
            return p_score

        return GMHook(
            name="VPSDE_condition_pre_update_in_step_fn_hook",
            stage=GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )
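
# --- Usage sketch (illustrative; not part of the module) ---------------------
# Constructing the two conditioning hooks. Only the construction below is
# grounded in this module; how the returned GMHook objects are registered with
# the diffuser's hook manager is not shown here:
#
#     post_loss_hook = diffuser.get_condition_post_compute_loss_hook([conditioner])
#     pre_step_hook = diffuser.get_condition_pre_update_in_step_fn_hook([conditioner])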