ls_mlkit.diffusion.so3_diffuser module
- class ls_mlkit.diffusion.so3_diffuser.SO3Diffuser(config: SO3DiffuserConfig, time_scheduler: DiffusionTimeScheduler, masker: Masker, sde: SDE, score_fn: Callable[[Tensor, Tensor, Tensor], Tensor], loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor])
Bases:
LieGroupDiffuser
- compute_loss(batch: dict[str, Any], *args: list[Any], **kwargs: dict[Any, Any]) → Tensor
Compute the training loss for a batch.
- Parameters:
batch (dict[str, Any]) – the batch of data
- Returns:
either a dictionary that must contain the key “loss”, or a loss tensor
- Return type:
dict | Tensor
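SO3Diffuser takes a loss_fn of type Callable[[Tensor, Tensor, Tensor], Tensor]. As an illustration only, a masked mean-squared error with that kind of signature could look like the sketch below. The argument order (prediction, target, mask), the mask convention (1 = the position counts toward the loss), and the name masked_mse are assumptions rather than documented behavior, and NumPy arrays stand in for torch Tensors.

```python
import numpy as np


def masked_mse(pred: np.ndarray, target: np.ndarray, mask: np.ndarray) -> float:
    """Hypothetical loss_fn(pred, target, mask).

    pred and target: shape (*macro_shape, *micro_shape), e.g. (..., 3, 3).
    mask: shape macro_shape; 1 marks positions that count toward the loss.
    """
    # Sum squared errors over the micro dimensions of each position.
    per_item = ((pred - target) ** 2).reshape(*mask.shape, -1).sum(axis=-1)
    # Average over the valid positions only; max(..., 1) avoids division by zero.
    n_valid = max(float(mask.sum()), 1.0)
    return float((per_item * mask).sum() / n_valid)
```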
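- forward_process(x_0: Tensor, discrete_t: Tensor, mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]) → dict
Forward process, based on the isotropic Gaussian distribution on SO(3) (IGSO(3)):
\[\text{IG}_{\text{SO}(3)} (\mathbf{x}; \boldsymbol{\mu}, \sigma^2) = f_{\sigma} \left(\arccos\left((\text{tr}(\boldsymbol{\mu}^T \mathbf{x}) - 1)/2\right)\right) \quad \forall \mathbf{x} \in \text{SO}(3)\]
The density depends on \(\mathbf{x}\) only through the rotation angle of \(\boldsymbol{\mu}^T \mathbf{x}\), which is the argument of \(f_\sigma\).
- Parameters: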
x_0 (Tensor) – the initial sample
discrete_t (Tensor) – the discrete timestep
mask (Tensor) – the mask
*args – additional arguments
**kwargs – additional keyword arguments
- Returns:
a dictionary that must contain the key “x_t”
- Return type:
dict
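\(f_\sigma\) is not defined on this page. A minimal sketch of the density, assuming the series form \(f_\sigma(\omega) = \sum_{l \ge 0} (2l+1) e^{-l(l+1)\sigma^2/2} \sin((l+1/2)\omega) / \sin(\omega/2)\), a common IGSO(3) convention whose scaling of \(\sigma\) may differ from this library's. The names rotation_angle, f_sigma, and igso3_density are illustrative, not part of ls_mlkit, and NumPy arrays stand in for torch Tensors.

```python
import numpy as np


def rotation_angle(mu: np.ndarray, x: np.ndarray) -> float:
    """Angle of the relative rotation mu^T x, i.e. arccos((tr(mu^T x) - 1) / 2)."""
    cos_omega = (np.trace(mu.T @ x) - 1.0) / 2.0
    # Clip to guard against round-off pushing the value slightly outside [-1, 1].
    return float(np.arccos(np.clip(cos_omega, -1.0, 1.0)))


def f_sigma(omega, sigma: float, l_max: int = 200):
    """Truncated series for f_sigma(omega), a density w.r.t. the uniform (Haar) measure.

    ASSUMPTION: exp(-l(l+1) sigma^2 / 2) convention; other codebases scale sigma
    differently. Requires omega > 0 (the series is 0/0 at omega = 0), and l_max
    must grow as sigma shrinks.
    """
    omega = np.asarray(omega, dtype=float)[..., None]
    l = np.arange(l_max + 1)
    terms = (2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2) * np.sin((l + 0.5) * omega)
    return terms.sum(axis=-1) / np.sin(omega[..., 0] / 2)


def igso3_density(x: np.ndarray, mu: np.ndarray, sigma: float) -> float:
    """IG_SO(3)(x; mu, sigma^2) = f_sigma(angle of mu^T x)."""
    return float(f_sigma(rotation_angle(mu, x), sigma))
```

Multiplying \(f_\sigma(\omega)\) by the Haar marginal \((1 - \cos\omega)/\pi\) of the rotation angle gives a density on \([0, \pi]\) that integrates to 1, which is a quick sanity check on the truncation l_max.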
- get_ground_truth_score(x_0: Tensor, x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor) → Tensor
Ground-truth score used as the denoising score matching target:
\[\nabla_{x_t} \log p_{0t}(x_t \mid x_0)\]
- Parameters:
x_0 (Tensor) – the initial sample
x_t (Tensor) – the noised sample at timestep discrete_t
discrete_t (Tensor) – the discrete timestep
padding_mask (Tensor) – the padding mask
- Returns:
the conditional score of x_t given x_0
- Return type:
Tensor
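For an IGSO(3) forward process the target has a closed form: writing \(\omega\) for the rotation angle of \(x_0^T x_t\), \(\nabla_{x_t} \log p_{0t}(x_t \mid x_0) = \partial_\omega \log f_\sigma(\omega) \cdot \log(x_0^T x_t)/\omega\), expressed as a body-frame tangent vector at \(x_t\). A sketch under stated assumptions: the same \(f_\sigma\) series convention as in forward_process above, a right-multiplication (body-frame) perturbation convention, and a score returned as a 3-vector. The library's Tensor layout for the score is not documented here, and all function names are illustrative.

```python
import numpy as np


def hat(w: np.ndarray) -> np.ndarray:
    """Rotation vector (3,) -> skew-symmetric matrix (3, 3)."""
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])


def exp_map(w: np.ndarray) -> np.ndarray:
    """Rodrigues' formula: rotation vector (3,) -> rotation matrix (3, 3)."""
    theta = np.linalg.norm(w)
    if theta < 1e-12:
        return np.eye(3) + hat(w)
    k = hat(w / theta)
    return np.eye(3) + np.sin(theta) * k + (1.0 - np.cos(theta)) * (k @ k)


def log_map(r: np.ndarray) -> np.ndarray:
    """Rotation matrix (3, 3) -> rotation vector (3,); valid for angles in (0, pi)."""
    omega = np.arccos(np.clip((np.trace(r) - 1.0) / 2.0, -1.0, 1.0))
    vee = np.array([r[2, 1] - r[1, 2], r[0, 2] - r[2, 0], r[1, 0] - r[0, 1]])
    return omega / (2.0 * np.sin(omega)) * vee


def f_sigma(omega: float, sigma: float, l_max: int = 200) -> float:
    """Truncated IGSO(3) series, ASSUMING the exp(-l(l+1) sigma^2 / 2) convention."""
    l = np.arange(l_max + 1)
    terms = (2 * l + 1) * np.exp(-l * (l + 1) * sigma**2 / 2) * np.sin((l + 0.5) * omega)
    return float(terms.sum() / np.sin(omega / 2))


def igso3_score(x_0: np.ndarray, x_t: np.ndarray, sigma: float, eps: float = 1e-5) -> np.ndarray:
    """Body-frame score of IGSO(3)(x_t; x_0, sigma^2) as a tangent vector (3,) at x_t."""
    v = log_map(x_0.T @ x_t)  # rotation vector of x_0^T x_t, with |v| = omega
    omega = np.linalg.norm(v)
    # Central finite difference for d/d omega of log f_sigma.
    dlogf = (np.log(f_sigma(omega + eps, sigma)) - np.log(f_sigma(omega - eps, sigma))) / (2 * eps)
    return dlogf * v / omega
```

Because \(f_\sigma\) decreases in \(\omega\), the score points from \(x_t\) back toward \(x_0\).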
- property igso3_cdf: Tensor
- property igso3_discrete_omega: Tensor
- property igso3_discrete_sigma: Tensor
- property igso3_exp_score_norms: Tensor
- property igso3_score_norm: Tensor
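The property names, together with the config fields igso3_min_sigma, igso3_max_sigma, igso3_num_sigma, and igso3_num_omega, suggest precomputed IGSO(3) lookup tables. A minimal sketch of how such tables can be built and used for inverse-CDF sampling of the rotation angle, assuming log-spaced sigmas, a uniform omega grid excluding 0, and the \(f_\sigma\) series convention used above. This is not the library's code, and build_igso3_tables and sample_angle are hypothetical names.

```python
import numpy as np


def f_sigma_grid(omega: np.ndarray, sigma: np.ndarray, l_max: int = 200) -> np.ndarray:
    """f_sigma(omega) on a (num_sigma, num_omega) grid; ASSUMED exp(-l(l+1) sigma^2 / 2) series."""
    l = np.arange(l_max + 1)
    decay = np.exp(-l * (l + 1) * sigma[:, None] ** 2 / 2)                    # (S, L)
    ratio = np.sin((l + 0.5) * omega[:, None]) / np.sin(omega[:, None] / 2)  # (W, L)
    return decay @ ((2 * l + 1) * ratio).T                                   # (S, W)


def build_igso3_tables(min_sigma, max_sigma, num_sigma, num_omega):
    # HYPOTHETICAL layout: log-spaced sigmas, uniform omega grid on (0, pi].
    sigma = np.exp(np.linspace(np.log(min_sigma), np.log(max_sigma), num_sigma))
    omega = np.linspace(0.0, np.pi, num_omega + 1)[1:]
    # Marginal density of the rotation angle; clip tiny negative round-off from the series.
    pdf = np.clip(f_sigma_grid(omega, sigma) * (1.0 - np.cos(omega)) / np.pi, 0.0, None)
    # Riemann-sum CDF, renormalized so each row ends exactly at 1.
    cdf = np.cumsum(pdf, axis=-1) * (omega[1] - omega[0])
    cdf = cdf / cdf[:, -1:]
    return sigma, omega, cdf


def sample_angle(rng, sigma_idx, omega, cdf, size):
    """Inverse-CDF sampling of rotation angles for the tabulated sigma at sigma_idx."""
    u = rng.uniform(size=size)
    return np.interp(u, cdf[sigma_idx], omega)
```

Tabulating the CDF once makes each angle draw a single interpolation, instead of evaluating the series at sampling time.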
- inpainting(x, padding_mask, inpainting_mask, device, x_init_posterior=None, inpainting_mask_key='inpainting_mask', *args, **kwargs)
Sampling with inpainting: run the reverse process on x, guided by inpainting_mask.
- Parameters:
x – the input sample
padding_mask – the padding mask
inpainting_mask – the inpainting mask
device – the device to run sampling on
x_init_posterior (optional) – defaults to None
inpainting_mask_key (str, optional) – defaults to “inpainting_mask”
sapmling_condition_key (str, optional; keyword argument) – defaults to “sapmling_condition”
return_all (bool, optional; keyword argument) – defaults to False
sampling_condition (optional; keyword argument) – defaults to None
- Returns:
the sampling results
- Return type:
dict
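The signature does not specify how inpainting_mask is applied. One common technique is the replacement method: after each reverse step, positions that should stay fixed are overwritten with the known values noised to the current timestep, and with the clean known values after the final step. A sketch of that technique, not necessarily what this method does; keep_mask (1 = keep the known value), reverse_step, and forward_noise are hypothetical stand-ins.

```python
import numpy as np
from typing import Callable, Sequence


def inpaint_by_replacement(
    x_known: np.ndarray,    # (*macro_shape, 3, 3) rotations at the known positions
    keep_mask: np.ndarray,  # (*macro_shape,) 1 = keep x_known, 0 = generate
    timesteps: Sequence[int],  # descending discrete timesteps, e.g. [T-1, ..., 0]
    x_T: np.ndarray,        # initial noise, same shape as x_known
    reverse_step: Callable[[np.ndarray, int], np.ndarray],   # hypothetical: x at t -> next timestep
    forward_noise: Callable[[np.ndarray, int], np.ndarray],  # hypothetical: noise clean x to timestep t
) -> np.ndarray:
    # Broadcast the per-position mask over the trailing 3x3 dimensions.
    keep = keep_mask.astype(bool)[..., None, None]
    x = x_T
    for i, t in enumerate(timesteps):
        x = reverse_step(x, t)
        if i + 1 < len(timesteps):
            # Paste in the known values, noised to the timestep x is now at.
            x = np.where(keep, forward_noise(x_known, timesteps[i + 1]), x)
        else:
            # After the last step x is clean, so paste the clean known values.
            x = np.where(keep, x_known, x)
    return x
```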
- prior_sampling(shape: Tuple[int, ...]) → Tensor
Sample the initial noise for the reverse process from the uniform distribution on SO(3):
\[\mathcal{U}(\text{SO}(3))\]
- Parameters:
shape (Tuple[int, ...]) – the shape of the sample
- Returns:
the initial noise
- Return type:
Tensor
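One standard way to draw from \(\mathcal{U}(\text{SO}(3))\) is to normalize a 4-D standard Gaussian into a unit quaternion and convert it to a rotation matrix. A sketch, assuming the output shape is (*macro_shape, 3, 3); how prior_sampling interprets its shape argument and which sampler it uses are not documented here, and the function names are illustrative.

```python
import numpy as np


def quat_to_matrix(q: np.ndarray) -> np.ndarray:
    """Unit quaternions (..., 4) in (w, x, y, z) order -> rotation matrices (..., 3, 3)."""
    w, x, y, z = np.moveaxis(q, -1, 0)
    return np.stack([
        np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)], axis=-1),
        np.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)], axis=-1),
        np.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)], axis=-1),
    ], axis=-2)


def sample_uniform_so3(macro_shape, rng: np.random.Generator) -> np.ndarray:
    # A normalized 4-D Gaussian is uniform on S^3, i.e. a uniform unit quaternion.
    q = rng.standard_normal((*macro_shape, 4))
    q /= np.linalg.norm(q, axis=-1, keepdims=True)
    return quat_to_matrix(q)
```

The double cover \(S^3 \to \text{SO}(3)\) maps the uniform measure on unit quaternions to the Haar measure, so no rejection step is needed.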
- sample_noise_in_lie_algebra(macro_shape: Tuple[int, ...]) → Tensor
Sample noise in the Lie algebra \(\mathfrak{so}(3)\), represented as skew-symmetric matrices.
- Parameters:
macro_shape (Tuple[int, ...]) – the macro shape of the noise
- Returns:
the noise in Lie algebra of shape \((*macro_shape, 3, 3)\)
- Return type:
Tensor
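Elements of \(\mathfrak{so}(3)\) can be produced by sampling a 3-vector and applying the hat map, which sends \(v\) to the skew-symmetric matrix \(\hat{v}\) with \(\hat{v} u = v \times u\). A sketch, assuming standard normal coefficients; the library's noise scale is not stated here, and hat and sample_noise_so3_algebra are illustrative names.

```python
import numpy as np


def hat(v: np.ndarray) -> np.ndarray:
    """Map vectors (..., 3) to skew-symmetric matrices (..., 3, 3)."""
    x, y, z = np.moveaxis(v, -1, 0)
    zero = np.zeros_like(x)
    return np.stack([
        np.stack([zero, -z, y], axis=-1),
        np.stack([z, zero, -x], axis=-1),
        np.stack([-y, x, zero], axis=-1),
    ], axis=-2)


def sample_noise_so3_algebra(macro_shape, rng: np.random.Generator) -> np.ndarray:
    # ASSUMPTION: standard normal coefficients in the basis of so(3) generators.
    return hat(rng.standard_normal((*macro_shape, 3)))
```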
- sampling(shape, device, x_init_posterior=None, *args, **kwargs)
Generate samples by running the reverse process.
- Parameters:
shape – the shape of the samples
device – the device to run sampling on
x_init_posterior (optional) – defaults to None
return_all (bool, optional; keyword argument) – defaults to False
sampling_condition (optional; keyword argument) – defaults to None
sapmling_condition_key (str, optional; keyword argument) – defaults to “sapmling_condition”
- Returns:
the sampling results
- Return type:
dict
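SO3DiffuserConfig distinguishes n_discretization_steps from n_inference_steps, so sampling presumably visits a subset of the discrete timesteps. A sketch of one such loop, assuming evenly spaced, rounded timesteps in descending order; the actual schedule comes from the time_scheduler and is not documented here, and step is a hypothetical callable that returns the next sample directly.

```python
import numpy as np
from typing import Any, Callable


def inference_timesteps(n_discretization_steps: int, n_inference_steps: int) -> np.ndarray:
    """Descending discrete timesteps, evenly spaced over [0, n_discretization_steps - 1]."""
    # ASSUMPTION: uniform spacing rounded to integers; the library may use another schedule.
    t = np.linspace(n_discretization_steps - 1, 0, n_inference_steps)
    return np.round(t).astype(int)


def run_reverse_process(x_T: Any, n_discretization_steps: int, n_inference_steps: int,
                        step: Callable[[Any, int], Any]) -> Any:
    x = x_T
    for t in inference_timesteps(n_discretization_steps, n_inference_steps):
        x = step(x, int(t))  # hypothetical: maps the sample at timestep t to the next one
    return x
```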
- step(x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor, *args: list[Any], **kwargs: dict[Any, Any]) → dict
One reverse-process step, discretizing the reverse-time SDE with the exponential map (the step runs backward in time, so \(\Delta t < 0\)):
\[\begin{split}dx &= f_{rev}\,dt + g_{rev}\,dw\\ x_{t+\Delta t} &= \exp_{x_t}\left(- f_{rev} |\Delta t| + g_{rev} \Delta w\right)\\ f_{rev} &= f - g^2 \nabla_x \ln p_t(x)\\ g_{rev} &= g\end{split}\]
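The update above is a geodesic Euler-Maruyama step. A sketch for a single rotation, assuming the score, drift, and noise are body-frame tangent vectors (so \(\exp_{x_t}(v) = x_t \operatorname{Exp}(\hat{v})\)) and that \(f = 0\) unless given, as in a variance-exploding SDE. These conventions and the name reverse_step are assumptions, not the library's API.

```python
import numpy as np


def hat(w: np.ndarray) -> np.ndarray:
    """Rotation vector (3,) -> skew-symmetric matrix (3, 3)."""
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])


def exp_so3(w: np.ndarray) -> np.ndarray:
    """Rodrigues' formula: rotation vector (3,) -> rotation matrix (3, 3)."""
    theta = np.linalg.norm(w)
    if theta < 1e-12:
        return np.eye(3) + hat(w)
    k = hat(w / theta)
    return np.eye(3) + np.sin(theta) * k + (1.0 - np.cos(theta)) * (k @ k)


def reverse_step(x_t, score, g, dt_abs, rng, f=None):
    """x_{t+dt} = x_t Exp(-f_rev |dt| + g dW) with dt < 0.

    score: body-frame tangent vector (3,) approximating grad log p_t at x_t.
    f: drift as a body-frame tangent vector (3,); ASSUMED zero if None.
    """
    f = np.zeros(3) if f is None else f
    f_rev = f - g**2 * score                       # reverse-time drift
    dw = np.sqrt(dt_abs) * rng.standard_normal(3)  # Brownian increment over |dt|
    return x_t @ exp_so3(-f_rev * dt_abs + g * dw)
```

Right-multiplying by \(\operatorname{Exp}(\cdot)\) treats tangent vectors in the body frame; a left-multiplication convention would express them in the space frame instead.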
- class ls_mlkit.diffusion.so3_diffuser.SO3DiffuserConfig(ndim_micro_shape: int, n_discretization_steps: int, n_inference_steps: int, igso3_num_sigma: int, igso3_num_omega: int, igso3_min_sigma: float, igso3_max_sigma: float, *args: list[Any], **kwargs: dict[Any, Any])
Bases:
LieGroupDiffuserConfig