ls_mlkit.diffusion.euclidean_edm_diffuser module
- class ls_mlkit.diffusion.euclidean_edm_diffuser.EuclideanEDMConfig(n_discretization_steps: int = 200, ndim_micro_shape: int = 2, P_mean: float = -1.2, P_std: float = 1.2, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0, use_2nd_order_correction: bool = True, use_ode_flow: bool = False, S_churn: float = 0.0, S_min: float = 0.0, S_max: float = inf, S_noise: float = 1.0, use_clip: bool = False, clip_sample_range: float = 1.0, use_dyn_thresholding: bool = False, dynamic_thresholding_ratio=0.995, sample_max_value: float = 1.0, sigma_multiply_by_sigma_data: bool = False, *args, **kwargs)
Bases: EuclideanDiffuserConfig

Config class for the Euclidean EDM diffuser.
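As a hedged illustration of how the sigma_min, sigma_max, rho, and n_discretization_steps fields typically interact, here is the Karras et al. (2022) noise-level discretization that EDM samplers conventionally use. This is a sketch under that assumption; the library's own schedule may differ:

```python
import numpy as np

def karras_sigmas(n_steps: int = 200, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> np.ndarray:
    # sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    # rho > 1 concentrates steps near sigma_min, where the ODE needs finer resolution.
    i = np.arange(n_steps)
    hi, lo = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return (hi + i / (n_steps - 1) * (lo - hi)) ** rho

sigmas = karras_sigmas()  # decreases monotonically from sigma_max to sigma_min
```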
- class ls_mlkit.diffusion.euclidean_edm_diffuser.EuclideanEDMDiffuser(config: EuclideanEDMConfig, time_scheduler: DiffusionTimeScheduler, masker: MaskerInterface, model: Module, loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor])
Bases: EuclideanDiffuser

- compute_loss(**batch) → dict
Compute the EDM loss.
- Parameters:
**batch – batch dictionary containing:
- gt_data: the ground-truth data x_0
- padding_mask: the padding mask
- Returns:
A dictionary containing the loss and other information
- Return type:
dict
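The EDM training objective (Karras et al., 2022) that this loss presumably follows samples a noise level from a log-normal distribution controlled by P_mean and P_std, and weights the denoising error by λ(σ) = (σ² + σ_data²)/(σ·σ_data)². A minimal NumPy sketch with a stand-in denoiser (`denoise_fn` is a hypothetical callable, not the library's API):

```python
import numpy as np

rng = np.random.default_rng(0)
P_mean, P_std, sigma_data = -1.2, 1.2, 0.5  # defaults from EuclideanEDMConfig

def edm_loss(denoise_fn, x0):
    # ln(sigma) ~ N(P_mean, P_std^2): concentrates training on mid noise levels
    sigma = float(np.exp(P_mean + P_std * rng.standard_normal()))
    x_t = x0 + sigma * rng.standard_normal(x0.shape)      # perturbed sample
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    return weight * np.mean((denoise_fn(x_t, sigma) - x0) ** 2)

x0 = rng.standard_normal((4, 8))
perfect = edm_loss(lambda x_t, sigma: x0, x0)  # an oracle denoiser gives zero loss
```

The λ(σ) weighting equalizes the effective gradient magnitude across noise levels, which is why EDM trains stably without a learned loss balancer.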
- forward_process(x_0: Tensor, t_a: Tensor, t_b: Tensor, mask: Tensor, is_continuous_time: bool = True, *args: list[Any], **kwargs: dict[Any, Any]) → dict
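forward_process is undocumented here; in EDM the forward corruption is variance-exploding, x_t = x_0 + σ_t·ε, typically with padded positions left untouched via the mask. A sketch under that assumption (not the library's exact signature):

```python
import numpy as np

def forward_process(x0, sigma_t, mask, rng):
    # EDM/VE-style perturbation: x_t = x_0 + sigma_t * eps.
    # `mask` marks valid (non-padding) positions; padded entries stay x_0.
    eps = rng.standard_normal(x0.shape)
    return np.where(mask, x0 + sigma_t * eps, x0)

rng = np.random.default_rng(0)
x0 = np.zeros((2, 3))
mask = np.array([[True, True, False], [True, False, False]])
x_t = forward_process(x0, 80.0, mask, rng)  # noise only where mask is True
```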
- get_condition_post_compute_loss_hook(conditioner_list: list[Conditioner])
Get hook for conditioning after loss computation (training).
This hook modifies the loss to include conditional guidance during training. It computes the conditional score and updates the loss accordingly.
- Parameters:
conditioner_list – list of conditioners
- Returns:
the hook for POST_COMPUTE_LOSS stage
- Return type:
Callable
- get_condition_pre_update_in_step_fn_hook(conditioner_list: list[Conditioner])
Get hook for conditioning before update in step function (sampling).
This hook applies conditional guidance during sampling by modifying the predicted denoised sample based on the conditional score.
- Parameters:
conditioner_list – list of conditioners
- Returns:
the hook for PRE_UPDATE_IN_STEP_FN stage
- Return type:
Callable
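Score-based guidance of the kind these hooks describe typically shifts the predicted denoised sample along the conditional score, scaled by σ². This is an assumption about the mechanism, and `guide_denoised` is a hypothetical helper, not the library's hook signature:

```python
import numpy as np

def guide_denoised(x0_pred, cond_score, sigma):
    # Tweedie-style correction:
    #   E[x0 | x_t, y] ~= E[x0 | x_t] + sigma^2 * grad_x log p(y | x_t)
    # The actual hook would wrap logic like this around the diffuser's step.
    return x0_pred + sigma**2 * cond_score

out = guide_denoised(np.ones(3), np.full(3, 0.5), 2.0)  # 1 + 4 * 0.5 = 3
```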
- get_posterior_mean_fn(score: Tensor = None, score_fn: Callable = None)
Get the posterior mean function for EDM.
For EDM, the posterior mean is:

E[x_0 | x_t] = D_\theta(x_t, \sigma_t)

where D_\theta is the denoised prediction.
- Parameters:
score (Tensor, optional) – the score of the sample
score_fn (Callable, optional) – the function to compute score
- Returns:
the posterior mean function
- Return type:
Callable
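The denoiser D_θ referenced above is conventionally built from a raw network F with the EDM preconditioning, which uses the config's sigma_data field. A sketch assuming the standard Karras et al. (2022) coefficients (`raw_net` is a stand-in; the library's wrapping may differ):

```python
import numpy as np

sigma_data = 0.5

def edm_denoiser(raw_net, x_t, sigma):
    # D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)
    # c_skip -> 1 as sigma -> 0, so D falls back to the (nearly clean) input.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / np.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / np.sqrt(sigma**2 + sigma_data**2)
    c_noise = 0.25 * np.log(sigma)
    return c_skip * x_t + c_out * raw_net(c_in * x_t, c_noise)

x = np.ones(4)
zero_net = lambda z, t: np.zeros_like(z)
d_low = edm_denoiser(zero_net, x, 1e-8)  # at sigma -> 0, D(x, sigma) -> x
```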
- prior_sampling(shape: Tuple[int, ...]) → Tensor

Sample from the prior distribution.

- Parameters:
shape (tuple[int, ...]) – the shape of the sample
- Returns:
data drawn from the prior distribution
- Return type:
Tensor
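In EDM the signal is negligible at the highest noise level, so the prior is conventionally an isotropic Gaussian with standard deviation sigma_max. A NumPy sketch under that assumption:

```python
import numpy as np

def prior_sampling(shape, sigma_max=80.0, rng=None):
    # x_T ~ N(0, sigma_max^2 * I): pure noise at the start of sampling.
    rng = rng if rng is not None else np.random.default_rng()
    return sigma_max * rng.standard_normal(shape)

x_T = prior_sampling((100_000,), rng=np.random.default_rng(0))
```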
- step(x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) → dict
EDM sampling step (Euler or Heun’s method).
- Parameters:
x_t – the sample at timestep t
t – the timestep (all elements must be the same)
padding_mask – the padding mask
- Returns:
x: the sample at timestep t-1
E_x0_xt: the predicted original sample
- Return type:
dict
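The Euler/Heun alternation named above can be sketched against the EDM probability-flow ODE dx/dσ = (x − D(x, σ))/σ; the config's use_2nd_order_correction flag presumably toggles the Heun half. A sketch, not the library's exact update (stochastic churn via S_churn/S_noise is omitted):

```python
import numpy as np

def heun_step(x, sigma, sigma_next, denoise):
    # Probability-flow ODE slope at the current noise level
    d = (x - denoise(x, sigma)) / sigma
    x_next = x + (sigma_next - sigma) * d          # 1st-order (Euler) proposal
    if sigma_next > 0:                             # Heun: average the two slopes
        d_next = (x_next - denoise(x_next, sigma_next)) / sigma_next
        x_next = x + (sigma_next - sigma) * 0.5 * (d + d_next)
    return x_next

# With D = 0 the ODE is linear (dx/dsigma = x/sigma, solution x ∝ sigma),
# so one Heun step from sigma=2 to sigma=1 halves x exactly.
x = heun_step(np.full(3, 2.0), 2.0, 1.0, lambda x, s: np.zeros_like(x))
```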