ls_mlkit.diffusion.euclidean_edm_diffuser module

class ls_mlkit.diffusion.euclidean_edm_diffuser.EuclideanEDMConfig(n_discretization_steps: int = 200, ndim_micro_shape: int = 2, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0, use_2nd_order_correction: bool = True, use_ode_flow: bool = False, S_churn: float = 0.0, S_min: float = 0.0, S_max: float = inf, S_noise: float = 1.0, use_clip: bool = False, clip_sample_range: float = 1.0, use_dyn_thresholding: bool = False, dynamic_thresholding_ratio=0.995, sample_max_value: float = 1.0, sigma_multiply_by_sigma_data: bool = False, *args, **kwargs)

Bases: EuclideanDiffuserConfig

Configuration class for the Euclidean EDM diffuser.

c_in(sigma: Tensor) → Tensor
c_noise(sigma: Tensor) → Tensor
c_out(sigma: Tensor) → Tensor
c_skip(sigma: Tensor) → Tensor
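The four coefficient methods above carry no description. A minimal sketch, assuming they follow the standard EDM preconditioning of Karras et al. (2022), in which the denoiser is D(x, σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x, c_noise(σ)) for a raw network F. The helper names `edm_preconditioning` and `denoise` are hypothetical, and the library's exact formulas (for example under `sigma_multiply_by_sigma_data`) may differ:

```python
import torch


def edm_preconditioning(sigma: torch.Tensor, sigma_data: float = 0.5):
    """Hypothetical helper: standard EDM coefficients (Karras et al., 2022)."""
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom             # weight on the noisy input
    c_out = sigma * sigma_data / denom.sqrt()  # scale of the network output
    c_in = 1.0 / denom.sqrt()                  # normalises the input to unit variance
    c_noise = sigma.log() / 4.0                # noise-level conditioning fed to the network
    return c_skip, c_out, c_in, c_noise


def denoise(F, x: torch.Tensor, sigma: torch.Tensor, sigma_data: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    # Reshape a per-sample sigma of shape (batch,) so it broadcasts over x's trailing dims.
    sigma = sigma.view(-1, *([1] * (x.ndim - 1)))
    c_skip, c_out, c_in, c_noise = edm_preconditioning(sigma, sigma_data)
    return c_skip * x + c_out * F(c_in * x, c_noise)
```

At σ = σ_data the skip and network paths are balanced: c_skip = 0.5, so a network that outputs zeros returns half of the noisy input.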
compute_loss_weight(sigma: Tensor) → Tensor

Compute EDM loss weight: (sigma² + sigma_data²) / (sigma * sigma_data)².

Parameters:

sigma – noise level, shape=(…)

Returns:

the loss weight, shape=(…)

Return type:

Tensor
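Written out as code (a sketch; the function name is hypothetical), this weight equals 1/c_out(σ)² when the standard EDM output scale c_out(σ) = σ·σ_data/√(σ² + σ_data²) is used, which is the choice Karras et al. (2022) make so that the loss on the raw network output has unit weight at every noise level:

```python
import torch


def edm_loss_weight(sigma: torch.Tensor, sigma_data: float = 0.5) -> torch.Tensor:
    """lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


sigma = torch.tensor([0.5, 1.0, 10.0])
weight = edm_loss_weight(sigma)

# Standard EDM output scale, used here only to show that weight == 1 / c_out^2.
c_out = sigma * 0.5 / (sigma**2 + 0.25).sqrt()
```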

sampling_timestep_for_training(macro_shape: Tensor)
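sampling_timestep_for_training is undocumented. Given the P_mean and P_std fields of EuclideanEDMConfig, it plausibly follows the EDM training distribution ln σ ~ N(P_mean, P_std²). A sketch under that assumption, with a hypothetical helper that takes a plain shape tuple instead of the macro_shape tensor:

```python
import torch


def sample_training_sigma(batch_shape, P_mean=-1.2, P_std=1.2, generator=None):
    """Hypothetical helper: draw ln(sigma) ~ N(P_mean, P_std^2) independently per entry."""
    log_sigma = P_mean + P_std * torch.randn(batch_shape, generator=generator)
    return log_sigma.exp()


g = torch.Generator().manual_seed(0)
sigma = sample_training_sigma((100_000,), generator=g)
```

With the defaults, the median training noise level is exp(−1.2) ≈ 0.30, concentrating training effort around σ_data = 0.5.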
sigma(t: Tensor, is_continuous_time: bool = True) → Tensor
timestep_index_to_sigma(timestep_index: Tensor) → Tensor

Convert discrete timesteps to sigma values.

Parameters:

timestep_index – discrete timestep indices, shape=(…)

Returns:

noise levels, shape=(…)

Return type:

Tensor
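Karras et al. (2022) discretize the noise levels as σ_i = (σ_max^{1/ρ} + i/(N−1)·(σ_min^{1/ρ} − σ_max^{1/ρ}))^ρ for i = 0, …, N−1, which matches the sigma_min, sigma_max, rho and n_discretization_steps fields of the config. A sketch assuming timestep_index_to_sigma uses this schedule; the helper name, the index direction (index 0 maps to σ_max here) and the omission of the extra final σ_N = 0 step are assumptions:

```python
import torch


def karras_sigmas(n_steps=200, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical helper: sigma_i for i = 0..n_steps-1, from sigma_max down to sigma_min."""
    i = torch.arange(n_steps, dtype=torch.float64)
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma^(1/rho), then raise back to the power rho;
    # rho > 1 packs more steps near sigma_min, where detail is resolved.
    return (max_inv_rho + i / (n_steps - 1) * (min_inv_rho - max_inv_rho)) ** rho


sigmas = karras_sigmas()
```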

class ls_mlkit.diffusion.euclidean_edm_diffuser.EuclideanEDMDiffuser(config: EuclideanEDMConfig, time_scheduler: DiffusionTimeScheduler, masker: MaskerInterface, model: Module, loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor])

Bases: EuclideanDiffuser

compute_loss(**batch) → dict

Compute the EDM loss.

Parameters:

**batch – batch dictionary containing:
  • gt_data: ground-truth data x_0
  • padding_mask: padding mask

Returns:

A dictionary containing the loss and other information

Return type:

dict
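A minimal sketch of a standard EDM training loss over a batch with the gt_data and padding_mask entries described above. It assumes log-normal σ sampling, the standard EDM preconditioning, a padding mask with 1 for valid entries, and a weighted squared error; the function name, argument layout and returned keys are hypothetical, and the diffuser's own loss_fn may replace the squared error:

```python
import torch


def edm_loss_sketch(F, x0, padding_mask, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
    """Hypothetical EDM loss. x0: (batch, n, d); padding_mask: (batch, n), 1 = valid."""
    batch = x0.shape[0]
    bshape = (batch,) + (1,) * (x0.ndim - 1)  # broadcast sigma over trailing dims
    # 1) Per-example noise levels: ln(sigma) ~ N(P_mean, P_std^2).
    sigma = (P_mean + P_std * torch.randn(batch)).exp().view(bshape)
    # 2) Corrupt the data: x_t = x_0 + sigma * eps.
    x_t = x0 + sigma * torch.randn_like(x0)
    # 3) Preconditioned denoiser D(x_t, sigma).
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = sigma * sigma_data / denom.sqrt()
    c_in = 1.0 / denom.sqrt()
    c_noise = sigma.log() / 4.0
    denoised = c_skip * x_t + c_out * F(c_in * x_t, c_noise)
    # 4) Weighted squared error, averaged over valid (unpadded) entries only.
    weight = denom / (sigma * sigma_data) ** 2
    extra = (1,) * (x0.ndim - padding_mask.ndim)
    mask = padding_mask.to(x0.dtype).view(tuple(padding_mask.shape) + extra)
    per_entry = weight * (denoised - x0) ** 2 * mask
    n_valid = mask.expand_as(x0).sum().clamp_min(1.0)
    return {"loss": per_entry.sum() / n_valid, "sigma": sigma.flatten()}
```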

forward_process(x_0: Tensor, t_a: Tensor, t_b: Tensor, mask: Tensor, is_continuous_time: bool = True, *args: list[Any], **kwargs: dict[Any, Any]) → dict
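forward_process is undocumented. If it follows the variance-exploding corruption x_σ = x_0 + σ·ε used by EDM, then noising from level σ_a to a higher level σ_b amounts to adding independent Gaussian noise with variance σ_b² − σ_a². A minimal sketch under that assumption; the helper name, and treating t_a and t_b as already converted to sigma values, are hypothetical:

```python
import torch


def ve_transition(x_a, sigma_a, sigma_b, mask):
    """Hypothetical helper: sample x_b ~ N(x_a, (sigma_b^2 - sigma_a^2) I) on masked entries."""
    # Variance-exploding noise is additive, so moving from sigma_a to a higher
    # sigma_b only needs the extra variance sigma_b^2 - sigma_a^2.
    std = (sigma_b**2 - sigma_a**2).clamp_min(0.0).sqrt()
    x_b = x_a + std * torch.randn_like(x_a)
    # Leave padded (mask == False) entries untouched.
    return torch.where(mask, x_b, x_a)
```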
get_condition_post_compute_loss_hook(conditioner_list: list[Conditioner])

Get hook for conditioning after loss computation (training).

This hook modifies the loss to include conditional guidance during training. It computes the conditional score and updates the loss accordingly.

Parameters:

conditioner_list – list of conditioners

Returns:

the hook for POST_COMPUTE_LOSS stage

Return type:

GMHook

get_condition_pre_update_in_step_fn_hook(conditioner_list: list[Conditioner])

Get hook for conditioning before update in step function (sampling).

This hook applies conditional guidance during sampling by modifying the predicted denoised sample based on the conditional score.

Parameters:

conditioner_list – list of conditioners

Returns:

the hook for PRE_UPDATE_IN_STEP_FN stage

Return type:

GMHook
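The exact guidance rule used by this hook is not specified above. One common way to guide the denoised prediction, shown only as a sketch: by Tweedie's formula, conditioning on y adds σ²·∇ log p(y | x_t) to the posterior mean, so E[x_0 | x_t, y] = D_θ(x_t, σ) + σ²·∇ log p(y | x_t). The helper name and the log_likelihood_fn argument are hypothetical; practical methods often approximate ∇ log p(y | x_t), for example by differentiating through the denoiser:

```python
import torch


def guided_denoised(denoised, x_t, sigma, log_likelihood_fn):
    """Hypothetical: shift D(x_t, sigma) by sigma^2 * grad_x log p(y | x_t)."""
    x = x_t.detach().requires_grad_(True)
    # Sum so autograd.grad receives a scalar; per-sample terms stay independent.
    log_lik = log_likelihood_fn(x).sum()
    (grad,) = torch.autograd.grad(log_lik, x)
    # Tweedie: E[x_0 | x_t, y] = E[x_0 | x_t] + sigma^2 * grad log p(y | x_t)
    return denoised + sigma**2 * grad
```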

get_posterior_mean_fn(score: Tensor = None, score_fn: Callable = None)

Get the posterior mean function for EDM.

For EDM, the posterior mean is:

E[x_0 \mid x_t] = D_\theta(x_t, \sigma_t)

where D_\theta is the denoiser's prediction of x_0.

Parameters:
  • score (Tensor, optional) – the score of the sample

  • score_fn (Callable, optional) – the function to compute score

Returns:

the posterior mean function

Return type:

Callable
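When a score rather than a denoiser output is available, the posterior mean follows from Tweedie's formula for the EDM corruption x_t = x_0 + σ·ε: E[x_0 | x_t] = x_t + σ²·∇ log p(x_t). Conversely, the score implied by a denoiser is (D_\theta(x_t, σ) − x_t)/σ², which recovers the equation above. A sketch under that assumption (the function name is hypothetical, and how the library combines score and score_fn is not stated here), checked against a Gaussian case with a closed form:

```python
import torch


def posterior_mean_from_score(x_t, sigma, score):
    """Tweedie's formula for x_t = x_0 + sigma * eps: E[x_0 | x_t] = x_t + sigma^2 * score."""
    return x_t + sigma**2 * score


# Closed-form check: x_0 ~ N(0, s^2) gives x_t ~ N(0, s^2 + sigma^2), so
# score(x_t) = -x_t / (s^2 + sigma^2) and E[x_0 | x_t] = s^2 / (s^2 + sigma^2) * x_t.
s, sigma = 0.5, 2.0
x_t = torch.tensor([1.0, -2.0, 0.3])
mean = posterior_mean_from_score(x_t, sigma, -x_t / (s**2 + sigma**2))
expected = s**2 / (s**2 + sigma**2) * x_t
```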

prior_sampling(shape: Tuple[int, ...]) → Tensor

Draw samples from the prior distribution.

Parameters:

shape (tuple[int, ...]) – the shape of the sample

Returns:

samples from the prior distribution

Return type:

Tensor
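In EDM the prior is the Gaussian at the largest noise level, N(0, σ_max² I); with σ_max = 80 far above σ_data = 0.5, the data contribution at that level is negligible. A sketch assuming this library uses the same prior (the function name is hypothetical, and the config flag sigma_multiply_by_sigma_data may change the scale in ways not documented here):

```python
import torch


def edm_prior_sampling(shape, sigma_max=80.0, generator=None):
    """Hypothetical helper: x_T ~ N(0, sigma_max^2 I)."""
    return sigma_max * torch.randn(shape, generator=generator)


g = torch.Generator().manual_seed(0)
x_T = edm_prior_sampling((4, 50_000), generator=g)
```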

step(x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) → dict

EDM sampling step (Euler or Heun’s method).

Parameters:
  • x_t – the sample at timestep t

  • t – the timestep (all elements must be the same)

  • padding_mask – the padding mask

Returns:

  • x: the sample at timestep t-1

  • E_x0_xt: the predicted original sample

Return type:

dict
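A sketch of one stochastic EDM step in the style of Algorithm 2 of Karras et al. (2022), assuming step follows it: optional "churn" noise controlled by S_churn, S_min, S_max and S_noise, an Euler step, and a Heun (2nd-order) correction when use_2nd_order_correction is set and the next noise level is non-zero. The helper name, the scalar sigma arguments and the denoise(x, sigma) callable are hypothetical:

```python
import math

import torch


def edm_step(denoise, x, sigma, sigma_next, n_steps, S_churn=0.0, S_min=0.0,
             S_max=float("inf"), S_noise=1.0, second_order=True):
    """Hypothetical EDM step from noise level sigma to sigma_next (Python floats)."""
    # Churn: temporarily raise the noise level to sigma_hat by adding fresh noise.
    gamma = min(S_churn / n_steps, math.sqrt(2) - 1) if S_min <= sigma <= S_max else 0.0
    sigma_hat = sigma * (1 + gamma)
    if gamma > 0:
        x = x + math.sqrt(sigma_hat**2 - sigma**2) * S_noise * torch.randn_like(x)
    # Euler step along the probability-flow ODE dx/dsigma = (x - D(x, sigma)) / sigma.
    x0_hat = denoise(x, sigma_hat)
    d = (x - x0_hat) / sigma_hat
    x_next = x + (sigma_next - sigma_hat) * d
    # Heun correction: average the slopes at both ends (skipped when sigma_next == 0).
    if second_order and sigma_next > 0:
        d_next = (x_next - denoise(x_next, sigma_next)) / sigma_next
        x_next = x + (sigma_next - sigma_hat) * 0.5 * (d + d_next)
    return {"x": x_next, "E_x0_xt": x0_hat}
```

With a denoiser that always predicts zero, the probability-flow ODE has the exact solution x ∝ σ, and both the Euler and Heun updates reproduce it: stepping from σ = 80 to σ = 40 halves x.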