ls_mlkit.diffusion.euclidean_vpsde_diffuser module

class ls_mlkit.diffusion.euclidean_vpsde_diffuser.EuclideanVPSDEConfig(n_discretization_steps: int = 1000, ndim_micro_shape: int = 2, use_probability_flow=False, beta_min: float = 0.1, beta_max: float = 20, n_correct_steps: int = 1, snr: float = 1.0, *args, **kwargs)

Bases: EuclideanDiffuserConfig
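In the standard variance-preserving SDE (VP-SDE), beta_min and beta_max define a linear noise rate beta(t) = beta_min + t * (beta_max - beta_min) on continuous time t in [0, 1]. The forward SDE is dx = -1/2 * beta(t) * x dt + sqrt(beta(t)) dw, and its marginal is x_t = alpha_t * x_0 + sigma_t * eps with sigma_t^2 = 1 - alpha_t^2. The minimal sketch below assumes this diffuser uses that standard linear parameterization. The helper names `beta` and `marginal_coeffs` are hypothetical and are not part of ls_mlkit.

```python
import math

# Defaults mirror EuclideanVPSDEConfig (beta_min=0.1, beta_max=20).
BETA_MIN = 0.1
BETA_MAX = 20.0


def beta(t: float, beta_min: float = BETA_MIN, beta_max: float = BETA_MAX) -> float:
    """Linear noise rate beta(t) for continuous time t in [0, 1] (assumed schedule)."""
    return beta_min + t * (beta_max - beta_min)


def marginal_coeffs(t: float, beta_min: float = BETA_MIN,
                    beta_max: float = BETA_MAX) -> tuple[float, float]:
    """Return (alpha_t, sigma_t) for the kernel x_t = alpha_t * x_0 + sigma_t * eps."""
    # log alpha_t = -1/2 * integral_0^t beta(s) ds for the linear schedule.
    log_alpha = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    alpha = math.exp(log_alpha)
    # Variance preserving: alpha_t^2 + sigma_t^2 = 1.
    sigma = math.sqrt(1.0 - alpha ** 2)
    return alpha, sigma
```

With the default values, alpha_1 is about 0.0066. The final noisy sample is therefore almost pure standard Gaussian noise, which is what makes an N(0, I) prior a reasonable choice.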

class ls_mlkit.diffusion.euclidean_vpsde_diffuser.EuclideanVPSDEDiffuser(config: EuclideanVPSDEConfig, time_scheduler: DiffusionTimeScheduler, masker: MaskerInterface, model: Module, loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor])

Bases: EuclideanDiffuser
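The constructor's loss_fn has the type Callable[[Tensor, Tensor, Tensor], Tensor]. The source does not say what the three tensors are. One plausible reading is (prediction, target, mask), where the mask selects non-padded positions. Under that assumption, a masked mean-squared error could look like the following sketch. It is written over plain Python lists so it runs without PyTorch; with tensors, the same arithmetic would be applied elementwise. The name `masked_mse` is hypothetical.

```python
from typing import Sequence


def masked_mse(pred: Sequence[float], target: Sequence[float],
               mask: Sequence[float]) -> float:
    """Mean squared error over positions where mask is 1.

    Assumed argument order: (prediction, target, mask); not confirmed by the source.
    """
    sq_err = [m * (p - y) ** 2 for p, y, m in zip(pred, target, mask)]
    n_valid = sum(mask)
    # Guard against a batch that is entirely padding.
    return sum(sq_err) / max(n_valid, 1)
```

For example, `masked_mse([1.0, 2.0, 5.0], [1.0, 0.0, 0.0], [1, 1, 0])` ignores the third position and averages the remaining squared errors (0 and 4), giving 2.0.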

compute_loss(batch: dict[str, Any], *args: Any, **kwargs: Any) → dict

Compute the loss for a batch.

Parameters:

batch (dict[str, Any]) – the batch of data

Returns:

a dictionary that must contain the key "loss", or a loss tensor

Return type:

dict | Tensor
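Because compute_loss may return either a dictionary containing "loss" or a bare loss tensor, calling code has to handle both shapes. The following is a minimal sketch of one way to normalize the result. The helper name `extract_loss` is hypothetical and is not part of ls_mlkit.

```python
from typing import Any


def extract_loss(output: Any) -> Any:
    """Return the loss from a compute_loss result (dict with "loss" or a bare tensor)."""
    if isinstance(output, dict):
        if "loss" not in output:
            raise KeyError('compute_loss returned a dict without the required "loss" key')
        return output["loss"]
    # Any non-dict result is taken to be the loss itself.
    return output
```

A training loop could then call `extract_loss(diffuser.compute_loss(batch)).backward()` regardless of which form is returned. Any extra dictionary entries (for example, logging metrics) are left to the caller.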

forward_process(x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]) → dict
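forward_process maps clean data x_0 and a discrete timestep to a noisy sample. For a VP-SDE, this is typically the closed-form kernel x_t = alpha_t * x_0 + sigma_t * eps. The sketch below rests on three assumptions, none of which is confirmed by the source:

- the discrete index k maps to continuous time t = k / (n_steps - 1);
- nonzero mask entries mark positions to be noised;
- the returned dict uses the hypothetical keys "x_t", "noise", "alpha", and "sigma".

It uses plain lists instead of tensors.

```python
import math
import random


def vp_alpha_sigma(t, beta_min=0.1, beta_max=20.0):
    """(alpha_t, sigma_t) for the linear-beta VP-SDE."""
    log_alpha = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    alpha = math.exp(log_alpha)
    return alpha, math.sqrt(1.0 - alpha ** 2)


def forward_process_sketch(x_0, discrete_t, mask, n_steps=1000, rng=random):
    """Sample x_t ~ N(alpha_t * x_0, sigma_t^2 I) at masked positions (hypothetical)."""
    # Assumption: discrete index k in [0, n_steps - 1] maps to t = k / (n_steps - 1).
    t = discrete_t / (n_steps - 1)
    alpha, sigma = vp_alpha_sigma(t)
    eps = [rng.gauss(0.0, 1.0) for _ in x_0]
    # Assumption: positions with mask == 0 are returned unchanged.
    x_t = [alpha * x + sigma * e if m else x for x, e, m in zip(x_0, eps, mask)]
    return {"x_t": x_t, "noise": eps, "alpha": alpha, "sigma": sigma}
```

Returning the noise alongside x_t is a common choice, because a denoising score-matching target such as -eps / sigma_t can then be formed directly.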
forward_process_n_step(x: Tensor, t: Tensor, next_t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) → Tensor
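forward_process_n_step advances a sample that is already noisy from t to a later next_t. For a VP-SDE, the transition kernel has a closed form:

x_next | x_t ~ N((alpha_next / alpha_t) * x_t, (1 - (alpha_next / alpha_t)^2) I)

The variance follows from sigma^2 = 1 - alpha^2. The sketch below assumes the linear-beta schedule and continuous times in [0, 1]. It omits padding_mask, and the name `forward_n_step_sketch` is hypothetical.

```python
import math
import random


def log_alpha(t, beta_min=0.1, beta_max=20.0):
    """log alpha_t for the linear-beta VP-SDE."""
    return -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min


def forward_n_step_sketch(x, t, next_t, rng=random):
    """Jump from x_t to x_{next_t} (next_t >= t) using the VP transition kernel."""
    if next_t < t:
        raise ValueError("next_t must not precede t")
    # Ratio of marginal mean coefficients alpha_{next_t} / alpha_t.
    ratio = math.exp(log_alpha(next_t) - log_alpha(t))
    # Conditional variance is 1 - ratio^2; clamp tiny negative rounding to zero.
    std = math.sqrt(max(1.0 - ratio ** 2, 0.0))
    return [ratio * xi + std * rng.gauss(0.0, 1.0) for xi in x]
```

Calling this with t = 0 reproduces the one-shot forward kernel. Calling it with next_t == t returns the input unchanged.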
get_condition_post_compute_loss_hook(conditioner_list: list[Conditioner])
get_condition_pre_update_in_step_fn_hook(conditioner_list: list[Conditioner])
get_posterior_mean_fn(score: Tensor = None, score_fn: Callable = None)

Return the posterior mean function.

Parameters:
  • score (Tensor, optional) – the score of the sample

  • score_fn (Callable, optional) – the function to compute score

Returns:

the posterior mean function

Return type:

Callable
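For a Gaussian kernel x_t = alpha_t * x_0 + sigma_t * eps, Tweedie's formula gives the posterior mean as:

E[x_0 | x_t] = (x_t + sigma_t^2 * score(x_t)) / alpha_t

The sketch below assumes get_posterior_mean_fn is built on this formula, which the source does not confirm. It also assumes a linear-beta schedule, a returned function with the hypothetical signature `posterior_mean(x_t, t)`, and a score_fn called as `score_fn(x_t, t)`.

```python
import math


def vp_alpha_sigma(t, beta_min=0.1, beta_max=20.0):
    """(alpha_t, sigma_t) for the linear-beta VP-SDE."""
    log_alpha = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    alpha = math.exp(log_alpha)
    return alpha, math.sqrt(1.0 - alpha ** 2)


def get_posterior_mean_fn_sketch(score=None, score_fn=None):
    """Build x_t -> E[x_0 | x_t] from a fixed score or a score function (hypothetical)."""
    if score is None and score_fn is None:
        raise ValueError("provide either score or score_fn")

    def posterior_mean(x_t, t):
        alpha, sigma = vp_alpha_sigma(t)
        s = score if score is not None else score_fn(x_t, t)
        # Tweedie's formula: E[x_0 | x_t] = (x_t + sigma_t^2 * score) / alpha_t
        return [(x + sigma ** 2 * si) / alpha for x, si in zip(x_t, s)]

    return posterior_mean
```

As a sanity check, consider data that is a point mass at c. The exact score is then -(x_t - alpha_t * c) / sigma_t^2, and the posterior mean recovers c for every x_t.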

prior_sampling(shape: Tuple[int, ...]) → Tensor

Draw samples from the prior distribution.

Parameters:

shape (tuple[int, ...]) – the shape of the sample

Returns:

samples drawn from the prior distribution

Return type:

Tensor
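For a VP-SDE, the prior at t = 1 is usually the standard Gaussian N(0, I), since alpha_1 is close to zero under the default beta range. The minimal sketch below assumes that prior. It builds nested lists of the requested shape; the name `prior_sampling_sketch` is hypothetical.

```python
import random


def prior_sampling_sketch(shape, rng=random):
    """Draw i.i.d. N(0, 1) values arranged as nested lists with the given shape."""
    if len(shape) == 0:
        # Scalar case: a single standard normal draw.
        return rng.gauss(0.0, 1.0)
    return [prior_sampling_sketch(shape[1:], rng) for _ in range(shape[0])]
```

With PyTorch tensors, the equivalent is `torch.randn(shape)`, which accepts a tuple of sizes.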

step(x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) → dict
Parameters:
  • x_t (Tensor) – the sample at timestep t

  • t (Tensor) – the timestep

  • padding_mask (Tensor) – the padding mask

Returns:

the sample at timestep t-1

Return type:

Tensor
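The config fields use_probability_flow, n_correct_steps, and snr suggest predictor-corrector sampling. In that scheme, each reverse step applies a few Langevin corrector updates, whose step size is set from a target signal-to-noise ratio, and then one predictor update. The predictor is either an Euler-Maruyama step of the reverse SDE or, with the probability flow enabled, a deterministic ODE step. The sketch below is not the library's implementation and rests on these assumptions:

- a linear-beta schedule;
- a hypothetical step size `dt`;
- a `score_fn(x, t)` callable;
- plain lists in place of tensors, with padding_mask omitted.

The corrector is also simplified: it omits any schedule-dependent scaling of the Langevin step size.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
ScoreFn = Callable[[Vector, float], Vector]


def beta(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    return beta_min + t * (beta_max - beta_min)


def predictor(x: Vector, t: float, dt: float, score_fn: ScoreFn,
              use_probability_flow: bool = False, rng=random) -> Vector:
    """Euler-Maruyama step of the reverse VP-SDE (or its probability-flow ODE), t -> t - dt."""
    b = beta(t)
    s = score_fn(x, t)
    if use_probability_flow:
        # ODE: x <- x + 1/2 * beta(t) * (x + score) * dt, no injected noise.
        return [xi + 0.5 * b * (xi + si) * dt for xi, si in zip(x, s)]
    # SDE: x <- x + (1/2 * beta(t) * x + beta(t) * score) * dt + sqrt(beta(t) * dt) * z
    noise_scale = math.sqrt(b * dt)
    return [xi + (0.5 * b * xi + b * si) * dt + noise_scale * rng.gauss(0.0, 1.0)
            for xi, si in zip(x, s)]


def langevin_corrector(x: Vector, t: float, score_fn: ScoreFn, snr: float = 1.0,
                       n_correct_steps: int = 1, rng=random) -> Vector:
    """Langevin MCMC updates at fixed t; step size chosen from the target snr."""
    for _ in range(n_correct_steps):
        s = score_fn(x, t)
        z = [rng.gauss(0.0, 1.0) for _ in x]
        s_norm = math.sqrt(sum(si * si for si in s))
        z_norm = math.sqrt(sum(zi * zi for zi in z))
        if s_norm == 0.0:
            break  # no informative gradient direction
        step = 2.0 * (snr * z_norm / s_norm) ** 2
        x = [xi + step * si + math.sqrt(2.0 * step) * zi for xi, si, zi in zip(x, s, z)]
    return x


def step_sketch(x_t: Vector, t: float, dt: float, score_fn: ScoreFn,
                use_probability_flow: bool = False, snr: float = 1.0,
                n_correct_steps: int = 1, rng=random) -> Vector:
    """One predictor-corrector step: correct at time t, then predict from t to t - dt."""
    x = langevin_corrector(x_t, t, score_fn, snr, n_correct_steps, rng)
    return predictor(x, t, dt, score_fn, use_probability_flow, rng)
```

The sketch can be checked on standard Gaussian data, whose score is -x at every t. A probability-flow step then leaves x unchanged, and a stochastic predictor step approximately preserves unit variance.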