Skip to content

diffusion

ls_mlkit.diffusion

BaseDiffuser

Bases: BaseGenerativeModel

Abstract methods: `forward_process`

Source code in src/ls_mlkit/diffusion/base_diffuser.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@inherit_docstrings
class BaseDiffuser(BaseGenerativeModel):
    """
    Base class for diffusers; keeps the diffuser config and the time scheduler.

    Abstract methods:
        forward_process: declared with ``@abstractmethod``; the body given here
            only returns an empty dict.
    """

    def __init__(
        self,
        config: BaseDiffuserConfig,
        time_scheduler: DiffusionTimeScheduler,
    ) -> None:
        r"""Initialize the BaseDiffuser

        Args:
            config (``BaseDiffuserConfig``): the config of the diffuser
            time_scheduler (``DiffusionTimeScheduler``): the time scheduler of the diffuser
        """
        # Only the config is handed to the parent constructor.
        super().__init__(config=config)
        self.config: BaseDiffuserConfig = config
        self.time_scheduler: DiffusionTimeScheduler = time_scheduler

    @abstractmethod
    def forward_process(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        # Placeholder body for the abstract method: an empty dict.
        return {}

__init__(config, time_scheduler)

Initialize the BaseDiffuser

Parameters:

Name Type Description Default
config ``BaseDiffuserConfig``

the config of the diffuser

required
time_scheduler ``DiffusionTimeScheduler``

the time scheduler of the diffuser

required
Source code in src/ls_mlkit/diffusion/base_diffuser.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def __init__(
    self,
    config: BaseDiffuserConfig,
    time_scheduler: DiffusionTimeScheduler,
) -> None:
    r"""Initialize the BaseDiffuser

    Args:
        config (``BaseDiffuserConfig``): the config of the diffuser; it is also
            forwarded to the parent constructor
        time_scheduler (``DiffusionTimeScheduler``): the time scheduler of the diffuser
    """
    # Only the config is handed to the parent constructor.
    super().__init__(config=config)
    # Both arguments are kept on the instance under their own names.
    self.config: BaseDiffuserConfig = config
    self.time_scheduler: DiffusionTimeScheduler = time_scheduler

EuclideanDDIMConfig

Bases: EuclideanDDPMConfig

Source code in src/ls_mlkit/diffusion/euclidean_ddim_diffuser.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
@inherit_docstrings
class EuclideanDDIMConfig(EuclideanDDPMConfig):
    # Adds two DDIM-specific settings (``n_inference_steps`` and ``eta``) on top
    # of the values handed to the EuclideanDDPMConfig constructor.
    def __init__(
        self,
        n_discretization_steps: int = 1000,
        ndim_micro_shape: int = 2,
        use_probability_flow: bool = False,
        use_clip: bool = True,
        clip_sample_range: float = 1.0,
        use_dyn_thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        betas: Tensor | None = None,
        n_inference_steps: int = 1000,
        eta: float = 0.0,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Initialize the EuclideanDDIMConfig

        Args:
            n_discretization_steps (int): the number of discretization steps
            ndim_micro_shape (int): the number of dimensions of the micro shape
            use_probability_flow (bool): whether to use probability flow
            use_clip (bool): whether to use clip
            clip_sample_range (float): the range of the clip
            use_dyn_thresholding (bool): whether to use dynamic thresholding
            dynamic_thresholding_ratio (float): the ratio of the dynamic thresholding
            sample_max_value (float): the maximum value of the sample used in thresholding
            betas (Tensor): the betas; None is forwarded to the parent constructor as-is
            n_inference_steps (int): the number of inference steps; the DDIM step
                uses ``n_discretization_steps // n_inference_steps`` as its timestep stride
            eta (float): factor applied to the square root of the DDIM variance term
                in the DDIM step; 0.0 makes the resulting sigma zero
            *args: accepted but not used by this constructor
            **kwargs: accepted but not used by this constructor

        Returns:
            None
        """
        # Everything except the two DDIM-only settings goes to the parent constructor.
        super().__init__(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
            use_probability_flow=use_probability_flow,
            use_clip=use_clip,
            clip_sample_range=clip_sample_range,
            use_dyn_thresholding=use_dyn_thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            betas=betas,
        )
        self.n_inference_steps: int = n_inference_steps
        self.eta: float = eta

__init__(n_discretization_steps=1000, ndim_micro_shape=2, use_probability_flow=False, use_clip=True, clip_sample_range=1.0, use_dyn_thresholding=False, dynamic_thresholding_ratio=0.995, sample_max_value=1.0, betas=None, n_inference_steps=1000, eta=0.0, *args, **kwargs)

Initialize the EuclideanDDIMConfig

Parameters:

Name Type Description Default
n_discretization_steps int

the number of discretization steps

1000
ndim_micro_shape int

the number of dimensions of the micro shape

2
use_probability_flow bool

whether to use probability flow

False
use_clip bool

whether to use clip

True
clip_sample_range float

the range of the clip

1.0
use_dyn_thresholding bool

whether to use dynamic thresholding

False
dynamic_thresholding_ratio float

the ratio of the dynamic thresholding

0.995
sample_max_value float

the maximum value of the sample used in thresholding

1.0
betas Tensor

the betas

None
n_inference_steps int

the number of inference steps

1000
eta float

the eta

0.0

Returns:

Type Description

None

Source code in src/ls_mlkit/diffusion/euclidean_ddim_diffuser.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def __init__(
    self,
    n_discretization_steps: int = 1000,
    ndim_micro_shape: int = 2,
    use_probability_flow: bool = False,
    use_clip: bool = True,
    clip_sample_range: float = 1.0,
    use_dyn_thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    betas: Tensor | None = None,
    n_inference_steps: int = 1000,
    eta: float = 0.0,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Initialize the EuclideanDDIMConfig

    Args:
        n_discretization_steps (int): the number of discretization steps
        ndim_micro_shape (int): the number of dimensions of the micro shape
        use_probability_flow (bool): whether to use probability flow
        use_clip (bool): whether to use clip
        clip_sample_range (float): the range of the clip
        use_dyn_thresholding (bool): whether to use dynamic thresholding
        dynamic_thresholding_ratio (float): the ratio of the dynamic thresholding
        sample_max_value (float): the maximum value of the sample used in thresholding
        betas (Tensor): the betas; None is forwarded to the parent constructor as-is
        n_inference_steps (int): the number of inference steps; the DDIM step
            uses ``n_discretization_steps // n_inference_steps`` as its timestep stride
        eta (float): factor applied to the square root of the DDIM variance term
            in the DDIM step; 0.0 makes the resulting sigma zero
        *args: accepted but not used by this constructor
        **kwargs: accepted but not used by this constructor

    Returns:
        None
    """
    # Everything except the two DDIM-only settings goes to the parent constructor.
    super().__init__(
        n_discretization_steps=n_discretization_steps,
        ndim_micro_shape=ndim_micro_shape,
        use_probability_flow=use_probability_flow,
        use_clip=use_clip,
        clip_sample_range=clip_sample_range,
        use_dyn_thresholding=use_dyn_thresholding,
        dynamic_thresholding_ratio=dynamic_thresholding_ratio,
        sample_max_value=sample_max_value,
        betas=betas,
    )
    self.n_inference_steps: int = n_inference_steps
    self.eta: float = eta

EuclideanDDIMDiffuser

Bases: EuclideanDDPMDiffuser

Source code in src/ls_mlkit/diffusion/euclidean_ddim_diffuser.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@inherit_docstrings
class EuclideanDDIMDiffuser(EuclideanDDPMDiffuser):
    """DDIM diffuser built on ``EuclideanDDPMDiffuser``.

    Construction is delegated unchanged to the parent class; this class
    provides ``get_sigma2`` (the DDIM variance term) and ``step`` (one reverse
    DDIM update).
    """

    def __init__(
        self,
        config: EuclideanDDIMConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: MaskerInterface,
        model: Module,
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_true, padding_mask)
    ):
        """Initialize the EuclideanDDIMDiffuser

        Args:
            config (EuclideanDDIMConfig): the config of the diffuser; ``step`` reads
                ``n_inference_steps`` and ``eta`` from it, which is why the DDIM
                config type is required rather than the DDPM one
            time_scheduler (DiffusionTimeScheduler): the time scheduler of the diffuser
            masker (MaskerInterface): the masker of the diffuser
            model (Module): the model of the diffuser
            loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): the loss function of the diffuser

        Returns:
            None
        """
        super().__init__(
            config=config,
            time_scheduler=time_scheduler,
            masker=masker,
            model=model,
            loss_fn=loss_fn,
        )

    def get_sigma2(self, t: Tensor, prev_t: Tensor) -> Tensor:
        r"""Compute DDIM variance term

        .. math::
            \sigma^2 = (\frac{1 - \bar{\alpha}_{pre}}{1 - \bar{\alpha}_{t}}) \cdot ( 1- \frac{\bar{\alpha}_{t}}{\bar{\alpha}_{pre}})

        When ``prev_t`` is negative, :math:`\bar{\alpha}_{pre}` is taken to be 1.

        Args:
            t (Tensor): timestep
            prev_t (Tensor): previous timestep

        Returns:
            Tensor: :math:`\sigma^2`
        """
        cumprod = cast(EuclideanDDIMConfig, self.config).alphas_cumprod
        abar_t = cumprod[t]
        if prev_t >= 0:
            abar_prev = cumprod[prev_t]
        else:
            # No earlier timestep exists: use a cumulative alpha of one.
            abar_prev = torch.ones(1).to(t.device)

        # DDIM variance formula
        return ((1 - abar_prev) / (1 - abar_t)) * (1 - abar_t / abar_prev)

    def step(
        self,
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        r"""DDIM sampling algorithm:

        .. math::

            \hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}

            \text{direction} = \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \cdot \epsilon_\theta(x_t, t)

            x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \cdot \hat{x}_0 + \text{direction} + \sigma_t \cdot z

        In the formulas, :math:`t-1` stands for the previous inference timestep
        ``t - n_discretization_steps // n_inference_steps``; when that index is
        negative, :math:`\bar{\alpha}_{t-1}` is taken to be 1.

        Args:
            x_t (Tensor): the sample at timestep t
            t (Tensor): the timestep; all entries must hold the same value
            padding_mask (Tensor): the padding mask, forwarded to the model
            **kwargs: forwarded to the model; ``mode`` ("epsilon" or "x_0",
                default "epsilon") also selects how the model output is interpreted

        Returns:
            dict: ``"x"`` is the sample at the previous timestep and
            ``"E_x0_xt"`` is the predicted :math:`\hat{x}_0`

        Raises:
            AssertionError: if the entries of ``t`` are not all equal
            ValueError: if ``mode`` is "score" (not supported) or an unknown value
        """
        assert torch.all(t == t.view(-1)[0]).item()
        config = cast(EuclideanDDIMConfig, self.config.to(t))
        t = t.long().view(-1)[0]

        # With fewer inference steps than training steps, each update jumps
        # back by a stride of several training timesteps.
        stride = config.n_discretization_steps // config.n_inference_steps
        prev_t = t - stride
        abar_t = config.alphas_cumprod[t]
        if prev_t >= 0:
            abar_prev = config.alphas_cumprod[prev_t]
        else:
            abar_prev = torch.ones(1).to(t.device)
        one_minus_abar_t = 1 - abar_t

        mode: Literal["epsilon", "x_0", "score"] = kwargs.get("mode", "epsilon")
        model_inputs = dict(x_t=x_t, t=t, padding_mask=padding_mask, **kwargs)

        # Reject unusable modes before the model is evaluated.
        if mode == "score":
            raise ValueError(f"Currently not supported mode: {mode}")
        if mode not in ("epsilon", "x_0"):
            raise ValueError(f"Invalid mode: {mode}")

        network_out = self.model(**model_inputs)["x"]
        if mode == "epsilon":
            # The model predicts the noise; recover x_0 from it.
            eps_hat = network_out
            x0_hat = (x_t - one_minus_abar_t ** (0.5) * eps_hat) / abar_t ** (0.5)
        else:
            # The model predicts x_0; recover the implied noise from it.
            x0_hat = network_out
            eps_hat = (x_t - abar_t ** (0.5) * x0_hat) / one_minus_abar_t ** (0.5)

        # sigma = eta * sqrt(sigma^2)
        sigma = config.eta * torch.sqrt(self.get_sigma2(t, prev_t))

        # direction = sqrt(1 - alpha_bar_prev - sigma^2) * epsilon
        direction = torch.sqrt(1 - abar_prev - sigma**2) * eps_hat

        # x_prev = sqrt(alpha_bar_prev) * x0_hat + direction (+ sigma * z while t > 0)
        x_prev = torch.sqrt(abar_prev) * x0_hat + direction

        # The noise is drawn on every call so the random stream does not depend on t.
        noise = torch.randn_like(x_t)
        if t > 0:
            x_prev = x_prev + sigma * noise

        return {"x": x_prev, "E_x0_xt": x0_hat}

get_sigma2(t, prev_t)

Compute DDIM variance term

$$\sigma^2 = \left(\frac{1 - \bar{\alpha}_{pre}}{1 - \bar{\alpha}_{t}}\right) \cdot \left(1 - \frac{\bar{\alpha}_{t}}{\bar{\alpha}_{pre}}\right)$$

Parameters:

Name Type Description Default
t Tensor

timestep

required
prev_t Tensor

previous timestep

required

Returns:

Name Type Description
Tensor Tensor

$\sigma^2$

Source code in src/ls_mlkit/diffusion/euclidean_ddim_diffuser.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def get_sigma2(self, t: Tensor, prev_t: Tensor) -> Tensor:
    r"""Return the DDIM variance term for the jump from ``t`` to ``prev_t``.

    .. math::
        \sigma^2 = (\frac{1 - \bar{\alpha}_{pre}}{1 - \bar{\alpha}_{t}}) \cdot ( 1- \frac{\bar{\alpha}_{t}}{\bar{\alpha}_{pre}})

    When ``prev_t`` is negative, :math:`\bar{\alpha}_{pre}` is taken to be 1.

    Args:
        t (Tensor): timestep
        prev_t (Tensor): previous timestep

    Returns:
        Tensor: :math:`\sigma^2`
    """
    cumprod = cast(EuclideanDDIMConfig, self.config).alphas_cumprod
    abar_t = cumprod[t]
    if prev_t >= 0:
        abar_prev = cumprod[prev_t]
    else:
        # No earlier timestep exists: use a cumulative alpha of one.
        abar_prev = torch.ones(1).to(t.device)

    # DDIM variance formula
    return ((1 - abar_prev) / (1 - abar_t)) * (1 - abar_t / abar_prev)

step(x_t, t, padding_mask=None, *args, **kwargs)

DDIM sampling algorithm:

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}$$

$$\text{direction} = \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \cdot \epsilon_\theta(x_t, t)$$

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \cdot \hat{x}_0 + \text{direction} + \sigma_t \cdot z$$

Parameters:

Name Type Description Default
x_t Tensor

the sample at timestep t

required
t Tensor

the timestep

required
padding_mask Tensor

the padding mask

None

Returns:

Name Type Description
result dict

a dictionary with keys "x" (the sample at the previous timestep) and "E_x0_xt" (the predicted clean sample)

Source code in src/ls_mlkit/diffusion/euclidean_ddim_diffuser.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def step(
    self,
    x_t: Tensor,
    t: Tensor,
    padding_mask: Tensor | None = None,
    *args: Any,
    **kwargs: Any,
) -> dict:
    r"""Run one reverse DDIM update starting from the sample at timestep ``t``.

    .. math::

        \hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}

        \text{direction} = \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \cdot \epsilon_\theta(x_t, t)

        x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \cdot \hat{x}_0 + \text{direction} + \sigma_t \cdot z

    In the formulas, :math:`t-1` stands for the previous inference timestep
    ``t - n_discretization_steps // n_inference_steps``; when that index is
    negative, :math:`\bar{\alpha}_{t-1}` is taken to be 1.

    Args:
        x_t (Tensor): the sample at timestep t
        t (Tensor): the timestep; all entries must hold the same value
        padding_mask (Tensor): the padding mask, forwarded to the model
        **kwargs: forwarded to the model; ``mode`` ("epsilon" or "x_0",
            default "epsilon") also selects how the model output is interpreted

    Returns:
        dict: ``"x"`` is the sample at the previous timestep and
        ``"E_x0_xt"`` is the predicted :math:`\hat{x}_0`

    Raises:
        AssertionError: if the entries of ``t`` are not all equal
        ValueError: if ``mode`` is "score" (not supported) or an unknown value
    """
    assert torch.all(t == t.view(-1)[0]).item()
    config = cast(EuclideanDDIMConfig, self.config.to(t))
    t = t.long().view(-1)[0]

    # With fewer inference steps than training steps, each update jumps
    # back by a stride of several training timesteps.
    stride = config.n_discretization_steps // config.n_inference_steps
    prev_t = t - stride
    abar_t = config.alphas_cumprod[t]
    if prev_t >= 0:
        abar_prev = config.alphas_cumprod[prev_t]
    else:
        abar_prev = torch.ones(1).to(t.device)
    one_minus_abar_t = 1 - abar_t

    mode: Literal["epsilon", "x_0", "score"] = kwargs.get("mode", "epsilon")
    model_inputs = dict(x_t=x_t, t=t, padding_mask=padding_mask, **kwargs)

    # Reject unusable modes before the model is evaluated.
    if mode == "score":
        raise ValueError(f"Currently not supported mode: {mode}")
    if mode not in ("epsilon", "x_0"):
        raise ValueError(f"Invalid mode: {mode}")

    network_out = self.model(**model_inputs)["x"]
    if mode == "epsilon":
        # The model predicts the noise; recover x_0 from it.
        eps_hat = network_out
        x0_hat = (x_t - one_minus_abar_t ** (0.5) * eps_hat) / abar_t ** (0.5)
    else:
        # The model predicts x_0; recover the implied noise from it.
        x0_hat = network_out
        eps_hat = (x_t - abar_t ** (0.5) * x0_hat) / one_minus_abar_t ** (0.5)

    # sigma = eta * sqrt(sigma^2)
    sigma = config.eta * torch.sqrt(self.get_sigma2(t, prev_t))

    # direction = sqrt(1 - alpha_bar_prev - sigma^2) * epsilon
    direction = torch.sqrt(1 - abar_prev - sigma**2) * eps_hat

    # x_prev = sqrt(alpha_bar_prev) * x0_hat + direction (+ sigma * z while t > 0)
    x_prev = torch.sqrt(abar_prev) * x0_hat + direction

    # The noise is drawn on every call so the random stream does not depend on t.
    noise = torch.randn_like(x_t)
    if t > 0:
        x_prev = x_prev + sigma * noise

    return {"x": x_prev, "E_x0_xt": x0_hat}

EuclideanDDPMConfig

Bases: EuclideanDiffuserConfig

Config Class for Euclidean DDPM Diffuser

Source code in src/ls_mlkit/diffusion/euclidean_ddpm_diffuser.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@inherit_docstrings
class EuclideanDDPMConfig(EuclideanDiffuserConfig):
    """
    Config Class for Euclidean DDPM Diffuser
    """

    def __init__(
        self,
        n_discretization_steps: int = 1000,
        ndim_micro_shape: int = 2,
        use_probability_flow: bool = False,
        use_clip: bool = False,
        clip_sample_range: float = 1.0,
        use_dyn_thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        betas=None,
        *args,
        **kwargs,
    ):
        r"""
        Args:
            n_discretization_steps: the number of discretization steps
            ndim_micro_shape: the number of dimensions of the micro shape
            use_probability_flow: whether to use probability flow
            use_clip: whether to use clip
            clip_sample_range: the range of the clip
            use_dyn_thresholding: whether to use dynamic thresholding
            dynamic_thresholding_ratio: the ratio of the dynamic thresholding
            sample_max_value: the maximum value of the sample used in thresholding
            betas: the betas, as a Tensor or any array-like accepted by
                ``torch.as_tensor``; when None, a linear schedule from 0.0001 to
                0.02 with ``n_discretization_steps`` entries is used
            *args: accepted but not used by this constructor
            **kwargs: accepted but not used by this constructor
        Returns:
            None
        """
        super().__init__(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
        )
        # NOTE(review): use_probability_flow is accepted but neither stored nor
        # forwarded in this constructor - confirm whether that is intended.
        self.betas: Tensor
        if betas is None:
            # Use the same beta schedule as standard DDPMScheduler
            # Linear schedule from beta_start=0.0001 to beta_end=0.02
            self.betas = torch.linspace(
                0.0001,
                0.02,
                steps=self.n_discretization_steps,
                dtype=torch.float32,
            )
        else:
            # A Tensor is returned unchanged; lists and arrays are converted so
            # the cumulative product below does not fail on non-Tensor input.
            self.betas = torch.as_tensor(betas)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(cast(Tensor, self.alphas), dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)  # expectation
        self.sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)  # std
        self.use_clip = use_clip
        self.clip_sample_range = clip_sample_range
        self.use_dyn_thresholding = use_dyn_thresholding
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.sample_max_value = sample_max_value

__init__(n_discretization_steps=1000, ndim_micro_shape=2, use_probability_flow=False, use_clip=False, clip_sample_range=1.0, use_dyn_thresholding=False, dynamic_thresholding_ratio=0.995, sample_max_value=1.0, betas=None, *args, **kwargs)

Parameters:

Name Type Description Default
n_discretization_steps int

the number of discretization steps

1000
ndim_micro_shape int

the number of dimensions of the micro shape

2
use_probability_flow

whether to use probability flow

False
use_clip bool

whether to use clip

False
clip_sample_range float

the range of the clip

1.0
use_dyn_thresholding bool

whether to use dynamic thresholding

False
dynamic_thresholding_ratio

the ratio of the dynamic thresholding

0.995
sample_max_value float

the maximum value of the sample used in thresholding

1.0
betas

the betas

None

Returns: None

Source code in src/ls_mlkit/diffusion/euclidean_ddpm_diffuser.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __init__(
    self,
    n_discretization_steps: int = 1000,
    ndim_micro_shape: int = 2,
    use_probability_flow: bool = False,
    use_clip: bool = False,
    clip_sample_range: float = 1.0,
    use_dyn_thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    betas=None,
    *args,
    **kwargs,
):
    r"""
    Args:
        n_discretization_steps: the number of discretization steps
        ndim_micro_shape: the number of dimensions of the micro shape
        use_probability_flow: whether to use probability flow
        use_clip: whether to use clip
        clip_sample_range: the range of the clip
        use_dyn_thresholding: whether to use dynamic thresholding
        dynamic_thresholding_ratio: the ratio of the dynamic thresholding
        sample_max_value: the maximum value of the sample used in thresholding
        betas: the betas, as a Tensor or any array-like accepted by
            ``torch.as_tensor``; when None, a linear schedule from 0.0001 to
            0.02 with ``n_discretization_steps`` entries is used
        *args: accepted but not used by this constructor
        **kwargs: accepted but not used by this constructor
    Returns:
        None
    """
    super().__init__(
        n_discretization_steps=n_discretization_steps,
        ndim_micro_shape=ndim_micro_shape,
    )
    # NOTE(review): use_probability_flow is accepted but neither stored nor
    # forwarded in this constructor - confirm whether that is intended.
    self.betas: Tensor
    if betas is None:
        # Use the same beta schedule as standard DDPMScheduler
        # Linear schedule from beta_start=0.0001 to beta_end=0.02
        self.betas = torch.linspace(
            0.0001,
            0.02,
            steps=self.n_discretization_steps,
            dtype=torch.float32,
        )
    else:
        # A Tensor is returned unchanged; lists and arrays are converted so
        # the cumulative product below does not fail on non-Tensor input.
        self.betas = torch.as_tensor(betas)
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(cast(Tensor, self.alphas), dim=0)
    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)  # expectation
    self.sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)  # std
    self.use_clip = use_clip
    self.clip_sample_range = clip_sample_range
    self.use_dyn_thresholding = use_dyn_thresholding
    self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
    self.sample_max_value = sample_max_value

EuclideanDDPMDiffuser

Bases: EuclideanDiffuser

Source code in src/ls_mlkit/diffusion/euclidean_ddpm_diffuser.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
@inherit_docstrings
class EuclideanDDPMDiffuser(EuclideanDiffuser):
    def __init__(
        self,
        config: EuclideanDDPMConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: MaskerInterface,
        model: Module,
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_true, padding_mask)
    ):
        """Build a DDPM diffuser around a denoising model and a loss function

        Args:
            config (EuclideanDDPMConfig): the config of the diffuser
            time_scheduler (DiffusionTimeScheduler): forwarded to the parent constructor
            masker (MaskerInterface): forwarded to the parent constructor
            model (Module): the network, stored as ``self.model``
            loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): stored as ``self.loss_fn``
                and called as ``loss_fn(predicted, ground_truth, padding_mask)``
        """
        super().__init__(config=config, time_scheduler=time_scheduler, masker=masker)
        self.loss_fn = loss_fn
        self.model = model
        # Re-declared so the attribute carries the DDPM-specific config type.
        self.config: EuclideanDDPMConfig = config

    def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
        """Draw a standard-normal tensor of the requested shape

        Args:
            shape (Tuple[int, ...]): the shape of the tensor to draw

        Returns:
            Tensor: the result of ``torch.randn(shape)`` (torch default device and dtype)
        """
        prior_sample = torch.randn(shape)
        return prior_sample

    def compute_loss(self, **batch) -> dict:
        r"""Compute the DDPM training loss for one batch

        A timestep index is sampled for every macro element, ``gt_data`` is noised
        with :meth:`forward_process`, the model is called on the noised batch and
        the (possibly converted) noise prediction is compared with the injected
        noise through ``self.loss_fn``.

        Args:
            **batch: the keyword batch. ``gt_data`` (:math:`x_0`) and ``padding_mask``
                are required; ``mode`` is optional and defaults to ``"epsilon"``
                (``"x_0"`` is also supported). The model is called with the batch
                plus the added keys ``t`` and ``x_t`` while ``gt_data`` and ``mode``
                are handed to ``TemporaryKeyRemover``.

        Returns:
            dict: ``loss`` together with the intermediate quantities (``gt_data``,
            ``t``, ``x_t``, ``noise``, ``p_noise``, ``p_x_0``, ``padding_mask``,
            ``a``, ``b``, ``loss_fn``, ``mode``, ``config``) and the raw model
            output under ``base_model_output``

        Raises:
            ValueError: if ``mode`` is ``"score"`` (not supported yet) or not a known mode
        """
        mode: Literal["epsilon", "x_0", "score"] = batch.get("mode", "epsilon")
        # Fail fast: reject unusable modes before sampling timesteps or running the model.
        if mode == "score":
            raise ValueError(f"Currently not supported mode: {mode}")
        if mode not in ("epsilon", "x_0"):
            raise ValueError(f"Invalid mode: {mode}")

        x_0 = batch["gt_data"]
        padding_mask = batch["padding_mask"]
        device = x_0.device

        macro_shape = self.get_macro_shape(x_0)  # (b, )
        macro_shape = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_GET_MACRO_SHAPE,
            tgt_key_name="macro_shape",
            macro_shape=macro_shape,
            batch=batch,
        )
        macro_shape = cast(tuple[int, ...], macro_shape)
        t = self.time_scheduler.sample_timestep_index_uniformly(macro_shape).to(device)  # (b, )
        t = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_SAMPLING_TIME_STEP,
            tgt_key_name="t",
            t=t,
            batch=batch,
        )
        t = cast(Tensor, t)

        # Index the schedule through ``self.config.to(t)``, exactly like ``q_xt_x_0`` and
        # ``step`` do, instead of indexing ``self.config`` tensors directly with ``t``.
        config = cast(EuclideanDDPMConfig, self.config.to(t))
        b = self.complete_micro_shape(config.sqrt_1m_alphas_cumprod[t])  # scale of the noise term
        a = self.complete_micro_shape(config.sqrt_alphas_cumprod[t])  # scale of the x_0 term

        forward_result = self.forward_process(x_0, t, padding_mask)
        x_t, noise = (forward_result["x_t"], forward_result["noise"])
        batch["t"] = t
        batch["x_t"] = x_t
        with TemporaryKeyRemover(mapping=batch, keys=["gt_data", "mode"]):
            model_output = self.model(**batch)

        # Both parameterisations are scored in noise space so the loss is comparable.
        if mode == "epsilon":
            p_noise = model_output["x"]
            p_x_0 = (x_t - b * p_noise) / a
        else:  # mode == "x_0"
            p_x_0 = model_output["x"]
            # Convert the x_0 prediction to the equivalent noise prediction
            p_noise = (x_t - a * p_x_0) / b
        # Standard DDPM loss between predicted and injected noise
        loss = self.loss_fn(p_noise, noise, padding_mask)

        return {
            "loss": loss,
            # ======================================
            "gt_data": x_0,
            "t": t,
            "x_t": x_t,
            "noise": noise,
            "p_noise": p_noise,
            "p_x_0": p_x_0,
            "padding_mask": padding_mask,
            "a": a,
            "b": b,
            "loss_fn": self.loss_fn,
            "mode": mode,
            "config": self.config,
            # ======================================
            "base_model_output": model_output,
        }

    def q_xt_x_0(self, x_0: Tensor, t: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Mean and standard deviation of the forward marginal :math:`q(x_t|x_0)`

        .. math::

            q(x_t|x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t) I)

        The two factors are read from the config tables ``sqrt_alphas_cumprod`` and
        ``sqrt_1m_alphas_cumprod`` at index ``t``.

        Args:
            x_0 (Tensor): :math:`x_0`
            t (Tensor): :math:`t`, used to index the schedule tables
            mask (Tensor): the mask of the sample (not read by this method)

        Returns:
            Tuple[Tensor, Tensor]: the expectation ``sqrt_alphas_cumprod[t] * x_0`` and the
            standard deviation ``sqrt_1m_alphas_cumprod[t]`` (passed through
            ``complete_micro_shape``)
        """
        schedule = cast(EuclideanDDPMConfig, self.config.to(t))
        mean_scale = self.complete_micro_shape(schedule.sqrt_alphas_cumprod[t])
        mean = mean_scale * x_0
        std = self.complete_micro_shape(schedule.sqrt_1m_alphas_cumprod[t])
        return mean, std

    def forward_process_n_step(
        self,
        x: Tensor,
        t: Tensor,
        next_t: Tensor,
        padding_mask: Tensor,
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        r"""Noise ``x`` from timestep ``t`` directly to a later timestep ``next_t``

        .. math::

            x_{next} = \sqrt{r}\, x + \sqrt{1 - r}\, \epsilon, \qquad
            r = \bar{\alpha}_{next} / \bar{\alpha}_{t}, \quad \epsilon \sim \mathcal{N}(0, I)

        where :math:`\bar{\alpha}` is read from ``config.alphas_cumprod``.

        Args:
            x (Tensor): the sample at timestep ``t``
            t (Tensor): current timestep indices, all ``>= 0``
            next_t (Tensor): target timestep indices, strictly greater than ``t`` and
                below ``config.n_discretization_steps``
            padding_mask (Tensor): accepted for interface compatibility (not read here)

        Returns:
            Tensor: the sample at timestep ``next_t``
        """
        assert (next_t > t).all()
        assert (t >= 0).all()
        assert (next_t < self.config.n_discretization_steps).all()
        schedule = cast(EuclideanDDPMConfig, self.config.to(t))
        ratio = schedule.alphas_cumprod[next_t] / schedule.alphas_cumprod[t]
        signal_scale = self.complete_micro_shape(ratio**0.5)
        noise_scale = self.complete_micro_shape((1 - ratio) ** 0.5)
        fresh_noise = torch.randn_like(x)
        return signal_scale * x + noise_scale * fresh_noise

    def forward_process(
        self,
        x_0: Tensor,
        discrete_t: Tensor,
        mask: Tensor,
        **kwargs: Any,
    ) -> dict:
        r"""Sample :math:`x_t` from :math:`q(x_t|x_0)` with the reparameterisation trick

        The mean and standard deviation come from :meth:`q_xt_x_0`; fresh Gaussian
        noise is drawn on the device of ``x_0``.

        Args:
            x_0 (Tensor): the clean sample
            discrete_t (Tensor): the timestep indices
            mask (Tensor): the mask, forwarded to :meth:`q_xt_x_0`

        Returns:
            dict: ``x_t`` (``expectation + standard_deviation * noise``), ``noise``,
            ``expectation`` and ``standard_deviation``
        """
        mean, std = self.q_xt_x_0(x_0, discrete_t, mask)
        eps = torch.randn_like(mean, device=x_0.device)
        result = {
            "x_t": mean + std * eps,
            "noise": eps,
            "expectation": mean,
            "standard_deviation": std,
        }
        return result

    def step(
        self,
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        r"""
        Predict the sample at the previous timestep from ``x_t`` and the model output.

        Based on the standard DDPM sampling formula:

        .. math::

            \hat{\mathbf{x}}_0:=\frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\mathbf{\epsilon}_{\theta}(\mathbf{x}_t,t))

            \mathcal{N}\left( \boldsymbol{x}_{t-1}; \underbrace{\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t + \sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\hat{\boldsymbol{x}}_0}{1-\bar{\alpha}_t}}_{\mu_q(\boldsymbol{x}_t, \hat{\boldsymbol{x}}_0)}, \underbrace{\frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{I}}_{\Sigma_q(t)} \right)

        Args:
            x_t (Tensor): the sample at timestep t
            t (Tensor): the timestep; every entry must hold the same value
            padding_mask (Tensor, optional): the padding mask, forwarded to the model and the hooks
            **kwargs: ``mode`` selects the parameterisation (only ``"epsilon"`` is
                implemented); ``sampling_condition`` is handed to the
                ``PRE_UPDATE_IN_STEP_FN`` hooks; every key except ``mode`` is
                forwarded to the model

        Returns:
            dict:
                "x": the sample at timestep t-1
                "E_x0_xt": the predicted original sample

        Raises:
            ValueError: if ``mode`` is anything other than ``"epsilon"``
        """
        mode: Literal["epsilon", "x_0", "score"] = kwargs.get("mode", "epsilon")
        # Only epsilon-prediction is implemented: reject everything else before running the model.
        if mode in ("x_0", "score"):
            raise ValueError(f"Currently not supported mode: {mode}")
        if mode != "epsilon":
            raise ValueError(f"Invalid mode: {mode}")

        assert torch.all(t == t.view(-1)[0]).item()
        config = cast(EuclideanDDPMConfig, self.config.to(t))

        # Integer views of the timestep: tensor for table lookups, scalar for the shared step
        t_index = t.long()
        t_scalar = t_index.view(-1)[0]
        t_scalar_int = int(t_scalar.item())

        # ``mode`` is a sampler option; ``compute_loss`` also hides it from the model.
        model_kwargs = {key: value for key, value in kwargs.items() if key != "mode"}
        model_output = self.model(**{"x_t": x_t, "t": t_index, "padding_mask": padding_mask, **model_kwargs})
        model_pred = model_output["x"]

        # Let registered hooks (e.g. conditional guidance) replace the predicted noise
        hook_output = self.hook_manager.run_hooks(
            GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            tgt_key_name="p_noise",
            x_t=x_t,
            t=t,
            p_noise=model_pred,
            padding_mask=padding_mask,
            config=self.config,
            sampling_condition=kwargs.get("sampling_condition"),
            b=self.complete_micro_shape(config.sqrt_1m_alphas_cumprod[t_index]),
        )
        if hook_output is not None:
            model_pred = hook_output

        # Previous timestep (overridable through ``_get_previous_timestep``)
        prev_t = self._get_previous_timestep(t_scalar_int)

        # Cumulative and per-step alpha / beta values; alpha_bar is taken as 1 before the first step
        alpha_prod_t = config.alphas_cumprod[t_scalar]
        alpha_prod_t_prev = config.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0).to(t.device)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # Predicted original sample from the predicted noise
        pred_original_sample = (x_t - beta_prod_t**0.5 * model_pred) / alpha_prod_t**0.5

        # Clip or threshold the predicted x_0
        if self.config.use_dyn_thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.use_clip:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # Coefficients of x_0 and x_t in the posterior mean
        # See formula (7) from https://huggingface.co/papers/2006.11239
        pred_original_sample_coeff = (alpha_prod_t_prev**0.5 * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t**0.5 * beta_prod_t_prev / beta_prod_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * x_t

        # No noise is added on the final step (t == 0)
        if t_scalar_int > 0:
            # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
            variance_value = self._get_variance(t_scalar_int, alpha_prod_t, alpha_prod_t_prev, current_beta_t)
            pred_prev_sample = pred_prev_sample + (variance_value**0.5) * torch.randn_like(x_t)

        return {
            "x": pred_prev_sample,
            "E_x0_xt": pred_original_sample,
        }

    def _get_previous_timestep(self, timestep: int) -> int:
        r"""Get the previous timestep for sampling.

        Args:
            timestep (int): timestep

        Returns:
            int: the previous timestep for sampling
        """
        return timestep - 1

    def _get_variance(
        self,
        t: int,
        alpha_prod_t: Tensor,
        alpha_prod_t_prev: Tensor,
        current_beta_t: Tensor,
    ) -> Tensor:
        r"""Calculate variance for timestep t following standard DDPM formula. For t > 0, compute predicted variance βt (see formula (6) and (7) from https://huggingface.co/papers/2006.11239)

        .. math::

            \sigma^2 = (\frac{1 - \bar{\alpha}_{pre}}{1 - \bar{\alpha}_{t}}) \cdot ( 1- \frac{\bar{\alpha}_{t}}{\bar{\alpha}_{pre}})

        Args:
            t (int): timestep
            alpha_prod_t (Tensor): :math:`\bar{\alpha}_t`
            alpha_prod_t_prev (Tensor): :math:`\bar{\alpha}_{t-1}`
            current_beta_t (Tensor): :math:`\beta_t`

        Returns:
            Tensor: the variance for timestep t
        """
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Clamp variance to ensure numerical stability
        variance = torch.clamp(variance, min=1e-20)
        return variance

    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)  # (batch_size, 1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def get_posterior_mean_fn(
        self,
        score: Tensor | None = None,
        score_fn: Callable[[Tensor, Tensor, Tensor | None], Tensor] | None = None,
    ):
        r"""Get the posterior mean function

        Args:
            score (Tensor, optional): a fixed score; when given it is used for every call
            score_fn (Callable, optional): the function to compute the score, called as
                ``score_fn(x_t, t, padding_mask)`` on every call when ``score`` is ``None``

        Returns:
            Callable: the posterior mean function ``(x_t, t, padding_mask) -> E[x_0|x_t]``
        """

        def _ddpm_posterior_mean_fn(
            x_t: Tensor,
            t: Tensor,
            padding_mask: Tensor,
        ):
            r"""
            Args:
                x_t: shape=(..., n_nodes, 3)
                t: shape=(...), dtype=torch.long
                padding_mask: forwarded to ``score_fn``

            For the case of DDPM sampling, the posterior mean is given by

            .. math::

                E[x_0|x_t] = \frac{1}{\sqrt{\bar{\alpha}(t)}}(x_t + (1 - \bar{\alpha}(t))\nabla_{x_t}\log p_t(x_t))

            """
            assert score is not None or score_fn is not None, "either score or score_fn must be provided"
            # Append singleton dims so ``t`` (and the table lookup below) broadcasts against ``x_t``
            t = t.view(*t.shape, *([1] * (x_t.ndim - t.ndim)))
            if score is None:
                assert score_fn is not None
                # Kept local on purpose: writing the result back into the closure would make
                # every later call reuse the score of the first ``x_t``.
                current_score = score_fn(x_t, t, padding_mask)
            else:
                current_score = score
            config = cast(EuclideanDDPMConfig, self.config.to(t))
            # Same shape as the reshaped ``t``, so no further reshaping is needed
            alpha_bar_t = config.alphas_cumprod[t]
            x_0 = (x_t + (1 - alpha_bar_t) * current_score) / torch.sqrt(alpha_bar_t)
            return x_0

        return _ddpm_posterior_mean_fn

    def get_condition_post_compute_loss_hook(self, conditioner_list: list[Conditioner]):
        r"""Build the ``POST_COMPUTE_LOSS`` hook that replaces the loss with a conditioned one

        The hook turns the predicted and injected noise into scores (``-noise / b``),
        hands every enabled conditioner its training condition, adds the accumulated
        conditional score to the ground-truth score and recomputes the loss on the
        ``b``-scaled scores.

        Args:
            conditioner_list (list[Conditioner]): the conditioners; disabled ones are not
                given a condition

        Returns:
            GMHook: the hook; its function returns the received kwargs with ``loss`` overwritten
        """

        def _hook_fn(**kwargs: Any):
            x_0 = require(cast(Tensor | None, kwargs.get("gt_data")), "gt_data")
            x_t = require(cast(Tensor | None, kwargs.get("x_t")), "x_t")
            t = require(cast(Tensor | None, kwargs.get("t")), "t")
            noise = require(cast(Tensor | None, kwargs.get("noise")), "noise")
            p_noise = require(cast(Tensor | None, kwargs.get("p_noise")), "p_noise")
            padding_mask = require(cast(Tensor | None, kwargs.get("padding_mask")), "padding_mask")
            b = require(cast(Tensor | None, kwargs.get("b")), "b")
            loss_fn = require(cast(Callable[..., Any] | None, kwargs.get("loss_fn")), "loss_fn")

            # noise -> unconditional score
            predicted_uncond_score = -p_noise / b
            target_uncond_score = -noise / b

            for conditioner in conditioner_list:
                if conditioner.is_enabled():
                    condition = conditioner.prepare_condition_dict(
                        train=True,
                        tgt_mask=padding_mask,
                        gt_data=x_0,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(score=predicted_uncond_score, score_fn=None),
                    )
                    conditioner.set_condition(**condition)

            guidance = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)
            target_score = target_uncond_score + guidance

            # Scale both scores by b before computing the conditioned loss
            kwargs["loss"] = loss_fn(b * predicted_uncond_score, b * target_score, padding_mask)
            return kwargs

        return GMHook(
            name="DDPM_condition_post_compute_loss_hook",
            stage=GMHookStageType.POST_COMPUTE_LOSS,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )

    def get_condition_pre_update_in_step_fn_hook(self, conditioner_list: list[Conditioner]):
        r"""Build the ``PRE_UPDATE_IN_STEP_FN`` hook that adds conditional guidance to the predicted noise

        The hook converts the predicted noise to a score (``-p_noise / b``), hands every
        enabled conditioner its sampling condition, adds the accumulated conditional
        score and converts the sum back to a noise prediction.

        Args:
            conditioner_list (list[Conditioner]): the conditioners; disabled ones are not
                given a condition

        Returns:
            GMHook: the hook; its function returns the guided noise prediction
        """

        def _hook_fn(**kwargs: Any):
            x_t = require(cast(Tensor | None, kwargs.get("x_t")), "x_t")
            t = require(cast(Tensor | None, kwargs.get("t")), "t")
            p_noise = require(cast(Tensor | None, kwargs.get("p_noise")), "p_noise")
            padding_mask = require(cast(Tensor | None, kwargs.get("padding_mask")), "padding_mask")
            b = require(cast(Tensor | None, kwargs.get("b")), "b")
            sampling_condition = kwargs.get("sampling_condition")

            # noise -> unconditional score
            uncond_score = -p_noise / b

            for conditioner in conditioner_list:
                if conditioner.is_enabled():
                    condition = conditioner.prepare_condition_dict(
                        train=False,
                        tgt_mask=padding_mask,
                        sampling_condition=sampling_condition,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(score=uncond_score, score_fn=None),
                    )
                    conditioner.set_condition(**condition)

            guidance = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)
            # Combined score -> noise prediction
            return -b * (uncond_score + guidance)

        return GMHook(
            name="DDPM_condition_pre_update_in_step_fn_hook",
            stage=GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )

__init__(config, time_scheduler, masker, model, loss_fn)

Initialize the EuclideanDDPMDiffuser

Parameters:

Name Type Description Default
config EuclideanDDPMConfig

the config of the diffuser

required
time_scheduler DiffusionTimeScheduler

the time scheduler of the diffuser

required
masker MaskerInterface

the masker of the diffuser

required
model Module

the model of the diffuser

required
loss_fn Callable[[Tensor, Tensor, Tensor], Tensor]

the loss function of the diffuser

required

Returns:

Type Description

None

Source code in src/ls_mlkit/diffusion/euclidean_ddpm_diffuser.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def __init__(
    self,
    config: EuclideanDDPMConfig,
    time_scheduler: DiffusionTimeScheduler,
    masker: MaskerInterface,
    model: Module,
    loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_true, padding_mask)
):
    """Build a DDPM diffuser around a denoising model and a loss function

    Args:
        config (EuclideanDDPMConfig): the config of the diffuser
        time_scheduler (DiffusionTimeScheduler): forwarded to the parent constructor
        masker (MaskerInterface): forwarded to the parent constructor
        model (Module): the network, stored as ``self.model``
        loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): stored as ``self.loss_fn``
            and called as ``loss_fn(predicted, ground_truth, padding_mask)``
    """
    super().__init__(config=config, time_scheduler=time_scheduler, masker=masker)
    self.loss_fn = loss_fn
    self.model = model
    # Re-declared so the attribute carries the DDPM-specific config type.
    self.config: EuclideanDDPMConfig = config

q_xt_x_0(x_0, t, mask)

Forward process

.. math::

q(x_t|x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t) I)

Parameters:

Name Type Description Default
x_0 Tensor

:math:x_0

required
t Tensor

:math:t

required
mask Tensor

the mask of the sample

required

Returns:

Type Description
Tuple[Tensor, Tensor]

Tuple[Tensor, Tensor]: the expectation and standard deviation of the sample

Source code in src/ls_mlkit/diffusion/euclidean_ddpm_diffuser.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def q_xt_x_0(self, x_0: Tensor, t: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Mean and standard deviation of the forward marginal :math:`q(x_t|x_0)`

    .. math::

        q(x_t|x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t) I)

    The two factors are read from the config tables ``sqrt_alphas_cumprod`` and
    ``sqrt_1m_alphas_cumprod`` at index ``t``.

    Args:
        x_0 (Tensor): :math:`x_0`
        t (Tensor): :math:`t`, used to index the schedule tables
        mask (Tensor): the mask of the sample (not read by this method)

    Returns:
        Tuple[Tensor, Tensor]: the expectation ``sqrt_alphas_cumprod[t] * x_0`` and the
        standard deviation ``sqrt_1m_alphas_cumprod[t]`` (passed through
        ``complete_micro_shape``)
    """
    schedule = cast(EuclideanDDPMConfig, self.config.to(t))
    mean_scale = self.complete_micro_shape(schedule.sqrt_alphas_cumprod[t])
    mean = mean_scale * x_0
    std = self.complete_micro_shape(schedule.sqrt_1m_alphas_cumprod[t])
    return mean, std

step(x_t, t, padding_mask=None, *args, **kwargs)

Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs.

Based on the standard DDPM sampling formula:

.. math::

\hat{\mathbf{x}}_0:=\frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\mathbf{\epsilon}_{\theta}(\mathbf{x}_t,t))

\mathcal{N}\left( \boldsymbol{x}_{t-1}; \underbrace{\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t + \sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\hat{\boldsymbol{x}}_0}{1-\bar{\alpha}_t}}_{\mu_q(\boldsymbol{x}_t, \hat{\boldsymbol{x}}_0)}, \underbrace{\frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{I}}_{\Sigma_q(t)} \right)

Parameters:

Name Type Description Default
x_t Tensor

the sample at timestep t

required
t Tensor

the timestep

required
padding_mask Tensor

the padding mask

None

Returns:

Name Type Description
dict dict

"x": the sample at timestep t-1; "E_x0_xt": the predicted original sample

Source code in src/ls_mlkit/diffusion/euclidean_ddpm_diffuser.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
def step(
    self,
    x_t: Tensor,
    t: Tensor,
    padding_mask: Tensor | None = None,
    *args: Any,
    **kwargs: Any,
) -> dict:
    r"""
    Predict the sample at the previous timestep from ``x_t`` and the model output.

    Based on the standard DDPM sampling formula:

    .. math::

        \hat{\mathbf{x}}_0:=\frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\mathbf{\epsilon}_{\theta}(\mathbf{x}_t,t))

        \mathcal{N}\left( \boldsymbol{x}_{t-1}; \underbrace{\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})\boldsymbol{x}_t + \sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\hat{\boldsymbol{x}}_0}{1-\bar{\alpha}_t}}_{\mu_q(\boldsymbol{x}_t, \hat{\boldsymbol{x}}_0)}, \underbrace{\frac{(1-\alpha_t)(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{I}}_{\Sigma_q(t)} \right)

    Args:
        x_t (Tensor): the sample at timestep t
        t (Tensor): the timestep; every entry must hold the same value
        padding_mask (Tensor, optional): the padding mask, forwarded to the model and the hooks
        **kwargs: ``mode`` selects the parameterisation (only ``"epsilon"`` is
            implemented); ``sampling_condition`` is handed to the
            ``PRE_UPDATE_IN_STEP_FN`` hooks; every key except ``mode`` is
            forwarded to the model

    Returns:
        dict:
            "x": the sample at timestep t-1
            "E_x0_xt": the predicted original sample

    Raises:
        ValueError: if ``mode`` is anything other than ``"epsilon"``
    """
    mode: Literal["epsilon", "x_0", "score"] = kwargs.get("mode", "epsilon")
    # Only epsilon-prediction is implemented: reject everything else before running the model.
    if mode in ("x_0", "score"):
        raise ValueError(f"Currently not supported mode: {mode}")
    if mode != "epsilon":
        raise ValueError(f"Invalid mode: {mode}")

    assert torch.all(t == t.view(-1)[0]).item()
    config = cast(EuclideanDDPMConfig, self.config.to(t))

    # Integer views of the timestep: tensor for table lookups, scalar for the shared step
    t_index = t.long()
    t_scalar = t_index.view(-1)[0]
    t_scalar_int = int(t_scalar.item())

    # ``mode`` is a sampler option; ``compute_loss`` also hides it from the model.
    model_kwargs = {key: value for key, value in kwargs.items() if key != "mode"}
    model_output = self.model(**{"x_t": x_t, "t": t_index, "padding_mask": padding_mask, **model_kwargs})
    model_pred = model_output["x"]

    # Let registered hooks (e.g. conditional guidance) replace the predicted noise
    hook_output = self.hook_manager.run_hooks(
        GMHookStageType.PRE_UPDATE_IN_STEP_FN,
        tgt_key_name="p_noise",
        x_t=x_t,
        t=t,
        p_noise=model_pred,
        padding_mask=padding_mask,
        config=self.config,
        sampling_condition=kwargs.get("sampling_condition"),
        b=self.complete_micro_shape(config.sqrt_1m_alphas_cumprod[t_index]),
    )
    if hook_output is not None:
        model_pred = hook_output

    # Previous timestep (overridable through ``_get_previous_timestep``)
    prev_t = self._get_previous_timestep(t_scalar_int)

    # Cumulative and per-step alpha / beta values; alpha_bar is taken as 1 before the first step
    alpha_prod_t = config.alphas_cumprod[t_scalar]
    alpha_prod_t_prev = config.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0).to(t.device)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # Predicted original sample from the predicted noise
    pred_original_sample = (x_t - beta_prod_t**0.5 * model_pred) / alpha_prod_t**0.5

    # Clip or threshold the predicted x_0
    if self.config.use_dyn_thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.use_clip:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # Coefficients of x_0 and x_t in the posterior mean
    # See formula (7) from https://huggingface.co/papers/2006.11239
    pred_original_sample_coeff = (alpha_prod_t_prev**0.5 * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t**0.5 * beta_prod_t_prev / beta_prod_t
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * x_t

    # No noise is added on the final step (t == 0)
    if t_scalar_int > 0:
        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        variance_value = self._get_variance(t_scalar_int, alpha_prod_t, alpha_prod_t_prev, current_beta_t)
        pred_prev_sample = pred_prev_sample + (variance_value**0.5) * torch.randn_like(x_t)

    return {
        "x": pred_prev_sample,
        "E_x0_xt": pred_original_sample,
    }

get_posterior_mean_fn(score=None, score_fn=None)

Get the posterior mean function

Parameters:

Name Type Description Default
score Tensor

the score of the sample

None
score_fn Callable

the function to compute score

None

Returns:

Name Type Description
Callable

the posterior mean function

Source code in src/ls_mlkit/diffusion/euclidean_ddpm_diffuser.py
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
def get_posterior_mean_fn(
    self,
    score: Tensor | None = None,
    score_fn: Callable[[Tensor, Tensor, Tensor | None], Tensor] | None = None,
):
    r"""Get the posterior mean function

    Args:
        score (Tensor, optional): a fixed score; when given it is used for every call
        score_fn (Callable, optional): the function to compute the score, called as
            ``score_fn(x_t, t, padding_mask)`` on every call when ``score`` is ``None``

    Returns:
        Callable: the posterior mean function ``(x_t, t, padding_mask) -> E[x_0|x_t]``
    """

    def _ddpm_posterior_mean_fn(
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor,
    ):
        r"""
        Args:
            x_t: shape=(..., n_nodes, 3)
            t: shape=(...), dtype=torch.long
            padding_mask: forwarded to ``score_fn``

        For the case of DDPM sampling, the posterior mean is given by

        .. math::

            E[x_0|x_t] = \frac{1}{\sqrt{\bar{\alpha}(t)}}(x_t + (1 - \bar{\alpha}(t))\nabla_{x_t}\log p_t(x_t))

        """
        assert score is not None or score_fn is not None, "either score or score_fn must be provided"
        # Append singleton dims so ``t`` (and the table lookup below) broadcasts against ``x_t``
        t = t.view(*t.shape, *([1] * (x_t.ndim - t.ndim)))
        if score is None:
            assert score_fn is not None
            # Kept local on purpose: writing the result back into the closure would make
            # every later call reuse the score of the first ``x_t``.
            current_score = score_fn(x_t, t, padding_mask)
        else:
            current_score = score
        config = cast(EuclideanDDPMConfig, self.config.to(t))
        # Same shape as the reshaped ``t``, so no further reshaping is needed
        alpha_bar_t = config.alphas_cumprod[t]
        x_0 = (x_t + (1 - alpha_bar_t) * current_score) / torch.sqrt(alpha_bar_t)
        return x_0

    return _ddpm_posterior_mean_fn

EuclideanDiffuser

Bases: BaseDiffuser

Source code in src/ls_mlkit/diffusion/euclidean_diffuser.py
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# Generic sampling / inpainting loops shared by the Euclidean diffusers.
# ``prior_sampling``, ``forward_process``, ``step``, ``complete_micro_shape`` and
# ``hook_manager`` are used here but defined elsewhere (base class / subclasses).
# NOTE(review): methods that had no docstring are documented with ``#`` comments so
# that ``@inherit_docstrings`` can keep supplying docstrings from the base classes
# (presumably its purpose -- confirm against its implementation).
@inherit_docstrings
class EuclideanDiffuser(BaseDiffuser):
    def __init__(
        self,
        config: EuclideanDiffuserConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: MaskerInterface,
    ):
        # Forward config and time scheduler to the base class, then keep typed
        # references plus the masker used to build and apply masks.
        super().__init__(config=config, time_scheduler=time_scheduler)
        self.config: EuclideanDiffuserConfig = config
        self.time_scheduler: DiffusionTimeScheduler = time_scheduler
        self.masker = masker

    def _make_timestep(self, value: Any, macro_shape: Any, device, batch: dict) -> Tensor:
        """Build a timestep-index tensor filled with ``value``.

        The tensor has shape ``tuple(macro_shape)`` and dtype ``torch.long``; it is
        passed through the ``POST_SAMPLING_TIME_STEP`` hooks and then through
        ``complete_micro_shape``.

        Args:
            value: the timestep index (Python int or 0-dim tensor) to broadcast
            macro_shape: the macro shape of the data
            device: the device on which the tensor is created
            batch: the keyword arguments handed to the hooks as ``batch``

        Returns:
            the timestep tensor returned by ``complete_micro_shape``
        """
        t = torch.ones(tuple(macro_shape), device=device, dtype=torch.long) * value
        t = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_SAMPLING_TIME_STEP,
            tgt_key_name="t",
            t=t,
            batch=batch,
        )
        return self.complete_micro_shape(cast(Tensor, t))

    @torch.no_grad()
    def sampling(
        self,
        shape: Tuple[int, ...],
        device,
        x_init_posterior: Optional[Tensor] = None,
        return_all=False,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        # Run the reverse process over the whole timestep schedule.
        #
        # shape:            shape of the sample; replaced by ``x_init_posterior.shape`` when given
        # device:           device on which timesteps and the initial sample are placed
        # x_init_posterior: optional starting data; it is pushed through ``forward_process``
        #                   from (start index - 1) to the end index to obtain the initial x_t,
        #                   otherwise the initial x_t comes from ``prior_sampling``
        # return_all:       when true, every intermediate x_t is appended to ``x_list``
        # kwargs:           handed to the hooks as ``batch`` and forwarded to ``step``
        #                   (except ``x_t``, ``t`` and ``padding_mask``)
        #
        # Returns {"x": final sample, "x_list": [...], "E_x0_xt_list": [...]}; both lists
        # start with the initial x_t.
        if x_init_posterior is not None:
            shape = x_init_posterior.shape
        macro_shape = shape[: -self.config.ndim_micro_shape]
        macro_shape = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_GET_MACRO_SHAPE,
            tgt_key_name="macro_shape",
            macro_shape=macro_shape,
            batch=kwargs,
        )
        assert macro_shape is not None
        masker = self.masker
        if x_init_posterior is None:
            x_t = self.prior_sampling(shape).to(device)
        else:
            padding_mask = kwargs.get("padding_mask", None)
            if padding_mask is None:
                padding_mask = masker.get_full_bright_mask(x_init_posterior)
            t_a = self._make_timestep(
                self.time_scheduler.get_timestep_index_start() - 1, macro_shape, device, kwargs
            )
            t_b = self._make_timestep(self.time_scheduler.get_timestep_index_end(), macro_shape, device, kwargs)

            x_t = self.forward_process(
                x_init_posterior,
                t_a,
                t_b,
                padding_mask,
                is_continuous_time=False,
            )["x_t"]

        x_list = [x_t]
        E_x0_xt_list = [x_t]

        time_steps = self.time_scheduler.get_timestep_indices_schedule().to(device)
        for idx, t in enumerate(tqdm(time_steps)):
            t = self._make_timestep(t, macro_shape, device, kwargs)
            # Unconditional sampling always steps with an all-bright mask; a caller supplied
            # ``padding_mask`` is only used for the initial forward process above.
            no_padding_mask = masker.get_full_bright_mask(x_t)
            kwargs["idx"] = idx
            step_kwargs = {k: v for k, v in kwargs.items() if k not in ("x_t", "t", "padding_mask")}
            step_output = self.step(x_t=x_t, t=t, padding_mask=no_padding_mask, **step_kwargs)
            x_t = step_output["x"]
            if "E_x0_xt" in step_output:
                E_x0_xt_list.append(step_output["E_x0_xt"])
            if return_all:
                x_list.append(x_t)
        return {"x": x_t, "x_list": x_list, "E_x0_xt_list": E_x0_xt_list}

    @torch.no_grad()
    def inpainting(
        self,
        x: Tensor,
        padding_mask: Tensor,
        inpainting_mask: Tensor,
        device,
        x_init_posterior: Optional[Tensor] = None,
        inpainting_mask_key: str = "inpainting_mask",
        sapmling_condition_key: Optional[str] = "sapmling_condition",
        return_all: bool = False,
        sampling_condition: Optional[Any] = None,
        n_repaint_steps: int = 1,
        **kwargs: Any,
    ) -> dict:
        # Inpaint ``x``: at every timestep the known region is re-imposed from a noised
        # copy of ``x`` (``recover_bright_region``) before a reverse ``step``; each
        # timestep is repeated ``n_repaint_steps`` times with a re-noising in between.
        #
        # x:                   the known data (masked with ``padding_mask`` before use)
        # padding_mask:        mask forwarded to ``forward_process`` / ``step`` / ``apply_mask``
        # inpainting_mask:     mask handed to ``masker.apply_inpainting_mask``
        # device:              device for timesteps, the config and the initial sample
        # x_init_posterior:    optional starting data, noised from (start index - 1) to the end index
        # inpainting_mask_key: key under which ``inpainting_mask`` is added to ``kwargs``
        # return_all:          when true, x_t after each timestep is appended to ``x_list``
        # n_repaint_steps:     number of recover/step repetitions per timestep
        # ``sapmling_condition_key`` and ``sampling_condition`` are not read in this method.
        #
        # Returns {"x": final sample, "x_list": [...], "E_x0_xt_list": [...]}.
        self.config = cast(EuclideanDiffuserConfig, self.config.to(device))
        x_0 = x
        shape = x_0.shape
        macro_shape = shape[: -self.config.ndim_micro_shape]
        macro_shape = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_GET_MACRO_SHAPE,
            tgt_key_name="macro_shape",
            macro_shape=macro_shape,
            batch=kwargs,
        )
        assert macro_shape is not None
        masker = self.masker
        # Add inpainting_mask to kwargs so it gets passed to the model
        kwargs[inpainting_mask_key] = inpainting_mask

        if x_init_posterior is None:
            x_t = self.prior_sampling(shape).to(device)
        else:
            t_a = self._make_timestep(
                self.time_scheduler.get_timestep_index_start() - 1, macro_shape, device, kwargs
            )
            t_b = self._make_timestep(self.time_scheduler.get_timestep_index_end(), macro_shape, device, kwargs)

            x_t = self.forward_process(
                x_init_posterior,
                t_a,
                t_b,
                padding_mask,
                is_continuous_time=False,
            )["x_t"]
        x_0 = masker.apply_mask(x_0, padding_mask)
        x_T = x_t.detach().clone()

        x_list = [x_t]
        E_x0_xt_list = [x_t]

        timesteps = self.time_scheduler.get_timestep_indices_schedule().to(device)
        n_timesteps = len(timesteps)
        for i, t in enumerate(tqdm(timesteps)):
            t = self._make_timestep(t, macro_shape, device, kwargs)
            for u in range(1, n_repaint_steps + 1):
                # ``recover_bright_region`` requires ``inpainting_mask`` as a named argument.
                # Pass it explicitly so a non-default ``inpainting_mask_key`` still works, and
                # drop a same-named kwarg to avoid a duplicate-keyword error.
                recover_kwargs = {k: v for k, v in kwargs.items() if k != "inpainting_mask"}
                x_t = self.recover_bright_region(
                    x_known=x_0,
                    x_t=x_t,
                    t=t,
                    padding_mask=padding_mask,
                    inpainting_mask=inpainting_mask,
                    x_prior=x_T,
                    **recover_kwargs,
                )
                step_output = self.step(x_t, t, padding_mask, **kwargs)  # get x_tm1
                x_t = step_output["x"]
                if "E_x0_xt" in step_output:
                    E_x0_xt_list.append(step_output["E_x0_xt"])
                x_t = masker.apply_mask(x_t, padding_mask)
                # Re-noise from the next schedule entry back to ``t`` before repeating the
                # step; skipped on the last repetition and when no next entry exists.
                if u < n_repaint_steps and i + 1 < n_timesteps and (t > 0).all():
                    # Built like ``t`` (macro-shaped) so hooks see a consistent shape.
                    prev_t = self._make_timestep(timesteps[i + 1], macro_shape, device, kwargs)
                    x_t = self.forward_process(
                        x_t,
                        prev_t,
                        t,
                        padding_mask,
                        is_continuous_time=False,
                        **kwargs,
                    )["x_t"]
            if return_all:
                x_list.append(x_t)
        x_t = masker.apply_inpainting_mask(x_0, x_t, inpainting_mask)

        return {"x": x_t, "x_list": x_list, "E_x0_xt_list": E_x0_xt_list}

    def recover_bright_region(
        self,
        x_known,
        x_t,
        t,
        padding_mask,
        inpainting_mask,
        x_prior,
        *args,
        **kwargs,
    ) -> Tensor:
        # Noise the known data ``x_known`` from (start index - 1) up to ``t`` with
        # ``forward_process`` and combine it with the current ``x_t`` through
        # ``masker.apply_inpainting_mask``. ``x_prior`` is accepted but not read here.
        x_0 = x_known
        # ``ones_like`` keeps the shape, dtype and device of ``t``.
        t_a = torch.ones_like(t, device=t.device) * (self.time_scheduler.get_timestep_index_start() - 1)
        t_a = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_SAMPLING_TIME_STEP,
            tgt_key_name="t",
            t=t_a,
            batch=kwargs,
        )
        t_a = cast(Tensor, t_a)
        x_0t = self.forward_process(
            x_0,
            t_a,
            t,
            padding_mask,
            is_continuous_time=False,
        )["x_t"]
        x_t = self.masker.apply_inpainting_mask(x_0t, x_t, inpainting_mask)
        return x_t

    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487

        Args:
            sample: tensor with at least two dimensions; the first is treated as the batch dimension and all
                remaining dimensions are flattened together for the per-sample quantile

        Returns:
            the thresholded sample, with the shape and dtype of the input
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image

        sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)  # (batch_size,)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) so the clamp below broadcasts over the flattened dim
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

EuclideanEDMConfig

Bases: EuclideanDiffuserConfig

Config Class for Euclidean EDM Diffuser

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
@inherit_docstrings
class EuclideanEDMConfig(EuclideanDiffuserConfig):
    """
    Config Class for Euclidean EDM Diffuser

    Stores the EDM hyper-parameters, precomputes a discrete ``sigma_schedule`` and
    exposes the sigma-dependent coefficients ``c_in``, ``c_noise``, ``c_skip``,
    ``c_out`` and the loss weight.
    """

    def __init__(
        self,
        n_discretization_steps: int = 200,
        ndim_micro_shape: int = 2,
        P_mean: float = -1.2,
        P_std: float = 1.2,
        sigma_data: float = 0.5,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        rho: float = 7.0,
        use_2nd_order_correction: bool = True,
        use_ode_flow: bool = False,
        S_churn: float = 0.0,
        S_min: float = 0.0,
        S_max: float = float("inf"),
        S_noise: float = 1.0,
        use_clip: bool = False,
        clip_sample_range: float = 1.0,
        use_dyn_thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        sigma_multiply_by_sigma_data: bool = False,
        *args,
        **kwargs,
    ):
        r"""
        Args:
            n_discretization_steps: the number of discretization steps
            ndim_micro_shape: the number of dimensions of the micro shape
            P_mean: mean of the log-normal distribution for sampling sigma during training
            P_std: standard deviation of the log-normal distribution for sampling sigma during training
            sigma_data: expected standard deviation of the training data
            sigma_min: minimum supported noise level
            sigma_max: maximum supported noise level
            rho: time step exponent for sampling schedule
            use_2nd_order_correction: stored on the config; not read inside this class
            use_ode_flow: stored on the config; not read inside this class
            S_churn: stored on the config; not read inside this class
            S_min: stored on the config; not read inside this class
            S_max: stored on the config; not read inside this class
            S_noise: stored on the config; not read inside this class
            use_clip: stored on the config; not read inside this class
            clip_sample_range: stored on the config; not read inside this class
            use_dyn_thresholding: stored on the config; not read inside this class
            dynamic_thresholding_ratio: stored on the config; not read inside this class
            sample_max_value: stored on the config; not read inside this class
            sigma_multiply_by_sigma_data: if true, ``sampling_timestep_for_training`` multiplies the sampled
                sigma by ``sigma_data``
            *args: accepted but not forwarded to the parent constructor
            **kwargs: accepted but not forwarded to the parent constructor
        Returns:
            None
        """
        super().__init__(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
        )
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho
        self.use_ode_flow = use_ode_flow
        self.use_2nd_order_correction = use_2nd_order_correction
        self.S_churn = S_churn
        self.S_min = S_min
        self.S_max = S_max
        self.S_noise = S_noise

        self.use_clip = use_clip
        self.clip_sample_range = clip_sample_range
        self.use_dyn_thresholding = use_dyn_thresholding
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.sample_max_value = sample_max_value

        self.sigma_multiply_by_sigma_data = sigma_multiply_by_sigma_data

        # Discrete schedule with n_discretization_steps + 1 entries: index 1 maps to
        # sigma_min, index n_discretization_steps maps to sigma_max, interpolated
        # linearly in sigma ** (1 / rho); index 0 is overwritten with 0.0 below.
        # NOTE(review): n_discretization_steps == 1 divides by zero and yields NaN.
        step_indices = torch.arange(n_discretization_steps + 1, dtype=torch.float32)
        self.sigma_schedule: Tensor = (
            sigma_min ** (1 / rho)
            + (step_indices - 1) / (n_discretization_steps - 1) * (sigma_max ** (1 / rho) - sigma_min ** (1 / rho))
        ) ** rho
        self.sigma_schedule[0] = 0.0

    def c_in(self, sigma: Tensor) -> Tensor:
        # 1 / sqrt(sigma^2 + sigma_data^2)
        return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma: Tensor) -> Tensor:
        # log(sigma) / 4; sigma == 0 gives -inf
        return 1 / 4 * torch.log(sigma)

    def c_skip(self, sigma: Tensor) -> Tensor:
        # sigma_data^2 / (sigma^2 + sigma_data^2)
        return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)

    def c_out(self, sigma: Tensor) -> Tensor:
        # sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
        return sigma * self.sigma_data / torch.sqrt(sigma**2 + self.sigma_data**2)

    def sigma(self, t: Tensor, is_continuous_time: bool = True) -> Tensor:
        # Continuous time is already a sigma value and is returned unchanged;
        # discrete timestep indices are looked up in ``sigma_schedule``.
        if is_continuous_time:
            return t
        else:
            return self.timestep_index_to_sigma(t)

    def timestep_index_to_sigma(self, timestep_index: Tensor) -> Tensor:
        """Convert discrete timesteps to sigma values.

        Indices are clamped to ``[1, n_discretization_steps]`` before the lookup, so the
        ``0.0`` stored at index 0 of ``sigma_schedule`` is never returned.

        Args:
            timestep_index: discrete timesteps, shape=(...)

        Returns:
            sigma: noise levels, shape=(...), on the device of ``timestep_index``
        """
        timestep_index = timestep_index.clamp(1, self.n_discretization_steps).long()
        # Move the schedule to the index tensor's device BEFORE indexing: indexing a CPU
        # schedule with indices on another device fails before a trailing ``.to`` can help.
        return self.sigma_schedule.to(timestep_index.device)[timestep_index]

    def compute_loss_weight(self, sigma: Tensor) -> Tensor:
        """Compute EDM loss weight: (sigma² + sigma_data²) / (sigma * sigma_data)².

        Args:
            sigma: noise level, shape=(...)

        Returns:
            weight: the loss weight, shape=(...)
        """
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2

    def sampling_timestep_for_training(self, macro_shape: tuple):
        # Draw training noise levels: exp(P_mean + P_std * N(0, 1)) with shape
        # ``macro_shape`` (created by ``torch.randn`` on its default device),
        # optionally scaled by ``sigma_data``.
        rnd_normal = torch.randn(macro_shape)
        t = (self.P_mean + self.P_std * rnd_normal).exp()
        if self.sigma_multiply_by_sigma_data:
            t = t * self.sigma_data
        return t

__init__(n_discretization_steps=200, ndim_micro_shape=2, P_mean=-1.2, P_std=1.2, sigma_data=0.5, sigma_min=0.002, sigma_max=80.0, rho=7.0, use_2nd_order_correction=True, use_ode_flow=False, S_churn=0.0, S_min=0.0, S_max=float('inf'), S_noise=1.0, use_clip=False, clip_sample_range=1.0, use_dyn_thresholding=False, dynamic_thresholding_ratio=0.995, sample_max_value=1.0, sigma_multiply_by_sigma_data=False, *args, **kwargs)

Parameters:

Name Type Description Default
n_discretization_steps int

the number of discretization steps

200
ndim_micro_shape int

the number of dimensions of the micro shape

2
P_mean float

mean of the log-normal distribution for sampling sigma during training

-1.2
P_std float

standard deviation of the log-normal distribution for sampling sigma during training

1.2
sigma_data float

expected standard deviation of the training data

0.5
sigma_min float

minimum supported noise level

0.002
sigma_max float

maximum supported noise level

80.0
rho float

time step exponent for sampling schedule

7.0

Returns: None

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def __init__(
    self,
    n_discretization_steps: int = 200,
    ndim_micro_shape: int = 2,
    P_mean: float = -1.2,
    P_std: float = 1.2,
    sigma_data: float = 0.5,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    rho: float = 7.0,
    use_2nd_order_correction: bool = True,
    use_ode_flow: bool = False,
    S_churn: float = 0.0,
    S_min: float = 0.0,
    S_max: float = float("inf"),
    S_noise: float = 1.0,
    use_clip: bool = False,
    clip_sample_range: float = 1.0,
    use_dyn_thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    sigma_multiply_by_sigma_data: bool = False,
    *args,
    **kwargs,
):
    r"""
    Args:
        n_discretization_steps: the number of discretization steps
        ndim_micro_shape: the number of dimensions of the micro shape
        P_mean: mean of the log-normal distribution for sampling sigma during training
        P_std: standard deviation of the log-normal distribution for sampling sigma during training
        sigma_data: expected standard deviation of the training data
        sigma_min: minimum supported noise level
        sigma_max: maximum supported noise level
        rho: time step exponent for sampling schedule
        use_2nd_order_correction: stored on the config as given
        use_ode_flow: stored on the config as given
        S_churn: stored on the config as given
        S_min: stored on the config as given
        S_max: stored on the config as given
        S_noise: stored on the config as given
        use_clip: stored on the config as given
        clip_sample_range: stored on the config as given
        use_dyn_thresholding: stored on the config as given
        dynamic_thresholding_ratio: stored on the config as given
        sample_max_value: stored on the config as given
        sigma_multiply_by_sigma_data: stored on the config as given
        *args: accepted but not forwarded to the parent constructor
        **kwargs: accepted but not forwarded to the parent constructor
    Returns:
        None
    """
    # Only these two arguments are forwarded to the parent config.
    super().__init__(
        n_discretization_steps=n_discretization_steps,
        ndim_micro_shape=ndim_micro_shape,
    )
    self.P_mean = P_mean
    self.P_std = P_std
    self.sigma_data = sigma_data
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    self.rho = rho
    self.use_ode_flow = use_ode_flow
    self.use_2nd_order_correction = use_2nd_order_correction
    self.S_churn = S_churn
    self.S_min = S_min
    self.S_max = S_max
    self.S_noise = S_noise

    self.use_clip = use_clip
    self.clip_sample_range = clip_sample_range
    self.use_dyn_thresholding = use_dyn_thresholding
    self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
    self.sample_max_value = sample_max_value

    self.sigma_multiply_by_sigma_data = sigma_multiply_by_sigma_data

    # Discrete schedule with n_discretization_steps + 1 entries: index 1 maps to
    # sigma_min, index n_discretization_steps maps to sigma_max, interpolated
    # linearly in sigma ** (1 / rho); index 0 is overwritten with 0.0 below.
    # NOTE(review): n_discretization_steps == 1 divides by zero and yields NaN.
    step_indices = torch.arange(n_discretization_steps + 1, dtype=torch.float32)
    self.sigma_schedule: Tensor = (
        sigma_min ** (1 / rho)
        + (step_indices - 1) / (n_discretization_steps - 1) * (sigma_max ** (1 / rho) - sigma_min ** (1 / rho))
    ) ** rho
    self.sigma_schedule[0] = 0.0

timestep_index_to_sigma(timestep_index)

Convert discrete timesteps to sigma values.

Parameters:

Name Type Description Default
timestep_index

discrete timesteps, shape=(...)

required

Returns:

Name Type Description
sigma Tensor

noise levels, shape=(...)

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
112
113
114
115
116
117
118
119
120
121
122
def timestep_index_to_sigma(self, timestep_index: Tensor) -> Tensor:
    """Convert discrete timesteps to sigma values.

    Indices are clamped to ``[1, n_discretization_steps]`` before the lookup, so the
    entry at index 0 of ``sigma_schedule`` is never returned.

    Args:
        timestep_index: discrete timesteps, shape=(...)

    Returns:
        sigma: noise levels, shape=(...), on the device of ``timestep_index``
    """
    timestep_index = timestep_index.clamp(1, self.n_discretization_steps).long()
    # Move the schedule to the index tensor's device BEFORE indexing: indexing a CPU
    # schedule with indices on another device fails before a trailing ``.to`` can help.
    return self.sigma_schedule.to(timestep_index.device)[timestep_index]

compute_loss_weight(sigma)

Compute EDM loss weight: (sigma² + sigma_data²) / (sigma * sigma_data)².

Parameters:

Name Type Description Default
sigma Tensor

noise level, shape=(...)

required

Returns:

Name Type Description
weight Tensor

the loss weight, shape=(...)

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
124
125
126
127
128
129
130
131
132
133
def compute_loss_weight(self, sigma: Tensor) -> Tensor:
    """Return the EDM loss weight (sigma² + sigma_data²) / (sigma * sigma_data)².

    Args:
        sigma: noise level, shape=(...)

    Returns:
        weight: the loss weight, shape=(...)
    """
    data_std = self.sigma_data
    numerator = sigma**2 + data_std**2
    denominator = (sigma * data_std) ** 2
    return numerator / denominator

EuclideanEDMDiffuser

Bases: EuclideanDiffuser

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
@inherit_docstrings
class EuclideanEDMDiffuser(EuclideanDiffuser):
    def __init__(
        self,
        config: EuclideanEDMConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: MaskerInterface,
        model: Module,
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_true, padding_mask)
    ):
        """Initialize the EDM diffuser.

        Args:
            config: the EDM configuration; also forwarded to the parent constructor
            time_scheduler: the time scheduler, forwarded to the parent constructor
            masker: the masker, forwarded to the parent constructor
            model: the network called during training and sampling
            loss_fn: callable invoked as ``loss_fn(predicted, ground_true, padding_mask)``
        """
        super().__init__(
            config=config,
            time_scheduler=time_scheduler,
            masker=masker,
        )
        # Keep the EDM-specific pieces on the instance; ``config`` is re-bound with its narrower type.
        self.loss_fn = loss_fn
        self.model = model
        self.config: EuclideanEDMConfig = config

    def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
        """Draw a prior sample: standard normal noise scaled by ``config.sigma_max``.

        Args:
            shape: shape of the tensor to draw

        Returns:
            Tensor: a tensor of the requested shape
        """
        sigma_max = self.config.sigma_max
        standard_normal = torch.randn(shape)
        return standard_normal * sigma_max

    def compute_loss(self, **batch) -> dict:
        """Compute the weighted EDM denoising loss for one training batch.

        A training time is sampled per macro element, the clean data is noised by
        ``forward_process`` starting from time zero, the model is run on the ``c_in``-scaled
        noisy sample, and the preconditioned denoised estimate is compared with the clean data
        under the weight returned by ``config.compute_loss_weight``.

        Args:
            **batch: batch dictionary containing:
                - gt_data: ground truth data x_0
                - padding_mask: padding mask
              The dictionary, extended with ``t``, ``x_t`` and ``gm_kwargs``, is passed to the
              model as keyword arguments.

        Returns:
            dict: the loss together with the values used to compute it (``gt_data``, ``t``,
            ``sigma``, ``x_t``, ``noise``, ``p_raw``, ``p_x_0``, ``padding_mask``, ``loss_fn``,
            ``config`` and ``base_model_output``)
        """
        clean = batch["gt_data"]
        padding_mask = batch["padding_mask"]
        target_device = clean.device

        # Hooks get a chance to replace the macro shape before times are sampled.
        macro_shape = cast(
            tuple[int, ...],
            self.hook_manager.run_hooks(
                stage=GMHookStageType.POST_GET_MACRO_SHAPE,
                tgt_key_name="macro_shape",
                macro_shape=self.get_macro_shape(clean),  # (b, )
                batch=batch,
            ),
        )
        sampled_t = self.config.sampling_timestep_for_training(macro_shape=macro_shape).to(target_device)
        # Hooks may also replace the sampled times.
        hooked_t = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_SAMPLING_TIME_STEP,
            tgt_key_name="t",
            t=sampled_t,
            batch=batch,
        )
        t = self.complete_micro_shape(cast(Tensor, hooked_t))

        # Forward process from time zero up to t: add noise to the clean data.
        noised = self.forward_process(clean, torch.zeros_like(t), t, padding_mask, is_continuous_time=True)
        x_t = noised["x_t"]
        noise = noised["noise"]
        sigma = noised["sigma_diff"]

        # The model sees the c_in-scaled sample; the scale itself is passed along in gm_kwargs.
        batch.update(
            {
                "t": t,
                "x_t": self.config.c_in(sigma) * x_t,
                "gm_kwargs": {"c_in": self.config.c_in(sigma)},
            }
        )

        # gt_data is taken out of the batch for the duration of the model call.
        with TemporaryKeyRemover(mapping=batch, keys=["gt_data"]):
            model_output = self.model(**batch)

        raw_prediction = model_output["x"]
        denoised = self._compute_denoised(x_t, raw_prediction, sigma)

        # EDM loss weight: lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.
        # Both operands are scaled by sqrt(lambda) before being handed to loss_fn.
        root_weight = self.config.compute_loss_weight(sigma).sqrt()
        loss = self.loss_fn(root_weight * denoised, root_weight * clean, padding_mask)

        return {
            "loss": loss,
            "gt_data": clean,
            "t": t,
            "sigma": sigma,
            "x_t": x_t,
            "noise": noise,
            "p_raw": raw_prediction,
            "p_x_0": denoised,
            "padding_mask": padding_mask,
            "loss_fn": self.loss_fn,
            "config": self.config,
            "base_model_output": model_output,
        }

    def forward_process(
        self,
        x_0: Tensor,
        t_a: Tensor,
        t_b: Tensor,
        mask: Tensor,
        is_continuous_time: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Noise ``x_0`` from time ``t_a`` up to time ``t_b``.

        The added noise has standard deviation ``sqrt(sigma(t_b)^2 - sigma(t_a)^2)``, with the
        difference of squares clamped at zero before the square root.

        Args:
            x_0: the sample to noise
            t_a: start time; must be elementwise <= ``t_b``
            t_b: end time
            mask: not used by this implementation
            is_continuous_time: forwarded to ``config.sigma``

        Returns:
            dict: ``x_t`` (noised sample), ``noise`` (the standard normal draw) and ``sigma_diff``
            (the standard deviation that scaled the noise)
        """
        assert (t_b >= t_a).all()
        sigma_from, sigma_to = (self.config.sigma(moment, is_continuous_time) for moment in (t_a, t_b))
        variance_gap = (sigma_to**2 - sigma_from**2).clamp(min=0)
        sigma_diff = variance_gap.sqrt()
        noise = torch.randn_like(x_0)
        return {
            "x_t": x_0 + sigma_diff * noise,
            "noise": noise,
            "sigma_diff": sigma_diff,
        }

    def _compute_denoised(self, x: Tensor, F_x: Tensor, sigma_expanded: Tensor) -> Tensor:
        """Compute denoised prediction using EDM preconditioning.

        Args:
            x: noisy input
            F_x: raw network output
            sigma_expanded: sigma value expanded to micro shape

        Returns:
            Denoised prediction D_x = c_skip * x + c_out * F_x
        """

        return self.config.c_skip(sigma_expanded) * x + self.config.c_out(sigma_expanded) * F_x

    def step(
        self,
        x_t: Tensor,
        t: Tensor,
        padding_mask: Optional[Tensor] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        r"""EDM sampling step (Euler, optionally followed by Heun's 2nd-order correction).

        Unless ``config.use_ode_flow`` is set, fresh noise is first injected to raise the noise
        level from ``sigma_cur`` to ``sigma_hat = (1 + gamma) * sigma_cur``. The denoiser is then
        evaluated at ``sigma_hat`` and the update is integrated from ``sigma_hat``, so the noise
        level given to the network matches the noise actually present in the sample.

        Args:
            x_t: the sample at timestep t
            t: the timestep (all elements must be the same); must have as many dimensions as ``x_t``
            padding_mask: the padding mask
            **kwargs: forwarded to the model and to the ``PRE_UPDATE_IN_STEP_FN`` hooks

        Returns:
            dict:
                - x: the sample at timestep t-1
                - E_x0_xt: the predicted original sample
        """
        assert torch.all(t == t.view(-1)[0]).item(), "All timesteps in batch must be the same for EDM step"
        assert t.ndim == x_t.ndim, "Timestep and sample must have the same number of dimensions"
        config = cast(EuclideanEDMConfig, self.config.to(t))
        t = t.long()
        t_next = t - 1
        is_final_step = (t_next == 0).all()
        # The final step returns the denoised sample directly, so it never needs the correction.
        use_heun = not is_final_step and self.config.use_2nd_order_correction

        # Get sigma values and preconditioning coefficients with batch dimension
        sigma_cur = config.sigma(t, is_continuous_time=False)

        if not self.config.use_ode_flow:
            epsilon = self.config.S_noise * torch.randn_like(x_t)
            # Churn is only applied while the current noise level lies inside [S_min, S_max].
            gamma = (
                min(
                    self.config.S_churn / self.config.n_discretization_steps,
                    math.sqrt(2) - 1,
                )
                if ((self.config.S_min <= sigma_cur).all() and (sigma_cur <= self.config.S_max).all())
                else 0.0
            )
            sigma_hat = sigma_cur + gamma * sigma_cur
            x_t = x_t + torch.sqrt(sigma_hat**2 - sigma_cur**2) * epsilon
            # The sample now carries noise of level sigma_hat, so everything below (network
            # conditioning, derivative, step size) must use sigma_hat rather than the pre-churn
            # value. With gamma == 0 this leaves sigma_cur unchanged.
            sigma_cur = sigma_hat

        # p_x_0 prediction
        c_in_cur = config.c_in(sigma_cur)
        scaled_x_t = c_in_cur * x_t
        batch_dict = {
            "x_t": scaled_x_t,
            "t": sigma_cur,
            "padding_mask": padding_mask,
            **kwargs,
            "gm_kwargs": {"c_in": c_in_cur},
        }
        F_x = self.model(**batch_dict)["x"]
        p_x_0 = self._compute_denoised(x_t, F_x, sigma_cur)

        # Clip or threshold the predicted x_0 (following standard DDPM implementation)
        if self.config.use_dyn_thresholding:
            p_x_0 = self._threshold_sample(p_x_0)
        elif self.config.use_clip:
            p_x_0 = p_x_0.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)

        # Run PRE_UPDATE_IN_STEP_FN hooks for conditional sampling; a non-None result replaces p_x_0.
        hook_input = {
            "x_t": x_t,
            "t": sigma_cur,
            "p_x_0": p_x_0,
            "p_raw": F_x,
            "padding_mask": padding_mask,
            **kwargs,
        }
        hook_output = self.hook_manager.run_hooks(
            GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            tgt_key_name="p_x_0",
            **hook_input,
        )
        if hook_output is not None:
            p_x_0 = hook_output

        # Final step: return denoised directly
        if is_final_step:
            return {"x": p_x_0, "E_x0_xt": p_x_0}

        # Euler step from the current noise level down to sigma_next
        sigma_next = config.sigma(t_next, is_continuous_time=False)
        d_cur = (x_t - p_x_0) / sigma_cur.clamp(min=1e-8)
        delta_sigma = sigma_next - sigma_cur
        x_next = x_t + delta_sigma * d_cur

        # Apply Heun's 2nd order correction: average the slopes at both ends of the step
        if use_heun:
            c_in_next = config.c_in(sigma_next)
            scaled_x_next = c_in_next * x_next
            batch_dict_next = {
                "x_t": scaled_x_next,  # Apply c_in scaling to match training
                "t": sigma_next,
                "padding_mask": padding_mask,
                **kwargs,
                "gm_kwargs": {"c_in": c_in_next},
            }
            F_x_next = self.model(**batch_dict_next)["x"]
            p_x_0_next = self._compute_denoised(x_next, F_x_next, sigma_next)

            hook_input = {
                "x_t": x_next,
                "t": sigma_next,
                "p_x_0": p_x_0_next,
                "p_raw": F_x_next,
                "padding_mask": padding_mask,
                **kwargs,
            }
            hook_output = self.hook_manager.run_hooks(
                GMHookStageType.PRE_UPDATE_IN_STEP_FN,
                tgt_key_name="p_x_0",
                **hook_input,
            )
            if hook_output is not None:
                p_x_0_next = hook_output
            d_prime = (x_next - p_x_0_next) / sigma_next.clamp(min=1e-8)
            x_next = x_t + 0.5 * (d_cur + d_prime) * delta_sigma

        return {"x": x_next, "E_x0_xt": p_x_0}

    def get_posterior_mean_fn(self, score: Optional[Tensor] = None, score_fn: Optional[Callable] = None):
        r"""Get the posterior mean function for EDM.

        For EDM, the posterior mean is:
        .. math::
            E[x_0|x_t] = D_\theta(x_t, \sigma_t)

        where D_\theta is the denoised prediction.

        Args:
            score (Tensor, optional): the score of the sample (currently unused)
            score_fn (Callable, optional): the function to compute score (currently unused)

        Returns:
            Callable: the posterior mean function
        """

        def _edm_posterior_mean_fn(
            x_t: Tensor,
            t: Tensor,
            padding_mask: Tensor,
            is_continuous_time: bool = True,
        ):
            r"""Return the denoised prediction D_\theta(x_t, \sigma_t) for ``x_t`` at time ``t``.

            Args:
                x_t: the noisy sample
                t: the time at which ``x_t`` lives
                padding_mask: the padding mask, forwarded to the model
                is_continuous_time: forwarded to ``config.sigma`` when converting ``t`` to sigma
            """
            # TODO: get x0 by score function (``score`` / ``score_fn`` are not used yet)
            # Honor the caller's flag instead of always treating ``t`` as continuous.
            sigma = self.config.sigma(t, is_continuous_time=is_continuous_time)
            c_in = self.config.c_in(sigma)
            batch_dict = {
                "x_t": c_in * x_t,  # the model is fed the c_in-scaled sample
                "t": t,
                "sigma": sigma,
                "padding_mask": padding_mask,
                "gm_kwargs": {"c_in": c_in},
            }
            F_x = self.model(**batch_dict)["x"]
            return self._compute_denoised(x_t, F_x, sigma)

        return _edm_posterior_mean_fn

    def _compute_edm_score(self, x_t: Tensor, x_0: Tensor, sigma: Tensor) -> Tensor:
        """Compute EDM score function: -(x_t - x_0) / sigma².

        Args:
            x_t: noisy sample at time t
            x_0: clean sample (predicted or ground truth)
            sigma: noise level

        Returns:
            score: the score function value
        """
        sigma_squared = (sigma**2).clamp(min=1e-8)
        return -(x_t - x_0) / sigma_squared

    def _setup_conditioners(
        self,
        conditioner_list: list[Conditioner],
        *,
        train: bool,
        tgt_mask: Tensor,
        padding_mask: Tensor,
        p_uc_score: Tensor,
        gt_data: Optional[Tensor] = None,
        sampling_condition: Optional[Tensor] = None,
    ) -> None:
        """Prepare and set the condition on every enabled conditioner.

        Args:
            conditioner_list: list of conditioners to setup; disabled ones are skipped
            train: whether in training mode
            tgt_mask: target mask
            padding_mask: padding mask
            p_uc_score: unconditional predicted score
            gt_data: ground truth data (passed on only when ``train`` is true)
            sampling_condition: sampling condition (passed on only when ``train`` is false)
        """
        posterior_mean_fn = self.get_posterior_mean_fn(score=p_uc_score, score_fn=None)

        # The two modes differ only in the payload they hand over: ground truth vs. sampling condition.
        if train:
            mode_kwargs = {"train": True, "tgt_mask": tgt_mask, "gt_data": gt_data}
        else:
            mode_kwargs = {"train": False, "tgt_mask": tgt_mask, "sampling_condition": sampling_condition}

        for conditioner in conditioner_list:
            if conditioner.is_enabled():
                condition_dict = conditioner.prepare_condition_dict(
                    **mode_kwargs,
                    padding_mask=padding_mask,
                    posterior_mean_fn=posterior_mean_fn,
                )
                conditioner.set_condition(**condition_dict)

    def get_condition_post_compute_loss_hook(self, conditioner_list: list[Conditioner]):
        """Get hook for conditioning after loss computation (training).

        This hook modifies the loss to include conditional guidance during training.
        It computes the conditional score and updates the loss accordingly.

        Args:
            conditioner_list: list of conditioners

        Returns:
            GMHook: the hook for POST_COMPUTE_LOSS stage
        """

        def _hook_fn(**kwargs):
            """Replace ``kwargs["loss"]`` with the conditioned EDM loss and return ``kwargs``."""
            x_0 = kwargs["gt_data"]
            x_t = kwargs["x_t"]
            t = kwargs["t"]
            padding_mask = kwargs["padding_mask"]
            loss_fn = kwargs["loss_fn"]

            # The denoised prediction must be supplied by the caller; there is no fallback here.
            p_x_0 = kwargs.get("p_x_0")

            # Compute scores
            # NOTE(review): sigma is recomputed from t rather than read from kwargs["sigma"];
            # confirm both agree for the configured schedule.
            sigma = self.config.sigma(t, is_continuous_time=True)
            p_x_0 = cast(Tensor, p_x_0)
            p_uc_score = self._compute_edm_score(x_t, p_x_0, sigma)
            gt_uc_score = self._compute_edm_score(x_t, x_0, sigma)

            # Setup conditioners and get accumulated conditional score
            self._setup_conditioners(
                conditioner_list,
                train=True,
                tgt_mask=padding_mask,
                padding_mask=padding_mask,
                p_uc_score=p_uc_score,
                gt_data=x_0,
            )
            acc_c_score = get_accumulated_conditional_score(
                conditioner_list, x_t, t, padding_mask, is_continuous_time=True
            )

            # Compute conditioned loss with EDM weighting: the target is the clean sample implied
            # by the conditioned ground-truth score, x_0 = x_t + sigma² * score.
            gt_score = gt_uc_score + acc_c_score
            gt_x_0 = x_t + sigma**2 * gt_score
            weight = self.config.compute_loss_weight(sigma)
            sqrt_weight = weight.sqrt()
            # loss_fn takes (predicted, ground_true, padding_mask): prediction first, target second.
            kwargs["loss"] = loss_fn(sqrt_weight * p_x_0, sqrt_weight * gt_x_0, padding_mask)
            return kwargs

        return GMHook(
            name="EDM_condition_post_compute_loss_hook",
            stage=GMHookStageType.POST_COMPUTE_LOSS,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )

    def get_condition_pre_update_in_step_fn_hook(self, conditioner_list: list[Conditioner]):
        """Build the sampling-time hook that applies conditional guidance.

        The hook runs before the update inside the step function and replaces the predicted
        denoised sample with one derived from the unconditional score plus the accumulated
        conditional score.

        Args:
            conditioner_list: list of conditioners

        Returns:
            GMHook: the hook for PRE_UPDATE_IN_STEP_FN stage
        """

        def _hook_fn(**kwargs):
            """Return the conditioned denoised prediction for the given step inputs."""
            x_t, t, padding_mask = kwargs["x_t"], kwargs["t"], kwargs["padding_mask"]
            sampling_condition = kwargs.get("sampling_condition")
            p_x_0 = cast(Tensor, kwargs.get("p_x_0"))

            # Unconditional score implied by the current denoised prediction.
            # NOTE(review): ``t`` is mapped through config.sigma as a continuous time here;
            # confirm that matches what the step function passes under the "t" key.
            sigma = self.config.sigma(t, is_continuous_time=True)
            p_uc_score = self._compute_edm_score(x_t, p_x_0, sigma)

            # Configure the conditioners, then collect their combined score contribution.
            self._setup_conditioners(
                conditioner_list,
                train=False,
                tgt_mask=padding_mask,
                padding_mask=padding_mask,
                p_uc_score=p_uc_score,
                sampling_condition=sampling_condition,
            )
            acc_c_score = get_accumulated_conditional_score(
                conditioner_list, x_t, t, padding_mask, is_continuous_time=True
            )

            # Invert score = -(x_t - x_0) / sigma² to get x_0 = x_t + sigma² * score.
            # The bare tensor is returned because the hook is run with tgt_key_name set.
            guided_score = p_uc_score + acc_c_score
            return x_t + sigma**2 * guided_score

        hook = GMHook(
            name="EDM_condition_pre_update_in_step_fn_hook",
            stage=GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )
        return hook

compute_loss(**batch)

Compute the EDM loss.

Parameters:

Name Type Description Default
**batch

Batch dictionary containing `gt_data` (the ground-truth data x_0) and `padding_mask` (the padding mask).

{}

Returns:

Name Type Description
dict dict

A dictionary containing the loss and other information

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def compute_loss(self, **batch) -> dict:
    """Compute the weighted EDM denoising loss for one training batch.

    A training time is sampled per macro element, the clean data is noised by
    ``forward_process`` starting from time zero, the model is run on the ``c_in``-scaled
    noisy sample, and the preconditioned denoised estimate is compared with the clean data
    under the weight returned by ``config.compute_loss_weight``.

    Args:
        **batch: batch dictionary containing:
            - gt_data: ground truth data x_0
            - padding_mask: padding mask
          The dictionary, extended with ``t``, ``x_t`` and ``gm_kwargs``, is passed to the
          model as keyword arguments.

    Returns:
        dict: the loss together with the values used to compute it (``gt_data``, ``t``,
        ``sigma``, ``x_t``, ``noise``, ``p_raw``, ``p_x_0``, ``padding_mask``, ``loss_fn``,
        ``config`` and ``base_model_output``)
    """
    clean = batch["gt_data"]
    padding_mask = batch["padding_mask"]
    target_device = clean.device

    # Hooks get a chance to replace the macro shape before times are sampled.
    macro_shape = cast(
        tuple[int, ...],
        self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_GET_MACRO_SHAPE,
            tgt_key_name="macro_shape",
            macro_shape=self.get_macro_shape(clean),  # (b, )
            batch=batch,
        ),
    )
    sampled_t = self.config.sampling_timestep_for_training(macro_shape=macro_shape).to(target_device)
    # Hooks may also replace the sampled times.
    hooked_t = self.hook_manager.run_hooks(
        stage=GMHookStageType.POST_SAMPLING_TIME_STEP,
        tgt_key_name="t",
        t=sampled_t,
        batch=batch,
    )
    t = self.complete_micro_shape(cast(Tensor, hooked_t))

    # Forward process from time zero up to t: add noise to the clean data.
    noised = self.forward_process(clean, torch.zeros_like(t), t, padding_mask, is_continuous_time=True)
    x_t = noised["x_t"]
    noise = noised["noise"]
    sigma = noised["sigma_diff"]

    # The model sees the c_in-scaled sample; the scale itself is passed along in gm_kwargs.
    batch.update(
        {
            "t": t,
            "x_t": self.config.c_in(sigma) * x_t,
            "gm_kwargs": {"c_in": self.config.c_in(sigma)},
        }
    )

    # gt_data is taken out of the batch for the duration of the model call.
    with TemporaryKeyRemover(mapping=batch, keys=["gt_data"]):
        model_output = self.model(**batch)

    raw_prediction = model_output["x"]
    denoised = self._compute_denoised(x_t, raw_prediction, sigma)

    # EDM loss weight: lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.
    # Both operands are scaled by sqrt(lambda) before being handed to loss_fn.
    root_weight = self.config.compute_loss_weight(sigma).sqrt()
    loss = self.loss_fn(root_weight * denoised, root_weight * clean, padding_mask)

    return {
        "loss": loss,
        "gt_data": clean,
        "t": t,
        "sigma": sigma,
        "x_t": x_t,
        "noise": noise,
        "p_raw": raw_prediction,
        "p_x_0": denoised,
        "padding_mask": padding_mask,
        "loss_fn": self.loss_fn,
        "config": self.config,
        "base_model_output": model_output,
    }

step(x_t, t, padding_mask=None, *args, **kwargs)

EDM sampling step (Euler or Heun's method).

Parameters:

Name Type Description Default
x_t Tensor

the sample at timestep t

required
t Tensor

the timestep (all elements must be the same)

required
padding_mask Optional[Tensor]

the padding mask

None

Returns:

Name Type Description
dict dict
  • x: the sample at timestep t-1
  • E_x0_xt: the predicted original sample
Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
def step(
    self,
    x_t: Tensor,
    t: Tensor,
    padding_mask: Optional[Tensor] = None,
    *args: Any,
    **kwargs: Any,
) -> dict:
    r"""EDM sampling step (Euler, optionally followed by Heun's 2nd-order correction).

    Unless ``config.use_ode_flow`` is set, fresh noise is first injected to raise the noise
    level from ``sigma_cur`` to ``sigma_hat = (1 + gamma) * sigma_cur``. The denoiser is then
    evaluated at ``sigma_hat`` and the update is integrated from ``sigma_hat``, so the noise
    level given to the network matches the noise actually present in the sample.

    Args:
        x_t: the sample at timestep t
        t: the timestep (all elements must be the same); must have as many dimensions as ``x_t``
        padding_mask: the padding mask
        **kwargs: forwarded to the model and to the ``PRE_UPDATE_IN_STEP_FN`` hooks

    Returns:
        dict:
            - x: the sample at timestep t-1
            - E_x0_xt: the predicted original sample
    """
    assert torch.all(t == t.view(-1)[0]).item(), "All timesteps in batch must be the same for EDM step"
    assert t.ndim == x_t.ndim, "Timestep and sample must have the same number of dimensions"
    config = cast(EuclideanEDMConfig, self.config.to(t))
    t = t.long()
    t_next = t - 1
    is_final_step = (t_next == 0).all()
    # The final step returns the denoised sample directly, so it never needs the correction.
    use_heun = not is_final_step and self.config.use_2nd_order_correction

    # Get sigma values and preconditioning coefficients with batch dimension
    sigma_cur = config.sigma(t, is_continuous_time=False)

    if not self.config.use_ode_flow:
        epsilon = self.config.S_noise * torch.randn_like(x_t)
        # Churn is only applied while the current noise level lies inside [S_min, S_max].
        gamma = (
            min(
                self.config.S_churn / self.config.n_discretization_steps,
                math.sqrt(2) - 1,
            )
            if ((self.config.S_min <= sigma_cur).all() and (sigma_cur <= self.config.S_max).all())
            else 0.0
        )
        sigma_hat = sigma_cur + gamma * sigma_cur
        x_t = x_t + torch.sqrt(sigma_hat**2 - sigma_cur**2) * epsilon
        # The sample now carries noise of level sigma_hat, so everything below (network
        # conditioning, derivative, step size) must use sigma_hat rather than the pre-churn
        # value. With gamma == 0 this leaves sigma_cur unchanged.
        sigma_cur = sigma_hat

    # p_x_0 prediction
    c_in_cur = config.c_in(sigma_cur)
    scaled_x_t = c_in_cur * x_t
    batch_dict = {
        "x_t": scaled_x_t,
        "t": sigma_cur,
        "padding_mask": padding_mask,
        **kwargs,
        "gm_kwargs": {"c_in": c_in_cur},
    }
    F_x = self.model(**batch_dict)["x"]
    p_x_0 = self._compute_denoised(x_t, F_x, sigma_cur)

    # Clip or threshold the predicted x_0 (following standard DDPM implementation)
    if self.config.use_dyn_thresholding:
        p_x_0 = self._threshold_sample(p_x_0)
    elif self.config.use_clip:
        p_x_0 = p_x_0.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)

    # Run PRE_UPDATE_IN_STEP_FN hooks for conditional sampling; a non-None result replaces p_x_0.
    hook_input = {
        "x_t": x_t,
        "t": sigma_cur,
        "p_x_0": p_x_0,
        "p_raw": F_x,
        "padding_mask": padding_mask,
        **kwargs,
    }
    hook_output = self.hook_manager.run_hooks(
        GMHookStageType.PRE_UPDATE_IN_STEP_FN,
        tgt_key_name="p_x_0",
        **hook_input,
    )
    if hook_output is not None:
        p_x_0 = hook_output

    # Final step: return denoised directly
    if is_final_step:
        return {"x": p_x_0, "E_x0_xt": p_x_0}

    # Euler step from the current noise level down to sigma_next
    sigma_next = config.sigma(t_next, is_continuous_time=False)
    d_cur = (x_t - p_x_0) / sigma_cur.clamp(min=1e-8)
    delta_sigma = sigma_next - sigma_cur
    x_next = x_t + delta_sigma * d_cur

    # Apply Heun's 2nd order correction: average the slopes at both ends of the step
    if use_heun:
        c_in_next = config.c_in(sigma_next)
        scaled_x_next = c_in_next * x_next
        batch_dict_next = {
            "x_t": scaled_x_next,  # Apply c_in scaling to match training
            "t": sigma_next,
            "padding_mask": padding_mask,
            **kwargs,
            "gm_kwargs": {"c_in": c_in_next},
        }
        F_x_next = self.model(**batch_dict_next)["x"]
        p_x_0_next = self._compute_denoised(x_next, F_x_next, sigma_next)

        hook_input = {
            "x_t": x_next,
            "t": sigma_next,
            "p_x_0": p_x_0_next,
            "p_raw": F_x_next,
            "padding_mask": padding_mask,
            **kwargs,
        }
        hook_output = self.hook_manager.run_hooks(
            GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            tgt_key_name="p_x_0",
            **hook_input,
        )
        if hook_output is not None:
            p_x_0_next = hook_output
        d_prime = (x_next - p_x_0_next) / sigma_next.clamp(min=1e-8)
        x_next = x_t + 0.5 * (d_cur + d_prime) * delta_sigma

    return {"x": x_next, "E_x0_xt": p_x_0}

get_posterior_mean_fn(score=None, score_fn=None)

Get the posterior mean function for EDM.

For EDM, the posterior mean is $E[x_0 \mid x_t] = D_\theta(x_t, \sigma_t)$,

where $D_\theta$ is the denoised prediction.

Parameters:

Name Type Description Default
score Tensor

the score of the sample

None
score_fn Callable

the function to compute score

None

Returns:

Name Type Description
Callable

the posterior mean function

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
def get_posterior_mean_fn(self, score: Optional[Tensor] = None, score_fn: Optional[Callable] = None):
    r"""Get the posterior mean function for EDM.

    For EDM, the posterior mean is:
    .. math::
        E[x_0|x_t] = D_\theta(x_t, \sigma_t)

    where D_\theta is the denoised prediction.

    Args:
        score (Tensor, optional): the score of the sample (currently unused)
        score_fn (Callable, optional): the function to compute score (currently unused)

    Returns:
        Callable: the posterior mean function
    """

    def _edm_posterior_mean_fn(
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor,
        is_continuous_time: bool = True,
    ):
        r"""Return the denoised prediction D_\theta(x_t, \sigma_t) for ``x_t`` at time ``t``.

        Args:
            x_t: the noisy sample
            t: the time at which ``x_t`` lives
            padding_mask: the padding mask, forwarded to the model
            is_continuous_time: forwarded to ``config.sigma`` when converting ``t`` to sigma
        """
        # TODO: get x0 by score function (``score`` / ``score_fn`` are not used yet)
        # Honor the caller's flag instead of always treating ``t`` as continuous.
        sigma = self.config.sigma(t, is_continuous_time=is_continuous_time)
        c_in = self.config.c_in(sigma)
        batch_dict = {
            "x_t": c_in * x_t,  # the model is fed the c_in-scaled sample
            "t": t,
            "sigma": sigma,
            "padding_mask": padding_mask,
            "gm_kwargs": {"c_in": c_in},
        }
        F_x = self.model(**batch_dict)["x"]
        return self._compute_denoised(x_t, F_x, sigma)

    return _edm_posterior_mean_fn

get_condition_post_compute_loss_hook(conditioner_list)

Get hook for conditioning after loss computation (training).

This hook modifies the loss to include conditional guidance during training. It computes the conditional score and updates the loss accordingly.

Parameters:

Name Type Description Default
conditioner_list list[Conditioner]

list of conditioners

required

Returns:

Name Type Description
GMHook

the hook for POST_COMPUTE_LOSS stage

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
def get_condition_post_compute_loss_hook(self, conditioner_list: list[Conditioner]):
    """Get hook for conditioning after loss computation (training).

    This hook modifies the loss to include conditional guidance during training.
    It computes the conditional score and updates the loss accordingly.

    Args:
        conditioner_list: list of conditioners

    Returns:
        GMHook: the hook for POST_COMPUTE_LOSS stage
    """

    def _hook_fn(**kwargs):
        """Replace ``kwargs["loss"]`` with the conditioned EDM loss and return ``kwargs``."""
        x_0 = kwargs["gt_data"]
        x_t = kwargs["x_t"]
        t = kwargs["t"]
        padding_mask = kwargs["padding_mask"]
        loss_fn = kwargs["loss_fn"]

        # The denoised prediction must be supplied by the caller; there is no fallback here.
        p_x_0 = kwargs.get("p_x_0")

        # Compute scores
        # NOTE(review): sigma is recomputed from t rather than read from kwargs["sigma"];
        # confirm both agree for the configured schedule.
        sigma = self.config.sigma(t, is_continuous_time=True)
        p_x_0 = cast(Tensor, p_x_0)
        p_uc_score = self._compute_edm_score(x_t, p_x_0, sigma)
        gt_uc_score = self._compute_edm_score(x_t, x_0, sigma)

        # Setup conditioners and get accumulated conditional score
        self._setup_conditioners(
            conditioner_list,
            train=True,
            tgt_mask=padding_mask,
            padding_mask=padding_mask,
            p_uc_score=p_uc_score,
            gt_data=x_0,
        )
        acc_c_score = get_accumulated_conditional_score(
            conditioner_list, x_t, t, padding_mask, is_continuous_time=True
        )

        # Compute conditioned loss with EDM weighting: the target is the clean sample implied
        # by the conditioned ground-truth score, x_0 = x_t + sigma² * score.
        gt_score = gt_uc_score + acc_c_score
        gt_x_0 = x_t + sigma**2 * gt_score
        weight = self.config.compute_loss_weight(sigma)
        sqrt_weight = weight.sqrt()
        # loss_fn takes (predicted, ground_true, padding_mask): prediction first, target second.
        kwargs["loss"] = loss_fn(sqrt_weight * p_x_0, sqrt_weight * gt_x_0, padding_mask)
        return kwargs

    return GMHook(
        name="EDM_condition_post_compute_loss_hook",
        stage=GMHookStageType.POST_COMPUTE_LOSS,
        fn=_hook_fn,
        priority=0,
        enabled=True,
    )

get_condition_pre_update_in_step_fn_hook(conditioner_list)

Get hook for conditioning before update in step function (sampling).

This hook applies conditional guidance during sampling by modifying the predicted denoised sample based on the conditional score.

Parameters:

Name Type Description Default
conditioner_list list[Conditioner]

list of conditioners

required

Returns:

Name Type Description
GMHook

the hook for PRE_UPDATE_IN_STEP_FN stage

Source code in src/ls_mlkit/diffusion/euclidean_edm_diffuser.py
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
def get_condition_pre_update_in_step_fn_hook(self, conditioner_list: list[Conditioner]):
    """Build the sampling-time hook that applies conditional guidance.

    The hook converts the denoised prediction into a score, adds the
    accumulated score of ``conditioner_list`` and converts the sum back
    into a conditioned denoised prediction.

    Args:
        conditioner_list: list of conditioners

    Returns:
        GMHook: the hook for PRE_UPDATE_IN_STEP_FN stage
    """

    def _guide_denoised_sample(**kwargs):
        x_t = kwargs["x_t"]
        t = kwargs["t"]
        padding_mask = kwargs["padding_mask"]
        sampling_condition = kwargs.get("sampling_condition")
        # The denoised prediction is taken from the hook input as-is; no
        # fallback is computed here when "p_x_0" is absent.
        denoised = cast(Tensor, kwargs.get("p_x_0"))

        # Unconditional score implied by the denoised prediction.
        sigma = self.config.sigma(t, is_continuous_time=True)
        uncond_score = self._compute_edm_score(x_t, denoised, sigma)

        # Hand the current state to the conditioners, then collect their score.
        self._setup_conditioners(
            conditioner_list,
            train=False,
            tgt_mask=padding_mask,
            padding_mask=padding_mask,
            p_uc_score=uncond_score,
            sampling_condition=sampling_condition,
        )
        guidance_score = get_accumulated_conditional_score(
            conditioner_list, x_t, t, padding_mask, is_continuous_time=True
        )

        # Invert score = -(x_t - x_0) / sigma² to get x_0 = x_t + sigma² * score.
        # The tensor itself is returned (not a kwargs dict) because the hook
        # manager expects the target value when tgt_key_name is set.
        return x_t + sigma**2 * (uncond_score + guidance_score)

    return GMHook(
        name="EDM_condition_pre_update_in_step_fn_hook",
        stage=GMHookStageType.PRE_UPDATE_IN_STEP_FN,
        fn=_guide_denoised_sample,
        priority=0,
        enabled=True,
    )

EuclideanVPSDEDiffuser

Bases: EuclideanDiffuser

Source code in src/ls_mlkit/diffusion/euclidean_vpsde_diffuser.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@inherit_docstrings
class EuclideanVPSDEDiffuser(EuclideanDiffuser):
    def __init__(
        self,
        config: EuclideanVPSDEConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: MaskerInterface,
        model: Module,
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_true, padding_mask)
    ):
        """Set up the VP-SDE diffuser and its Langevin corrector.

        Args:
            config (EuclideanVPSDEConfig): the config of the diffuser; also provides the SDE
            time_scheduler (DiffusionTimeScheduler): the time scheduler of the diffuser
            masker (MaskerInterface): the masker of the diffuser
            model (Module): the model of the diffuser; its output dict must contain ``"x"``
            loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): the loss function of the diffuser

        Returns:
            None
        """
        super().__init__(config=config, time_scheduler=time_scheduler, masker=masker)
        self.config: EuclideanVPSDEConfig = config
        self.sde = config.sde
        self.model = model
        self.loss_fn = loss_fn

        def _model_score(x: Tensor, t: Tensor, mask: Tensor) -> Tensor:
            # The corrector calls this with (x, t, mask); the model receives
            # the timestep cast to integer indices.
            prediction = self.model(x_t=x, t=t.long(), padding_mask=mask)
            return prediction["x"]

        self.corrector = LangevinCorrector(
            sde=self.sde,
            score_fn=_model_score,
            snr=self.config.snr,
            n_steps=self.config.n_correct_steps,
            ndim_micro_shape=self.config.ndim_micro_shape,
        )

    def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
        # Prior sampling is fully delegated to the configured SDE.
        prior_sample = self.sde.prior_sampling(shape)
        return prior_sample

    def forward_process(
        self,
        x_0: Tensor,
        discrete_t: Tensor,
        mask: Tensor,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Noise ``x_0`` to the given timestep indices via the SDE.

        Args:
            x_0 (Tensor): the clean sample
            discrete_t (Tensor): timestep indices, converted to continuous time first
            mask (Tensor): the mask forwarded to the SDE

        Returns:
            dict: the ``"x_t"``, ``"mean"``, ``"std"``, ``"a"`` and ``"b"``
            entries of the SDE's forward-process result
        """
        continuous_t = self.time_scheduler.timestep_index_to_continuous_time(discrete_t)
        noised = self.sde.forward_process(x_0, continuous_t, mask)
        # Expose only the documented keys, in a fixed order.
        return {key: noised[key] for key in ("x_t", "mean", "std", "a", "b")}

    def compute_loss(
        self, batch: dict[str, Any], *args: Any, **kwargs: Any
    ) -> dict:  # ty: ignore[invalid-method-override]
        """Compute the ``b``-weighted score-matching loss for one batch.

        Args:
            batch: must contain ``"gt_data"`` and ``"padding_mask"``; may contain
                ``"t"`` (timestep indices), which is sampled uniformly when
                absent or ``None``. The dict is updated in place with ``"x_t"``
                and ``"t"`` and then passed to the model with ``"gt_data"``
                temporarily removed.

        Returns:
            dict: the loss under ``"loss"`` plus the intermediate values
            (``"gt_data"``, ``"t"``, ``"x_t"``, ``"padding_mask"``,
            ``"gt_uc_score"``, ``"p_uc_score"``, ``"a"``, ``"b"``, ``"loss_fn"``,
            ``"config"``) for later hooks. The two scores are returned
            *unscaled*.
        """
        x_0 = batch["gt_data"]
        padding_mask = batch["padding_mask"]
        device = x_0.device
        macro_shape = self.get_macro_shape(x_0)

        t = batch.get("t", None)
        if t is None:
            t = self.time_scheduler.sample_timestep_index_uniformly(macro_shape).to(device)
        self.config = cast(EuclideanVPSDEConfig, self.config.to(t))

        forward_result = self.forward_process(x_0, t, padding_mask)
        x_t = forward_result["x_t"]
        mean = forward_result["mean"]
        std = forward_result["std"]
        a = forward_result["a"]
        b = forward_result["b"]
        gt_uc_score = self.sde.get_score(x_t=x_t, mean=mean, std=std)

        batch["x_t"] = x_t
        batch["t"] = t
        # The model must not see the clean data while predicting the score.
        with TemporaryKeyRemover(mapping=batch, keys=["gt_data"]):
            model_output = self.model(**batch)
        p_uc_score = model_output["x"]

        # Weight both scores by b for the loss only. The raw scores are what
        # gets returned: the hook from get_condition_post_compute_loss_hook
        # adds the conditional score to gt_uc_score, applies b itself and uses
        # p_uc_score in the posterior-mean formula, so returning the b-scaled
        # tensors would apply the weighting twice.
        loss = self.loss_fn(b * p_uc_score, b * gt_uc_score, padding_mask)

        return {
            "loss": loss,
            "gt_data": x_0,
            "t": t,
            "x_t": x_t,
            "padding_mask": padding_mask,
            "gt_uc_score": gt_uc_score,
            "p_uc_score": p_uc_score,
            "a": a,
            "b": b,
            "loss_fn": self.loss_fn,
            "config": self.config,
        }

    def forward_process_n_step(
        self,
        x: Tensor,
        t: Tensor,
        next_t: Tensor,
        padding_mask: Tensor,
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """Diffuse ``x`` forward from timestep index ``t`` to ``next_t``.

        Args:
            x (Tensor): the sample at timestep ``t``
            t (Tensor): current timestep indices, all ``>= 0``
            next_t (Tensor): target timestep indices, strictly greater than
                ``t`` and below ``config.n_discretization_steps``
            padding_mask (Tensor): accepted for interface compatibility; not
                used by this implementation

        Returns:
            Tensor: the sample at timestep ``next_t``
        """
        # Index sanity checks: strictly forward in time and inside the grid.
        assert (next_t > t).all()
        assert (t >= 0).all()
        assert (next_t < self.config.n_discretization_steps).all()

        to_continuous = self.time_scheduler.timestep_index_to_continuous_time
        start_time = to_continuous(t)
        end_time = to_continuous(next_t)
        return self.sde.forward_from_t1_to_t2(x, start_time, end_time)

    def step(
        self,
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        r"""Advance the sample by one reverse-time sampling step.

        The model's score, optionally replaced by the output of the
        ``PRE_UPDATE_IN_STEP_FN`` hooks, defines the reverse SDE, which is
        integrated from ``schedule[idx]`` to ``schedule[idx + 1]``. While every
        entry of ``t`` is positive, noise is added and the corrector is applied.

        Args:
            x_t (Tensor): the sample at timestep t
            t (Tensor): the timestep; all entries must be equal
            padding_mask (Tensor): the padding mask
            **kwargs: must contain ``idx``, the position in the continuous
                timestep schedule; every entry is also forwarded to the model

        Returns:
            dict: ``{"x": x}`` where ``x`` is the sample after this step
        """
        assert torch.all(t == t.view(-1)[0]).item()
        device = x_t.device
        idx = require(kwargs.get("idx"), "idx")
        schedule = self.time_scheduler.get_continuous_timesteps_schedule().to(device)
        step_index = int(idx)
        ones = torch.ones_like(t)
        t_start = schedule[step_index] * ones
        t_end = schedule[step_index + 1] * ones
        config = cast(EuclideanVPSDEConfig, self.config.to(device))

        model_inputs = {"x_t": x_t, "t": t.long(), "padding_mask": padding_mask, **kwargs}
        p_uc_score = self.model(**model_inputs)["x"]

        # Give registered hooks (e.g. conditional guidance) the chance to
        # replace the score; a None result keeps the model's own score.
        guided_score = self.hook_manager.run_hooks(
            GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            tgt_key_name="p_uc_score",
            p_uc_score=p_uc_score,
            x_t=x_t,
            t=t,
            padding_mask=padding_mask,
            config=config,
            sampling_condition=kwargs.get("sampling_condition"),
        )
        if guided_score is not None:
            p_uc_score = guided_score

        rsde = self.sde.get_reverse_sde(
            score=p_uc_score,
            score_fn=None,
            use_probability_flow=self.config.use_probability_flow,
        )
        delta_t = self.complete_micro_shape(t_end - t_start)
        f, g = rsde.get_drift_and_diffusion(x_t, t_start, mask=padding_mask)
        g = self.complete_micro_shape(g)
        # The noise is drawn unconditionally so the RNG stream does not depend
        # on whether this step ends up using it.
        z = torch.randn_like(x_t)
        x_mean = x_t + f * delta_t

        if (t > 0).all():
            # Stochastic update followed by the corrector.
            x = x_mean + g * z * torch.sqrt(delta_t.abs())
            x, _ = self.corrector.update_fn(x, t - 1, padding_mask)
        else:
            # Last step: return the noise-free mean and skip the corrector.
            x = x_mean

        return {"x": x}

    def get_posterior_mean_fn(
        self,
        score: Tensor | None = None,
        score_fn: Callable[[Tensor, Tensor, Tensor | None], Tensor] | None = None,
    ):
        r"""Get the posterior mean function

        Args:
            score (Tensor, optional): a fixed score to use for every call
            score_fn (Callable, optional): the function to compute the score
                from ``(x_t, t, padding_mask)``; only used when ``score`` is
                ``None`` and then evaluated on every call

        Returns:
            Callable: the posterior mean function
        """

        def _posterior_mean_fn(
            x_t: Tensor,
            t: Tensor,
            padding_mask: Tensor,
        ):
            r"""
            Args:
                x_t: shape=(..., n_nodes, 3)
                t: shape=(...), dtype=torch.long
                padding_mask: forwarded to ``score_fn`` when it is used

            For the case of VPSDE sampling, the posterior mean is given by

            .. math::

                E[x_0|x_t] = \frac{b^2}{a} \nabla_{x_t}\log p_t(x_t) + \frac{x_t}{a}

            """
            assert score is not None or score_fn is not None, "either score or score_fn must be provided"
            if score is not None:
                current_score = score
            else:
                assert score_fn is not None
                # Evaluate per call and keep the result local: storing it in
                # the enclosing scope would make every later call reuse the
                # score of the first x_t.
                current_score = score_fn(x_t, t, padding_mask)
            # The string form keeps the cast a purely static annotation.
            sde = cast("EuclideanVPSDEConfig", self.config.to(t)).sde
            continuous_t = self.time_scheduler.timestep_index_to_continuous_time(t)
            a, b = sde.get_a_b(continuous_t)
            return b**2 / a * current_score + x_t / a

        return _posterior_mean_fn

    def get_condition_post_compute_loss_hook(self, conditioner_list: list[Conditioner]):
        """Build the training-time hook that recomputes the loss with conditioning.

        The hook reads the values produced by the loss computation from its
        keyword arguments, configures every enabled conditioner, adds the
        accumulated conditional score to the ground-truth score and replaces
        ``"loss"`` with the ``b``-weighted loss against that target.

        Args:
            conditioner_list: list of conditioners

        Returns:
            GMHook: the hook for the POST_COMPUTE_LOSS stage
        """

        def _recompute_conditioned_loss(**kwargs: Any):
            x_0 = require(cast(Tensor | None, kwargs.get("gt_data")), "gt_data")
            x_t = require(cast(Tensor | None, kwargs.get("x_t")), "x_t")
            t = require(cast(Tensor | None, kwargs.get("t")), "t")
            padding_mask = require(cast(Tensor | None, kwargs.get("padding_mask")), "padding_mask")
            loss_fn = require(cast(Callable[..., Any] | None, kwargs.get("loss_fn")), "loss_fn")
            p_uc_score = require(cast(Tensor | None, kwargs.get("p_uc_score")), "p_uc_score")
            gt_uc_score = require(cast(Tensor | None, kwargs.get("gt_uc_score")), "gt_uc_score")
            b = require(cast(Tensor | None, kwargs.get("b")), "b")

            for conditioner in conditioner_list:
                if conditioner.is_enabled():
                    # The padding mask doubles as the target mask.
                    condition = conditioner.prepare_condition_dict(
                        train=True,
                        tgt_mask=padding_mask,
                        gt_data=x_0,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(score=p_uc_score, score_fn=None),
                    )
                    conditioner.set_condition(**condition)

            conditional_score = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)

            # Scale prediction and conditioned target by b, then recompute the loss.
            weighted_prediction = b * p_uc_score
            weighted_target = b * (gt_uc_score + conditional_score)
            kwargs["loss"] = loss_fn(weighted_prediction, weighted_target, padding_mask)
            return kwargs

        return GMHook(
            name="VPSDE_condition_post_compute_loss_hook",
            stage=GMHookStageType.POST_COMPUTE_LOSS,
            fn=_recompute_conditioned_loss,
            priority=0,
            enabled=True,
        )

    def get_condition_pre_update_in_step_fn_hook(self, conditioner_list: list[Conditioner]):
        """Build the sampling-time hook that adds the conditional score.

        The hook configures every enabled conditioner from its keyword
        arguments and returns the unconditional score plus the accumulated
        conditional score.

        Args:
            conditioner_list: list of conditioners

        Returns:
            GMHook: the hook for the PRE_UPDATE_IN_STEP_FN stage
        """

        def _add_conditional_score(**kwargs: Any):
            p_uc_score = require(cast(Tensor | None, kwargs.get("p_uc_score")), "p_uc_score")
            x_t = require(cast(Tensor | None, kwargs.get("x_t")), "x_t")
            t = require(cast(Tensor | None, kwargs.get("t")), "t")
            padding_mask = require(cast(Tensor | None, kwargs.get("padding_mask")), "padding_mask")
            sampling_condition = kwargs.get("sampling_condition")

            for conditioner in conditioner_list:
                if conditioner.is_enabled():
                    # The padding mask doubles as the target mask.
                    condition = conditioner.prepare_condition_dict(
                        train=False,
                        tgt_mask=padding_mask,
                        sampling_condition=sampling_condition,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(score=p_uc_score, score_fn=None),
                    )
                    conditioner.set_condition(**condition)

            conditional_score = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)
            return p_uc_score + conditional_score

        return GMHook(
            name="VPSDE_condition_pre_update_in_step_fn_hook",
            stage=GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            fn=_add_conditional_score,
            priority=0,
            enabled=True,
        )

__init__(config, time_scheduler, masker, model, loss_fn)

Initialize the EuclideanVPSDEDiffuser

Parameters:

Name Type Description Default
config EuclideanVPSDEConfig

the config of the diffuser

required
time_scheduler DiffusionTimeScheduler

the time scheduler of the diffuser

required
masker MaskerInterface

the masker of the diffuser

required
model Module

the model of the diffuser

required
loss_fn Callable[[Tensor, Tensor, Tensor], Tensor]

the loss function of the diffuser

required

Returns:

Type Description

None

Source code in src/ls_mlkit/diffusion/euclidean_vpsde_diffuser.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def __init__(
    self,
    config: EuclideanVPSDEConfig,
    time_scheduler: DiffusionTimeScheduler,
    masker: MaskerInterface,
    model: Module,
    loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted, ground_true, padding_mask)
):
    """Set up the VP-SDE diffuser and its Langevin corrector.

    Args:
        config (EuclideanVPSDEConfig): the config of the diffuser; also provides the SDE
        time_scheduler (DiffusionTimeScheduler): the time scheduler of the diffuser
        masker (MaskerInterface): the masker of the diffuser
        model (Module): the model of the diffuser; its output dict must contain ``"x"``
        loss_fn (Callable[[Tensor, Tensor, Tensor], Tensor]): the loss function of the diffuser

    Returns:
        None
    """
    super().__init__(config=config, time_scheduler=time_scheduler, masker=masker)
    self.config: EuclideanVPSDEConfig = config
    self.sde = config.sde
    self.model = model
    self.loss_fn = loss_fn

    def _model_score(x: Tensor, t: Tensor, mask: Tensor) -> Tensor:
        # The corrector calls this with (x, t, mask); the model receives
        # the timestep cast to integer indices.
        prediction = self.model(x_t=x, t=t.long(), padding_mask=mask)
        return prediction["x"]

    self.corrector = LangevinCorrector(
        sde=self.sde,
        score_fn=_model_score,
        snr=self.config.snr,
        n_steps=self.config.n_correct_steps,
        ndim_micro_shape=self.config.ndim_micro_shape,
    )

step(x_t, t, padding_mask=None, *args, **kwargs)

Parameters:

Name Type Description Default
x_t Tensor

the sample at timestep t

required
t Tensor

the timestep

required
padding_mask Tensor

the padding mask

None

Returns:

Name Type Description
Tensor dict

a dictionary whose "x" entry holds the sample at timestep t-1

Source code in src/ls_mlkit/diffusion/euclidean_vpsde_diffuser.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def step(
    self,
    x_t: Tensor,
    t: Tensor,
    padding_mask: Tensor | None = None,
    *args: Any,
    **kwargs: Any,
) -> dict:
    r"""Advance the sample by one reverse-time sampling step.

    The model's score, optionally replaced by the output of the
    ``PRE_UPDATE_IN_STEP_FN`` hooks, defines the reverse SDE, which is
    integrated from ``schedule[idx]`` to ``schedule[idx + 1]``. While every
    entry of ``t`` is positive, noise is added and the corrector is applied.

    Args:
        x_t (Tensor): the sample at timestep t
        t (Tensor): the timestep; all entries must be equal
        padding_mask (Tensor): the padding mask
        **kwargs: must contain ``idx``, the position in the continuous
            timestep schedule; every entry is also forwarded to the model

    Returns:
        dict: ``{"x": x}`` where ``x`` is the sample after this step
    """
    assert torch.all(t == t.view(-1)[0]).item()
    device = x_t.device
    idx = require(kwargs.get("idx"), "idx")
    schedule = self.time_scheduler.get_continuous_timesteps_schedule().to(device)
    step_index = int(idx)
    ones = torch.ones_like(t)
    t_start = schedule[step_index] * ones
    t_end = schedule[step_index + 1] * ones
    config = cast(EuclideanVPSDEConfig, self.config.to(device))

    model_inputs = {"x_t": x_t, "t": t.long(), "padding_mask": padding_mask, **kwargs}
    p_uc_score = self.model(**model_inputs)["x"]

    # Give registered hooks (e.g. conditional guidance) the chance to
    # replace the score; a None result keeps the model's own score.
    guided_score = self.hook_manager.run_hooks(
        GMHookStageType.PRE_UPDATE_IN_STEP_FN,
        tgt_key_name="p_uc_score",
        p_uc_score=p_uc_score,
        x_t=x_t,
        t=t,
        padding_mask=padding_mask,
        config=config,
        sampling_condition=kwargs.get("sampling_condition"),
    )
    if guided_score is not None:
        p_uc_score = guided_score

    rsde = self.sde.get_reverse_sde(
        score=p_uc_score,
        score_fn=None,
        use_probability_flow=self.config.use_probability_flow,
    )
    delta_t = self.complete_micro_shape(t_end - t_start)
    f, g = rsde.get_drift_and_diffusion(x_t, t_start, mask=padding_mask)
    g = self.complete_micro_shape(g)
    # The noise is drawn unconditionally so the RNG stream does not depend
    # on whether this step ends up using it.
    z = torch.randn_like(x_t)
    x_mean = x_t + f * delta_t

    if (t > 0).all():
        # Stochastic update followed by the corrector.
        x = x_mean + g * z * torch.sqrt(delta_t.abs())
        x, _ = self.corrector.update_fn(x, t - 1, padding_mask)
    else:
        # Last step: return the noise-free mean and skip the corrector.
        x = x_mean

    return {"x": x}

get_posterior_mean_fn(score=None, score_fn=None)

Get the posterior mean function

Parameters:

Name Type Description Default
score Tensor

the score of the sample

None
score_fn Callable

the function to compute score

None

Returns:

Name Type Description
Callable

the posterior mean function

Source code in src/ls_mlkit/diffusion/euclidean_vpsde_diffuser.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
def get_posterior_mean_fn(
    self,
    score: Tensor | None = None,
    score_fn: Callable[[Tensor, Tensor, Tensor | None], Tensor] | None = None,
):
    r"""Get the posterior mean function

    Args:
        score (Tensor, optional): the score of the sample
        score_fn (Callable, optional): the function to compute score

    Returns:
        Callable: the posterior mean function
    """

    def _posterior_mean_fn(
        x_t: Tensor,
        t: Tensor,
        padding_mask: Tensor,
    ):
        r"""
        Args:
            x_t: shape=(..., n_nodes, 3)
            t: shape=(...), dtype=torch.long

        For the case of VPSDE sampling, the posterior mean is given by

        .. math::

            E[x_0|x_t] = \frac{b^2}{a} \nabla_{x_t}\log p_t(x_t) - \frac{x_t}{a}

        """
        nonlocal score, score_fn
        assert score is not None or score_fn is not None, "either score or score_fn must be provided"
        if score is None:
            assert score_fn is not None
            score = score_fn(x_t, t, padding_mask)
        sde = cast(EuclideanVPSDEConfig, self.config.to(t)).sde
        t = self.time_scheduler.timestep_index_to_continuous_time(t)
        a, b = sde.get_a_b(t)
        E_x0_xt = b**2 / a * score + x_t / a
        return E_x0_xt

    return _posterior_mean_fn

SO3Diffuser

Bases: LieGroupDiffuser

Source code in src/ls_mlkit/diffusion/so3_diffuser.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
@inherit_docstrings
class SO3Diffuser(LieGroupDiffuser):
    def __init__(
        self,
        config: SO3DiffuserConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: BioSO3Masker,
        sde: SDE,
        score_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (x, t, mask) -> score
        loss_fn: Callable[[Tensor, Tensor, Tensor], Tensor],  # (predicted_score, ground_truth_score, mask) -> loss
    ):
        """Set up the SO(3) diffuser and precompute its IGSO(3) lookup tables.

        Args:
            config (SO3DiffuserConfig): the config; supplies the ``igso3_*`` table settings
            time_scheduler (DiffusionTimeScheduler): the time scheduler of the diffuser
            masker (BioSO3Masker): the masker of the diffuser
            sde (SDE): the SDE; must be a ``VESDE``
            score_fn (Callable): ``(x, t, mask) -> score``
            loss_fn (Callable): ``(predicted_score, ground_truth_score, mask) -> loss``
        """
        so3 = SO3()
        super().__init__(
            config=config,
            time_scheduler=time_scheduler,
            lie_group=so3,
        )
        self.config = config
        self.time_scheduler = time_scheduler
        self.masker = masker
        self.sde = sde
        self.loss_fn = loss_fn
        self.so3 = so3
        self.score_fn = score_fn
        assert isinstance(self.sde, VESDE), "only VESDE is supported"

        # Angle grid over (0, pi]: the leading 0 of the linspace is dropped.
        omega_grid = torch.linspace(0, torch.pi, config.igso3_num_omega + 1)[1:]
        igso3_cache = calculate_igso3(
            num_sigma=config.igso3_num_sigma,
            num_omega=config.igso3_num_omega,
            min_sigma=config.igso3_min_sigma,
            max_sigma=config.igso3_max_sigma,
            discrete_omega=omega_grid,
            discrete_sigma=self.sde.discrete_sigmas,
        )

        # Register buffers - these will automatically move with the model.
        #   _igso3_cdf:             [num_sigma, num_omega]
        #   _igso3_score_norm:      [num_sigma, num_omega]  $$\frac{d}{d\omega} f(\omega, c, L)$$
        #   _igso3_exp_score_norms: [num_sigma, ]  $$\sqrt{\mathbb{E}_{\omega} || \frac{d}{d\omega} f(\omega, c, L)||_2^2}$$
        #   _igso3_discrete_omega:  [num_omega, ]
        #   _igso3_discrete_sigma:  [num_sigma, ]
        buffer_sources = (
            ("_igso3_cdf", "cdf"),
            ("_igso3_score_norm", "score_norm"),
            ("_igso3_exp_score_norms", "exp_score_norms"),
            ("_igso3_discrete_omega", "discrete_omega"),
            ("_igso3_discrete_sigma", "discrete_sigma"),
        )
        for buffer_name, cache_key in buffer_sources:
            self.register_buffer(buffer_name, igso3_cache[cache_key])

    # Typed accessors for the buffers registered in ``__init__``; the casts
    # only narrow the static type to ``Tensor``.

    @property
    def igso3_cdf(self) -> Tensor:
        """The ``_igso3_cdf`` buffer, registered as [num_sigma, num_omega]."""
        return cast(Tensor, self._igso3_cdf)

    @property
    def igso3_score_norm(self) -> Tensor:
        """The ``_igso3_score_norm`` buffer, registered as [num_sigma, num_omega]."""
        return cast(Tensor, self._igso3_score_norm)

    @property
    def igso3_exp_score_norms(self) -> Tensor:
        """The ``_igso3_exp_score_norms`` buffer, registered as [num_sigma, ]."""
        return cast(Tensor, self._igso3_exp_score_norms)

    @property
    def igso3_discrete_omega(self) -> Tensor:
        """The ``_igso3_discrete_omega`` buffer, registered as [num_omega, ]."""
        return cast(Tensor, self._igso3_discrete_omega)

    @property
    def igso3_discrete_sigma(self) -> Tensor:
        """The ``_igso3_discrete_sigma`` buffer, registered as [num_sigma, ]."""
        return cast(Tensor, self._igso3_discrete_sigma)

    def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
        r"""Sample initial noise used for reverse process

        A rotation axis is drawn uniformly on the sphere and a rotation angle
        is drawn from the IGSO(3) CDF of the last training timestep.
        NOTE(review): presumably this is meant to approximate

        .. math::

            \mathcal{U}_{SO(3)}

        at the largest noise level -- confirm.

        Args:
            shape (Tuple[int, ...]): the shape of the sample

        Returns:
            Tensor: the initial noise, one rotation matrix per entry of ``shape``
        """
        macro_shape = shape
        # Create the axis on the device of the lookup tables so that it can be
        # combined with the angles sampled from them.
        device = self.igso3_cdf.device
        discrete_t = self.time_scheduler.num_train_timesteps - 1
        axis = torch.randn(macro_shape + (3,), device=device)
        axis_in_s2 = axis / torch.norm(axis, dim=-1, keepdim=True)
        angle = inverse_transform_sampling(
            shape=macro_shape,
            cdf=self.igso3_cdf[discrete_t],
            discrete_omega=self.igso3_discrete_omega,
        )
        # angle is (*shape) while the axis is (*shape, 3): add a trailing
        # dimension so each angle scales its own axis (as forward_process does).
        rotation_vector = angle.unsqueeze(-1) * axis_in_s2
        rotation_skew_symmetric = vector_to_skew_symmetric(rotation_vector)
        rotation_matrix = self.so3.exp(v=rotation_skew_symmetric)
        return rotation_matrix

    def forward_process(
        self,
        x_0: Tensor,
        discrete_t: Tensor,
        mask: Tensor,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        r"""Forward process

        Each rotation in ``x_0`` is left-multiplied by a random rotation whose
        axis is uniform on the sphere and whose angle follows the IGSO(3) CDF
        of its timestep.

        .. math::

            \text{IG}_{\text{SO}(3)} (\mathbf{x}; \mathbf{\mu}, \sigma^2) = f_{\sigma} (\arccos((\text{tr}(\mathbf{\mu}^T \mathbf{x}) - 1)/2)) \quad \forall \mathbf{x} \in \text{SO}(3)

        Args:
            x_0 (Tensor): the initial sample
            discrete_t (Tensor): the discrete timestep
            mask (Tensor): the mask
            *args: additional arguments
            **kwargs: additional keyword arguments

        Returns:
            dict: a dictionary that must contain the key "x_t"
        """
        # x_0.shape = (b, n, 3, 3)
        macro_shape = self.get_macro_shape(x_0)  # (*macro_shape, ) = (b,)
        n = x_0.shape[-3]
        per_rotation_shape = macro_shape + (n,)

        # Random unit axis for every rotation.
        raw_axis = torch.randn(per_rotation_shape + (3,), device=x_0.device)  # (*macro_shape, n, 3)
        unit_axis = raw_axis / raw_axis.norm(dim=-1, keepdim=True)  # (*macro_shape, n, 3)

        # One CDF row per sample, shared by its n rotations.
        cdf_per_rotation = (
            self.igso3_cdf[discrete_t]  # (*macro_shape, num_omega)
            .unsqueeze(-2)
            .expand(*macro_shape, n, -1)  # (*macro_shape, n, num_omega)
        )
        angle = inverse_transform_sampling(
            shape=per_rotation_shape,
            cdf=cdf_per_rotation,
            discrete_omega=self.igso3_discrete_omega,
        )  # (*macro_shape, n)

        # Axis-angle -> skew-symmetric matrix -> rotation matrix.
        rotation_vector = angle.unsqueeze(-1) * unit_axis  # (*macro_shape, n, 3)
        skew = vector_to_skew_symmetric(rotation_vector)  # (*macro_shape, n, 3, 3)
        perturbation = self.so3.exp(v=skew)  # (*macro_shape, n, 3, 3)
        return {"x_t": self.so3.multiply(perturbation, x_0)}  # (*macro_shape, n, 3, 3)

    def get_ground_truth_score(self, x_0: Tensor, x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor) -> Tensor:
        r"""Denoise Score Matching

        .. math::
            \nabla_x \log p_{0t} (x_t | x_0)

        Args:
            x_0 (Tensor): the clean rotations, (*macro_shape, n, 3, 3)
            x_t (Tensor): the noised rotations, same shape as ``x_0``
            discrete_t (Tensor): timestep indices used to select the score-norm table row
            padding_mask (Tensor): accepted for interface symmetry; not used here

        Returns:
            Tensor: the ground-truth score, (*macro_shape, n, 3, 3)
        """
        macro_shape = self.get_macro_shape(x_0)
        n = x_0.shape[-3]

        # Relative rotation from x_0 to x_t and its rotation angle.
        relative_rotation = x_0.transpose(-1, -2) @ x_t  # (*macro_shape, n, 3, 3)
        omega = rotation_matrix_to_angle(relative_rotation)  # (*macro_shape, n)

        # Score-norm table row of each sample, shared by its n rotations.
        score_norm_table = (
            self.igso3_score_norm[discrete_t]  # (*macro_shape, num_omega)
            .unsqueeze(-2)
            .expand(*macro_shape, n, -1)  # (*macro_shape, n, num_omega)
        )

        # Direction: the log map divided by the angle (EPS guards omega == 0).
        direction = self.so3.log(q=relative_rotation) / (omega[..., None, None] + EPS)  # (*macro_shape, n, 3, 3)
        # Magnitude: table value interpolated at the observed angle.
        magnitude = interp(x=omega, xp=self.igso3_discrete_omega, fp=score_norm_table)[..., None, None]
        return (x_t @ direction) * magnitude  # (*macro_shape, n, 3, 3)

    def compute_loss(
        self, batch: dict[str, Any], *args: Any, **kwargs: Any
    ) -> dict:  # ty: ignore[invalid-method-override]
        """Compute the denoising score-matching loss for one batch.

        Args:
            batch: must contain ``"x_0"`` and ``"padding_mask"``; may contain
                ``"t"`` (timestep indices), which is sampled uniformly when
                absent or ``None``.

        Returns:
            dict: ``{"loss": loss}``
        """
        x_0 = batch["x_0"]
        padding_mask = batch["padding_mask"]
        macro_shape = self.get_macro_shape(x_0)
        discrete_t = batch.get("t", None)
        if discrete_t is None:
            # Keep the sampled timesteps on the data's device so they can be
            # passed to score_fn together with x_t.
            discrete_t = self.time_scheduler.sample_timestep_index_uniformly(macro_shape=macro_shape).to(
                x_0.device
            )
        x_t = self.forward_process(x_0, discrete_t=discrete_t, mask=padding_mask)["x_t"]
        ground_truth_score = self.get_ground_truth_score(x_0, x_t, discrete_t, padding_mask)
        predicted_score = self.score_fn(x_t, discrete_t, padding_mask)
        loss = self.loss_fn(predicted_score, ground_truth_score, padding_mask)
        return {"loss": loss}

    def step(  # ty: ignore[invalid-method-override]
        self,
        x_t: Tensor,
        discrete_t: Tensor,
        padding_mask: Tensor | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        r"""Take one reverse-time step on SO(3).

        .. math::

            dx &= \exp_{x_t}(f_{rev} dt + g_{rev} dw)\\
            x_{t+\Delta_t} &= \exp_{x_t}(- f_{rev} |\Delta_t| + g_{rev} \Delta w)\\
            f_{rev} &= (f - g^2 \nabla_x \ln p_t(x))\\
            g_{rev} &= g\\

        Args:
            x_t (Tensor): the current rotations; the last two dimensions are
                the 3x3 matrices
            discrete_t (Tensor): the current timestep indices
            padding_mask (Tensor): the padding mask; required despite the
                ``None`` default

        Returns:
            dict: ``{"x": x}`` with the rotations after the step
        """
        continuous_t = self.time_scheduler.timestep_index_to_continuous_time(discrete_t)
        f, g = self.sde.get_drift_and_diffusion(x=x_t, t=continuous_t, mask=padding_mask)

        assert padding_mask is not None
        riemannian_grad = self.score_fn(x_t, discrete_t, padding_mask)

        assert f.sum() == 0, "f should be 0"
        rev_f = f - g**2 * riemannian_grad
        rev_g = g

        delta_t = self.time_scheduler.T / self.time_scheduler.num_inference_timesteps
        term1 = -rev_f * delta_t

        # Draw one Lie-algebra noise matrix per rotation matrix in x_t. Using
        # only the macro shape would give (*macro_shape, 3, 3) noise, which
        # does not line up with x_t's (*macro_shape, n, 3, 3) in the matmul.
        noise_lie_algebra = self.sample_noise_in_lie_algebra(macro_shape=tuple(x_t.shape[:-2]))
        noise_lie_algebra = noise_lie_algebra.to(device=x_t.device, dtype=x_t.dtype)
        sqrt_delta_t = torch.sqrt(torch.as_tensor(delta_t, device=x_t.device, dtype=x_t.dtype))
        # Left-multiplying by x_t moves the noise into the tangent space at x_t.
        delta_w = sqrt_delta_t * (x_t @ noise_lie_algebra)
        term2 = rev_g * delta_w

        move_in_tangent_space = term1 + term2
        x_tm1 = self.so3.exp(p=x_t, v=move_in_tangent_space)
        return {"x": x_tm1}

    def sample_noise_in_lie_algebra(
        self,
        macro_shape: Tuple[int, ...],
    ) -> Tensor:
        r"""Draw random tangent vectors at the identity of SO(3).

        The Lie algebra is the tangent space at the identity, so the noise is
        sampled as a random tangent at identity elements of the given shape.

        Args:
            macro_shape (Tuple[int, ...]): the macro shape of the noise

        Returns:
            Tensor: the noise in Lie algebra of shape :math:`(*macro_shape, 3, 3)`
        """
        identity = self.so3.identity(macro_shape=macro_shape)
        return self.so3.random_tangent(p=identity)

    def sampling(self, shape, device, x_init_posterior=None, *args, **kwargs):
        # Full sampling loop is not provided by this class; calling it always raises.
        raise NotImplementedError

    def inpainting(
        self,
        x,
        padding_mask,
        inpainting_mask,
        device,
        x_init_posterior=None,
        inpainting_mask_key="inpainting_mask",
        *args,
        **kwargs,
    ):
        # Inpainting is not provided by this class; calling it always raises.
        raise NotImplementedError

prior_sampling(shape)

Sample initial noise used for reverse process

.. math::

\mathcal{U}_{SO(3)}

Parameters:

Name Type Description Default
shape Tuple[int, ...]

the shape of the sample

required

Returns:

Name Type Description
Tensor Tensor

the initial noise

Source code in src/ls_mlkit/diffusion/so3_diffuser.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def prior_sampling(self, shape: Tuple[int, ...]) -> Tensor:
    r"""Sample initial noise used for reverse process

    .. math::

        \mathcal{U}_{SO(3)}

    A random unit axis is combined with an angle drawn by inverse transform
    sampling from the IGSO(3) CDF table at the last training timestep; the
    resulting rotation vector is converted to a skew-symmetric matrix and
    mapped through ``self.so3.exp``.

    Args:
        shape (Tuple[int, ...]): the shape of the sample

    Returns:
        Tensor: the initial noise
    """
    macro_shape = tuple(shape)
    discrete_t = self.time_scheduler.num_train_timesteps - 1
    # Create the axis on the same device as the lookup table so it can be
    # combined with the sampled angle.
    # NOTE(review): assumes inverse_transform_sampling returns its result on
    # the device of ``cdf`` -- confirm.
    device = self.igso3_cdf.device
    axis = torch.randn(macro_shape + (3,), device=device)  # (*shape, 3)
    axis_in_s2 = axis / torch.norm(axis, dim=-1, keepdim=True)  # (*shape, 3)
    angle = inverse_transform_sampling(
        shape=macro_shape,
        cdf=self.igso3_cdf[discrete_t],
        discrete_omega=self.igso3_discrete_omega,
    )
    # Add a trailing axis so the per-sample angle scales all three axis
    # components, mirroring what forward_process does.
    rotation_vector = angle.unsqueeze(-1) * axis_in_s2  # (*shape, 3)
    rotation_skew_symmetric = vector_to_skew_symmetric(rotation_vector)
    rotation_matrix = self.so3.exp(v=rotation_skew_symmetric)
    return rotation_matrix

forward_process(x_0, discrete_t, mask, *args, **kwargs)

Forward process

.. math::

\text{IG}_{\text{SO}(3)} (\mathbf{x}; \mathbf{\mu}, \sigma^2) = f_{\sigma} (\arccos((\text{tr}(\mathbf{\mu}^T \mathbf{x}) - 1)/2)) \quad \forall \mathbf{x} \in \text{SO}(3)

Parameters:

Name Type Description Default
x_0 Tensor

the initial sample

required
discrete_t Tensor

the discrete timestep

required
mask Tensor

the mask

required
*args Any

additional arguments

()
**kwargs Any

additional keyword arguments

{}

Returns:

Name Type Description
dict dict

a dictionary that must contain the key "x_t"

Source code in src/ls_mlkit/diffusion/so3_diffuser.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def forward_process(
    self,
    x_0: Tensor,
    discrete_t: Tensor,
    mask: Tensor,
    *args: Any,
    **kwargs: Any,
) -> dict:
    r"""Noise a clean sample by composing it with a randomly sampled rotation.

    The rotation angle is drawn by inverse transform sampling from the
    IGSO(3) CDF table at ``discrete_t``; the rotation axis is a normalised
    Gaussian draw.

    .. math::

        \text{IG}_{\text{SO}(3)} (\mathbf{x}; \mathbf{\mu}, \sigma^2) = f_{\sigma} (\arccos((\text{tr}(\mathbf{\mu}^T \mathbf{x}) - 1)/2)) \quad \forall \mathbf{x} \in \text{SO}(3)

    Args:
        x_0 (Tensor): the initial sample, e.g. of shape ``(b, n, 3, 3)``
        discrete_t (Tensor): the discrete timestep
        mask (Tensor): the mask (not used in this method)
        *args: additional arguments
        **kwargs: additional keyword arguments

    Returns:
        dict: a dictionary that must contain the key "x_t"
    """
    batch_dims = self.get_macro_shape(x_0)  # (b,) for x_0 of shape (b, n, 3, 3)
    num_items = x_0.shape[-3]
    sample_shape = batch_dims + (num_items,)

    # Rotation axis: Gaussian draw projected onto the unit sphere.
    raw_axis = torch.randn(sample_shape + (3,), device=x_0.device)  # (*batch_dims, n, 3)
    unit_axis = raw_axis / torch.norm(raw_axis, dim=-1, keepdim=True)

    # Per-timestep CDF table, repeated along the item axis.
    cdf_table = self.igso3_cdf[discrete_t]  # (*batch_dims, num_omega)
    cdf_table = cdf_table.unsqueeze(-2).expand(*batch_dims, num_items, -1)  # (*batch_dims, n, num_omega)

    theta = inverse_transform_sampling(
        shape=sample_shape,
        cdf=cdf_table,
        discrete_omega=self.igso3_discrete_omega,
    )  # (*batch_dims, n)

    # Angle * axis -> skew-symmetric matrix -> rotation matrix.
    skew = vector_to_skew_symmetric(theta.unsqueeze(-1) * unit_axis)  # (*batch_dims, n, 3, 3)
    noise_rotation = self.so3.exp(v=skew)  # (*batch_dims, n, 3, 3)
    return {"x_t": self.so3.multiply(noise_rotation, x_0)}

get_ground_truth_score(x_0, x_t, discrete_t, padding_mask)

Denoising Score Matching

.. math:: \nabla_x \log p_{0t} (x_t | x_0)

Parameters:

Name Type Description Default
x_0 Tensor

the clean (un-noised) sample

required
x_t Tensor

the noised sample at timestep discrete_t

required
discrete_t Tensor

the discrete timestep index

required
padding_mask Tensor

the padding mask

required

Returns:

Name Type Description
Tensor Tensor

the ground-truth score of x_t given x_0

Source code in src/ls_mlkit/diffusion/so3_diffuser.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def get_ground_truth_score(self, x_0: Tensor, x_t: Tensor, discrete_t: Tensor, padding_mask: Tensor) -> Tensor:
    r"""Denoising score matching target

    .. math::
        \nabla_x \log p_{0t} (x_t | x_0)

    The direction comes from ``self.so3.log`` of the relative rotation
    ``x_0^T x_t`` divided by its angle; the magnitude is interpolated from
    the score-norm table at ``discrete_t``.

    Args:
        x_0 (Tensor): the clean sample
        x_t (Tensor): the noised sample
        discrete_t (Tensor): the discrete timestep used to index the score-norm table
        padding_mask (Tensor): the padding mask (not used in this method)

    Returns:
        Tensor: the ground-truth score, with the same trailing ``(3, 3)`` layout as ``x_t``
    """
    batch_dims = self.get_macro_shape(x_0)
    num_items = x_0.shape[-3]

    relative_rotation = x_0.transpose(-1, -2) @ x_t  # (*batch_dims, n, 3, 3)
    omega = rotation_matrix_to_angle(relative_rotation)  # (*batch_dims, n)

    # Score-norm table for the given timestep, repeated along the item axis.
    score_norm_table = self.igso3_score_norm[discrete_t]  # (*batch_dims, num_omega)
    score_norm_table = score_norm_table.unsqueeze(-2).expand(*batch_dims, num_items, -1)

    # Log map divided by the angle; EPS guards the division at omega == 0.
    direction = self.so3.log(q=relative_rotation) / (omega[..., None, None] + EPS)
    magnitude = interp(x=omega, xp=self.igso3_discrete_omega, fp=score_norm_table)

    return (x_t @ direction) * magnitude[..., None, None]

step(x_t, discrete_t, padding_mask=None, *args, **kwargs)

.. math::

dx &= \exp_{x_t}(f_{rev} dt + g_{rev} dw)\\
x_{t+\Delta_t} &= \exp_{x_t}(- f_{rev} |\Delta_t| + g_{rev} \Delta w)\\
f_{rev} &= (f - g^2 \nabla_x \ln p_t(x))\\
g_{rev} &= g\\
Source code in src/ls_mlkit/diffusion/so3_diffuser.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def step(  # ty: ignore[invalid-method-override]
    self,
    x_t: Tensor,
    discrete_t: Tensor,
    padding_mask: Tensor | None = None,
    *args: Any,
    **kwargs: Any,
) -> dict:
    r"""
    .. math::

        dx &= \exp_{x_t}(f_{rev} dt + g_{rev} dw)\\
        x_{t+\Delta_t} &= \exp_{x_t}(- f_{rev} |\Delta_t| + g_{rev} \Delta w)\\
        f_{rev} &= (f - g^2 \nabla_x \ln p_t(x))\\
        g_{rev} &= g\\

    """
    continuous_t = self.time_scheduler.timestep_index_to_continuous_time(discrete_t)
    f, g = self.sde.get_drift_and_diffusion(x=x_t, t=continuous_t, mask=padding_mask)

    # p_x_0: Tensor = kwargs.get("p_x_0", None)
    # assert p_x_0 is not None, "p_x_0 is required"
    # riemannian_grad = self.get_ground_truth_score(
    #     x_0=p_x_0, x_t=x_t, discrete_t=discrete_t, padding_mask=padding_mask
    # )

    assert padding_mask is not None
    riemannian_grad = self.score_fn(x_t, discrete_t, padding_mask)

    assert f.sum() == 0, "f should be 0"
    rev_f = f - g**2 * riemannian_grad
    rev_g = g

    delta_t = self.time_scheduler.T / self.time_scheduler.num_inference_timesteps
    term1 = -rev_f * delta_t

    noise_lie_algebra = self.sample_noise_in_lie_algebra(macro_shape=self.get_macro_shape(x_t))
    delta_w = torch.sqrt(torch.as_tensor(delta_t, device=x_t.device, dtype=x_t.dtype)) * x_t @ noise_lie_algebra
    term2 = rev_g * delta_w

    move_in_tangent_space = term1 + term2
    x_tm1 = self.so3.exp(p=x_t, v=move_in_tangent_space)
    return {"x": x_tm1}

sample_noise_in_lie_algebra(macro_shape)

Sample noise in Lie algebra, Skew-symmetric matrix

Parameters:

Name Type Description Default
macro_shape Tuple[int, ...]

the macro shape of the noise

required

Returns:

Name Type Description
Tensor Tensor

the noise in Lie algebra of shape `(*macro_shape, 3, 3)`

Source code in src/ls_mlkit/diffusion/so3_diffuser.py
275
276
277
278
279
280
281
282
283
284
285
286
287
def sample_noise_in_lie_algebra(
    self,
    macro_shape: Tuple[int, ...],
) -> Tensor:
    r"""Draw random noise in the Lie algebra (a skew-symmetric matrix).

    The noise is a random tangent taken at the identity element that
    ``self.so3.identity`` builds for the requested macro shape.

    Args:
        macro_shape (Tuple[int, ...]): the macro shape of the noise

    Returns:
        Tensor: the noise in Lie algebra of shape :math:`(*macro_shape, 3, 3)`
    """
    identity_element = self.so3.identity(macro_shape=macro_shape)
    tangent_noise = self.so3.random_tangent(p=identity_element)
    return tangent_noise