ls_mlkit.diffusion.so3_diffuser module

class ls_mlkit.diffusion.so3_diffuser.SO3Diffuser(config: SO3DiffuserConfig, time_scheduler: DiffusionTimeScheduler, masker: Masker, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor], Tensor], loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor])

Bases: LieGroupDiffuser

compute_loss(batch: dict[str, Any], *args: list[Any], **kwargs: dict[Any, Any]) → Tensor

Compute the loss for a batch.

Parameters:

batch (dict[str, Any]) – the batch of data

Returns:

a dictionary that must contain the key “loss”, or the loss tensor itself

Return type:

dict | Tensor
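The return contract above (a dict containing “loss”, or the loss itself) can be illustrated with a minimal NumPy sketch. It assumes a masked mean-squared error between predicted and ground-truth scores; the real loss_fn may differ. `masked_score_matching_loss` and `get_loss` are hypothetical names, and NumPy arrays stand in for PyTorch tensors.

```python
import numpy as np


def masked_score_matching_loss(pred_score: np.ndarray,
                               target_score: np.ndarray,
                               mask: np.ndarray) -> dict:
    """Hypothetical loss: masked MSE between predicted and ground-truth scores.

    pred_score, target_score: shape (..., 3), scores as rotation vectors (assumption).
    mask: shape (...), 1 for valid positions and 0 for padding.
    """
    sq_err = np.sum((pred_score - target_score) ** 2, axis=-1)  # (...)
    n_valid = max(float(mask.sum()), 1.0)  # guard against an all-padding batch
    return {"loss": float((sq_err * mask).sum() / n_valid)}


def get_loss(output):
    """Accept either allowed return form: a dict with key "loss", or the loss itself."""
    return output["loss"] if isinstance(output, dict) else output
```

Returning a dict keeps room for extra logged values next to “loss”, while a caller that only needs the scalar can normalize both forms with one check.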

forward_process(x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]) → dict

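Forward (noising) process based on the isotropic Gaussian distribution on SO(3), whose density is

\[\text{IG}_{\text{SO}(3)} (\mathbf{x}; \mathbf{\mu}, \sigma^2) = f_{\sigma} (\arccos((\text{tr}(\mathbf{\mu}^T \mathbf{x}) - 1)/2)) \quad \forall \mathbf{x} \in \text{SO}(3)\]
where \(\arccos((\text{tr}(\mathbf{\mu}^T \mathbf{x}) - 1)/2)\) is the rotation angle of \(\mathbf{\mu}^T \mathbf{x}\), so the density depends only on the angle between \(\mathbf{x}\) and the mean \(\mathbf{\mu}\).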
Parameters:
  • x_0 (Tensor) – the clean sample at t = 0

  • discrete_t (Tensor) – the discrete timestep

  • mask (Tensor) – the mask

  • *args – additional arguments

  • **kwargs – additional keyword arguments

Returns:

a dictionary that must contain the key “x_t”

Return type:

dict
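A minimal NumPy sketch of drawing \(x_t \sim \text{IG}_{\text{SO}(3)}(x_0, \sigma^2)\): sample the rotation angle by inverse-CDF on a grid, draw a uniform axis, and compose with \(x_0\). It assumes the standard truncated series \(f_\sigma(\omega) = \sum_l (2l+1) e^{-l(l+1)\sigma^2/2} \sin((l+\tfrac12)\omega)/\sin(\omega/2)\) (conventions for \(\sigma\) differ) and right multiplication \(x_0 R\); all function names are hypothetical and NumPy stands in for PyTorch.

```python
import numpy as np


def igso3_f(omega, sigma, L=1000):
    """Truncated series for f_sigma(omega) (assumed parametrization)."""
    omega = np.asarray(omega, dtype=float)
    l = np.arange(L).reshape((L,) + (1,) * omega.ndim)  # broadcast over omega
    terms = ((2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2)
             * np.sin((l + 0.5) * omega) / np.sin(omega / 2))
    return terms.sum(axis=0)


def hat(v):
    """Rotation vectors (..., 3) -> skew-symmetric matrices (..., 3, 3)."""
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    o = np.zeros_like(x)
    return np.stack([np.stack([o, -z, y], -1),
                     np.stack([z, o, -x], -1),
                     np.stack([-y, x, o], -1)], -2)


def exp_so3(v):
    """Rodrigues' formula: rotation vector (..., 3) -> rotation matrix (..., 3, 3)."""
    theta = np.linalg.norm(v, axis=-1)[..., None, None]
    K = hat(v / np.maximum(theta[..., 0], 1e-12))  # unit-axis generator
    return np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)


def sample_igso3(x_0, sigma, rng, num_omega=1000):
    """Hypothetical forward step: x_t = x_0 @ R with R ~ IGSO(3)(I, sigma^2)."""
    omega = np.linspace(1e-3, np.pi, num_omega)  # skip omega = 0 (0/0 in the series)
    # Angle marginal: f_sigma(omega) times the Haar angle density (1 - cos omega) / pi.
    pdf = np.clip(igso3_f(omega, sigma) * (1 - np.cos(omega)) / np.pi, 0.0, None)
    cdf = np.cumsum(pdf)
    cdf = cdf / cdf[-1]
    n = x_0.shape[0]
    angle = np.interp(rng.uniform(size=n), cdf, omega)  # inverse-CDF sampling
    axis = rng.normal(size=(n, 3))
    axis /= np.linalg.norm(axis, axis=-1, keepdims=True)  # uniform axis on the sphere
    return x_0 @ exp_so3(axis * angle[:, None])
```

Sampling the angle separately is what makes this practical: the density depends only on the angle, so a one-dimensional CDF suffices and the axis is uniform.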

get_ground_truth_score(x_0: Tensor, x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor) → Tensor

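Ground-truth score for denoising score matching:

\[\nabla_{x_t} \log p_{0t}(x_t \mid x_0)\]
Parameters:
  • x_0 (Tensor) – the clean sample

  • x_t (Tensor) – the noisy sample at timestep discrete_t

  • discrete_t (Tensor) – the discrete timestep

  • padding_mask (Tensor) – the padding mask

Returns:

the ground-truth score

Return type:

Tensor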
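For IGSO(3) noise the score has a closed form up to a scalar: take the rotation vector \(v = \log(x_0^T x_t)\) with angle \(\omega = \|v\|\), then the score is \(\partial_\omega \log f_\sigma(\omega) \cdot v/\omega\). A minimal NumPy sketch, assuming the same truncated series for \(f_\sigma\) as in standard IGSO(3) treatments, a finite-difference derivative, and a rotation-vector representation of the tangent space; the names are hypothetical and not the library's implementation.

```python
import numpy as np


def igso3_f(omega, sigma, L=1000):
    """Truncated series for f_sigma(omega) (assumed parametrization)."""
    omega = np.asarray(omega, dtype=float)
    l = np.arange(L).reshape((L,) + (1,) * omega.ndim)
    terms = ((2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2)
             * np.sin((l + 0.5) * omega) / np.sin(omega / 2))
    return terms.sum(axis=0)


def log_so3(R):
    """Rotation matrix (..., 3, 3) -> rotation vector (..., 3); assumes angle < pi."""
    cos = np.clip((np.trace(R, axis1=-2, axis2=-1) - 1) / 2, -1.0, 1.0)
    omega = np.arccos(cos)
    w = np.stack([R[..., 2, 1] - R[..., 1, 2],
                  R[..., 0, 2] - R[..., 2, 0],
                  R[..., 1, 0] - R[..., 0, 1]], -1)  # vee(R - R^T) = 2 sin(omega) * axis
    return w * (omega / np.maximum(2 * np.sin(omega), 1e-12))[..., None]


def igso3_score(x_0, x_t, sigma, h=1e-4):
    """Hypothetical ground-truth score as a rotation vector (tangent space at x_t)."""
    v = log_so3(np.swapaxes(x_0, -1, -2) @ x_t)  # relative rotation x_0^T x_t
    omega = np.linalg.norm(v, axis=-1)
    # Central finite difference of log f_sigma with respect to the angle.
    dlogf = (np.log(igso3_f(omega + h, sigma))
             - np.log(igso3_f(omega - h, sigma))) / (2 * h)
    return dlogf[..., None] * v / np.maximum(omega, 1e-12)[..., None]
```

Because \(f_\sigma\) decreases with the angle, the score points from \(x_t\) back toward \(x_0\) along the rotation axis.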

property igso3_cdf: Tensor
property igso3_discrete_omega: Tensor
property igso3_discrete_sigma: Tensor
property igso3_exp_score_norms: Tensor
property igso3_score_norm: Tensor
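The properties above suggest precomputed lookup tables over a discrete grid of noise levels and angles. A minimal NumPy sketch of how such tables could be built and used, assuming a log-spaced sigma grid between the config's igso3_min_sigma and igso3_max_sigma, a linear omega grid, and nearest-sigma lookup; the table layout, spacing, and function names are assumptions, not the library's actual implementation.

```python
import numpy as np


def igso3_f(omega, sigma, L=1000):
    """Truncated series for f_sigma(omega) (assumed parametrization)."""
    omega = np.asarray(omega, dtype=float)
    l = np.arange(L).reshape((L,) + (1,) * omega.ndim)
    terms = ((2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2)
             * np.sin((l + 0.5) * omega) / np.sin(omega / 2))
    return terms.sum(axis=0)


def build_igso3_tables(min_sigma, max_sigma, num_sigma, num_omega):
    """Hypothetical tables: sigma grid, omega grid, and CDFs of shape (num_sigma, num_omega)."""
    sigmas = np.exp(np.linspace(np.log(min_sigma), np.log(max_sigma), num_sigma))  # log-spaced
    omegas = np.linspace(1e-3, np.pi, num_omega)
    # Angle marginal per sigma: f_sigma(omega) * (1 - cos omega) / pi.
    pdf = np.stack([igso3_f(omegas, s) * (1 - np.cos(omegas)) / np.pi for s in sigmas])
    cdf = np.cumsum(np.clip(pdf, 0.0, None), axis=-1)
    cdf = cdf / cdf[:, -1:]  # each row ends at 1
    return sigmas, omegas, cdf


def sample_angle(sigma, sigmas, omegas, cdf, rng, n):
    """Inverse-CDF sampling of the rotation angle using the nearest tabulated sigma."""
    idx = np.argmin(np.abs(np.log(sigmas) - np.log(sigma)))
    return np.interp(rng.uniform(size=n), cdf[idx], omegas)
```

Precomputing once trades memory for speed: the series is evaluated only at construction time, and each training step reduces to an index lookup plus a 1-D interpolation.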

inpainting(x, padding_mask, inpainting_mask, device, x_init_posterior=None, inpainting_mask_key='inpainting_mask', *args, **kwargs)

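Inpainting with the reverse diffusion process: inpainting_mask selects which parts of x are generated and which are kept.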

Parameters:
  • x (_type_) – the sample to be inpainted

  • padding_mask (_type_) – the padding mask

  • inpainting_mask (_type_) – the inpainting mask

  • device (_type_) – the device on which to run the reverse process

  • x_init_posterior (_type_, optional) – _description_. Defaults to None.

  • inpainting_mask_key (str, optional) – _description_. Defaults to “inpainting_mask”.

  • sapmling_condition_key (str, optional) – _description_. Defaults to “sapmling_condition”.

  • return_all (bool, optional) – _description_. Defaults to False.

  • sampling_condition (_type_, optional) – _description_. Defaults to None.

Returns:

a dictionary with the inpainting results

Return type:

dict

prior_sampling(shape: Tuple[int, ...]) → Tensor

Sample the initial noise for the reverse process from the uniform distribution on SO(3):

\[\mathcal{U}_{\text{SO}(3)}\]
Parameters:

shape (Tuple[int, ...]) – the shape of the sample

Returns:

the initial noise

Return type:

Tensor
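A minimal NumPy sketch of sampling from \(\mathcal{U}_{\text{SO}(3)}\): normalizing a 4-D standard Gaussian gives a uniformly distributed unit quaternion, and the quaternion-to-matrix map pushes it to the uniform (Haar) distribution on SO(3). Here shape is assumed to be the batch shape only (whether the library's shape includes the trailing (3, 3) is not stated), and sample_uniform_so3 is a hypothetical name.

```python
import numpy as np


def sample_uniform_so3(shape, rng):
    """Hypothetical prior sampler: Haar-uniform rotations of shape (*shape, 3, 3)."""
    q = rng.normal(size=(*shape, 4))
    q /= np.linalg.norm(q, axis=-1, keepdims=True)  # uniform unit quaternion (w, x, y, z)
    w, x, y, z = np.moveaxis(q, -1, 0)
    # Standard unit-quaternion to rotation-matrix formula.
    return np.stack([
        np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)], -1),
        np.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)], -1),
        np.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)], -1),
    ], -2)
```

A quick sanity check on uniformity is that the mean trace of Haar-distributed rotations is 0.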

sample_noise_in_lie_algebra(macro_shape: Tuple[int, ...]) → Tensor

Sample noise in the Lie algebra \(\mathfrak{so}(3)\), represented as skew-symmetric matrices.

Parameters:

macro_shape (Tuple[int, ...]) – the macro shape of the noise

Returns:

the noise in Lie algebra of shape \((*macro_shape, 3, 3)\)

Return type:

Tensor
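A minimal NumPy sketch of Lie-algebra noise: draw a standard Gaussian 3-vector per element of macro_shape and map it to a skew-symmetric matrix with the hat operator, giving shape (*macro_shape, 3, 3). The use of a standard Gaussian on the coordinates is an assumption, and lie_algebra_noise is a hypothetical name.

```python
import numpy as np


def lie_algebra_noise(macro_shape, rng):
    """Hypothetical sampler: Gaussian coordinates mapped to so(3) via the hat operator."""
    v = rng.normal(size=(*macro_shape, 3))  # coordinates in the basis of so(3)
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    o = np.zeros_like(x)
    # hat(v) = [[0, -z, y], [z, 0, -x], [-y, x, 0]]
    return np.stack([np.stack([o, -z, y], -1),
                     np.stack([z, o, -x], -1),
                     np.stack([-y, x, o], -1)], -2)
```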

sampling(shape, device, x_init_posterior=None, *args, **kwargs)

Generate samples by running the reverse diffusion process, starting from noise drawn from the prior.

Parameters:
  • shape (_type_) – the shape of the samples to generate

  • device (_type_) – the device on which to run the reverse process

  • x_init_posterior (_type_, optional) – _description_. Defaults to None.

  • return_all (bool, optional) – _description_. Defaults to False.

  • sampling_condition (_type_, optional) – _description_. Defaults to None.

  • sapmling_condition_key (str, optional) – _description_. Defaults to “sapmling_condition”.

Returns:

a dictionary with the sampling results

Return type:

dict

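step(x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]) → dict

Perform one reverse-time step as a geodesic random walk on SO(3):

\[\begin{split}x_{t+dt} &= \exp_{x_t}(f_{rev}\, dt + g_{rev}\, dw)\\ x_{t+\Delta_t} &= \exp_{x_t}(- f_{rev} |\Delta_t| + g_{rev} \Delta w)\\ f_{rev} &= f - g^2 \nabla_x \ln p_t(x)\\ g_{rev} &= g\end{split}\]
where \(\Delta_t < 0\) is the reverse-time step size and \(\Delta w \sim \mathcal{N}(0, |\Delta_t| I)\).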
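The discrete update above can be sketched in NumPy, assuming tangent vectors (drift, score, noise) are represented as rotation vectors and that \(\exp_{x_t}(v)\) is realized as right multiplication \(x_t \operatorname{Exp}(\hat v)\). reverse_step and its arguments are hypothetical; the library's step takes discrete_t and padding_mask instead of explicit f, g, and score.

```python
import numpy as np


def hat(v):
    """Rotation vectors (..., 3) -> skew-symmetric matrices (..., 3, 3)."""
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    o = np.zeros_like(x)
    return np.stack([np.stack([o, -z, y], -1),
                     np.stack([z, o, -x], -1),
                     np.stack([-y, x, o], -1)], -2)


def exp_so3(v):
    """Rodrigues' formula: rotation vector (..., 3) -> rotation matrix (..., 3, 3)."""
    theta = np.linalg.norm(v, axis=-1)[..., None, None]
    K = hat(v / np.maximum(theta[..., 0], 1e-12))
    return np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)


def reverse_step(x_t, score, f, g, dt_abs, rng):
    """Hypothetical step: x_{t+Delta_t} = x_t Exp(-f_rev |Delta_t| + g Delta w)."""
    f_rev = f - g**2 * score                               # reverse-time drift
    dw = rng.normal(size=score.shape) * np.sqrt(dt_abs)    # Delta w ~ N(0, |Delta_t| I)
    v = -f_rev * dt_abs + g * dw                           # tangent update at x_t
    return x_t @ exp_so3(v)                                # exponential map (right action, assumed)
```

Applying the update through the exponential map keeps every iterate exactly on SO(3), unlike adding the increment to the matrix entries directly.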

class ls_mlkit.diffusion.so3_diffuser.SO3DiffuserConfig(ndim_micro_shape: int, n_discretization_steps: int, n_inference_steps: int, igso3_num_sigma: int, igso3_num_omega: int, igso3_min_sigma: float, igso3_max_sigma: float, *args: list[Any], **kwargs: dict[Any, Any])

Bases: LieGroupDiffuserConfig
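The ndim_micro_shape field relates to the macro_shape used by sample_noise_in_lie_algebra: a sample's trailing ndim_micro_shape dimensions describe one element, and the leading ones are batch dimensions. A minimal sketch, assuming SO(3) samples are stored as 3×3 rotation matrices (so ndim_micro_shape = 2); split_shape is a hypothetical helper, not part of the library.

```python
def split_shape(shape, ndim_micro_shape):
    """Hypothetical helper: split a sample shape into (macro_shape, micro_shape).

    For rotation matrices the micro shape is assumed to be (3, 3), i.e. ndim_micro_shape = 2.
    """
    if not 0 <= ndim_micro_shape <= len(shape):
        raise ValueError("ndim_micro_shape must be between 0 and len(shape)")
    k = len(shape) - ndim_micro_shape  # number of leading (macro) dimensions
    return tuple(shape[:k]), tuple(shape[k:])
```

Computing the split index explicitly avoids the `shape[:-0]` pitfall, which would return an empty tuple when ndim_micro_shape is 0.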