import math
from typing import Any, Callable, Tuple, cast
import torch
from torch import Tensor
from torch.nn import Module
from ..util.base_class.base_gm_class import GMHook, GMHookStageType
from ..util.context.temp_remove import TemporaryKeyRemover
from ..util.decorators import inherit_docstrings
from ..util.mask.masker_interface import MaskerInterface
from .conditioner import Conditioner
from .conditioner.utils import get_accumulated_conditional_score
from .euclidean_diffuser import EuclideanDiffuser, EuclideanDiffuserConfig
from .time_scheduler import DiffusionTimeScheduler
@inherit_docstrings
class EuclideanEDMConfig(EuclideanDiffuserConfig):
"""
Config Class for Euclidean EDM Diffuser
"""
def __init__(
self,
n_discretization_steps: int = 200,
ndim_micro_shape: int = 2,
P_mean: float = -1.2,
P_std: float = 1.2,
sigma_data: float = 0.5,
sigma_min: float = 0.002,
sigma_max: float = 80.0,
rho: float = 7.0,
use_2nd_order_correction: bool = True,
use_ode_flow: bool = False,
S_churn: float = 0.0,
S_min: float = 0.0,
S_max: float = float("inf"),
S_noise: float = 1.0,
use_clip: bool = False,
clip_sample_range: float = 1.0,
use_dyn_thresholding: bool = False,
dynamic_thresholding_ratio=0.995,
sample_max_value: float = 1.0,
sigma_multiply_by_sigma_data: bool = False,
*args,
**kwargs,
):
r"""
Args:
n_discretization_steps: the number of discretization steps
ndim_micro_shape: the number of dimensions of the micro shape
P_mean: mean of the log-normal distribution for sampling sigma during training
P_std: standard deviation of the log-normal distribution for sampling sigma during training
sigma_data: expected standard deviation of the training data
sigma_min: minimum supported noise level
sigma_max: maximum supported noise level
rho: time step exponent for sampling schedule
Returns:
None
"""
super().__init__(
n_discretization_steps=n_discretization_steps,
ndim_micro_shape=ndim_micro_shape,
)
self.P_mean = P_mean
self.P_std = P_std
self.sigma_data = sigma_data
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
self.use_ode_flow = use_ode_flow
self.use_2nd_order_correction = use_2nd_order_correction
self.S_churn = S_churn
self.S_min = S_min
self.S_max = S_max
self.S_noise = S_noise
self.use_clip = use_clip
self.clip_sample_range = clip_sample_range
self.use_dyn_thresholding = use_dyn_thresholding
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
self.sample_max_value = sample_max_value
self.sigma_multiply_by_sigma_data = sigma_multiply_by_sigma_data
step_indices = torch.arange(n_discretization_steps + 1, dtype=torch.float32)
self.sigma_schedule: Tensor = (
sigma_min ** (1 / rho)
+ (step_indices - 1) / (n_discretization_steps - 1) * (sigma_max ** (1 / rho) - sigma_min ** (1 / rho))
) ** rho
self.sigma_schedule[0] = 0.0
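    # Illustrative sketch (not part of the API): the schedule is ascending in
    # the timestep index, with index 0 reserved for the clean sample:
    #
    #   cfg = EuclideanEDMConfig(n_discretization_steps=10)
    #   cfg.sigma_schedule[0]    # 0.0
    #   cfg.sigma_schedule[1]    # sigma_min (0.002)
    #   cfg.sigma_schedule[-1]   # ~sigma_max (80.0, up to float32 error)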
def c_in(self, sigma: Tensor) -> Tensor:
return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)
def c_noise(self, sigma: Tensor) -> Tensor:
return 1 / 4 * torch.log(sigma)
def c_skip(self, sigma: Tensor) -> Tensor:
return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
def c_out(self, sigma: Tensor) -> Tensor:
return sigma * self.sigma_data / torch.sqrt(sigma**2 + self.sigma_data**2)
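    # Illustrative consistency check (a sketch of the EDM preconditioning
    # identity from Karras et al. 2022): c_skip(s) + s**2 * c_in(s)**2 == 1.
    #
    #   cfg = EuclideanEDMConfig()
    #   s = torch.tensor(1.5)
    #   torch.allclose(cfg.c_skip(s) + s**2 * cfg.c_in(s)**2, torch.tensor(1.0))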
def sigma(self, t: Tensor, is_continuous_time: bool = True) -> Tensor:
if is_continuous_time:
return t
else:
return self.timestep_index_to_sigma(t)
def timestep_index_to_sigma(self, timestep_index: Tensor) -> Tensor:
"""Convert discrete timesteps to sigma values.
Args:
discrete_t: discrete timesteps, shape=(...)
Returns:
sigma: noise levels, shape=(...)
"""
timestep_index = timestep_index.clamp(1, self.n_discretization_steps).long()
return self.sigma_schedule[timestep_index].to(timestep_index.device)
def compute_loss_weight(self, sigma: Tensor) -> Tensor:
"""Compute EDM loss weight: (sigma² + sigma_data²) / (sigma * sigma_data)².
Args:
sigma: noise level, shape=(...)
Returns:
weight: the loss weight, shape=(...)
"""
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
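    # Note (sketch): this weight equals 1 / c_out(sigma)**2, so multiplying the
    # residual by weight.sqrt() cancels the c_out factor in
    # D = c_skip * x + c_out * F, i.e. the weighted loss is an unweighted MSE
    # between F and its effective target (x_0 - c_skip * x_t) / c_out,
    # the standard EDM training objective:
    #
    #   w = cfg.compute_loss_weight(s)
    #   torch.allclose(w, 1.0 / cfg.c_out(s) ** 2)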
    def sampling_timestep_for_training(self, macro_shape: Tuple[int, ...]) -> Tensor:
        """Sample training noise levels from the EDM log-normal distribution.
        Args:
            macro_shape: shape of the batch (macro) dimensions
        Returns:
            t: sampled noise levels sigma, shape=macro_shape
        """
rnd_normal = torch.randn(macro_shape)
t = (self.P_mean + self.P_std * rnd_normal).exp()
if self.sigma_multiply_by_sigma_data:
t = t * self.sigma_data
return t
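    # Usage sketch (illustrative): training sigma follows
    # exp(N(P_mean, P_std**2)), optionally scaled by sigma_data:
    #
    #   cfg = EuclideanEDMConfig(P_mean=-1.2, P_std=1.2)
    #   sigma = cfg.sampling_timestep_for_training(macro_shape=(16,))
    #   sigma.shape  # torch.Size([16])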
@inherit_docstrings
class EuclideanEDMDiffuser(EuclideanDiffuser):
def __init__(
self,
config: EuclideanEDMConfig,
time_scheduler: DiffusionTimeScheduler,
masker: MaskerInterface,
model: Module,
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_truth, padding_mask)
):
super().__init__(config=config, time_scheduler=time_scheduler, masker=masker)
self.config: EuclideanEDMConfig = config
self.model = model
self.loss_fn = loss_fn
def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
return torch.randn(shape) * self.config.sigma_max
def compute_loss(self, **batch) -> dict:
"""Compute the EDM loss.
Args:
**batch: batch dictionary containing:
- gt_data: ground truth data x_0
- padding_mask: padding mask
Returns:
dict: A dictionary containing the loss and other information
"""
x_0 = batch["gt_data"]
padding_mask = batch["padding_mask"]
device = x_0.device
macro_shape = self.get_macro_shape(x_0) # (b, )
macro_shape = self.hook_manager.run_hooks(
stage=GMHookStageType.POST_GET_MACRO_SHAPE, tgt_key_name="macro_shape", macro_shape=macro_shape, batch=batch
)
t = self.config.sampling_timestep_for_training(macro_shape=macro_shape).to(device)
t = self.hook_manager.run_hooks(
stage=GMHookStageType.POST_SAMPLING_TIME_STEP, tgt_key_name="t", t=t, batch=batch
)
t = self.complete_micro_shape(t)
# Forward process: add noise
forward_result = self.forward_process(x_0, torch.zeros_like(t), t, padding_mask, is_continuous_time=True)
x_t, noise, sigma_diff = (forward_result["x_t"], forward_result["noise"], forward_result["sigma_diff"])
sigma = sigma_diff
batch["t"] = t
batch["x_t"] = self.config.c_in(sigma) * x_t
batch["gm_kwargs"] = {"c_in": self.config.c_in(sigma)}
with TemporaryKeyRemover(mapping=batch, keys=["gt_data"]):
model_output = self.model(**batch)
# Compute EDM loss
p_raw = model_output["x"]
D_yn = self._compute_denoised(x_t, p_raw, sigma)
# EDM loss weight: lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
weight = self.config.compute_loss_weight(sigma)
sqrt_weight = weight.sqrt()
loss = self.loss_fn(sqrt_weight * D_yn, sqrt_weight * x_0, padding_mask)
p_x_0 = D_yn
return {
"loss": loss,
"gt_data": x_0,
"t": t,
"sigma": sigma,
"x_t": x_t,
"noise": noise,
"p_raw": p_raw,
"p_x_0": p_x_0,
"padding_mask": padding_mask,
"loss_fn": self.loss_fn,
"config": self.config,
"base_model_output": model_output,
}
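    # Training-step sketch (illustrative; `diffuser`, `x_0`, `mask`, and
    # `optimizer` are hypothetical names):
    #
    #   out = diffuser.compute_loss(gt_data=x_0, padding_mask=mask)
    #   out["loss"].backward()
    #   optimizer.step()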
def forward_process(
self,
x_0: Tensor,
t_a: Tensor,
t_b: Tensor,
mask: Tensor,
is_continuous_time: bool = True,
*args: list[Any],
**kwargs: dict[Any, Any],
) -> dict:
assert (t_b >= t_a).all()
sigma_a = self.config.sigma(t_a, is_continuous_time)
sigma_b = self.config.sigma(t_b, is_continuous_time)
sigma_diff = (sigma_b**2 - sigma_a**2).clamp(min=0).sqrt()
noise = torch.randn_like(x_0)
x_t = x_0 + sigma_diff * noise
return {"x_t": x_t, "noise": noise, "sigma_diff": sigma_diff}
def _compute_denoised(self, x: Tensor, F_x: Tensor, sigma_expanded: Tensor) -> Tensor:
"""Compute denoised prediction using EDM preconditioning.
Args:
x: noisy input
F_x: raw network output
sigma_expanded: sigma value expanded to micro shape
Returns:
Denoised prediction D_x = c_skip * x + c_out * F_x
"""
return self.config.c_skip(sigma_expanded) * x + self.config.c_out(sigma_expanded) * F_x
def step(self, x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) -> dict:
r"""EDM sampling step (Euler or Heun's method).
Args:
x_t: the sample at timestep t
t: the timestep (all elements must be the same)
padding_mask: the padding mask
Returns:
dict:
- x: the sample at timestep t-1
- E_x0_xt: the predicted original sample
"""
assert torch.all(t == t.view(-1)[0]).item(), "All timesteps in batch must be the same for EDM step"
assert t.ndim == x_t.ndim, "Timestep and sample must have the same number of dimensions"
config = cast(EuclideanEDMConfig, self.config.to(t))
t = t.long()
t_next = t - 1
is_final_step = (t_next == 0).all()
use_heun = not is_final_step and self.config.use_2nd_order_correction
# Get sigma values and preconditioning coefficients with batch dimension
sigma_cur = config.sigma(t, is_continuous_time=False)
        if not self.config.use_ode_flow:
            epsilon = self.config.S_noise * torch.randn_like(x_t)
            gamma = (
                min(self.config.S_churn / self.config.n_discretization_steps, math.sqrt(2) - 1)
                if ((self.config.S_min <= sigma_cur).all() and (sigma_cur <= self.config.S_max).all())
                else 0.0
            )
            sigma_cur_hat = sigma_cur + gamma * sigma_cur
            x_t = x_t + torch.sqrt(sigma_cur_hat**2 - sigma_cur**2) * epsilon
            # Continue the step at the churned noise level, as in the EDM reference sampler;
            # this is a no-op when gamma == 0 (the default, since S_churn = 0.0)
            sigma_cur = sigma_cur_hat
# p_x_0 prediction
c_in_cur = config.c_in(sigma_cur)
scaled_x_t = c_in_cur * x_t
batch_dict = {
"x_t": scaled_x_t,
"t": sigma_cur,
"padding_mask": padding_mask,
**kwargs,
"gm_kwargs": {"c_in": c_in_cur},
}
F_x = self.model(**batch_dict)["x"]
p_x_0 = self._compute_denoised(x_t, F_x, sigma_cur)
        # Clip or threshold the predicted x_0 (following the standard diffusers DDPM implementation)
if self.config.use_dyn_thresholding:
p_x_0 = self._threshold_sample(p_x_0)
elif self.config.use_clip:
p_x_0 = p_x_0.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
# Run PRE_UPDATE_IN_STEP_FN hooks for conditional sampling
hook_input = {
"x_t": x_t,
"t": sigma_cur,
"p_x_0": p_x_0,
"p_raw": F_x,
"padding_mask": padding_mask,
**kwargs,
}
hook_output = self.hook_manager.run_hooks(
GMHookStageType.PRE_UPDATE_IN_STEP_FN, tgt_key_name="p_x_0", **hook_input
)
if hook_output is not None:
p_x_0 = hook_output
# Final step: return denoised directly
if is_final_step:
return {"x": p_x_0, "E_x0_xt": p_x_0}
# Euler step
sigma_next = config.sigma(t_next, is_continuous_time=False)
d_cur = (x_t - p_x_0) / sigma_cur.clamp(min=1e-8)
delta_sigma = sigma_next - sigma_cur
x_next = x_t + delta_sigma * d_cur
# Apply Heun's 2nd order correction
if use_heun:
c_in_next = config.c_in(sigma_next)
scaled_x_next = c_in_next * x_next
batch_dict_next = {
"x_t": scaled_x_next, # Apply c_in scaling to match training
"t": sigma_next,
"padding_mask": padding_mask,
**kwargs,
"gm_kwargs": {"c_in": c_in_next},
}
F_x_next = self.model(**batch_dict_next)["x"]
p_x_0_next = self._compute_denoised(x_next, F_x_next, sigma_next)
hook_input = {
"x_t": x_next,
"t": sigma_next,
"p_x_0": p_x_0_next,
"p_raw": F_x_next,
"padding_mask": padding_mask,
**kwargs,
}
hook_output = self.hook_manager.run_hooks(
GMHookStageType.PRE_UPDATE_IN_STEP_FN, tgt_key_name="p_x_0", **hook_input
)
if hook_output is not None:
p_x_0_next = hook_output
d_prime = (x_next - p_x_0_next) / sigma_next.clamp(min=1e-8)
x_next = x_t + 0.5 * (d_cur + d_prime) * delta_sigma
return {"x": x_next, "E_x0_xt": p_x_0}
    def get_posterior_mean_fn(self, score: Tensor | None = None, score_fn: Callable | None = None):
r"""Get the posterior mean function for EDM.
For EDM, the posterior mean is:
.. math::
E[x_0|x_t] = D_\theta(x_t, \sigma_t)
where D_\theta is the denoised prediction.
Args:
score (Tensor, optional): the score of the sample
score_fn (Callable, optional): the function to compute score
Returns:
Callable: the posterior mean function
"""
def _edm_posterior_mean_fn(
x_t: Tensor,
t: Tensor,
padding_mask: Tensor,
is_continuous_time: bool = True,
):
r"""
Args:
x_t: shape=(..., n_nodes, 3)
t: shape=(...), dtype=torch.long
For EDM, the posterior mean is the denoised prediction D_\theta(x_t, \sigma_t).
"""
# TODO: get x0 by score function
nonlocal score, score_fn
sigma = self.config.sigma(t, is_continuous_time=True)
c_in = self.config.c_in(sigma)
batch_dict = {
"x_t": c_in * x_t,
"t": t,
"sigma": sigma,
"padding_mask": padding_mask,
"gm_kwargs": {"c_in": c_in},
}
F_x = self.model(**batch_dict)["x"]
return self._compute_denoised(x_t, F_x, sigma)
return _edm_posterior_mean_fn
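    # Usage sketch (illustrative): the returned closure evaluates E[x_0 | x_t]
    # at a continuous noise level (here `t` is interpreted directly as sigma):
    #
    #   posterior_mean_fn = diffuser.get_posterior_mean_fn()
    #   e_x0 = posterior_mean_fn(x_t, t=sigma_t, padding_mask=mask)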
def _compute_edm_score(self, x_t: Tensor, x_0: Tensor, sigma: Tensor) -> Tensor:
"""Compute EDM score function: -(x_t - x_0) / sigma².
Args:
x_t: noisy sample at time t
x_0: clean sample (predicted or ground truth)
sigma: noise level
Returns:
score: the score function value
"""
sigma_squared = (sigma**2).clamp(min=1e-8)
return -(x_t - x_0) / sigma_squared
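    # Note (sketch): this is the Gaussian score identity; with x_0 = D(x_t, sigma)
    # it reads grad_x log p(x_t; sigma) = (D(x_t, sigma) - x_t) / sigma**2,
    # so p_x_0 = x_t + sigma**2 * score recovers the denoised prediction.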
def _setup_conditioners(
self,
conditioner_list: list[Conditioner],
*,
train: bool,
tgt_mask: Tensor,
padding_mask: Tensor,
p_uc_score: Tensor,
        gt_data: Tensor | None = None,
        sampling_condition: Tensor | None = None,
) -> None:
"""Setup conditioners with common parameters.
Args:
conditioner_list: list of conditioners to setup
train: whether in training mode
tgt_mask: target mask
padding_mask: padding mask
p_uc_score: unconditional predicted score
gt_data: ground truth data (for training)
sampling_condition: sampling condition (for inference)
"""
posterior_mean_fn = self.get_posterior_mean_fn(score=p_uc_score, score_fn=None)
for conditioner in conditioner_list:
if not conditioner.is_enabled():
continue
if train:
condition_dict = conditioner.prepare_condition_dict(
train=True,
tgt_mask=tgt_mask,
gt_data=gt_data,
padding_mask=padding_mask,
posterior_mean_fn=posterior_mean_fn,
)
else:
condition_dict = conditioner.prepare_condition_dict(
train=False,
tgt_mask=tgt_mask,
sampling_condition=sampling_condition,
padding_mask=padding_mask,
posterior_mean_fn=posterior_mean_fn,
)
conditioner.set_condition(**condition_dict)
def get_condition_post_compute_loss_hook(self, conditioner_list: list[Conditioner]):
"""Get hook for conditioning after loss computation (training).
This hook modifies the loss to include conditional guidance during training.
It computes the conditional score and updates the loss accordingly.
Args:
conditioner_list: list of conditioners
Returns:
GMHook: the hook for POST_COMPUTE_LOSS stage
"""
def _hook_fn(**kwargs):
x_0 = kwargs["gt_data"]
x_t = kwargs["x_t"]
t = kwargs["t"]
padding_mask = kwargs["padding_mask"]
loss_fn = kwargs["loss_fn"]
            # Predicted x_0 from the unconditional model output (provided via the hook kwargs)
            p_x_0 = kwargs.get("p_x_0")
# Compute scores
sigma = self.config.sigma(t, is_continuous_time=True)
p_uc_score = self._compute_edm_score(x_t, p_x_0, sigma)
gt_uc_score = self._compute_edm_score(x_t, x_0, sigma)
# Setup conditioners and get accumulated conditional score
self._setup_conditioners(
conditioner_list,
train=True,
tgt_mask=padding_mask,
padding_mask=padding_mask,
p_uc_score=p_uc_score,
gt_data=x_0,
)
acc_c_score = get_accumulated_conditional_score(
conditioner_list, x_t, t, padding_mask, is_continuous_time=True
)
# Compute conditioned loss with EDM weighting
gt_score = gt_uc_score + acc_c_score
gt_x_0 = x_t + sigma**2 * gt_score
weight = self.config.compute_loss_weight(sigma)
sqrt_weight = weight.sqrt()
kwargs["loss"] = loss_fn(sqrt_weight * gt_x_0, sqrt_weight * p_x_0, padding_mask)
return kwargs
return GMHook(
name="EDM_condition_post_compute_loss_hook",
stage=GMHookStageType.POST_COMPUTE_LOSS,
fn=_hook_fn,
priority=0,
enabled=True,
)
def get_condition_pre_update_in_step_fn_hook(self, conditioner_list: list[Conditioner]):
"""Get hook for conditioning before update in step function (sampling).
This hook applies conditional guidance during sampling by modifying
the predicted denoised sample based on the conditional score.
Args:
conditioner_list: list of conditioners
Returns:
GMHook: the hook for PRE_UPDATE_IN_STEP_FN stage
"""
def _hook_fn(**kwargs):
x_t = kwargs["x_t"]
t = kwargs["t"]
padding_mask = kwargs["padding_mask"]
sampling_condition = kwargs.get("sampling_condition")
            # Predicted x_0 from the unconditional model output (provided via the hook kwargs)
            p_x_0 = kwargs.get("p_x_0")
# Compute unconditional score
sigma = self.config.sigma(t, is_continuous_time=True)
p_uc_score = self._compute_edm_score(x_t, p_x_0, sigma)
# Setup conditioners and get accumulated conditional score
self._setup_conditioners(
conditioner_list,
train=False,
tgt_mask=padding_mask,
padding_mask=padding_mask,
p_uc_score=p_uc_score,
sampling_condition=sampling_condition,
)
acc_c_score = get_accumulated_conditional_score(
conditioner_list, x_t, t, padding_mask, is_continuous_time=True
)
# Compute conditioned denoised prediction: x_0 = x_t + sigma² * score
# From: score = -(x_t - x_0) / sigma² => x_0 = x_t + sigma² * score
sigma_squared = sigma**2
p_c_x_0 = x_t + sigma_squared * (p_uc_score + acc_c_score)
# Return p_c_x_0 directly (hook manager expects target value when tgt_key_name is set)
return p_c_x_0
return GMHook(
name="EDM_condition_pre_update_in_step_fn_hook",
stage=GMHookStageType.PRE_UPDATE_IN_STEP_FN,
fn=_hook_fn,
priority=0,
enabled=True,
)
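    # Usage sketch (illustrative; the hook-registration API of `hook_manager`
    # is an assumption, and the method name below is hypothetical):
    #
    #   hook = diffuser.get_condition_pre_update_in_step_fn_hook(conditioners)
    #   diffuser.hook_manager.register(hook)  # hypothetical registration call
    #   # subsequent calls to `step` now apply conditional guidance to p_x_0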