Skip to content

util

ls_mlkit.util

BaseGenerativeModel

Bases: BaseLoss

abstract methods: compute_loss, prior_sampling, step, sampling, inpainting

Source code in src/ls_mlkit/util/base_class/base_gm_class.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
@inherit_docstrings
class BaseGenerativeModel(BaseLoss):
    """
    Base class for generative models; extends ``BaseLoss`` with a hook manager.

    abstract method: compute_loss, prior_sampling, step, sampling, inpainting
    """

    def __init__(self, config: BaseGenerativeModelConfig):
        super().__init__(config=config)
        # Re-annotate the config with the more specific generative-model type.
        self.config: BaseGenerativeModelConfig = config
        # Every model instance owns its own hook registry.
        self.hook_manager = GMHookManager()

    # Abstract: concrete models return a tensor for the requested ``shape``.
    @abstractmethod
    def prior_sampling(self, shape: tuple[int, ...]) -> Tensor: ...

    @abstractmethod
    def step(
        self,
        x_t: Tensor,
        t: Tensor,
        padding_mask: Optional[Tensor] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        """Abstract single-step method; concrete subclasses must implement it.

        NOTE(review): the previous docstring was an unfilled template, so the
        argument descriptions below are inferred from the names only and
        should be confirmed against a concrete subclass.

        Args:
            x_t (``Tensor``): presumably the sample/state at time ``t`` -- TODO confirm.
            t (``Tensor``): presumably the time (or timestep) associated with ``x_t`` -- TODO confirm.
            padding_mask (``Tensor``, *optional*): presumably marks padded positions -- TODO confirm. Defaults to None.

        Returns:
            ``dict``: A dictionary that must contain the key "x"
        """

    # Abstract: the sampling entry point; semantics are defined by subclasses.
    @abstractmethod
    def sampling(
        self,
        shape,
        device,
        x_init_posterior=None,
        return_all=False,
        sampling_condition=None,
        sapmling_condition_key="sapmling_condition",
        **kwargs,
    ) -> dict: ...

    # Abstract: the inpainting entry point; semantics are defined by subclasses.
    @abstractmethod
    def inpainting(
        self,
        x,
        padding_mask,
        inpainting_mask,
        device,
        x_init_posterior: Optional[Tensor] = None,
        inpainting_mask_key: str = "inpainting_mask",
        sapmling_condition_key: Optional[str] = "sapmling_condition",
        return_all: bool = False,
        sampling_condition: Optional[Any] = None,
        **kwargs,
    ) -> dict: ...

    def forward(self, **batch) -> dict:
        r"""Compute the loss for a batch, then run the post-compute-loss hooks.

        The batch is passed as keyword arguments to ``compute_loss``. Its
        result is unpacked into the hook manager for the
        ``POST_COMPUTE_LOSS`` stage; a non-``None`` hook result replaces the
        loss result.

        Args:
            batch (``dict[str, Any]``): the batch of data

        Returns:
            ``dict``: a dictionary that must contain the key "loss"
        """
        loss_output = self.compute_loss(**batch)
        hooked_output = self.hook_manager.run_hooks(
            stage=GMHookStageType.POST_COMPUTE_LOSS,
            tgt_key_name=None,
            **loss_output,
        )
        if hooked_output is None:
            # No hook produced a replacement; keep the raw loss result.
            return loss_output
        assert isinstance(hooked_output, (dict, Tensor))
        return hooked_output

    def register_post_compute_loss_hook(
        self,
        name: str,
        fn: Callable[..., Any],
        priority: int = 0,
        enabled: bool = True,
    ) -> GMHookHandler:
        r"""Build a ``POST_COMPUTE_LOSS`` hook and register it on the hook manager.

        Args:
            name (``str``): the name of the hook
            fn (``Callable[..., Any]``): the function to be called
            priority (``int``, optional): the priority of the hook. Defaults to 0.
            enabled (``bool``, optional): whether the hook is enabled. Defaults to True.

        Returns:
            ``GMHookHandler``: the handler returned by the hook manager
        """
        post_loss_hook = Hook(
            name=name,
            stage=GMHookStageType.POST_COMPUTE_LOSS,
            fn=fn,
            priority=priority,
            enabled=enabled,
        )
        return cast(GMHookHandler, self.hook_manager.register_hook(post_loss_hook))

    def register_hooks(self, hooks: list[GMHook]) -> list[GMHookHandler]:
        """Register every hook in ``hooks`` (in order) and return their handlers."""
        return [cast(GMHookHandler, self.hook_manager.register_hook(hook)) for hook in hooks]

    def register_hook(self, hook: GMHook) -> GMHookHandler:
        """Register a single hook on the hook manager and return its handler."""
        return cast(GMHookHandler, self.hook_manager.register_hook(hook))

step(x_t, t, padding_mask=None, *args, **kwargs) abstractmethod

summary

Parameters:

Name Type Description Default
x_t ``Tensor``

description

required
t ``Tensor``

description

required
padding_mask ``Tensor``, *optional*

description. Defaults to None.

None

Returns:

Type Description
dict

dict: A dictionary that must contain the key "x"

Source code in src/ls_mlkit/util/base_class/base_gm_class.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@abstractmethod
def step(
    self,
    x_t: Tensor,
    t: Tensor,
    padding_mask: Optional[Tensor] = None,
    *args: Any,
    **kwargs: Any,
) -> dict:
    """Abstract single-step method; concrete subclasses must implement it.

    NOTE(review): the previous docstring was an unfilled template, so the
    argument descriptions below are inferred from the names only and should
    be confirmed against a concrete subclass.

    Args:
        x_t (``Tensor``): presumably the sample/state at time ``t`` -- TODO confirm.
        t (``Tensor``): presumably the time (or timestep) associated with ``x_t`` -- TODO confirm.
        padding_mask (``Tensor``, *optional*): presumably marks padded positions -- TODO confirm. Defaults to None.

    Returns:
        ``dict``: A dictionary that must contain the key "x"
    """

forward(**batch)

Forward function, input batch of data and return the dictionary containing the loss

Parameters:

Name Type Description Default
batch ``dict[str, Any]``

the batch of data

{}

Returns:

Type Description
dict

dict: a dictionary that must contain the key "loss"

Source code in src/ls_mlkit/util/base_class/base_gm_class.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def forward(self, **batch) -> dict:
    r"""Compute the loss for a batch, then run the post-compute-loss hooks.

    The batch is passed as keyword arguments to ``compute_loss``. Its result
    is unpacked into the hook manager for the ``POST_COMPUTE_LOSS`` stage; a
    non-``None`` hook result replaces the loss result.

    Args:
        batch (``dict[str, Any]``): the batch of data

    Returns:
        ``dict``: a dictionary that must contain the key "loss"
    """
    loss_output = self.compute_loss(**batch)
    hooked_output = self.hook_manager.run_hooks(
        stage=GMHookStageType.POST_COMPUTE_LOSS,
        tgt_key_name=None,
        **loss_output,
    )
    if hooked_output is None:
        # No hook produced a replacement; keep the raw loss result.
        return loss_output
    assert isinstance(hooked_output, (dict, Tensor))
    return hooked_output

register_post_compute_loss_hook(name, fn, priority=0, enabled=True)

Register a hook to be called after loss computation

Parameters:

Name Type Description Default
name ``str``

the name of the hook

required
fn ``Callable[..., Any]``

the function to be called

required
priority ``int``

the priority of the hook. Defaults to 0.

0
enabled ``bool``

whether the hook is enabled. Defaults to True.

True
Source code in src/ls_mlkit/util/base_class/base_gm_class.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def register_post_compute_loss_hook(
    self,
    name: str,
    fn: Callable[..., Any],
    priority: int = 0,
    enabled: bool = True,
) -> GMHookHandler:
    r"""Build a ``POST_COMPUTE_LOSS`` hook and register it on the hook manager.

    Args:
        name (``str``): the name of the hook
        fn (``Callable[..., Any]``): the function to be called
        priority (``int``, optional): the priority of the hook. Defaults to 0.
        enabled (``bool``, optional): whether the hook is enabled. Defaults to True.

    Returns:
        ``GMHookHandler``: the handler returned by the hook manager
    """
    post_loss_hook = Hook(
        name=name,
        stage=GMHookStageType.POST_COMPUTE_LOSS,
        fn=fn,
        priority=priority,
        enabled=enabled,
    )
    return cast(GMHookHandler, self.hook_manager.register_hook(post_loss_hook))

BaseLoss

Bases: Module, ABC

abstract method: compute_loss

Source code in src/ls_mlkit/util/base_class/base_loss_class.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@inherit_docstrings
class BaseLoss(Module, abc.ABC):
    r"""
    Abstract loss module combining ``Module`` with ``abc.ABC``.

    abstract method: compute_loss
    """

    def __init__(
        self,
        config: BaseLossConfig,
    ):
        # Both bases are initialised explicitly rather than through super().
        Module.__init__(self)
        abc.ABC.__init__(self)
        self.config: BaseLossConfig = config
        # Shape helper configured from the config's ``ndim_micro_shape``.
        shape_config = ShapeConfig(ndim_micro_shape=config.ndim_micro_shape)
        self.shape_util = Shape(config=shape_config)

    @abc.abstractmethod
    def compute_loss(self, **batch) -> dict:
        r"""Compute the loss for one batch (implemented by subclasses).

        Args:
            batch (``dict[str, Any]``): the batch of data

        Returns:
            ``dict``|``Tensor``: a dictionary that must contain the key "loss"
        """

    def get_macro_shape(self, x: Tensor) -> tuple[int, ...]:
        """Delegate to ``self.shape_util.get_macro_shape`` for ``x``."""
        macro_shape = self.shape_util.get_macro_shape(x)
        return macro_shape

    def complete_micro_shape(self, x: Tensor) -> Tensor:
        """Delegate to ``self.shape_util.complete_micro_shape`` for ``x``."""
        completed = self.shape_util.complete_micro_shape(x)
        return completed

compute_loss(**batch) abstractmethod

Compute loss

Parameters:

Name Type Description Default
batch ``dict[str, Any]``

the batch of data

{}

Returns:

Type Description
dict

dict|Tensor: a dictionary that must contain the key "loss"

Source code in src/ls_mlkit/util/base_class/base_loss_class.py
40
41
42
43
44
45
46
47
48
49
@abc.abstractmethod
def compute_loss(self, **batch) -> dict:
    r"""Compute the loss for one batch; abstract, implemented by subclasses.

    Args:
        batch (``dict[str, Any]``): the batch of data, received as keyword arguments

    Returns:
        ``dict``|``Tensor``: a dictionary that must contain the key "loss"

    NOTE(review): the signature is annotated ``-> dict`` while the return
    description above also mentions ``Tensor`` -- confirm the intended contract.
    """

BaseTimeScheduler

Bases: ABC

Base class for time schedulers in diffusion models.

Notation Convention

Let the total diffusion time be :math:T, discretized into :math:N diffusion steps, corresponding to :math:N+1 continuous time points:

.. math:: 0 = t_0 < t_1 < \cdots < t_N = T

where :math:\{t_i\}_{i=0}^N represents continuous time. For uniform discretization:

.. math:: t_i = \frac{i}{N} \cdot T

The corresponding discrete time steps are defined as:

.. math:: i \in \{0, 1, \ldots, N\}

In diffusion models, :math:t_0 corresponds to the clean data distribution :math:q(x_0), so training and sampling typically only consider:

.. math:: i \in \{1, \ldots, N\}

For engineering convenience (0-based array indexing), we use:

.. math:: \text{idx} = i - 1

Therefore:

  • idx = 0 corresponds to discrete step :math:i=1, i.e., continuous time :math:t_1
  • idx = N-1 corresponds to discrete step :math:i=N, i.e., continuous time :math:t_N = T

In this implementation:

  • num_train_timesteps = :math:N (number of diffusion steps)
  • timestep_index (or idx) :math:\in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}
  • continuous_time :math:\in [t_1, t_N] = [\frac{T}{N}, T] for training/sampling

The idx_start parameter controls the starting value of timestep indices:

  • When idx_start = 0 (default): :math:\text{idx} = i - 1, so :math:\text{idx} \in \{0, \ldots, N-1\}
  • When idx_start = 1: :math:\text{idx} = i, so :math:\text{idx} \in \{1, \ldots, N\}

Parameters:

Name Type Description Default
continuous_time_start ``float``, *optional*

The start of continuous time range (typically 0). Defaults to 0.0.

0.0
continuous_time_end ``float``, *optional*

The end of continuous time range (i.e., :math:T). Defaults to 1.0.

1.0
num_train_timesteps ``int``, *optional*

Number of diffusion steps :math:N. Defaults to 1000.

1000
num_inference_steps ``int``, *optional*

Number of inference steps. If None, uses num_train_timesteps. Defaults to None.

None
idx_start ``int``, *optional*

The starting value for timestep indices. Set to 1 if you prefer 1-based indexing where idx directly equals the discrete step i. Defaults to 0.

0
Source code in src/ls_mlkit/util/base_class/base_time_class.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
@inherit_docstrings
class BaseTimeScheduler(ABC):
    r"""Base class for time schedulers in diffusion models.

    Notation Convention
    -------------------
    Let the total diffusion time be :math:`T`, discretized into :math:`N` diffusion steps,
    corresponding to :math:`N+1` continuous time points:

    .. math::
        0 = t_0 < t_1 < \cdots < t_N = T

    where :math:`\{t_i\}_{i=0}^N` represents continuous time. For uniform discretization:

    .. math::
        t_i = \frac{i}{N} \cdot T

    The corresponding discrete time steps are defined as:

    .. math::
        i \in \{0, 1, \ldots, N\}

    In diffusion models, :math:`t_0` corresponds to the clean data distribution :math:`q(x_0)`,
    so training and sampling typically only consider:

    .. math::
        i \in \{1, \ldots, N\}

    For engineering convenience (0-based array indexing), we use:

    .. math::
        \text{idx} = i - 1

    Therefore:

    - ``idx = 0`` corresponds to discrete step :math:`i=1`, i.e., continuous time :math:`t_1`
    - ``idx = N-1`` corresponds to discrete step :math:`i=N`, i.e., continuous time :math:`t_N = T`

    In this implementation:

    - ``num_train_timesteps`` = :math:`N` (number of diffusion steps)
    - ``timestep_index`` (or ``idx``) :math:`\in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`
    - ``continuous_time`` :math:`\in [t_1, t_N] = [\frac{T}{N}, T]` for training/sampling

    The ``idx_start`` parameter controls the starting value of timestep indices:

    - When ``idx_start = 0`` (default): :math:`\text{idx} = i - 1`, so :math:`\text{idx} \in \{0, \ldots, N-1\}`
    - When ``idx_start = 1``: :math:`\text{idx} = i`, so :math:`\text{idx} \in \{1, \ldots, N\}`

    Args:
        continuous_time_start (``float``, *optional*): The start of continuous time range (typically 0). Defaults to 0.0.
        continuous_time_end (``float``, *optional*): The end of continuous time range (i.e., :math:`T`). Defaults to 1.0.
        num_train_timesteps (``int``, *optional*): Number of diffusion steps :math:`N`. Defaults to 1000.
        num_inference_steps (``int``, *optional*): Number of inference steps. If None, uses ``num_train_timesteps``. Defaults to None.
        idx_start (``int``, *optional*): The starting value for timestep indices.
            Set to 1 if you prefer 1-based indexing where idx directly equals the discrete step i. Defaults to 0.
    """

    def __init__(
        self,
        continuous_time_start: float = 0.0,
        continuous_time_end: float = 1.0,
        num_train_timesteps: int = 1000,
        num_inference_steps: Optional[int] = None,
        idx_start: int = 0,
    ) -> None:
        self.continuous_time_start: float = continuous_time_start
        self.continuous_time_end: float = continuous_time_end
        self.num_train_timesteps: int = num_train_timesteps  # This is N
        self.num_inference_timesteps: int = (
            num_inference_steps if num_inference_steps is not None else num_train_timesteps
        )
        self.idx_start: int = idx_start  # Starting value for timestep indices
        self.idx_end: int = idx_start + num_train_timesteps - 1  # Last value for timestep indices
        self._timesteps_idx: Optional[Tensor] = None  # Stores timestep indices
        self._continuous_timesteps: Optional[Tensor] = None  # Stores continuous times
        self.T: float = continuous_time_end - continuous_time_start
        # Called last so the subclass implementation can read every attribute set above.
        self.initialize_timesteps_schedule()

    def get_timestep_index_start(self) -> int:
        """Return ``idx_start``, the first timestep index."""
        return self.idx_start

    def get_timestep_index_end(self) -> int:
        """Return ``idx_end``, the last timestep index (inclusive)."""
        return self.idx_end

    def get_continuous_time_start(self) -> float:
        """Return ``continuous_time_start``."""
        return self.continuous_time_start

    def get_continuous_time_end(self) -> float:
        """Return ``continuous_time_end``."""
        return self.continuous_time_end

    def get_coutinuous_time_end(self) -> float:
        """Misspelled alias of :meth:`get_continuous_time_end`, kept for backward compatibility."""
        return self.get_continuous_time_end()

    @abstractmethod
    def initialize_timesteps_schedule(self) -> None:
        """Initialize timesteps schedule for sampling/inference."""

    def continuous_time_to_timestep_index(self, continuous_time: Tensor) -> Tensor:
        r"""Convert continuous time to timestep index.

        Given continuous time :math:`t`, compute the timestep index:

        .. math::
            \text{idx} = \text{round}\left(\frac{t - t_0}{T} \cdot N\right) - 1 + \text{idx\_start}

        where :math:`t_0` is ``continuous_time_start``, :math:`T` is the total time span,
        :math:`N` is ``num_train_timesteps``, and :math:`\text{idx\_start}` is the starting index.

        The result is clamped to :math:`[\text{idx\_start}, \text{idx\_start} + N - 1]`.

        Args:
            continuous_time (``Tensor``): Continuous time values :math:`t \in [t_0, t_0 + T]`.

        Returns:
            ``Tensor``: Timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.
        """
        # Normalize time to [0, 1] range, then scale to [0, N]
        normalized = (continuous_time - self.continuous_time_start) / self.T
        # idx = round(normalized * N) - 1 + idx_start
        return torch.clamp(
            torch.round(normalized * self.num_train_timesteps) - 1 + self.idx_start,
            min=self.idx_start,
            max=self.idx_start + self.num_train_timesteps - 1,
        ).to(torch.int64)

    def timestep_index_to_continuous_time(self, timestep_index: Tensor) -> Tensor:
        r"""Convert timestep index to continuous time.

        Given timestep index :math:`\text{idx}`, compute the continuous time:

        .. math::
            t = t_0 + \frac{\text{idx} + 1 - \text{idx\_start}}{N} \cdot T

        where :math:`t_0` is ``continuous_time_start``, :math:`T` is the total time span,
        :math:`N` is ``num_train_timesteps``, and :math:`\text{idx\_start}` is the starting index.

        Args:
            timestep_index (``Tensor``): Timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.

        Returns:
            ``Tensor``: Continuous time values :math:`t \in [t_1, t_N]`.
        """
        # t = t_0 + (idx + 1 - idx_start) / N * T
        return (
            self.continuous_time_start
            + (timestep_index + 1 - self.idx_start).float() / self.num_train_timesteps * self.T
        )

    def get_timestep_indices_schedule(self) -> Tensor:
        r"""Get the timestep indices schedule for sampling/inference.

        Returns:
            ``Tensor``: 1D tensor of timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.
        """
        assert self._timesteps_idx is not None, "timestep indices schedule is not set"
        assert isinstance(self._timesteps_idx, Tensor), "timestep indices must be a Tensor"
        assert self._timesteps_idx.ndim == 1, "timestep indices must be a 1D Tensor"
        return self._timesteps_idx

    def get_continuous_timesteps_schedule(self) -> Tensor:
        r"""Get the continuous timesteps schedule for sampling/inference.

        Returns:
            ``Tensor``: 1D tensor of continuous time values :math:`t \in [t_1, t_N]`.
        """
        assert self._continuous_timesteps is not None, "continuous timesteps schedule is not set"
        assert isinstance(self._continuous_timesteps, Tensor), "continuous timesteps must be a Tensor"
        assert self._continuous_timesteps.ndim == 1, "continuous timesteps must be a 1D Tensor"
        return self._continuous_timesteps

    def set_timestep_indices_schedule(self, timestep_indices: Tensor) -> None:
        r"""Set the timestep indices schedule for sampling/inference.

        Args:
            timestep_indices (``Tensor``): 1D tensor of timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.
        """
        self._timesteps_idx = timestep_indices

    def set_continuous_timesteps_schedule(self, continuous_timesteps: Tensor) -> None:
        r"""Set the continuous timesteps schedule for sampling/inference.

        Args:
            continuous_timesteps (``Tensor``): 1D tensor of continuous time values.
        """
        self._continuous_timesteps = continuous_timesteps

    def sample_timestep_index_uniformly(
        self, macro_shape: Tuple[int, ...], same_for_all_samples: bool = False
    ) -> Tensor:
        r"""Sample timestep indices uniformly from :math:`\{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.

        This corresponds to sampling discrete steps :math:`i` uniformly from :math:`\{1, \ldots, N\}`
        and converting to index via :math:`\text{idx} = i - 1 + \text{idx\_start}`.

        Args:
            macro_shape (``Tuple[int, ...]``): Shape of the output tensor.
            same_for_all_samples (``bool``, *optional*): If True, use the same timestep index for all samples. Defaults to False.

        Returns:
            ``Tensor``: Timestep indices with shape ``macro_shape``.
        """
        idx_min = self.idx_start
        idx_max = self.idx_start + self.num_train_timesteps  # exclusive upper bound
        if same_for_all_samples:
            # Draw a single index and broadcast it over the requested shape.
            return torch.ones(macro_shape, dtype=torch.int64) * torch.randint(idx_min, idx_max, (1,), dtype=torch.int64)
        else:
            return torch.randint(idx_min, idx_max, macro_shape, dtype=torch.int64)

    def sample_continuous_time_uniformly(
        self, macro_shape: Tuple[int, ...], same_for_all_samples: bool = False
    ) -> Tensor:
        r"""Sample continuous time uniformly from :math:`[t_1, t_N]`.

        Note: This samples from :math:`[t_0 + \frac{T}{N}, t_0 + T]` to exclude :math:`t_0`
        (the clean data point).

        Args:
            macro_shape (``Tuple[int, ...]``): Shape of the output tensor.
            same_for_all_samples (``bool``, *optional*): If True, use the same time for all samples. Defaults to False.

        Returns:
            ``Tensor``: Continuous time values with shape ``macro_shape``.
        """
        # Sample from [t_1, t_N] = [t_0 + T/N, t_0 + T]
        t_min = self.continuous_time_start + self.T / self.num_train_timesteps  # t_1
        t_max = self.continuous_time_end  # t_N = t_0 + T
        t_range = t_max - t_min

        if same_for_all_samples:
            # Draw a single time and broadcast it over the requested shape.
            return torch.ones(macro_shape) * (torch.rand(1) * t_range + t_min)
        else:
            return torch.rand(macro_shape) * t_range + t_min

initialize_timesteps_schedule() abstractmethod

Initialize timesteps schedule for sampling/inference.

Source code in src/ls_mlkit/util/base_class/base_time_class.py
101
102
103
@abstractmethod
def initialize_timesteps_schedule(self) -> None:
    """Initialize timesteps schedule for sampling/inference.

    Abstract: each concrete scheduler supplies its own implementation.
    ``BaseTimeScheduler.__init__`` invokes this method as its final step,
    after the configuration attributes have been assigned.

    NOTE(review): implementations are presumably expected to populate the
    schedules read by ``get_timestep_indices_schedule`` and
    ``get_continuous_timesteps_schedule`` -- confirm against a subclass.
    """

continuous_time_to_timestep_index(continuous_time)

Convert continuous time to timestep index.

Given continuous time :math:t, compute the timestep index:

.. math:: \text{idx} = \text{round}\left(\frac{t - t_0}{T} \cdot N\right) - 1 + \text{idx\_start}

where :math:t_0 is continuous_time_start, :math:T is the total time span, :math:N is num_train_timesteps, and :math:\text{idx\_start} is the starting index.

The result is clamped to :math:[\text{idx\_start}, \text{idx\_start} + N - 1].

Parameters:

Name Type Description Default
continuous_time ``Tensor``

Continuous time values :math:t \in [t_0, t_0 + T].

required

Returns:

Type Description
Tensor

Tensor: Timestep indices :math:\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}.

Source code in src/ls_mlkit/util/base_class/base_time_class.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def continuous_time_to_timestep_index(self, continuous_time: Tensor) -> Tensor:
    r"""Map continuous time values onto integer timestep indices.

    For a continuous time :math:`t` the index is

    .. math::
        \text{idx} = \text{round}\left(\frac{t - t_0}{T} \cdot N\right) - 1 + \text{idx\_start}

    with :math:`t_0` = ``continuous_time_start``, :math:`T` = ``self.T`` (the total time span),
    :math:`N` = ``num_train_timesteps`` and :math:`\text{idx\_start}` the first index.

    Values falling outside the valid range are clamped to
    :math:`[\text{idx\_start}, \text{idx\_start} + N - 1]`.

    Args:
        continuous_time (``Tensor``): Continuous time values :math:`t \in [t_0, t_0 + T]`.

    Returns:
        ``Tensor``: ``int64`` timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.
    """
    lowest_idx = self.idx_start
    highest_idx = self.idx_start + self.num_train_timesteps - 1
    # Fraction of the total time span that has elapsed, in [0, 1].
    fraction = (continuous_time - self.continuous_time_start) / self.T
    # Nearest discrete step i in [0, N], shifted to the index convention.
    raw_idx = torch.round(fraction * self.num_train_timesteps) - 1 + self.idx_start
    return raw_idx.clamp(min=lowest_idx, max=highest_idx).to(torch.int64)

timestep_index_to_continuous_time(timestep_index)

Convert timestep index to continuous time.

Given timestep index :math:\text{idx}, compute the continuous time:

.. math:: t = t_0 + \frac{\text{idx} + 1 - \text{idx\_start}}{N} \cdot T

where :math:t_0 is continuous_time_start, :math:T is the total time span, :math:N is num_train_timesteps, and :math:\text{idx\_start} is the starting index.

Parameters:

Name Type Description Default
timestep_index ``Tensor``

Timestep indices :math:\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}.

required

Returns:

Type Description
Tensor

Tensor: Continuous time values :math:t \in [t_1, t_N].

Source code in src/ls_mlkit/util/base_class/base_time_class.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def timestep_index_to_continuous_time(self, timestep_index: Tensor) -> Tensor:
    r"""Map integer timestep indices onto continuous time values.

    For a timestep index :math:`\text{idx}` the continuous time is

    .. math::
        t = t_0 + \frac{\text{idx} + 1 - \text{idx\_start}}{N} \cdot T

    with :math:`t_0` = ``continuous_time_start``, :math:`T` = ``self.T`` (the total time span),
    :math:`N` = ``num_train_timesteps`` and :math:`\text{idx\_start}` the first index.

    Args:
        timestep_index (``Tensor``): Timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.

    Returns:
        ``Tensor``: Continuous time values :math:`t \in [t_1, t_N]`.
    """
    # Discrete step i = idx + 1 - idx_start, i.e. the index undone to 1-based steps.
    discrete_step = timestep_index + 1 - self.idx_start
    elapsed = discrete_step.float() / self.num_train_timesteps * self.T
    return self.continuous_time_start + elapsed

get_timestep_indices_schedule()

Get the timestep indices schedule for sampling/inference.

Returns:

Type Description
Tensor

Tensor: 1D tensor of timestep indices :math:\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}.

Source code in src/ls_mlkit/util/base_class/base_time_class.py
156
157
158
159
160
161
162
163
164
165
def get_timestep_indices_schedule(self) -> Tensor:
    r"""Return the stored timestep-index schedule used for sampling/inference.

    The stored value is checked to be a non-``None`` 1D ``Tensor`` before it is returned.

    Returns:
        ``Tensor``: 1D tensor of timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.
    """
    schedule = self._timesteps_idx
    assert schedule is not None, "timestep indices schedule is not set"
    assert isinstance(schedule, Tensor), "timestep indices must be a Tensor"
    assert schedule.ndim == 1, "timestep indices must be a 1D Tensor"
    return schedule

get_continuous_timesteps_schedule()

Get the continuous timesteps schedule for sampling/inference.

Returns:

Type Description
Tensor

Tensor: 1D tensor of continuous time values :math:t \in [t_1, t_N].

Source code in src/ls_mlkit/util/base_class/base_time_class.py
167
168
169
170
171
172
173
174
175
176
def get_continuous_timesteps_schedule(self) -> Tensor:
    r"""Return the stored continuous-time schedule used for sampling/inference.

    The stored value is checked to be a non-``None`` 1D ``Tensor`` before it is returned.

    Returns:
        ``Tensor``: 1D tensor of continuous time values :math:`t \in [t_1, t_N]`.
    """
    schedule = self._continuous_timesteps
    assert schedule is not None, "continuous timesteps schedule is not set"
    assert isinstance(schedule, Tensor), "continuous timesteps must be a Tensor"
    assert schedule.ndim == 1, "continuous timesteps must be a 1D Tensor"
    return schedule

set_timestep_indices_schedule(timestep_indices)

Set the timestep indices schedule for sampling/inference.

Parameters:

Name Type Description Default
timestep_indices ``Tensor``

1D tensor of timestep indices :math:\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}.

required
Source code in src/ls_mlkit/util/base_class/base_time_class.py
178
179
180
181
182
183
184
def set_timestep_indices_schedule(self, timestep_indices: Tensor) -> None:
    r"""Store ``timestep_indices`` as the timestep-index schedule for sampling/inference.

    The value is stored as given; no validation is performed here.

    Args:
        timestep_indices (``Tensor``): 1D tensor of timestep indices :math:`\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`.
    """
    self._timesteps_idx = timestep_indices

set_continuous_timesteps_schedule(continuous_timesteps)

Set the continuous timesteps schedule for sampling/inference.

Parameters:

Name Type Description Default
continuous_timesteps ``Tensor``

1D tensor of continuous time values.

required
Source code in src/ls_mlkit/util/base_class/base_time_class.py
186
187
188
189
190
191
192
def set_continuous_timesteps_schedule(self, continuous_timesteps: Tensor) -> None:
    r"""Store ``continuous_timesteps`` as the continuous-time schedule for sampling/inference.

    The value is stored as given; no validation is performed here.

    Args:
        continuous_timesteps (``Tensor``): 1D tensor of continuous time values.
    """
    self._continuous_timesteps = continuous_timesteps

sample_timestep_index_uniformly(macro_shape, same_for_all_samples=False)

Sample timestep indices uniformly from :math:\{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}.

This corresponds to sampling discrete steps :math:i uniformly from :math:\{1, \ldots, N\} and converting to index via :math:\text{idx} = i - 1 + \text{idx\_start}.

Parameters:

Name Type Description Default
macro_shape ``Tuple[int, ...]``

Shape of the output tensor.

required
same_for_all_samples ``bool``, *optional*

If True, use the same timestep index for all samples. Defaults to False.

False

Returns:

Type Description
Tensor

Tensor: Timestep indices with shape macro_shape.

Source code in src/ls_mlkit/util/base_class/base_time_class.py
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def sample_timestep_index_uniformly(
    self, macro_shape: Tuple[int, ...], same_for_all_samples: bool = False
) -> Tensor:
    r"""Draw integer timestep indices uniformly at random.

    Indices are drawn from :math:`\{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}`
    where :math:`N` is ``self.num_train_timesteps``. Equivalently, a discrete step
    :math:`i \in \{1, \ldots, N\}` is drawn and mapped to
    :math:`\text{idx} = i - 1 + \text{idx\_start}`.

    Args:
        macro_shape (``Tuple[int, ...]``): Shape of the output tensor.
        same_for_all_samples (``bool``, *optional*): If True, a single index is drawn
            and broadcast to every position. Defaults to False.

    Returns:
        ``Tensor``: ``int64`` timestep indices with shape ``macro_shape``.
    """
    low = self.idx_start
    high = self.idx_start + self.num_train_timesteps  # ``torch.randint`` excludes ``high``

    if not same_for_all_samples:
        return torch.randint(low, high, macro_shape, dtype=torch.int64)

    # Draw once, then broadcast the shared index over the requested shape.
    shared_index = torch.randint(low, high, (1,), dtype=torch.int64)
    return torch.ones(macro_shape, dtype=torch.int64) * shared_index

sample_continuous_time_uniformly(macro_shape, same_for_all_samples=False)

Sample continuous time uniformly from :math:[t_1, t_N].

Note: This samples from :math:[t_0 + \frac{T}{N}, t_0 + T] to exclude :math:t_0 (the clean data point).

Parameters:

Name Type Description Default
macro_shape ``Tuple[int, ...]``

Shape of the output tensor.

required
same_for_all_samples ``bool``, *optional*

If True, use the same time for all samples. Defaults to False.

False

Returns:

Type Description
Tensor

Tensor: Continuous time values with shape macro_shape.

Source code in src/ls_mlkit/util/base_class/base_time_class.py
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
def sample_continuous_time_uniformly(
    self, macro_shape: Tuple[int, ...], same_for_all_samples: bool = False
) -> Tensor:
    r"""Draw continuous times uniformly from :math:`[t_1, t_N]`.

    The lower bound is :math:`t_1 = t_0 + \frac{T}{N}` rather than :math:`t_0`, so the
    clean-data time :math:`t_0` is never sampled. The upper bound is
    ``self.continuous_time_end``.

    Args:
        macro_shape (``Tuple[int, ...]``): Shape of the output tensor.
        same_for_all_samples (``bool``, *optional*): If True, a single time is drawn
            and broadcast to every position. Defaults to False.

    Returns:
        ``Tensor``: Continuous time values with shape ``macro_shape``.
    """
    lower = self.continuous_time_start + self.T / self.num_train_timesteps  # t_1
    upper = self.continuous_time_end  # t_N
    width = upper - lower

    if not same_for_all_samples:
        return torch.rand(macro_shape) * width + lower

    # Draw once, then broadcast the shared time over the requested shape.
    shared_time = torch.rand(1) * width + lower
    return torch.ones(macro_shape) * shared_time

DeviceConfig

Bases: object

Source code in src/ls_mlkit/util/base_class/base_config_class.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
@inherit_docstrings
class DeviceConfig(object):
    """Config base whose tensor-valued attributes can be moved between devices."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Accept and ignore any arguments; the base config holds no state."""

    def to(self, device: torch.device | str | Tensor, inplace: bool = True) -> "DeviceConfig":
        """Move every tensor attribute of the config to ``device``.

        Only attributes stored directly in the instance ``__dict__`` that are
        ``Tensor`` instances are moved; all other attributes are left untouched.

        Args:
            device (torch.device | str | Tensor): target device, or a tensor whose device is used
            inplace (bool, optional): if True, modify and return this object; otherwise
                work on a deep copy. Defaults to True.

        Returns:
            DeviceConfig: the config (this object or its deep copy) moved to the given device
        """
        target = self if inplace else deepcopy(self)
        # A tensor argument stands for "wherever this tensor lives".
        destination = device.device if isinstance(device, Tensor) else device
        for attr_name, attr_value in target.__dict__.items():
            if not isinstance(attr_value, Tensor):
                continue
            setattr(target, attr_name, attr_value.to(destination))
        return target

to(device, inplace=True)

Move the config to the given device

Parameters:

Name Type Description Default
device device | str | Tensor

the device to move the config to

required
inplace bool

whether to move the config in place. Defaults to True.

True

Returns:

Name Type Description
DeviceConfig DeviceConfig

the config moved to the given device

Source code in src/ls_mlkit/util/base_class/base_config_class.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def to(self, device: torch.device | str | Tensor, inplace: bool = True) -> "DeviceConfig":
    """Move the config to the given device

    Args:
        device (torch.device | str | Tensor): the device to move the config to
        inplace (bool, optional): whether to move the config in place. Defaults to True.

    Returns:
        BaseConfig: the config moved to the given device
    """
    obj = self if inplace else deepcopy(self)
    if isinstance(device, Tensor):
        device = device.device
    for k, v in obj.__dict__.items():
        if isinstance(v, Tensor):
            setattr(obj, k, v.to(device))
    return obj

HookManager

Bases: Generic[HookStageType]

Source code in src/ls_mlkit/util/hook/base_hook.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class HookManager(Generic[HookStageType]):
    """Registry that groups hooks by stage and runs them in ascending priority order."""

    def __init__(self) -> None:
        # stage -> hooks registered for that stage, kept sorted by ``priority``
        self._hooks: Dict[HookStageType, list[Hook[HookStageType]]] = {}

    def register_hook(self, hook: Hook[HookStageType]) -> HookHandler[HookStageType]:
        """Register ``hook`` under ``hook.stage`` and return a handler bound to it.

        The stage's hook list is re-sorted by ascending ``priority`` after every
        registration; ``list.sort`` is stable, so hooks with equal priority keep
        their registration order.
        """
        self._hooks.setdefault(hook.stage, []).append(hook)
        self._hooks[hook.stage].sort(key=lambda h: h.priority, reverse=False)
        return HookHandler(self, hook=hook)

    def register_hooks(self, hooks: list[Hook[HookStageType]]) -> list[HookHandler[HookStageType]]:
        """Register several hooks and return their handlers in the same order."""
        return [self.register_hook(hook) for hook in hooks]

    def unregister_hook(self, name: str, stage: Optional[HookStageType] = None) -> None:
        """Remove every hook called ``name``.

        Args:
            name (``str``): name of the hook(s) to remove.
            stage (``HookStageType``, *optional*): restrict removal to this stage.
                When None, the hook is removed from all stages. Unknown stages
                and unknown names are ignored silently.
        """
        # Compare against None explicitly so that falsy stage keys (for example an
        # IntEnum member equal to 0) are still treated as a specific stage.
        if stage is not None:
            if stage in self._hooks:
                self._hooks[stage] = [h for h in self._hooks[stage] if h.name != name]
        else:
            for s in self._hooks:
                self._hooks[s] = [h for h in self._hooks[s] if h.name != name]

    def enable_hook(
        self,
        name: Optional[str] = None,
        stage: Optional[HookStageType] = None,
        enabled: bool = True,
    ) -> None:
        """Set the ``enabled`` flag on matching hooks.

        Args:
            name (``str``, *optional*): only touch hooks with this name; None matches all.
            stage (``HookStageType``, *optional*): only touch hooks of this stage; None matches all.
            enabled (``bool``): value to assign to ``hook.enabled``. Defaults to True.

        Raises:
            ValueError: if no registered hook matched the given filters.
        """
        hook_found = False
        for stage_key, hooks in self._hooks.items():
            if stage is not None and stage_key != stage:
                continue
            for h in hooks:
                if name is None or h.name == name:
                    h.enabled = enabled
                    hook_found = True
        if not hook_found:
            raise ValueError(f"Hook with name {name} not found.")

    def disable_hook(self, name: Optional[str] = None, stage: Optional[HookStageType] = None) -> None:
        """Disable matching hooks; same filters and error behaviour as ``enable_hook``."""
        self.enable_hook(name=name, stage=stage, enabled=False)

    def run_hooks(self, stage: HookStageType, tgt_key_name: Optional[str] = None, **kwargs) -> Optional[Any]:
        """Executes all enabled hooks for a given stage, optionally updating or collecting results in kwargs,
        and returns either the final modified kwargs or a specific key's value.

        Each enabled hook is called as ``hook(**kwargs)``. When ``tgt_key_name`` is
        given, the hook's output is written to ``kwargs[tgt_key_name]`` before the
        next hook runs. Otherwise a ``dict`` output replaces ``kwargs`` entirely and
        any other output is ignored.

        Args:
            stage (``HookStageType``): stage whose registered hooks should run.
            tgt_key_name (``str``, *optional*): target key name. Defaults to None.

        Returns:
            ``kwargs.get(tgt_key_name)`` when ``tgt_key_name`` is given, otherwise the
            final ``kwargs`` dict.
        """
        hook_output: Optional[Any] = None

        if stage is not None and stage in self._hooks:
            for hook in self._hooks[stage]:
                if not hook.enabled:
                    continue
                hook_output = hook(**kwargs)
                if tgt_key_name is not None:
                    kwargs[tgt_key_name] = hook_output
                elif isinstance(hook_output, dict):
                    kwargs = hook_output

        if tgt_key_name is not None:
            return kwargs.get(tgt_key_name)
        return kwargs

    def list_hooks(self) -> None:
        """Print every registered hook, grouped by stage, with its enabled state."""
        for stage, hooks in self._hooks.items():
            print(f"[{stage}]")
            for h in hooks:
                print(f"  - {h} {'(enabled)' if h.enabled else '(disabled)'}")

run_hooks(stage, tgt_key_name=None, **kwargs)

Executes all enabled hooks for a given stage, optionally updating or collecting results in kwargs, and returns either the final modified kwargs or a specific key's value.

Parameters:

Name Type Description Default
stage ``HookStageType``

The stage whose registered hooks should be executed.

required
tgt_key_name ``str``, *optional*

target key name. Defaults to None.

None
Source code in src/ls_mlkit/util/hook/base_hook.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def run_hooks(self, stage: HookStageType, tgt_key_name: Optional[str] = None, **kwargs) -> Optional[Any]:
    """Run every enabled hook registered for ``stage`` and return the outcome.

    Each enabled hook is called as ``hook(**kwargs)``. When ``tgt_key_name`` is
    given, the hook's output is written to ``kwargs[tgt_key_name]`` before the
    next hook runs. Otherwise a ``dict`` output replaces ``kwargs`` entirely and
    any other output is ignored.

    Args:
        stage (``HookStageType``): stage whose registered hooks should run.
        tgt_key_name (``str``, *optional*): target key name. Defaults to None.

    Returns:
        ``kwargs.get(tgt_key_name)`` when ``tgt_key_name`` is given, otherwise the
        final ``kwargs`` dict.
    """
    collect_into_key = tgt_key_name is not None
    stage_is_known = stage is not None and stage in self._hooks
    stage_hooks = self._hooks[stage] if stage_is_known else []

    for hook in stage_hooks:
        if hook.enabled:
            result = hook(**kwargs)
            if collect_into_key:
                kwargs[tgt_key_name] = result
            elif isinstance(result, dict):
                kwargs = result

    return kwargs.get(tgt_key_name) if collect_into_key else kwargs

SO3

Bases: LieGroup

SO(3): Special Orthogonal Group

Source code in src/ls_mlkit/util/manifold/so3.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
class SO3(LieGroup):
    """SO(3): Special Orthogonal Group"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def exp(self, p: Tensor | None = None, v: Tensor | None = None) -> Tensor:
        r"""Exponential map
        $$\exp_p(v)$$ map a point in tangent space $$T_p M$$ to a point on the manifold $$M$$
        $$\exp_p(v) = p \cdot \exp(p^{-1} v)$$
        if p is None, it will be set to the identity matrix
        """
        assert v is not None, "v is required"
        if p is None:
            return exponential_map(v)
        # Pull v back to the identity, exponentiate there, then translate by p.
        p_inv = self.inverse(p)
        v_e = p_inv @ v
        result = p @ exponential_map(v_e)
        return result

    def log(self, p: Tensor | None = None, q: Tensor | None = None) -> Tensor:
        r"""Logarithm map
        $$\log_p(q)$$ map a point on the manifold $$M$$ to a point in tangent space $$T_p M$$
        $$\log_p(q)=p\log(p^{-1} q)$$
        if p is None, it will be set to the identity matrix
        """
        assert q is not None, "q is required"
        if p is None:
            return logarithmic_map(q)
        # Express q relative to p, take the log at the identity, then translate by p.
        p_inv = self.inverse(p)
        q_e = p_inv @ q
        result = p @ logarithmic_map(q_e)
        return result

    def random_tangent(self, p: Tensor, random_type: str = "gaussian", std: float = 1.0) -> Tensor:
        r"""Sample noise from $$T_p M$$

        A 3-vector is drawn per leading index of ``p`` (same dtype and device as
        ``p``), scaled by ``std`` and passed to ``vector_to_skew_symmetric``.

        Raises:
            ValueError: if ``random_type`` is not ``"gaussian"``.
        """
        noise = None
        if random_type == "gaussian":
            noise = torch.randn(p.shape[:-2] + (3,), dtype=p.dtype, device=p.device) * std
        else:
            raise ValueError(f"Invalid random type: {random_type}")
        assert noise is not None
        result = vector_to_skew_symmetric(noise)
        return result

    def metric(self, p: Tensor, v: Tensor, w: Tensor) -> Tensor:
        r"""Inner product
        $$<v, w>_p: \mathfrak{so}(3) \times \mathfrak{so}(3) \to \mathbb{R}$$ is the inner product at point $$p$$
        $$
        <v, w>_p = \frac{1}{2} \text{Tr}(v^T w)
        $$

        ``p`` is accepted for interface compatibility but is not used in the computation.
        """
        result = 1 / 2 * trace(v.transpose(-1, -2) @ w)
        return result

    def grad(self, f: Callable, p: Tensor) -> Tensor:
        r"""Riemannian gradient of f at point p on SO(3)

        $$
        p \cdot skew(p^{-1} \nabla_p f(p))
        $$

        Args:
            f: Callable[[Tensor], Tensor], scalar function of p
            p: (..., 3, 3) point on SO(3)

        Returns:
            (..., 3, 3) gradient in the tangent space T_p SO(3)
        """
        # Work on a detached copy so the caller's tensor and graph are untouched.
        p = p.clone().detach().requires_grad_(True)
        y = f(p)
        if y.ndim > 0:
            # Reduce non-scalar outputs to a scalar before differentiating.
            y = y.sum()

        # Euclidean gradient
        grad_euclid = torch.autograd.grad(y, p, create_graph=True)[0]  # (..., 3, 3)

        # project to tangent space T_p SO(3)
        grad_riemann = p @ (
            0.5 * (p.transpose(-1, -2) @ grad_euclid - (p.transpose(-1, -2) @ grad_euclid).transpose(-1, -2))
        )

        return grad_riemann

    def multiply(self, p, q):
        r"""Multiply in Group"""
        assert p.shape == q.shape, "p and q must have the same shape"
        result = p @ q
        return result

    def inverse(self, p):
        r"""Inverse in Group

        Implemented as the transpose of the last two dimensions.
        """
        result = p.transpose(-1, -2)
        return result

    def identity(self, macro_shape: Tuple[int, ...] = tuple()):
        r"""Identity in Group

        Args:
            macro_shape: leading (batch) shape of the result. Defaults to ``()``.

        Returns:
            Tensor of shape ``(*macro_shape, 3, 3)`` in which every 3x3 slice is the
            identity matrix.
        """
        # ``view`` cannot create new elements, so it only worked when every entry of
        # ``macro_shape`` was 1; ``repeat`` tiles the matrix over the batch dimensions.
        result = torch.eye(3).repeat(*macro_shape, 1, 1)
        return result

    def left_translation(self, g, h):
        r"""
        $$L_g(h) = g \cdot h$$
        """
        result = g @ h
        return result

exp(p=None, v=None)

Exponential map $$\exp_p(v)$$ map a point in tangent space $$T_p M$$ to a point on the manifold $$M$$ $$\exp_p(v) = p \cdot \exp(p^{-1} v)$$ if p is None, it will be set to the identity matrix

Source code in src/ls_mlkit/util/manifold/so3.py
27
28
29
30
31
32
33
34
35
36
37
38
39
def exp(self, p: Tensor | None = None, v: Tensor | None = None) -> Tensor:
    r"""Exponential map
    $$\exp_p(v)$$ map a point in tangent space $$T_p M$$ to a point on the manifold $$M$$
    $$\exp_p(v) = p \cdot \exp(p^{-1} v)$$
    if p is None, it will be set to the identity matrix
    """
    assert v is not None, "v is required"
    if p is None:
        return exponential_map(v)
    p_inv = self.inverse(p)
    v_e = p_inv @ v
    result = p @ exponential_map(v_e)
    return result

log(p=None, q=None)

Logarithm map $$\log_p(q)$$ map a point on the manifold $$M$$ to a point in tangent space $$T_p M$$ $$\log_p(q)=p\log(p^{-1} q)$$ if p is None, it will be set to the identity matrix

Source code in src/ls_mlkit/util/manifold/so3.py
41
42
43
44
45
46
47
48
49
50
51
52
53
def log(self, p: Tensor | None = None, q: Tensor | None = None) -> Tensor:
    r"""Logarithm map
    $$\log_p(q)$$ map a point on the manifold $$M$$ to a point in tangent space $$T_p M$$
    $$\log_p(q)=p\log(p^{-1} q)$$
    if p is None, it will be set to the identity matrix
    """
    assert q is not None, "q is required"
    if p is None:
        return logarithmic_map(q)
    p_inv = self.inverse(p)
    q_e = p_inv @ q
    result = p @ logarithmic_map(q_e)
    return result

random_tangent(p, random_type='gaussian', std=1.0)

Sample noise from $$T_p M$$

Source code in src/ls_mlkit/util/manifold/so3.py
55
56
57
58
59
60
61
62
63
64
def random_tangent(self, p: Tensor, random_type: str = "gaussian", std: float = 1.0) -> Tensor:
    r"""Sample noise from $$T_p M$$

    A 3-vector is drawn per leading index of ``p`` (same dtype and device as ``p``),
    scaled by ``std`` and passed to ``vector_to_skew_symmetric``.

    Raises:
        ValueError: if ``random_type`` is not ``"gaussian"``.
    """
    if random_type != "gaussian":
        raise ValueError(f"Invalid random type: {random_type}")
    vector_shape = p.shape[:-2] + (3,)
    coefficients = torch.randn(vector_shape, dtype=p.dtype, device=p.device) * std
    assert coefficients is not None
    return vector_to_skew_symmetric(coefficients)

metric(p, v, w)

Inner product $$\langle v, w \rangle_p: \mathfrak{so}(3) \times \mathfrak{so}(3) \to \mathbb{R}$$ is the inner product at point $$p$$ $$ \langle v, w \rangle_p = \frac{1}{2} \text{Tr}(v^T w) $$

Source code in src/ls_mlkit/util/manifold/so3.py
66
67
68
69
70
71
72
73
74
def metric(self, p: Tensor, v: Tensor, w: Tensor) -> Tensor:
    r"""Inner product on $$\mathfrak{so}(3)$$

    $$
    <v, w>_p = \frac{1}{2} \text{Tr}(v^T w)
    $$

    ``p`` is accepted for interface compatibility but is not used in the computation.
    """
    v_transposed_w = v.transpose(-1, -2) @ w
    return 0.5 * trace(v_transposed_w)

grad(f, p)

Riemannian gradient of f at point p on SO(3)

$$ p \cdot skew(p^{-1} \nabla_p f(p)) $$

Parameters:

Name Type Description Default
f Callable

Callable[[Tensor], Tensor], scalar function of p

required
p Tensor

(..., 3, 3) point on SO(3)

required

Returns:

Type Description
Tensor

(..., 3, 3) gradient in the tangent space T_p SO(3)

Source code in src/ls_mlkit/util/manifold/so3.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def grad(self, f: Callable, p: Tensor) -> Tensor:
    r"""Riemannian gradient of f at point p on SO(3)

    $$
    p \cdot skew(p^{-1} \nabla_p f(p))
    $$

    The Euclidean gradient is obtained with autograd on a detached copy of ``p``
    and then projected onto the tangent space at ``p``.

    Args:
        f: Callable[[Tensor], Tensor], scalar function of p
        p: (..., 3, 3) point on SO(3)

    Returns:
        (..., 3, 3) gradient in the tangent space T_p SO(3)
    """
    point = p.clone().detach().requires_grad_(True)
    value = f(point)
    # Reduce non-scalar outputs to a scalar before differentiating.
    scalar = value.sum() if value.ndim > 0 else value

    # Euclidean gradient, shape (..., 3, 3)
    (euclidean_grad,) = torch.autograd.grad(scalar, point, create_graph=True)

    # Skew-symmetric part of p^T * grad, pushed back to T_p SO(3) by p.
    pulled_back = point.transpose(-1, -2) @ euclidean_grad
    skew_part = 0.5 * (pulled_back - pulled_back.transpose(-1, -2))
    return point @ skew_part

multiply(p, q)

Multiply in Group

Source code in src/ls_mlkit/util/manifold/so3.py
105
106
107
108
109
def multiply(self, p, q):
    r"""Group product: the matrix product of ``p`` and ``q``.

    Both operands must have exactly the same shape; no broadcasting is allowed.
    """
    assert p.shape == q.shape, "p and q must have the same shape"
    return torch.matmul(p, q)

inverse(p)

Inverse in Group

Source code in src/ls_mlkit/util/manifold/so3.py
111
112
113
114
def inverse(self, p):
    r"""Group inverse, computed as the transpose of the last two dimensions."""
    return p.transpose(-2, -1)

identity(macro_shape=tuple())

Identity in Group

Source code in src/ls_mlkit/util/manifold/so3.py
116
117
118
119
def identity(self, macro_shape: Tuple[int, ...] = tuple()):
    r"""Identity in Group

    Args:
        macro_shape: leading (batch) shape of the result. Defaults to ``()``.

    Returns:
        Tensor of shape ``(*macro_shape, 3, 3)`` in which every 3x3 slice is the
        identity matrix.
    """
    # ``view`` cannot create new elements, so it only worked when every entry of
    # ``macro_shape`` was 1; ``repeat`` tiles the matrix over the batch dimensions.
    result = torch.eye(3).repeat(*macro_shape, 1, 1)
    return result

left_translation(g, h)

$$L_g(h) = g \cdot h$$

Source code in src/ls_mlkit/util/manifold/so3.py
121
122
123
124
125
126
def left_translation(self, g, h):
    r"""Left translation of ``h`` by ``g``.

    $$L_g(h) = g \cdot h$$
    """
    return torch.matmul(g, h)

LieGroup

Bases: RiemannianManifold

Lie Group

Source code in src/ls_mlkit/util/manifold/lie_group.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class LieGroup(RiemannianManifold):
    r"""Lie Group

    Extends the Riemannian manifold interface with the group operations
    (product, inverse, identity) that concrete groups must implement.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def multiply(self, p, q):
        r"""Multiply in Group

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def inverse(self, p):
        r"""Inverse in Group

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def identity(self):
        r"""Identity in Group

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def left_translation(self, g, h):
        r"""
        $$L_g(h) = g \cdot h$$

        Default implementation in terms of :meth:`multiply`, so the base method no
        longer returns ``None`` silently. Subclasses may override it.
        """
        return self.multiply(g, h)

multiply(p, q)

Multiply in Group

Source code in src/ls_mlkit/util/manifold/lie_group.py
14
15
16
def multiply(self, p, q):
    r"""Multiply in Group

    Raises:
        NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError

inverse(p)

Inverse in Group

Source code in src/ls_mlkit/util/manifold/lie_group.py
18
19
20
def inverse(self, p):
    r"""Inverse in Group

    Raises:
        NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError

identity()

Identity in Group

Source code in src/ls_mlkit/util/manifold/lie_group.py
22
23
24
def identity(self):
    r"""Identity in Group

    Raises:
        NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError

left_translation(g, h)

$$L_g(h) = g \cdot h$$

Source code in src/ls_mlkit/util/manifold/lie_group.py
26
27
28
29
def left_translation(self, g, h):
    r"""
    $$L_g(h) = g \cdot h$$

    Default implementation in terms of :meth:`multiply`, so the base method no
    longer returns ``None`` silently. Subclasses may override it.
    """
    return self.multiply(g, h)

RiemannianManifold

Bases: ABC

Riemannian Manifold

Source code in src/ls_mlkit/util/manifold/riemannian_manifold.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class RiemannianManifold(abc.ABC):
    """Riemannian Manifold

    Abstract interface: every method below is abstract and must be overridden
    by concrete manifolds.
    """

    @abc.abstractmethod
    def exp(self, p, v):
        r"""Exponential map
        $$\exp_p(v)$$ maps a point in tangent space to a point on the manifold
        $$
        T_p M \to M
        $$
        """
        raise NotImplementedError

    @abc.abstractmethod
    def log(self, p, q):
        r"""Logarithm map
        $$\log_p(q)$$ maps a point on the manifold to a point in tangent space
        $$
        M \to T_p M
        $$
        """
        raise NotImplementedError

    @abc.abstractmethod
    def random_tangent(self, p, random_type="gaussian", std=1.0):
        r"""Sample noise in the tangent space at point p
        $$T_p M$$
        """
        raise NotImplementedError

    @abc.abstractmethod
    def metric(self, p, v, w):
        r"""Inner product
        $$<v, w>_p$$ is the inner product of $$v$$ and $$w$$ at point $$p$$
        """
        raise NotImplementedError

    @abc.abstractmethod
    def grad(self, f, p):
        r"""Gradient
        $$\nabla_p f$$ is the gradient of $$f$$ at point $$p$$
        """
        raise NotImplementedError

exp(p, v) abstractmethod

Exponential map $$exp_p(v)$$ map a point in tangent space to a point on the manifold $$ T_p M \to M $$

Source code in src/ls_mlkit/util/manifold/riemannian_manifold.py
11
12
13
14
15
16
17
18
19
@abc.abstractmethod
def exp(self, p, v):
    r"""Exponential map
    $$\exp_p(v)$$ maps a point in tangent space to a point on the manifold
    $$
    T_p M \to M
    $$
    """
    raise NotImplementedError

log(p, q) abstractmethod

Logarithm map $$log_p(q)$$ map a point on the manifold to a point in tangent space $$ M \to T_p M $$

Source code in src/ls_mlkit/util/manifold/riemannian_manifold.py
21
22
23
24
25
26
27
28
29
@abc.abstractmethod
def log(self, p, q):
    r"""Logarithm map
    $$\log_p(q)$$ maps a point on the manifold to a point in tangent space
    $$
    M \to T_p M
    $$
    """
    raise NotImplementedError

random_tangent(p, random_type='gaussian', std=1.0) abstractmethod

Sample noise in the tangent space at point p $$T_p M$$

Source code in src/ls_mlkit/util/manifold/riemannian_manifold.py
31
32
33
34
35
36
@abc.abstractmethod
def random_tangent(self, p, random_type="gaussian", std=1.0):
    r"""Sample noise in the tangent space at point p
    $$T_p M$$

    Abstract: concrete manifolds must override this method.
    """
    raise NotImplementedError

metric(p, v, w) abstractmethod

Inner product $$\langle v, w \rangle_p$$ is the inner product of $$v$$ and $$w$$ at point $$p$$

Source code in src/ls_mlkit/util/manifold/riemannian_manifold.py
38
39
40
41
42
43
@abc.abstractmethod
def metric(self, p, v, w):
    r"""Inner product
    $$<v, w>_p$$ is the inner product of $$v$$ and $$w$$ at point $$p$$

    Abstract: concrete manifolds must override this method.
    """
    raise NotImplementedError

grad(f, p) abstractmethod

Gradient $$\nabla_p f$$ is the gradient of $$f$$ at point $$p$$

Source code in src/ls_mlkit/util/manifold/riemannian_manifold.py
45
46
47
48
49
50
@abc.abstractmethod
def grad(self, f, p):
    r"""Gradient
    $$\nabla_p f$$ is the gradient of $$f$$ at point $$p$$

    Abstract: concrete manifolds must override this method.
    """
    raise NotImplementedError

Masker

Bases: MaskerInterface

Source code in src/ls_mlkit/util/mask/masker.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class Masker(MaskerInterface):
    """Mask helper for tensors whose trailing dimensions are not covered by the mask.

    ``ndim_mini_micro_shape`` is the number of trailing dimensions of ``x`` that the
    mask does not have; the mask is padded with that many size-1 dimensions so it
    broadcasts against ``x``.
    """

    def __init__(self, ndim_mini_micro_shape: int = 0, **kwargs: dict[Any, Any]):
        super().__init__(**kwargs)
        self.ndim_mini_micro_shape: int = ndim_mini_micro_shape

    def apply_mask(self, x: Tensor, mask: Tensor) -> Tensor:
        """Return ``x`` multiplied elementwise by ``mask`` (broadcast over trailing dims)."""
        self.check_mask_shape(x, mask)
        if self.ndim_mini_micro_shape == 0:
            return x * mask
        else:
            return x * mask.view(*mask.shape, *[1 for _ in range(self.ndim_mini_micro_shape)])

    def check_mask_shape(self, x: Tensor, mask: Tensor):
        """Assert that ``mask`` matches ``x`` on all but the trailing uncovered dimensions."""
        if self.ndim_mini_micro_shape == 0:
            assert x.shape == mask.shape
        else:
            assert x.shape[: -self.ndim_mini_micro_shape] == mask.shape

    def count_bright_area(self, mask: Tensor) -> Tensor:
        r"""
        Bright area can be seen
        Dark area cannot be seen

        Returns the sum of all mask entries.
        """
        return torch.sum(mask)

    def get_full_bright_mask(self, x: Tensor) -> Tensor:
        """Return an all-ones mask for ``x``, on the same device as ``x``."""
        if self.ndim_mini_micro_shape == 0:
            shape = x.shape
        else:
            shape = x.shape[: -self.ndim_mini_micro_shape]
        device = x.device

        return torch.ones(shape, device=device)

    def apply_inpainting_mask(self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor) -> Tensor:
        r"""
        1 represents the region that can be seen

        Returns ``x_t * (1 - mask) + x_0 * mask``: visible (mask == 1) positions take
        their value from ``x_0``, the rest from ``x_t``. Boolean masks are accepted.
        """
        self.check_mask_shape(x_0, inpainting_mask)
        if inpainting_mask.dtype == torch.bool:
            # ``1 - mask`` is not defined for bool tensors in PyTorch; blend numerically.
            inpainting_mask = inpainting_mask.to(x_t.dtype)
        inpainting_mask = inpainting_mask.view(
            *inpainting_mask.shape,
            *[1 for _ in range(self.ndim_mini_micro_shape)],
        )
        return x_t * (1 - inpainting_mask) + x_0 * inpainting_mask

count_bright_area(mask)

Bright area can be seen Dark area cannot be seen

Source code in src/ls_mlkit/util/mask/masker.py
27
28
29
30
31
32
def count_bright_area(self, mask: Tensor) -> Tensor:
    r"""
    Total brightness of the mask, i.e. the sum of all its entries.

    Bright area can be seen
    Dark area cannot be seen
    """
    bright_total = mask.sum()
    return bright_total

apply_inpainting_mask(x_0, x_t, inpainting_mask)

1 represents the region that can be seen

Source code in src/ls_mlkit/util/mask/masker.py
43
44
45
46
47
48
49
50
51
52
def apply_inpainting_mask(self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor) -> Tensor:
    r"""
    1 represents the region that can be seen

    Returns ``x_t * (1 - mask) + x_0 * mask``: visible (mask == 1) positions take
    their value from ``x_0``, the rest from ``x_t``. Boolean masks are accepted.
    """
    self.check_mask_shape(x_0, inpainting_mask)
    if inpainting_mask.dtype == torch.bool:
        # ``1 - mask`` is not defined for bool tensors in PyTorch; blend numerically.
        inpainting_mask = inpainting_mask.to(x_t.dtype)
    inpainting_mask = inpainting_mask.view(
        *inpainting_mask.shape,
        *[1 for _ in range(self.ndim_mini_micro_shape)],
    )
    return x_t * (1 - inpainting_mask) + x_0 * inpainting_mask

MaskerInterface

Bases: ABC

Source code in src/ls_mlkit/util/mask/masker_interface.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class MaskerInterface(abc.ABC):
    """Abstract interface for mask utilities; all operations below are abstract."""

    def __init__(self, *args, **kwargs: dict[Any, Any]):
        """Store the positional and keyword arguments unchanged on the instance."""
        self.args = args
        self.kwargs: dict[Any, Any] = kwargs

    @abc.abstractmethod
    def apply_mask(self, x: Tensor, mask: Tensor) -> Tensor:
        """Apply ``mask`` to ``x`` and return the masked tensor."""
        pass

    @abc.abstractmethod
    def check_mask_shape(self, x: Tensor, mask: Tensor):
        """
        Check whether the shape of ``mask`` is as expected for ``x``.
        """

    @abc.abstractmethod
    def count_bright_area(self, mask: Tensor) -> Tensor:
        """
        Count the bright area of ``mask``.

        Bright area can be seen
        Dark area cannot be seen
        """

    @abc.abstractmethod
    def get_full_bright_mask(self, x: Tensor) -> Tensor:
        """
        Return a mask that is all bright
        """

    @abc.abstractmethod
    def apply_inpainting_mask(self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor) -> Tensor:
        """
        Combine ``x_0`` and ``x_t`` according to ``inpainting_mask``.

        1 represents the region that can be seen
        """

check_mask_shape(x, mask) abstractmethod

check whether the shape of mask is as expected

Source code in src/ls_mlkit/util/mask/masker_interface.py
16
17
18
19
20
@abc.abstractmethod
def check_mask_shape(self, x: Tensor, mask: Tensor):
    """
    Check whether the shape of ``mask`` is as expected for ``x``.
    """

count_bright_area(mask) abstractmethod

Bright area can be seen Dark area cannot be seen

Source code in src/ls_mlkit/util/mask/masker_interface.py
22
23
24
25
26
27
@abc.abstractmethod
def count_bright_area(self, mask: Tensor) -> Tensor:
    """
    Count the bright area of ``mask``.

    Bright area can be seen
    Dark area cannot be seen
    """

get_full_bright_mask(x) abstractmethod

Return a mask that is all bright

Source code in src/ls_mlkit/util/mask/masker_interface.py
29
30
31
32
33
@abc.abstractmethod
def get_full_bright_mask(self, x: Tensor) -> Tensor:
    """
    Return a mask for ``x`` that is all bright.
    """

apply_inpainting_mask(x_0, x_t, inpainting_mask) abstractmethod

1 represents the region that can be seen

Source code in src/ls_mlkit/util/mask/masker_interface.py
35
36
37
38
39
@abc.abstractmethod
def apply_inpainting_mask(self, x_0: Tensor, x_t: Tensor, inpainting_mask: Tensor) -> Tensor:
    """
    Combine ``x_0`` and ``x_t`` according to ``inpainting_mask``.

    1 represents the region that can be seen
    """

Observer

Bases: object

Source code in src/ls_mlkit/util/observer.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
class Observer(object):
    function_mapping = {
        "weight_norm": weight_norm_fn,
        "gradient_norm": gradient_norm_fn,
        "weights": weights_fn,
        "gradients": gradients_fn,
    }

    def __init__(
        self,
        model: Optional[Module] = None,
        optimizer: Optional[Optimizer] = None,
        scheduler: Optional[LambdaLR] = None,
        dataset: Optional[Dataset | HFDataset] = None,
        target_modules: Optional[List[str]] = None,
        no_split_classes: Optional[List[str]] = None,
    ):
        """Initialize the Observer


        Args:
            model (Module, optional): the model to observe. Defaults to None.
            optimizer (Optimizer, optional): the optimizer to observe. Defaults to None.
            scheduler (LambdaLR, optional): the scheduler to observe. Defaults to None.
            dataset (Dataset | HFDataset, optional): the dataset to observe. Defaults to None.
            target_modules (List[str], optional): the modules to observe. Defaults to None. if target_modules is not None, then no_split_classes and strategy is ignored.
            no_split_classes (List[str], optional): the classes to not split. Defaults to None.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.dataset = dataset
        self.no_split_classes = no_split_classes
        self.target_modules = target_modules

    # get something=================================================================

    @torch.no_grad()
    @staticmethod
    def _get_something(
        model: Module,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: Optional[List[str]] = None,
        function: Callable = weight_norm_fn,
    ):
        info = dict()

        def __get_something(module: Module, prefix=""):
            if (
                len(list(module.named_children())) == 0
                or (no_split_classes is not None and module.__class__.__name__ in no_split_classes)
            ) and any(param.requires_grad for param in module.parameters()):
                info[prefix] = function(module)
                return
            for name, sub_module in module.named_children():
                sub_module_name = f"{prefix}.{name}" if prefix != "" else name
                __get_something(sub_module, sub_module_name)

        match strategy:
            case "all":
                something = function(model)
                return {"total_model": something}
            case "block":
                __get_something(model, "")
                return info
            case _:
                raise ValueError(f"Unsupported strategy: {strategy}")

    @torch.no_grad()
    @staticmethod
    def _get_target_modules(model: Module, target_modules: List[str]):
        info = dict()

        def __get_target_modules(module: Module, prefix=""):
            if any(target_module in prefix for target_module in target_modules):
                info[prefix] = module
                return
            for name, sub_module in module.named_children():
                sub_module_name = f"{prefix}.{name}" if prefix != "" else name
                __get_target_modules(sub_module, sub_module_name)

        __get_target_modules(model, "")
        return info

    @torch.no_grad()
    @staticmethod
    def _get_something_from_targets(
        model: Optional[Module] = None,
        target_modules_dict: Optional[Dict[str, Module]] = None,
        target_modules: Optional[List[str]] = None,
        function: Callable = weight_norm_fn,
    ):
        info = dict()
        if target_modules_dict is None:
            assert model is not None and target_modules is not None
            target_modules_dict = Observer._get_target_modules(model, target_modules)
        for module_path, module in target_modules_dict.items():
            info[module_path] = function(module)
        return info

    @torch.no_grad()
    def get_something_from_targets(self, function: Callable):
        return Observer._get_something_from_targets(
            model=self.model,
            target_modules_dict=None,
            target_modules=self.target_modules,
            function=function,
        )

    @torch.no_grad()
    def get_something(
        self,
        name,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: Optional[List[str]] = None,
    ):
        if self.target_modules is None:
            if no_split_classes is None:
                no_split_classes = self.no_split_classes
            assert self.model is not None
            return Observer._get_something(
                model=self.model,
                strategy=strategy,
                no_split_classes=no_split_classes,
                function=Observer.function_mapping[name],
            )
        return self.get_something_from_targets(function=Observer.function_mapping[name])

    @torch.no_grad()
    def get_weight_norm(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: Optional[List[str]] = None,
    ):
        return self.get_something("weight_norm", strategy, no_split_classes)

    @torch.no_grad()
    def get_gradient_norm(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: Optional[List[str]] = None,
    ):
        return self.get_something("gradient_norm", strategy, no_split_classes)

    @torch.no_grad()
    def get_weights(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: Optional[List[str]] = None,
    ):
        return self.get_something("weights", strategy, no_split_classes)

    @torch.no_grad()
    def get_gradients(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: Optional[List[str]] = None,
    ):
        return self.get_something("gradients", strategy, no_split_classes)

    @torch.no_grad()
    @staticmethod
    def _get_statistics(data: List[Tensor]):
        flattened_tensor = torch.cat([item.reshape(-1) for item in data], dim=0)
        mean = flattened_tensor.mean()
        std = flattened_tensor.std()
        median = flattened_tensor.median()
        var = flattened_tensor.var()
        return {"mean": mean, "std": std, "median": median, "variance": var}

    @torch.no_grad()
    def get_statistics(
        self,
        name,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: Optional[List[str]] = None,
    ):
        something = self.get_something(name, strategy=strategy, no_split_classes=no_split_classes)
        return {key: Observer._get_statistics(value) for key, value in something.items()}

__init__(model=None, optimizer=None, scheduler=None, dataset=None, target_modules=None, no_split_classes=None)

Initialize the Observer

Parameters:

Name Type Description Default
model Module

the model to observe. Defaults to None.

None
optimizer Optimizer

the optimizer to observe. Defaults to None.

None
scheduler LambdaLR

the scheduler to observe. Defaults to None.

None
dataset Dataset | HFDataset

the dataset to observe. Defaults to None.

None
target_modules List[str]

the modules to observe. Defaults to None. If target_modules is not None, then no_split_classes and strategy are ignored.

None
no_split_classes List[str]

the classes to not split. Defaults to None.

None
Source code in src/ls_mlkit/util/observer.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def __init__(
    self,
    model: Optional[Module] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LambdaLR] = None,
    dataset: Optional[Dataset | HFDataset] = None,
    target_modules: Optional[List[str]] = None,
    no_split_classes: Optional[List[str]] = None,
):
    """Store the objects to observe and the module-selection settings.

    Args:
        model (Module, optional): the model to observe. Defaults to None.
        optimizer (Optimizer, optional): the optimizer to observe. Defaults to None.
        scheduler (LambdaLR, optional): the scheduler to observe. Defaults to None.
        dataset (Dataset | HFDataset, optional): the dataset to observe. Defaults to None.
        target_modules (List[str], optional): the modules to observe. Defaults to None.
            If it is not None, no_split_classes and strategy are ignored.
        no_split_classes (List[str], optional): the classes to not split. Defaults to None.
    """
    # Objects under observation.
    self.model, self.optimizer = model, optimizer
    self.scheduler, self.dataset = scheduler, dataset
    # Settings that decide which modules are inspected.
    self.target_modules = target_modules
    self.no_split_classes = no_split_classes

ForwardBackwardOffloadHookContext

Bases: ForwardHookForDevice

Source code in src/ls_mlkit/util/offload/forward_backward_offload.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class ForwardBackwardOffloadHookContext(ForwardHookForDevice):
    """Context manager that moves modules between the compute device and CPU.

    On enter, forward and backward (pre and post) hooks are registered on every
    module listed by ``get_module_list``; on exit the hooks are removed. The
    hooks move a module to ``device`` right before it runs and back to CPU
    afterwards, except where the block flags say it should stay resident.
    """

    def __init__(
        self,
        model,
        device="cuda",
        no_split_module_classes=None,
        enable=True,
        num_block: int = 2,
        strategy="block",
    ):
        """Offload model weights to CPU between forward/backward sub-blocks.

        Args:
            model: The model to which hooks will be applied.
            device: The compute device (e.g. "cuda").
            no_split_module_classes: Module class names that should not be split further.
            enable: If False, this context is a no-op.
            num_block: Number of blocks for the "block" strategy.
            strategy: Only ``'block'`` is implemented: partition leaf boundary modules into
                ``num_block`` groups.

        Raises:
            ValueError: if ``enable`` is True and ``strategy`` is not ``'block'``.
        """
        super().__init__()
        self.enable = enable
        # A disabled context sets no other attribute; __enter__/__exit__ guard on `enable`.
        if not enable:
            return
        if strategy != "block":
            raise ValueError(f"Unsupported strategy {strategy!r}; only 'block' is implemented.")
        self.strategy = strategy
        self.num_block = num_block
        self.device = device
        self.model = model
        # Hook handles registered in __enter__ and removed in __exit__.
        self.handle_list: list = []
        self.module_list = get_module_list(model, no_split_module_classes=no_split_module_classes)
        # Indexed by module name below; each entry carries the block flags read by the hooks.
        self.module_info = get_partition_block(self.module_list, self.num_block)

    def __enter__(self):
        """Register the offload hooks (no-op when disabled) and return ``self``."""
        if not self.enable:
            return self
        self.register_hook_by_block(self.model)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove every registered hook; never suppresses an exception."""
        if not self.enable:
            return False
        for handle in self.handle_list:
            handle.remove()
        return False

    def register_hook_by_block(self, module: torch.nn.Module, parent_name=""):
        """Recursively register the four hooks on each module named in ``module_list``.

        Args:
            module: the module currently being visited.
            parent_name: dotted path of ``module`` ("" for the root).
        """
        if parent_name and parent_name in self.module_list:
            handle = module.register_forward_pre_hook(
                hook=self.get_forward_hook_by_block(info=self.module_info[parent_name], pre=True, device=self.device),
                with_kwargs=True,
            )
            self.handle_list.append(handle)
            handle = module.register_forward_hook(
                hook=self.get_forward_hook_by_block(info=self.module_info[parent_name], pre=False, device=self.device),
                with_kwargs=True,
            )
            self.handle_list.append(handle)
            handle = module.register_full_backward_pre_hook(
                hook=self.get_backward_hook_by_block(info=self.module_info[parent_name], pre=True, device=self.device)
            )
            self.handle_list.append(handle)
            handle = module.register_full_backward_hook(
                hook=self.get_backward_hook_by_block(info=self.module_info[parent_name], pre=False, device=self.device)
            )
            self.handle_list.append(handle)
            # A hooked module is not descended into.
            return

        for name, sub_module in module.named_children():
            full_name = f"{parent_name}.{name}" if parent_name else name
            self.register_hook_by_block(sub_module, full_name)

    @staticmethod
    def get_forward_hook_by_block(info: dict, pre=True, device="cuda"):
        """Build the forward pre-hook (``pre=True``) or forward post-hook.

        Args:
            info: per-module entry; reads ``"last_block_flag"`` and ``"first_module_flag"``.
            pre: which of the two hooks to return.
            device: compute device; ``None`` falls back to ``"cuda"``.
        """
        if device is None:
            device = "cuda"
        offload_device = "cpu"
        last_block_flag = info["last_block_flag"]
        first_module_flag = info["first_module_flag"]

        def pre_hook(module, args, kwargs):
            # Bring the module and all tensor inputs onto the compute device.
            module.to(device)
            args = tuple(arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args)
            kwargs = {n: v.to(device) if isinstance(v, torch.Tensor) else v for n, v in kwargs.items()}
            # Steer saved-tensor offloading: non-last blocks offload activations to CPU.
            if first_module_flag:
                OffloadSavedTensorHook._set_offload_device(offload_device if not last_block_flag else device)
            return args, kwargs

        def after_hook(module, args, kwargs, output):
            # Modules of the last block stay on the compute device; others go back to CPU.
            module.to(offload_device if not last_block_flag else device)
            return output

        if pre:
            return pre_hook
        else:
            return after_hook

    @staticmethod
    def get_backward_hook_by_block(info: dict, pre=True, device="cuda"):
        """Build the backward pre-hook (``pre=True``) or backward post-hook.

        Args:
            info: per-module entry; reads ``"first_block_flag"``.
            pre: which of the two hooks to return.
            device: compute device; ``None`` falls back to ``"cuda"``.
        """
        if device is None:
            device = "cuda"
        offload_device = "cpu"
        first_block_flag = info["first_block_flag"]

        def pre_hook(module, grad_output):
            # Weights must be on the compute device before this module's backward runs.
            module.to(device)
            return grad_output

        def after_hook(module, grad_input, grad_output):
            # Modules of the first block are left on the compute device after backward.
            if not first_block_flag:
                module.to(offload_device)
            return grad_input

        if pre:
            return pre_hook
        else:
            return after_hook

__init__(model, device='cuda', no_split_module_classes=None, enable=True, num_block=2, strategy='block')

Offload model weights to CPU between forward/backward sub-blocks.

Parameters:

Name Type Description Default
model

The model to which hooks will be applied.

required
device

The compute device (e.g. "cuda").

'cuda'
no_split_module_classes

Module class names that should not be split further.

None
enable

If False, this context is a no-op.

True
num_block int

Number of blocks for the "block" strategy.

2
strategy

Only 'block' is implemented: partition leaf boundary modules into num_block groups.

'block'
Source code in src/ls_mlkit/util/offload/forward_backward_offload.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def __init__(
    self,
    model,
    device="cuda",
    no_split_module_classes=None,
    enable=True,
    num_block: int = 2,
    strategy="block",
):
    """Set up weight offloading between forward/backward sub-blocks.

    Args:
        model: The model to which hooks will be applied.
        device: The compute device (e.g. "cuda").
        no_split_module_classes: Module class names that should not be split further.
        enable: If False, this context is a no-op and nothing else is configured.
        num_block: Number of blocks for the "block" strategy.
        strategy: Only ``'block'`` is implemented: partition leaf boundary modules into
            ``num_block`` groups.

    Raises:
        ValueError: if ``enable`` is True and ``strategy`` is not ``'block'``.
    """
    super().__init__()
    self.enable = enable
    if enable:
        # The strategy is validated only for an enabled context.
        if strategy != "block":
            raise ValueError(f"Unsupported strategy {strategy!r}; only 'block' is implemented.")
        self.model = model
        self.device = device
        self.strategy = strategy
        self.num_block = num_block
        self.handle_list: list = []
        # The partition is derived from the module list, so the list comes first.
        self.module_list = get_module_list(model, no_split_module_classes=no_split_module_classes)
        self.module_info = get_partition_block(self.module_list, self.num_block)

GradientOffloadHookContext

Source code in src/ls_mlkit/util/offload/gradient_offload.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class GradientOffloadHookContext:
    """Context manager that moves accumulated gradients into ``record_dict``.

    While active, a post-accumulate-grad hook on every named parameter copies
    the gradient to CPU, clears ``param.grad`` and adds the copy to the entry
    stored under the parameter's name.
    """

    def __init__(self, model: torch.nn.Module, enable: bool, record_dict: dict, *args, **kwargs):
        """Configure gradient offloading.

        Args:
            model: The model whose gradients will be offloaded.
            enable: If False, this context is a no-op and nothing else is stored.
            record_dict: Dictionary that accumulates offloaded named gradients.
        """
        self.enable = enable
        if enable:
            self.handle_list: list = []
            self.offload_device = "cpu"
            self.record_dict = record_dict
            self.model = model

    def __enter__(self):
        """Register one offload hook per named parameter and return ``self``."""
        if self.enable:
            for param_name, parameter in self.model.named_parameters():
                hook = self._make_offload_grad_hook(param_name, self.record_dict, self.offload_device)
                self.handle_list.append(parameter.register_post_accumulate_grad_hook(hook))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the registered hooks; never suppresses an exception."""
        if self.enable:
            for registered in self.handle_list:
                registered.remove()
        return False

    @staticmethod
    def _make_offload_grad_hook(name: str, record_dict: dict, offload_device: str):
        """Return a hook that accumulates ``param.grad`` into ``record_dict[name]``."""

        def offload_grad_hook(param):
            pending = param.grad
            if pending is None:
                return
            moved = pending.to(offload_device)
            # Drop the on-parameter gradient once it has been copied.
            param.grad = None
            if name not in record_dict:
                record_dict[name] = moved
                return
            total = record_dict[name]
            # In-place add only when dtype and device agree; otherwise rebind the sum.
            if total.dtype == moved.dtype and total.device == moved.device:
                total.add_(moved)
            else:
                record_dict[name] = total + moved

        return offload_grad_hook

__init__(model, enable, record_dict, *args, **kwargs)

Offload gradients to CPU after each accumulation step.

Parameters:

Name Type Description Default
model Module

The model whose gradients will be offloaded.

required
enable bool

If False, this context is a no-op.

required
record_dict dict

Dictionary that accumulates offloaded named gradients.

required
Source code in src/ls_mlkit/util/offload/gradient_offload.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
def __init__(self, model: torch.nn.Module, enable: bool, record_dict: dict, *args, **kwargs):
    """Configure gradient offloading to CPU.

    Args:
        model: The model whose gradients will be offloaded.
        enable: If False, this context is a no-op and no other attribute is set.
        record_dict: Dictionary that accumulates offloaded named gradients.
    """
    self.enable = enable
    if enable:
        # Hook handles are filled in when the context is entered.
        self.handle_list: list = []
        self.offload_device = "cpu"
        self.record_dict = record_dict
        self.model = model

ModelOffloadHookContext

Source code in src/ls_mlkit/util/offload/model_offload.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class ModelOffloadHookContext:
    """Context manager that enters weight offloading and saved-tensor offloading together."""

    def __init__(
        self,
        model,
        no_split_module_classes=None,
        num_block: int = 2,
        enable=True,
        device="cuda",
        strategy="block",
    ):
        """Combine forward/backward offloading with saved-tensor offloading.

        Args:
            model: The model to which hooks will be applied.
            no_split_module_classes: Module class names that should not be split further.
            num_block: Number of blocks for the "block" strategy.
            enable: If False, this context is a no-op.
            device: The compute device (e.g. "cuda").
            strategy: Only ``'block'`` is supported (see ``ForwardBackwardOffloadHookContext``).
        """
        self.enable = enable
        # A disabled context builds neither sub-context.
        if not enable:
            return
        self.forwardBackwardOffloadHookContext = ForwardBackwardOffloadHookContext(
            model=model,
            device=device,
            no_split_module_classes=no_split_module_classes,
            enable=True,
            num_block=num_block,
            strategy=strategy,
        )
        self.savedTensorOffloadContext = OffloadSavedTensorHookContext()

    def __enter__(self):
        """Enter both sub-contexts (none when disabled) and return ``self``.

        If entering the second sub-context raises, the first one is exited
        before the exception propagates, so no hooks are left registered.
        """
        with ExitStack() as stack:
            if self.enable:
                stack.enter_context(self.forwardBackwardOffloadHookContext)
                stack.enter_context(self.savedTensorOffloadContext)
            # Everything entered: hand the callbacks over so the `with` above
            # does not unwind them; they run later from __exit__.
            self._stack = stack.pop_all()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the sub-contexts in reverse order of entry."""
        return self._stack.__exit__(exc_type, exc_val, exc_tb)

__init__(model, no_split_module_classes=None, num_block=2, enable=True, device='cuda', strategy='block')

Combine forward/backward offloading with saved-tensor offloading.

Parameters:

Name Type Description Default
model

The model to which hooks will be applied.

required
no_split_module_classes

Module class names that should not be split further.

None
num_block int

Number of blocks for the "block" strategy.

2
enable

If False, this context is a no-op.

True
device

The compute device (e.g. "cuda").

'cuda'
strategy

Only 'block' is supported (see ForwardBackwardOffloadHookContext).

'block'
Source code in src/ls_mlkit/util/offload/model_offload.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def __init__(
    self,
    model,
    no_split_module_classes=None,
    num_block: int = 2,
    enable=True,
    device="cuda",
    strategy="block",
):
    """Build the weight-offload and saved-tensor-offload sub-contexts.

    Args:
        model: The model to which hooks will be applied.
        no_split_module_classes: Module class names that should not be split further.
        num_block: Number of blocks for the "block" strategy.
        enable: If False, this context is a no-op and no sub-context is built.
        device: The compute device (e.g. "cuda").
        strategy: Only ``'block'`` is supported (see ``ForwardBackwardOffloadHookContext``).
    """
    self.enable = enable
    if enable:
        # The inner context is always built enabled; disabling is handled here.
        offload_options = dict(
            model=model,
            device=device,
            no_split_module_classes=no_split_module_classes,
            enable=True,
            num_block=num_block,
            strategy=strategy,
        )
        self.forwardBackwardOffloadHookContext = ForwardBackwardOffloadHookContext(**offload_options)
        self.savedTensorOffloadContext = OffloadSavedTensorHookContext()

Scheduler

Source code in src/ls_mlkit/util/scheduler.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class Scheduler:
    """Step-driven scheduler for a set of named values.

    ``info`` maps a name to a dict with the keys ``"value"`` (base value),
    ``"schedule"`` (callable invoked as
    ``schedule(value, current, total, warmup_steps)``) and either
    ``"warmup_steps"`` or ``"warmup_ratio"``. ``step()`` stores the result of
    the schedule under ``"current_value"``.
    """

    def __init__(
        self,
        info: dict[str, dict[str, Any]],
        total: int,
    ):
        """Validate ``info`` and derive missing ``warmup_steps``.

        Args:
            info: per-name configuration; it is stored and updated in place.
            total: total number of steps, passed to every schedule.

        Raises:
            ValueError: if an entry lacks ``"value"`` or ``"schedule"``, or has
                neither ``"warmup_steps"`` nor ``"warmup_ratio"``.
        """
        self.info = info
        self.total = total
        self.current = 0
        for key, value in self.info.items():
            if value.get("value") is None:
                raise ValueError(f"value of {key} is not defined")
            if value.get("schedule") is None:
                raise ValueError(f"schedule of {key} is not defined")
            if value.get("warmup_steps") is None:
                # Raise instead of assert: asserts are stripped under ``python -O``.
                if value.get("warmup_ratio") is None:
                    raise ValueError(f"warmup_ratio of {key} must be provided if warmup_steps is not provided")
                value["warmup_steps"] = int(self.total * value["warmup_ratio"])

    def step(self):
        """Step the scheduler"""
        self.current += 1
        for value in self.info.values():
            value["current_value"] = value["schedule"](value["value"], self.current, self.total, value["warmup_steps"])

    def get(self, key=None):
        """Get the current value of the scheduler

        Args:
            key (str, optional): The key of the scheduler to get. If None, return the entire scheduler info. Defaults to None.

        Returns:
            dict[str, Any] or Any: The entire scheduler info or the value of the scheduler for the given key

        Raises:
            KeyError: if ``key`` is unknown, or ``step()`` has not been called yet
                (``"current_value"`` is only written by ``step()``).
        """
        if key is None:
            return self.info
        else:
            return self.info[key]["current_value"]

step()

Step the scheduler

Source code in src/ls_mlkit/util/scheduler.py
70
71
72
73
74
def step(self):
    """Advance the step counter by one and refresh every ``current_value``."""
    self.current = self.current + 1
    for entry in self.info.values():
        schedule_fn = entry["schedule"]
        # Schedules are called as schedule(value, current, total, warmup_steps).
        entry["current_value"] = schedule_fn(entry["value"], self.current, self.total, entry["warmup_steps"])

get(key=None)

Get the current value of the scheduler

Parameters:

Name Type Description Default
key str

The key of the scheduler to get. If None, return the entire scheduler info. Defaults to None.

None

Returns:

Type Description

dict[str, Any] or Any: The entire scheduler info or the value of the scheduler for the given key

Source code in src/ls_mlkit/util/scheduler.py
76
77
78
79
80
81
82
83
84
85
86
87
88
def get(self, key=None):
    """Return one scheduled value, or the whole scheduler info.

    Args:
        key (str, optional): The key of the scheduler to get. If None, return the entire scheduler info. Defaults to None.

    Returns:
        dict[str, Any] or Any: The entire scheduler info or the value of the scheduler for the given key
    """
    if key is not None:
        return self.info[key]["current_value"]
    return self.info

SDE

Bases: ABC

SDE abstract class. Functions are designed for a mini-batch of inputs.

Source code in src/ls_mlkit/util/sde/base_sde.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class SDE(abc.ABC):
    r"""
    SDE abstract class. Functions are designed for a mini-batch of inputs.
    """

    def __init__(self, ndim_micro_shape: int = 2, n_discretization_steps: int = 1000):
        r"""Initialize the SDE

        Args:
            ndim_micro_shape (``int``, *optional*): number of dimensions of a sample.
                e.g. for image with shape ``[b, c, h, w]``, ndim_micro_shape = 3
                e.g. for protein with shape ``[b, n_res, 3]``, ndim_micro_shape = 2
            n_discretization_steps (``int``, *optional*): number of discretization steps.
        """
        super().__init__()
        self.ndim_micro_shape = ndim_micro_shape
        self.n_discretization_steps = n_discretization_steps

    @property
    @abc.abstractmethod
    def T(self) -> float:
        r"""End time of the SDE."""

    @abc.abstractmethod
    def prior_sampling(self, shape: Tuple) -> Tensor:
        r"""Sample from the prior distribution.

        Args:
            shape (``Tuple``): the shape of the sample.

        Returns:
            ``Tensor``: a sample from the prior distribution.
        """

    @abc.abstractmethod
    def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Get the drift and diffusion of the SDE.

        Args:
            x (``Tensor``): the sample.
            t (``Tensor``): the time step.
            mask (``Tensor``, *optional*): the mask of the sample. Defaults to None.

        Returns:
            ``Tuple[Tensor, Tensor]``: the drift and diffusion of the SDE.
        """

    def get_reverse_sde(
        self,
        score: Optional[Tensor] = None,
        score_fn: Optional[Callable[..., Tensor]] = None,
        use_probability_flow: bool = False,
    ):
        r"""Create the reverse-time SDE/ODE.

        Args:
            score: A fixed, precomputed score. When given it is used for every call
                and ``score_fn`` is ignored.
            score_fn: A time-dependent score-based model that takes (x ,t, mask) and returns the score.
                It is evaluated on every call when ``score`` is not given.
            use_probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

        Returns:
            An ``SDE`` whose ``get_drift_and_diffusion`` yields the reverse-time
            drift and diffusion. ``T`` and ``prior_sampling`` delegate to this
            (forward) SDE.
        """
        forward_sde = self
        ndim_micro_shape = self.ndim_micro_shape

        class RSDE(SDE):
            def __init__(self):
                self.use_probability_flow = use_probability_flow
                self.ndim_micro_shape = ndim_micro_shape

            # T and prior_sampling are abstract on SDE; without these overrides
            # RSDE() cannot be instantiated at all.
            @property
            def T(self) -> float:
                r"""End time, taken from the forward SDE."""
                return forward_sde.T

            def prior_sampling(self, shape: Tuple) -> Tensor:
                r"""Sample from the forward SDE's prior distribution."""
                return forward_sde.prior_sampling(shape)

            def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
                r"""
                Create the drift and diffusion functions for the reverse SDE/ODE.
                $$
                \begin{align*}
                    dx = (f(x,t) - g(x,t)^2 \nabla_x \log p_t(x)) dt + g(x,t) dw
                \end{align*}
                $$
                if use ODE probability flow:
                $$
                \begin{align*}
                    dx = (f(x,t) - \frac{1}{2} g(x,t)^2 \nabla_x \log p_t(x)) dt
                \end{align*}
                $$
                """
                if score is None and score_fn is None:
                    raise ValueError("either score or score_fn must be provided")
                # Use a local: rebinding the enclosing `score` would freeze the
                # first score_fn result for all later (x, t).
                current_score = score if score is not None else score_fn(x, t, mask)
                drift, diffusion = forward_sde.get_drift_and_diffusion(x, t, mask=mask)
                # Set the diffusion function to zero for ODEs.
                rev_diffusion: Tensor = torch.zeros_like(diffusion) if self.use_probability_flow else diffusion
                # Reshape to (leading dims of x, 1, ..., 1) so it broadcasts over each sample.
                diffusion = diffusion.view(
                    *x.shape[: -self.ndim_micro_shape],
                    *[1 for _ in range(self.ndim_micro_shape)],
                )
                rev_drift = drift - diffusion**2 * current_score * (0.5 if self.use_probability_flow else 1.0)
                return rev_drift, rev_diffusion

        return RSDE()

T abstractmethod property

End time of the SDE.

__init__(ndim_micro_shape=2, n_discretization_steps=1000)

Initialize the SDE

Parameters:

Name Type Description Default
ndim_micro_shape ``int``, *optional*

number of dimensions of a sample.

2
n_discretization_steps ``int``, *optional*

number of discretization steps.

1000
Source code in src/ls_mlkit/util/sde/base_sde.py
19
20
21
22
23
24
25
26
27
28
29
30
def __init__(self, ndim_micro_shape: int = 2, n_discretization_steps: int = 1000):
    r"""Initialize the SDE

    Args:
        ndim_micro_shape (``int``, *optional*): number of dimensions of a sample.
            e.g. for image with shape ``[b, c, h, w]``, ndim_micro_shape = 3
            e.g. for protein with shape ``[b, n_res, 3]``, ndim_micro_shape = 2
        n_discretization_steps (``int``, *optional*): number of discretization steps.
    """
    super().__init__()
    self.ndim_micro_shape = ndim_micro_shape
    self.n_discretization_steps = n_discretization_steps

prior_sampling(shape) abstractmethod

Sample from the prior distribution.

Parameters:

Name Type Description Default
shape ``Tuple``

the shape of the sample.

required

Returns:

Type Description
Tensor

Tensor: a sample from the prior distribution.

Source code in src/ls_mlkit/util/sde/base_sde.py
37
38
39
40
41
42
43
44
45
46
@abc.abstractmethod
def prior_sampling(self, shape: Tuple) -> Tensor:
    r"""Sample from the prior distribution.

    Abstract: each concrete SDE defines its own prior.

    Args:
        shape (``Tuple``): the shape of the sample.

    Returns:
        ``Tensor``: a sample from the prior distribution.
    """

get_drift_and_diffusion(x, t, mask=None) abstractmethod

Get the drift and diffusion of the SDE.

Parameters:

Name Type Description Default
x ``Tensor``

the sample.

required
t ``Tensor``

the time step.

required
mask ``Tensor``, *optional*

the mask of the sample. Defaults to None.

None

Returns:

Type Description
Tuple[Tensor, Tensor]

Tuple[Tensor, Tensor]: the drift and diffusion of the SDE.

Source code in src/ls_mlkit/util/sde/base_sde.py
48
49
50
51
52
53
54
55
56
57
58
59
@abc.abstractmethod
def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
    r"""Get the drift and diffusion of the SDE.

    Abstract: this base declaration has no body, so every concrete SDE must
    provide its own implementation.

    Args:
        x (``Tensor``): the sample.
        t (``Tensor``): the time step.
        mask (``Tensor``, *optional*): the mask of the sample. Defaults to None.

    Returns:
        ``Tuple[Tensor, Tensor]``: the drift and diffusion of the SDE.
        NOTE(review): ``get_reverse_sde`` reshapes the returned diffusion with
        ``view(*macro_shape, 1, ..., 1)``, so implementations presumably return
        one diffusion value per macro (batch) element -- confirm.
    """

get_reverse_sde(score=None, score_fn=None, use_probability_flow=False)

Create the reverse-time SDE/ODE.

Parameters:

Name Type Description Default
score_fn Optional[Callable[..., Tensor]]

A time-dependent score-based model that takes (x ,t, mask) and returns the score.

None
use_probability_flow bool

If True, create the reverse-time ODE used for probability flow sampling.

False
Source code in src/ls_mlkit/util/sde/base_sde.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def get_reverse_sde(
    self,
    score: Optional[Tensor] = None,
    score_fn: Optional[Callable[..., Tensor]] = None,
    use_probability_flow: bool = False,
):
    r"""Create the reverse-time SDE/ODE.

    Args:
        score: A fixed, precomputed score tensor. When given, it is used for every
            call of the returned object and ``score_fn`` is ignored.
        score_fn: A time-dependent score-based model that takes (x ,t, mask) and returns the score.
            It is evaluated on every call, so the score always matches the current ``(x, t)``.
        use_probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.

    Returns:
        An ``SDE`` instance whose ``get_drift_and_diffusion`` returns the
        reverse-time drift and diffusion.

    Raises:
        ValueError: raised lazily by the returned object's
            ``get_drift_and_diffusion`` when neither ``score`` nor ``score_fn``
            was provided.
    """
    ndim_micro_shape = self.ndim_micro_shape
    n_discretization_steps = self.n_discretization_steps
    get_forward_drift_and_diffusion = self.get_drift_and_diffusion

    class RSDE(SDE):
        def __init__(self):
            # Run the base initializer so the reverse SDE carries the same
            # ``ndim_micro_shape`` / ``n_discretization_steps`` as the forward one.
            super().__init__(
                ndim_micro_shape=ndim_micro_shape,
                n_discretization_steps=n_discretization_steps,
            )
            self.use_probability_flow = use_probability_flow

        def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
            r"""
            Create the drift and diffusion functions for the reverse SDE/ODE.
            $$
            \begin{align*}
                dx = (f(x,t) - g(x,t)^2 \nabla_x \log p_t(x)) dt + g(x,t) dw
            \end{align*}
            $$
            if use ODE probability flow:
            $$
            \begin{align*}
                dx = (f(x,t) - \frac{1}{2} g(x,t)^2 \nabla_x \log p_t(x)) dt
            \end{align*}
            $$
            """
            if score is None and score_fn is None:
                raise ValueError("either score or score_fn must be provided")
            # Keep the evaluated score in a local variable. Writing it back to the
            # enclosing ``score`` would freeze the first evaluation and silently
            # reuse it for every later (x, t).
            current_score = score if score is not None else score_fn(x, t, mask)
            drift, diffusion = get_forward_drift_and_diffusion(x, t, mask=mask)
            # Set the diffusion function to zero for ODEs.
            rev_diffusion: Tensor = torch.zeros_like(diffusion) if self.use_probability_flow else diffusion
            # Append singleton micro dimensions so the diffusion broadcasts against x.
            diffusion = diffusion.view(
                *x.shape[: -self.ndim_micro_shape],
                *[1 for _ in range(self.ndim_micro_shape)],
            )
            rev_drift = drift - diffusion**2 * current_score * (0.5 if self.use_probability_flow else 1.0)
            return rev_drift, rev_diffusion

    return RSDE()

VESDE

Bases: SDE

Source code in src/ls_mlkit/util/sde/sde_lib.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
class VESDE(SDE):
    """Variance Exploding SDE with a geometric noise schedule from ``sigma_min`` to ``sigma_max``."""

    def __init__(
        self,
        sigma_min=0.01,
        sigma_max=50,
        n_discretization_steps=1000,
        ndim_micro_shape=2,
        drop_first_step=False,
    ):
        """Construct a Variance Exploding SDE.

        Args:
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            n_discretization_steps: number of discretization steps
            ndim_micro_shape: number of dimensions of a sample
            drop_first_step: if True, build ``n_discretization_steps + 1``
                log-spaced sigmas and discard the first one (``sigma_min``);
                otherwise build exactly ``n_discretization_steps`` log-spaced
                sigmas that include both end points.
        """
        super().__init__(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
        )
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.drop_first_step = drop_first_step
        sigma_min = torch.tensor(sigma_min)
        sigma_max = torch.tensor(sigma_max)
        if drop_first_step:
            self.discrete_sigmas = (
                10
                ** torch.linspace(
                    torch.log10(sigma_min),
                    torch.log10(sigma_max),
                    n_discretization_steps + 1,
                )[1:]
            )
        else:
            self.discrete_sigmas = torch.exp(
                torch.linspace(
                    torch.log(sigma_min),
                    torch.log(sigma_max),
                    n_discretization_steps,
                )
            )

    @property
    def T(self) -> float:
        """End time of the SDE."""
        return 1

    def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Return the drift (zero) and diffusion of the VE SDE at time ``t``.

        .. math::

            dx = 0 dt + \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t \sqrt{2 \log(\frac{\sigma_{max}}{\sigma_{min}})} dw

            \sigma_t = \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t

            diffusion = \sigma_t * \sqrt{2 \log(\frac{\sigma_{max}}{\sigma_{min}})}

        Args:
            x: the sample; only its shape/dtype/device are used for the zero drift.
            t: the time step.
            mask: unused.

        Returns:
            drift: zeros with the shape of ``x``.
            diffusion: same shape as ``t``.
        """
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        drift = torch.zeros_like(x)
        diffusion = sigma * torch.sqrt(
            torch.tensor(
                2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),
                device=t.device,
            )
        )
        return drift, diffusion

    def get_discretized_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""SMLD(NCSN) discretization.

        .. math::

            x_t &= x_0 + g \epsilon

            x_t &\sim \mathcal{N}(x_0, \sigma_t^2)

            \sigma_t^2 &= \sigma_{t-1}^2 + g^2

            g &= \sqrt{\sigma_t^2 - \sigma_{t-1}^2}

        Args:
            x: the sample; only its shape/dtype/device are used for the zero drift.
            t: continuous time, mapped to an index into ``self.discrete_sigmas``.
            mask: unused.

        Returns:
            f: zeros with the shape of ``x``.
            g: same shape as ``t``.
        """
        timestep = (t * (self.n_discretization_steps - 1) / self.T).long()
        # Move the schedule once so both lookups index a tensor that lives on the
        # same device as ``timestep``.
        discrete_sigmas = self.discrete_sigmas.to(t.device)
        sigma = discrete_sigmas[timestep]
        # ``timestep - 1`` wraps to the last entry at step 0; ``torch.where``
        # replaces that entry with 0.
        adjacent_sigma = torch.where(
            timestep == 0,
            torch.zeros_like(t),
            discrete_sigmas[timestep - 1],
        )
        f = torch.zeros_like(x)
        g = torch.sqrt(sigma**2 - adjacent_sigma**2)
        return f, g

    def marginal_prob(self, x, t, mask=None):
        """Return ``(mean, std)`` with ``mean = x`` and ``std = sigma_min * (sigma_max / sigma_min) ** t``."""
        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        mean = x
        return mean, std

    def prior_sampling(self, shape):
        """Draw standard normal noise of ``shape`` scaled by ``sigma_max``."""
        return torch.randn(*shape) * self.sigma_max

    def prior_logp(self, z):
        r"""Log-density of ``z`` under :math:`\mathcal{N}(0, \sigma_{max}^2 I)`, one value per batch element.

        Args:
            z: tensor whose first dimension is the batch; all remaining
                dimensions are treated as the sample.

        Returns:
            Tensor of shape ``(z.shape[0],)``.
        """
        shape = z.shape
        N = np.prod(shape[1:])
        log_norm = float(-N / 2.0 * np.log(2 * np.pi * self.sigma_max**2))
        # Reduce over every non-batch dimension instead of assuming a 4-D input.
        squared_norm = (z**2).reshape(shape[0], -1).sum(dim=1)
        return log_norm - squared_norm / (2 * self.sigma_max**2)

__init__(sigma_min=0.01, sigma_max=50, n_discretization_steps=1000, ndim_micro_shape=2, drop_first_step=False)

Construct a Variance Exploding SDE.

Args:

sigma_min: smallest sigma.
sigma_max: largest sigma.
n_discretization_steps: number of discretization steps
ndim_micro_shape: number of dimensions of a sample
Source code in src/ls_mlkit/util/sde/sde_lib.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def __init__(
    self,
    sigma_min=0.01,
    sigma_max=50,
    n_discretization_steps=1000,
    ndim_micro_shape=2,
    drop_first_step=False,
):
    """Construct a Variance Exploding SDE.

    Args:
        sigma_min: smallest sigma.
        sigma_max: largest sigma.
        n_discretization_steps: number of discretization steps
        ndim_micro_shape: number of dimensions of a sample
        drop_first_step: if True, build ``n_discretization_steps + 1``
            log-spaced sigmas and discard the first one (``sigma_min``);
            otherwise build exactly ``n_discretization_steps`` log-spaced
            sigmas that include both end points.
    """
    super().__init__(
        n_discretization_steps=n_discretization_steps,
        ndim_micro_shape=ndim_micro_shape,
    )
    # Keep the plain Python values as attributes; the tensor copies below are
    # only used to build the discrete schedule.
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    self.drop_first_step = drop_first_step
    sigma_min = torch.tensor(sigma_min)
    sigma_max = torch.tensor(sigma_max)
    if drop_first_step:
        # n + 1 points in log10 space, then drop the first (sigma_min).
        self.discrete_sigmas = (
            10
            ** torch.linspace(
                torch.log10(sigma_min),
                torch.log10(sigma_max),
                n_discretization_steps + 1,
            )[1:]
        )
    else:
        # n points in natural-log space, both end points included.
        self.discrete_sigmas = torch.exp(
            torch.linspace(
                torch.log(sigma_min),
                torch.log(sigma_max),
                n_discretization_steps,
            )
        )

get_drift_and_diffusion(x, t, mask=None)

.. math::

dx = 0 dt + \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t \sqrt{2 \log(\frac{\sigma_{max}}{\sigma_{min}})} dw
\sigma_t = \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t

diffusion = \sigma_t * \sqrt{2 \log(\frac{\sigma_{max}}{\sigma_{min}})}
Source code in src/ls_mlkit/util/sde/sde_lib.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
    r"""Return the drift (zero) and diffusion of the VE SDE at time ``t``.

    .. math::

        dx = 0 dt + \sigma_t \sqrt{2 \log(\frac{\sigma_{max}}{\sigma_{min}})} dw

        \sigma_t = \sigma_{min} \left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t

    Args:
        x: the sample; only its shape/dtype/device are used for the zero drift.
        t: the time step.
        mask: unused.

    Returns:
        drift: zeros with the shape of ``x``.
        diffusion: same shape as ``t``.
    """
    twice_log_ratio = 2 * (np.log(self.sigma_max) - np.log(self.sigma_min))
    scale = torch.sqrt(torch.tensor(twice_log_ratio, device=t.device))
    sigma_t = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    return torch.zeros_like(x), sigma_t * scale

get_discretized_drift_and_diffusion(x, t, mask=None)

SMLD(NCSN) discretization. .. math::

x_t &= x_0 + g \epsilon

x_t &\sim \mathcal{N}(x_0, \sigma_t^2)

\sigma_t^2 &= \sigma_{t-1}^2 + g^2

g &= \sqrt{\sigma_t^2 - \sigma_{t-1}^2}
Source code in src/ls_mlkit/util/sde/sde_lib.py
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def get_discretized_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
    r"""SMLD(NCSN) discretization.

    .. math::

        x_t &= x_0 + g \epsilon

        x_t &\sim \mathcal{N}(x_0, \sigma_t^2)

        \sigma_t^2 &= \sigma_{t-1}^2 + g^2

        g &= \sqrt{\sigma_t^2 - \sigma_{t-1}^2}

    Args:
        x: the sample; only its shape/dtype/device are used for the zero drift.
        t: continuous time, mapped to an index into ``self.discrete_sigmas``.
        mask: unused.

    Returns:
        f: zeros with the shape of ``x``.
        g: same shape as ``t``.
    """
    timestep = (t * (self.n_discretization_steps - 1) / self.T).long()
    # Move the schedule once so both lookups index a tensor that lives on the
    # same device as ``timestep``.
    discrete_sigmas = self.discrete_sigmas.to(t.device)
    sigma = discrete_sigmas[timestep]
    # ``timestep - 1`` wraps to the last entry at step 0; ``torch.where``
    # replaces that entry with 0.
    adjacent_sigma = torch.where(
        timestep == 0,
        torch.zeros_like(t),
        discrete_sigmas[timestep - 1],
    )
    f = torch.zeros_like(x)
    g = torch.sqrt(sigma**2 - adjacent_sigma**2)
    return f, g

VPSDE

Bases: SDE

Source code in src/ls_mlkit/util/sde/sde_lib.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class VPSDE(SDE):
    """Variance Preserving SDE with a linear ``beta`` schedule from ``beta_min`` to ``beta_max``."""

    def __init__(
        self,
        beta_min: float = 0.1,
        beta_max: float = 20,
        ndim_micro_shape: int = 2,
    ):
        r"""Construct a Variance Preserving SDE.

        Args:
            beta_min: value of beta(0)
            beta_max: value of beta(1)
            ndim_micro_shape: number of dimensions of a sample
        """
        super().__init__(ndim_micro_shape=ndim_micro_shape)
        self.beta_0 = beta_min
        self.beta_1 = beta_max

    @property
    def T(self) -> float:
        """End time of the SDE."""
        return 1

    def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""continuous DDPM SDE

        .. math::

            dx &= -\frac{1}{2}\beta_t x dt + \sqrt{\beta_t} dw

        Args:
            x: the sample.
            t: (macro_shape)
            mask: unused.

        Returns:
            drift: shape = x.shape
            diffusion: shape=x.macro_shape
        """
        macro_shape = x.shape[: -self.ndim_micro_shape]
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        # Append singleton micro dimensions so beta_t broadcasts against x.
        drift = -0.5 * beta_t.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def get_score(self, x_t, mean, std) -> Tensor:
        r"""Score of a Gaussian perturbation kernel with the given ``mean`` and ``std``.

        .. math::

            \nabla_{x_t} \ln p_{0t} (x_t|x_0) = -\frac{x_t - mean}{std^2}

        """
        score = -(x_t - mean) / std**2
        return score

    def get_a_b(self, t: Tensor) -> Tuple[Tensor, Tensor]:
        """x_t = a * x_0 + b * epsilon, epsilon ~ N(0, 1)

        Args:
            t (``Tensor``): continuous time

        Returns:
            ``Tuple[Tensor, Tensor]``: a, b, each with shape
            ``(*t.shape, 1, ..., 1)`` (``ndim_micro_shape`` trailing ones).
        """
        macro_shape = t.shape
        log_mean_coeff = -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0  # macro_shape
        log_mean_coeff = log_mean_coeff.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)])
        a = torch.exp(log_mean_coeff)
        b = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
        return a, b

    def forward_process(self, x_0: Tensor, t: Tensor, mask: Optional[Tensor] = None) -> Dict[str, Tensor]:
        r"""Sample ``x_t`` from the perturbation kernel :math:`p_{0t} (x_t|x_0)`.

        .. math::

            \gamma = -\frac{1}{4}t^2 (\beta_1 - \beta_0) - \frac{1}{2} t  \beta_0

            mean = e^{\gamma} * x

            std = \sqrt{1 - e^{2 \gamma }}

        Args:
            x_0: the clean sample.
            t: continuous time.
            mask: unused.

        Returns:
            Dictionary with keys ``"x_t"``, ``"mean"``, ``"std"``, ``"a"`` and ``"b"``
            (``"std"`` and ``"b"`` are the same tensor).
        """
        a, b = self.get_a_b(t)
        mean = a * x_0
        x_t = mean + b * torch.randn_like(x_0)
        return {
            "x_t": x_t,
            "mean": mean,
            "std": b,
            "a": a,
            "b": b,
        }

    def forward_from_t1_to_t2(self, x_t1: Tensor, t1: Tensor, t2: Tensor) -> Tensor:
        """Sample ``x_{t2}`` given ``x_{t1}`` for ``t1 <= t2``.

        With ``x_t = a_t * x_0 + b_t * eps`` this uses
        ``x_t2 = (a2 / a1) * x_t1 + b12 * eps`` where
        ``b12 = a2 * sqrt((b2 / a2)**2 - (b1 / a1)**2)``.

        Args:
            x_t1: the sample at time ``t1``.
            t1: start time.
            t2: end time, element-wise not smaller than ``t1``.

        Returns:
            The sample at time ``t2``.
        """
        assert (t1 <= t2).all(), "t1 must be less than or equal to t2"
        a1, b1 = self.get_a_b(t1)
        a2, b2 = self.get_a_b(t2)
        a12 = a2 / a1
        # Clamp at zero so rounding error for nearly equal times cannot turn the
        # square root into NaN.
        variance_gap = torch.clamp((b2 / a2) ** 2 - (b1 / a1) ** 2, min=0.0)
        b12 = a2 * torch.sqrt(variance_gap)
        x_t2 = a12 * x_t1 + b12 * torch.randn_like(x_t1)
        return x_t2

    def prior_sampling(self, shape: Tuple) -> Tensor:
        r"""
        .. math::
            \epsilon \sim \mathcal{N}(0, I)

        """
        return torch.randn(*shape)

    def prior_logp(self, z: torch.Tensor) -> Tensor:
        r"""Log-density of ``z`` under the standard normal prior, one value per batch element.

        .. math::

            \log p(z) = -\frac{k}{2} \log(2\pi) - \frac{1}{2} \lVert z \rVert^2

        where :math:`k` is the number of elements of one sample
        (:math:`\Sigma = I` and :math:`\mathbf{\mu} = 0`).

        Args:
            z: tensor whose first dimension is the batch; all remaining
                dimensions are treated as the sample.

        Returns:
            Tensor of shape ``(z.shape[0],)``.
        """
        shape = z.shape
        N = np.prod(shape[1:])
        log_norm = float(-N / 2.0 * np.log(2 * np.pi))
        # Reduce over every non-batch dimension instead of assuming a 4-D input.
        squared_norm = (z**2).reshape(shape[0], -1).sum(dim=1)
        logps = log_norm - squared_norm / 2.0
        return logps

__init__(beta_min=0.1, beta_max=20, ndim_micro_shape=2)

Construct a Variance Preserving SDE.

Parameters:

Name Type Description Default
beta_min float

value of beta(0)

0.1
beta_max float

value of beta(1)

20
ndim_micro_shape int

number of dimensions of a sample

2
Source code in src/ls_mlkit/util/sde/sde_lib.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def __init__(
    self,
    beta_min: float = 0.1,
    beta_max: float = 20,
    ndim_micro_shape: int = 2,
):
    r"""Construct a Variance Preserving SDE.

    Args:
        beta_min: value of beta(0); stored as ``self.beta_0``.
        beta_max: value of beta(1); stored as ``self.beta_1``.
        ndim_micro_shape: number of dimensions of a sample
    """
    # n_discretization_steps is not forwarded, so the base-class default applies.
    super().__init__(ndim_micro_shape=ndim_micro_shape)
    self.beta_0 = beta_min
    self.beta_1 = beta_max

get_drift_and_diffusion(x, t, mask=None)

continuous DDPM SDE

.. math::

dx &= -\frac{1}{2}\beta_t x dt + \sqrt{\beta_t} dw

Parameters:

Name Type Description Default
x Tensor
required
t Tensor

(macro_shape)

required
mask
None

Returns:

Name Type Description
drift Tensor

shape = x.shape

diffusion Tensor

shape=x.macro_shape

Source code in src/ls_mlkit/util/sde/sde_lib.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
    r"""Drift and diffusion of the continuous DDPM (VP) SDE.

    .. math::

        dx &= -\frac{1}{2}\beta_t x dt + \sqrt{\beta_t} dw

    Args:
        x: the sample.
        t: (macro_shape)
        mask: unused.

    Returns:
        drift: shape = x.shape
        diffusion: shape=x.macro_shape
    """
    batch_dims = x.shape[: -self.ndim_micro_shape]
    singleton_dims = (1,) * self.ndim_micro_shape
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    # Reshape beta_t to (*batch_dims, 1, ..., 1) so it broadcasts against x.
    beta_for_x = beta_t.view(*batch_dims, *singleton_dims)
    return -0.5 * beta_for_x * x, torch.sqrt(beta_t)

get_score(x_t, mean, std)

.. math::

\nabla_{x_t} \ln p_{0t} (x_t|x_0) = -\frac{x_t - mean}{std^2}
Source code in src/ls_mlkit/util/sde/sde_lib.py
54
55
56
57
58
59
60
61
62
63
def get_score(self, x_t, mean, std) -> Tensor:
    r"""Score of a Gaussian perturbation kernel with the given ``mean`` and ``std``.

    .. math::

        \nabla_{x_t} \ln p_{0t} (x_t|x_0) = -\frac{x_t - mean}{std^2}

    """
    residual = x_t - mean
    return -residual / std**2

get_a_b(t)

x_t = a * x_0 + b * epsilon, epsilon ~ N(0, 1)

Parameters:

Name Type Description Default
t ``Tensor``

continuous time

required

Returns:

Type Description
Tuple[Tensor, Tensor]

Tuple[Tensor, Tensor]: a, b

Source code in src/ls_mlkit/util/sde/sde_lib.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def get_a_b(self, t: Tensor) -> Tuple[Tensor, Tensor]:
    """Coefficients of x_t = a * x_0 + b * epsilon, epsilon ~ N(0, 1)

    Args:
        t (``Tensor``): continuous time

    Returns:
        ``Tuple[Tensor, Tensor]``: a, b, each with shape
        ``(*t.shape, 1, ..., 1)`` (``ndim_micro_shape`` trailing ones).
    """
    beta_gap = self.beta_1 - self.beta_0
    exponent = -0.25 * t**2 * beta_gap - 0.5 * t * self.beta_0
    trailing_ones = (1,) * self.ndim_micro_shape
    # Append singleton micro dimensions so a and b broadcast against samples.
    exponent = exponent.view(*t.shape, *trailing_ones)
    return torch.exp(exponent), torch.sqrt(1.0 - torch.exp(2.0 * exponent))

forward_process(x_0, t, mask=None)

.. math::

p_{0t} (x_t|x_0)

.. math::

\gamma = -\frac{1}{4}t^2 (\beta_1 - \beta_0) - \frac{1}{2} t  \beta_0

mean = e^{\gamma} * x

std = \sqrt{1 - e^{2 \gamma }}
Source code in src/ls_mlkit/util/sde/sde_lib.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def forward_process(self, x_0: Tensor, t: Tensor, mask: Optional[Tensor] = None) -> Dict[str, Tensor]:
    r"""Sample ``x_t`` from the perturbation kernel :math:`p_{0t} (x_t|x_0)`.

    .. math::

        \gamma = -\frac{1}{4}t^2 (\beta_1 - \beta_0) - \frac{1}{2} t  \beta_0

        mean = e^{\gamma} * x

        std = \sqrt{1 - e^{2 \gamma }}

    Args:
        x_0: the clean sample.
        t: continuous time.
        mask: unused.

    Returns:
        Dictionary with keys ``"x_t"``, ``"mean"``, ``"std"``, ``"a"`` and ``"b"``
        (``"std"`` and ``"b"`` are the same tensor).
    """
    signal_scale, noise_scale = self.get_a_b(t)
    clean_mean = signal_scale * x_0
    noise = torch.randn_like(x_0)
    return dict(
        x_t=clean_mean + noise_scale * noise,
        mean=clean_mean,
        std=noise_scale,
        a=signal_scale,
        b=noise_scale,
    )

prior_sampling(shape)

.. math:: \epsilon \sim \mathcal{N}(0, I)

Source code in src/ls_mlkit/util/sde/sde_lib.py
117
118
119
120
121
122
123
def prior_sampling(self, shape: Tuple) -> Tensor:
    r"""Draw a standard normal sample of the requested shape.

    .. math::
        \epsilon \sim \mathcal{N}(0, I)

    """
    noise = torch.randn(*shape)
    return noise

prior_logp(z)

.. math::

(2\pi)^{-k/2} \det(\Sigma)^{-1/2} \exp\left( -\frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^\mathrm{T} \Sigma^{-1} (\mathbf{x} - \boldsymbol{\mu}) \right)

where :math:\Sigma = I and :math:\mathbf{\mu} = 0

Source code in src/ls_mlkit/util/sde/sde_lib.py
125
126
127
128
129
130
131
132
133
134
135
136
137
def prior_logp(self, z: torch.Tensor) -> Tensor:
    r"""Log-density of ``z`` under the standard normal prior, one value per batch element.

    .. math::

        \log p(z) = -\frac{k}{2} \log(2\pi) - \frac{1}{2} \lVert z \rVert^2

    where :math:`k` is the number of elements of one sample
    (:math:`\Sigma = I` and :math:`\mathbf{\mu} = 0`).

    Args:
        z: tensor whose first dimension is the batch; all remaining
            dimensions are treated as the sample.

    Returns:
        Tensor of shape ``(z.shape[0],)``.
    """
    shape = z.shape
    N = np.prod(shape[1:])
    log_norm = float(-N / 2.0 * np.log(2 * np.pi))
    # Reduce over every non-batch dimension instead of assuming a 4-D input, so
    # 3-D samples such as [b, n_res, 3] work as well.
    squared_norm = (z**2).reshape(shape[0], -1).sum(dim=1)
    logps = log_norm - squared_norm / 2.0
    return logps

Corrector

Bases: ABC

The abstract class for a corrector algorithm.

Source code in src/ls_mlkit/util/sde/corrector.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Corrector(abc.ABC):
    """The abstract class for a corrector algorithm.

    The constructor only stores its arguments; subclasses implement ``update_fn``.
    """

    def __init__(self, sde: SDE, score_fn: Callable[..., Tensor], snr: float, n_steps: int):
        """Store the corrector configuration on the instance.

        Args:
          sde: the SDE object; stored as ``self.sde``.
          score_fn: callable returning a score tensor; stored as ``self.score_fn``.
          snr: stored as ``self.snr`` -- presumably a signal-to-noise ratio that
            sizes the correction step (TODO confirm against subclasses).
          n_steps: stored as ``self.n_steps`` -- presumably the number of corrector
            steps per update (TODO confirm against subclasses).
        """
        super().__init__()
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abc.abstractmethod
    def update_fn(self, x: Tensor, t: Tensor, mask=None):
        """One update of the corrector.

        Args:
          x: A PyTorch tensor representing the current state
          t: A PyTorch tensor representing the current time step.
          mask: Optional mask of the sample. Defaults to None.

        Returns:
          x: A PyTorch tensor of the next state.
          x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """

update_fn(x, t, mask=None) abstractmethod

One update of the corrector.

Parameters:

Name Type Description Default
x Tensor

A PyTorch tensor representing the current state

required
t Tensor

A PyTorch tensor representing the current time step.

required

Returns:

Name Type Description
x

A PyTorch tensor of the next state.

x_mean

A PyTorch tensor. The next state without random noise. Useful for denoising.

Source code in src/ls_mlkit/util/sde/corrector.py
28
29
30
31
32
33
34
35
36
37
38
39
@abc.abstractmethod
def update_fn(self, x: Tensor, t: Tensor, mask=None):
    """One update of the corrector.

    Args:
      x: A PyTorch tensor representing the current state
      t: A PyTorch tensor representing the current time step.
      mask: Optional mask of the sample. Defaults to None.

    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """

NoneCorrector

Bases: Corrector

An empty corrector that does nothing.

Source code in src/ls_mlkit/util/sde/corrector.py
42
43
44
45
46
47
48
49
@register_corrector(key_name="none")
class NoneCorrector(Corrector):
    """An empty corrector that does nothing."""

    def __init__(self, sde: SDE, score_fn: Callable[..., Tensor], snr: float, n_steps: int):
        """Store the arguments exactly like the base ``Corrector`` does.

        The values are not used by ``update_fn``; they are kept so that the
        ``sde``, ``score_fn``, ``snr`` and ``n_steps`` attributes exist on every
        corrector instance.
        """
        super().__init__(sde=sde, score_fn=score_fn, snr=snr, n_steps=n_steps)

    def update_fn(self, x, t, mask=None):
        """Return ``x`` unchanged as both the next state and the noise-free state."""
        return x, x

Predictor

Bases: ABC

Source code in src/ls_mlkit/util/sde/predictor.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Predictor(abc.ABC):
    """The abstract class for a predictor algorithm.

    The constructor builds the reverse-time SDE/ODE once and keeps it as
    ``self.rsde``; subclasses implement ``update_fn``.
    """

    def __init__(self, sde: SDE, score_fn: Callable[..., Tensor], use_probability_flow=False):
        """Store the forward SDE and derive its reverse-time counterpart.

        Args:
            sde: the forward SDE; stored as ``self.sde``.
            score_fn: callable returning a score tensor; forwarded to
                ``sde.get_reverse_sde`` and stored as ``self.score_fn``.
            use_probability_flow: forwarded to ``sde.get_reverse_sde``.
        """
        super().__init__()
        self.sde = sde
        # Compute the reverse SDE/ODE
        self.rsde = sde.get_reverse_sde(score_fn=score_fn, use_probability_flow=use_probability_flow)
        self.score_fn = score_fn

    @abc.abstractmethod
    def update_fn(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""One update of the predictor.

        Args:
            x: A PyTorch tensor representing the current state
            t: A Pytorch tensor representing the current time step.
            mask: Optional mask of the sample. Defaults to None.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """

update_fn(x, t, mask=None) abstractmethod

One update of the predictor.

Parameters:

Name Type Description Default
x Tensor

A PyTorch tensor representing the current state

required
t Tensor

A Pytorch tensor representing the current time step.

required

Returns:

Name Type Description
x Tensor

A PyTorch tensor of the next state.

x_mean Tensor

A PyTorch tensor. The next state without random noise. Useful for denoising.

Source code in src/ls_mlkit/util/sde/predictor.py
25
26
27
28
29
30
31
32
33
34
35
36
@abc.abstractmethod
def update_fn(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
    r"""One update of the predictor.

    Args:
        x: A PyTorch tensor representing the current state
        t: A Pytorch tensor representing the current time step.
        mask: Optional mask of the sample. Defaults to None.

    Returns:
        x: A PyTorch tensor of the next state.
        x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """

ReverseDiffusionPredictor

Bases: Predictor

Source code in src/ls_mlkit/util/sde/predictor.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@register_predictor(key_name="reverse_diffusion_predictor")
class ReverseDiffusionPredictor(Predictor):
    """Predictor that takes one step of the discretized reverse-time SDE/ODE."""

    def __init__(self, sde: SDE, score_fn: Callable[..., Tensor], use_probability_flow=False, n_dim: int = 3):
        """Build the reverse SDE through the base class and remember ``n_dim``.

        Args:
            sde: the forward SDE.
            score_fn: callable returning a score tensor.
            use_probability_flow: forwarded to the base ``Predictor``.
            n_dim: stored as ``self.n_dim``; not read anywhere in this class.
        """
        super().__init__(sde=sde, score_fn=score_fn, use_probability_flow=use_probability_flow)
        self.n_dim = n_dim

    @override
    def update_fn(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""One reverse-diffusion step.

        .. math::

            x_{next} &= x_t - f + g \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

        where :math:`f` and :math:`g` are the discretized drift and diffusion
        returned by the reverse SDE (the step size is already folded into them).

        Returns:
            x: the next state.
            x_mean: the next state without random noise.
        """
        drift_step, noise_scale = self.rsde.get_discretized_drift_and_diffusion(x, t, mask=mask)
        noise = torch.randn_like(x)
        denoised = x - drift_step
        return denoised + noise_scale * noise, denoised

update_fn(x, t, mask=None)

.. math::

x_{t+\Delta t} &= x_t + f(x_t, t)\Delta t + g(x_t, t)\sqrt{|\Delta t|}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

f &= f(x_t, t)|\Delta t|

g &= g(x_t, t)\sqrt{|\Delta t|}
Source code in src/ls_mlkit/util/sde/predictor.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
def update_fn(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
    r"""One reverse-diffusion step.

    .. math::

        x_{next} &= x_t - f + g \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

    where :math:`f` and :math:`g` are the discretized drift and diffusion
    returned by the reverse SDE (the step size is already folded into them).

    Returns:
        x: the next state.
        x_mean: the next state without random noise.
    """
    drift_step, noise_scale = self.rsde.get_discretized_drift_and_diffusion(x, t, mask=mask)
    noise = torch.randn_like(x)
    denoised = x - drift_step
    return denoised + noise_scale * noise, denoised

SubVPSDE

Bases: SDE

Source code in src/ls_mlkit/util/sde/sde_lib.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
class SubVPSDE(SDE):
    """Sub-VP SDE with a linear ``beta`` schedule from ``beta_min`` to ``beta_max``."""

    def __init__(
        self,
        beta_min: float = 0.1,
        beta_max: float = 20,
        ndim_micro_shape: int = 2,
    ):
        """Construct the sub-VP SDE that excels at likelihoods.

        Args:
            beta_min: value of beta(0)
            beta_max: value of beta(1)
            ndim_micro_shape: number of dimensions of a sample
        """
        super().__init__(ndim_micro_shape=ndim_micro_shape)
        self.beta_0 = beta_min
        self.beta_1 = beta_max

    @property
    def T(self) -> float:
        """End time of the SDE."""
        return 1

    def get_drift_and_diffusion(self, x: Tensor, t: Tensor, mask=None) -> Tuple[Tensor, Tensor]:
        r"""Drift and diffusion of the sub-VP SDE.

        .. math::

            dx = -\frac{1}{2}\beta_t x dt + \sqrt{\beta_t (1 - e^{-2 \beta_0 t - (\beta_1 - \beta_0) t^2})} dw

        Args:
            x: the sample.
            t: (macro_shape)
            mask: unused.

        Returns:
            drift: shape = x.shape
            diffusion: shape = t.shape (macro_shape)
        """
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        macro_shape = x.shape[: -self.ndim_micro_shape]
        # Only the drift needs beta_t broadcast against x. The diffusion keeps
        # the macro shape, so it must use the un-reshaped beta_t; multiplying the
        # reshaped tensor by ``discount`` would broadcast to a wrong shape.
        drift = -0.5 * beta_t.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)]) * x
        discount = 1.0 - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t**2)
        diffusion = torch.sqrt(beta_t * discount)
        return drift, diffusion

    def marginal_prob(self, x, t, mask=None):
        """Return ``(mean, std)`` of the perturbation kernel at time ``t``.

        ``std`` has shape ``(*macro_shape, 1, ..., 1)`` so it broadcasts against ``x``.
        """
        log_mean_coeff = -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        macro_shape = x.shape[: -self.ndim_micro_shape]
        log_mean_coeff = log_mean_coeff.view(*macro_shape, *[1 for _ in range(self.ndim_micro_shape)])
        mean = torch.exp(log_mean_coeff) * x
        std = 1 - torch.exp(2.0 * log_mean_coeff)
        return mean, std

    def prior_sampling(self, shape):
        """Draw a standard normal sample of the requested shape."""
        return torch.randn(*shape)

    def prior_logp(self, z):
        """Log-density of ``z`` under the standard normal prior, one value per batch element.

        Args:
            z: tensor whose first dimension is the batch; all remaining
                dimensions are treated as the sample.

        Returns:
            Tensor of shape ``(z.shape[0],)``.
        """
        shape = z.shape
        N = np.prod(shape[1:])
        log_norm = float(-N / 2.0 * np.log(2 * np.pi))
        # Reduce over every non-batch dimension instead of assuming a 4-D input.
        squared_norm = (z**2).reshape(shape[0], -1).sum(dim=1)
        return log_norm - squared_norm / 2.0

__init__(beta_min=0.1, beta_max=20, ndim_micro_shape=2)

Construct the sub-VP SDE that excels at likelihoods.

Parameters:

Name Type Description Default
beta_min float

value of beta(0)

0.1
beta_max float

value of beta(1)

20
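n_discretization_steps

number of discretization steps (listed in the docstring, but not a parameter of this constructor; the base-class default is used)

not accepted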
ndim_micro_shape int

number of dimensions of a sample

2
Source code in src/ls_mlkit/util/sde/sde_lib.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def __init__(
    self,
    beta_min: float = 0.1,
    beta_max: float = 20,
    ndim_micro_shape: int = 2,
):
    """Construct the sub-VP SDE that excels at likelihoods.

    Args:
        beta_min: value of beta(0); stored as ``self.beta_0``.
        beta_max: value of beta(1); stored as ``self.beta_1``.
        ndim_micro_shape: number of dimensions of a sample
    """
    # n_discretization_steps is not a parameter here and is not forwarded, so
    # the base-class default applies.
    super().__init__(ndim_micro_shape=ndim_micro_shape)
    self.beta_0 = beta_min
    self.beta_1 = beta_max

Sniffer

Source code in src/ls_mlkit/util/sniffer.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class Sniffer:
    """Collect paths below a directory and filter them with regular expressions."""

    def __init__(self):
        pass

    def sniff_file(self, directory_path, pattern, max_deep=-1) -> list[str]:
        """Return the collected paths whose last ``/``-separated component matches.

        Args:
            directory_path (``str``): the path of the directory to sniff
            pattern (``str``): regular expression searched for in the file name
            max_deep (``int``, *optional*): the maximum depth to sniff. Defaults to -1.

        Returns:
            ``list[str]``: the full paths whose file name matches the pattern
        """
        candidates = self.get_all_by_recursion(directory_path, deep=0, max_deep=max_deep)
        return [
            candidate
            for candidate in candidates
            if re.search(pattern, candidate.rsplit("/", 1)[-1]) is not None
        ]

    def sniff_file_by_path_pattern(self, directory_path, pattern, max_deep=-1) -> list[str]:
        """Return the collected paths whose full path string matches.

        Args:
            directory_path (``str``): the path of the directory to sniff
            pattern (``str``): regular expression searched for in the full path
            max_deep (``int``, *optional*): the maximum depth to sniff. Defaults to -1.

        Returns:
            ``list[str]``: the full paths that match the pattern
        """
        candidates = self.get_all_by_recursion(directory_path, deep=0, max_deep=max_deep)
        return [candidate for candidate in candidates if re.search(pattern, candidate) is not None]

    def get_all_by_recursion(self, directory_path, deep, max_deep):
        """Walk ``directory_path`` depth-first and list every path at which the walk stops.

        A path is listed when it is not a directory, or when it is a directory
        sitting at the depth limit (so directories at the cut-off depth appear
        in the result). A negative ``max_deep`` disables the depth limit.

        Args:
            directory_path (``str``): where the walk starts; one trailing "/" is dropped
            deep (``int``): depth assigned to ``directory_path``
            max_deep (``int``): directories at this depth or deeper are not expanded

        Returns:
            ``list[str]``: the collected paths, joined with "/"
        """

        def _walk(current, level):
            if current[-1] == "/":
                current = current[:-1]
            expandable = os.path.isdir(current) and (max_deep < 0 or level < max_deep)
            if not expandable:
                yield current
                return
            for entry in os.listdir(path=current):
                yield from _walk(current + "/" + entry, level + 1)

        return list(_walk(directory_path, deep))

sniff_file(directory_path, pattern, max_deep=-1)

Get all files under directory_path whose file name matches the pattern

Parameters:

Name Type Description Default
directory_path ``str``

the path of the directory to sniff

required
pattern ``str``

the pattern to match

required
max_deep ``int``, *optional*

the maximum depth to sniff. Defaults to -1.

-1

Returns:

Type Description
list[str]

list[str]: the full path of the files that match the pattern

Source code in src/ls_mlkit/util/sniffer.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def sniff_file(self, directory_path, pattern, max_deep=-1) -> list[str]:
    """Return the collected paths whose last ``/``-separated component matches.

    Args:
        directory_path (``str``): the path of the directory to sniff
        pattern (``str``): regular expression searched for in the file name
        max_deep (``int``, *optional*): the maximum depth to sniff. Defaults to -1.

    Returns:
        ``list[str]``: the full paths whose file name matches the pattern
    """
    matched = []
    for full_path in self.get_all_by_recursion(directory_path, deep=0, max_deep=max_deep):
        # Only the part after the last "/" is tested against the pattern.
        basename = full_path.rsplit("/", 1)[-1]
        if re.search(pattern, basename) is not None:
            matched.append(full_path)
    return matched

sniff_file_by_path_pattern(directory_path, pattern, max_deep=-1)

Get all files under directory_path whose full path matches the pattern

Parameters:

Name Type Description Default
directory_path ``str``

the path of the directory to sniff

required
pattern ``str``

the pattern to match

required
max_deep ``int``, *optional*

the maximum depth to sniff. Defaults to -1.

-1

Returns:

Type Description
list[str]

list[str]: the full path of the files that match the pattern

Source code in src/ls_mlkit/util/sniffer.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def sniff_file_by_path_pattern(self, directory_path, pattern, max_deep=-1) -> list[str]:
    """Return the collected paths whose full path string matches the pattern.

    Args:
        directory_path (``str``): the path of the directory to sniff
        pattern (``str``): regular expression searched for in the full path
        max_deep (``int``, *optional*): the maximum depth to sniff. Defaults to -1.

    Returns:
        ``list[str]``: the full paths that match the pattern
    """
    candidates = self.get_all_by_recursion(directory_path, deep=0, max_deep=max_deep)
    return [candidate for candidate in candidates if re.search(pattern, candidate) is not None]

check_cuda()

Check the CUDA environment, test whether torch can use cuda

Returns:

Type Description

None

Source code in src/ls_mlkit/util/cuda.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def check_cuda():
    """Check the CUDA environment, test whether torch can use cuda

    Prints availability, device count and the CUDA version torch reports. When
    CUDA is available it additionally prints the current device, its name,
    cuDNN status, memory figures and the preferred linalg backend. When CUDA is
    not available the device-specific queries are skipped, because
    ``torch.cuda.current_device()`` raises without a usable device and the
    diagnostic would otherwise crash exactly where it is most needed.

    Returns:
        None
    """
    banner = "============================================================"
    bytes_per_gb = 1024 * 1024 * 1024
    cuda_ready = torch.cuda.is_available()

    print(banner)
    print("cuda.is_available:", cuda_ready)
    print("cuda.device_count:", torch.cuda.device_count())
    cuda_version = getattr(torch, "version", None)
    if cuda_version is not None:
        cuda_version = getattr(cuda_version, "cuda", None)
    print("version.cuda:", cuda_version)

    if not cuda_ready:
        print("CUDA is not available; skipping device-specific queries.")
        print(banner)
        return

    current = torch.cuda.current_device()
    print("current_device:", current)
    print(
        "cuda.get_device_name:",
        torch.cuda.get_device_name(current),
    )
    print("torch.backends.cudnn.is_available", torch.backends.cudnn.is_available())
    print("print(torch.backends.cudnn.version", torch.backends.cudnn.version())
    print(
        "(free, total)GB",
        [x / bytes_per_gb for x in torch.cuda.mem_get_info()],
    )
    print(
        "torch.cuda.max_memory_allocated",
        torch.cuda.max_memory_allocated() / bytes_per_gb,
        "GB",
    )
    current_backend = torch.backends.cuda.preferred_linalg_library()
    print(f"Current preferred linalg library: {current_backend}")
    print(banner)

cache_to_disk(root_datadir='cached_dataset', exclude_first_arg=False)

Cache the result of a function to disk

Parameters:

Name Type Description Default
root_datadir str

the root directory to save the cached data. Defaults to "cached_dataset".

'cached_dataset'
exclude_first_arg bool

whether to exclude the first argument of the function when generating the cache filename. Defaults to False.

False
Source code in src/ls_mlkit/util/decorators.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def cache_to_disk(root_datadir="cached_dataset", exclude_first_arg=False):
    """Cache the result of a function to disk

    The cache key is the MD5 of the stringified positional and keyword
    arguments, so arguments whose ``str()`` is not stable between runs (for
    example objects using the default ``repr``) never produce a cache hit.
    Each newly cached entry is also appended to ``hash_table.txt`` in
    ``root_datadir`` as ``<cache file>: <argument string>``.

    Args:
        root_datadir (str, optional): the root directory to save the cached data. Defaults to "cached_dataset".
        exclude_first_arg (bool, optional): whether to exclude the first argument of the function when generating the cache filename. Defaults to False.

    Returns:
        A decorator that wraps a function with the on-disk cache.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # exist_ok avoids the check-then-create race between parallel workers.
            os.makedirs(root_datadir, exist_ok=True)

            func_name = func.__name__.replace("/", "")
            args_str = "_".join(map(str, args[1:] if exclude_first_arg else args))
            kwargs_str = "_".join(f"{k}={v}" for k, v in kwargs.items())
            params_str = f"{args_str}_{kwargs_str}"
            params_hash = hashlib.md5(params_str.encode()).hexdigest()
            cache_filename = os.path.join(root_datadir, f"{func_name}_{params_hash}.pkl")
            print("cache_filename =", cache_filename)

            if os.path.exists(cache_filename):
                with open(cache_filename, "rb") as f:
                    print(f"Loading cached data for {func.__name__} {params_str}")
                    # NOTE(security): pickle.load executes arbitrary code from the
                    # file; only point root_datadir at a trusted location.
                    return pickle.load(f)

            result = func(*args, **kwargs)

            print("caching " + cache_filename)
            # Write to a temporary file and rename it into place so that a failed
            # or interrupted dump never leaves a truncated cache file behind.
            tmp_filename = f"{cache_filename}.{os.getpid()}.tmp"
            try:
                with open(tmp_filename, "wb") as f:
                    pickle.dump(result, f)
                os.replace(tmp_filename, cache_filename)
            except BaseException:
                if os.path.exists(tmp_filename):
                    os.remove(tmp_filename)
                raise
            print(f"Cached data for {func.__name__}")

            # Append mode creates the file when it does not exist yet.
            hash_table_filename = os.path.join(root_datadir, "hash_table.txt")
            with open(hash_table_filename, "a") as f:
                f.write(f"{cache_filename}: {params_str}\n")

            return result

        return wrapper

    return decorator

inherit_docstring_from_parent(method_name=None)

Method decorator that inherits docstring from a specific parent class method.

This decorator allows you to explicitly inherit a docstring from a parent class method, even if the method names are different.

Usage:

.. code-block:: python

    class ChildClass(ParentClass):
        @inherit_docstring_from_parent('parent_method_name')
        def child_method(self):
            pass

        @inherit_docstring_from_parent()  # Uses same method name
        def some_method(self):
            pass

Parameters:

Name Type Description Default
method_name str | None

Name of the parent method to inherit docstring from. If None, uses the decorated method's name.

None

Returns:

Type Description

The decorated method with inherited docstring

Source code in src/ls_mlkit/util/decorators.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def inherit_docstring_from_parent(method_name: str | None = None):
    r"""
    Method decorator that inherits docstring from a specific parent class method.

    This decorator allows you to explicitly inherit a docstring from a parent class method,
    even if the method names are different.

    Usage:

        .. code-block:: python

            class ChildClass(ParentClass):
                @inherit_docstring_from_parent('parent_method_name')
                def child_method(self):
                    pass

                @inherit_docstring_from_parent()  # Uses same method name
                def some_method(self):
                    pass

    Args:
        method_name: Name of the parent method to inherit docstring from.
                    If None, uses the decorated method's name.

    Returns:
        The decorated method with inherited docstring
    """

    def decorator(func):
        target_method_name = method_name or func.__name__

        # We need to defer the docstring inheritance until the class is fully defined
        # So we'll mark the function and handle it in the class decorator
        func._inherit_docstring_from = target_method_name
        return func

    return decorator

inherit_docstrings(cls)

Class decorator that automatically inherits docstrings from parent class methods.

This decorator will:

1. Find methods in the class that don't have docstrings
2. Look for the same method in parent classes
3. Copy the docstring from the first parent class that has one
4. Handle methods marked with @inherit_docstring_from_parent

Usage:

.. code-block:: python

    @inherit_docstrings
    class ChildClass(ParentClass):
        def some_method(self):
            # This method will inherit docstring from ParentClass.some_method
            pass

        @inherit_docstring_from_parent('parent_method')
        def child_method(self):
            # This method will inherit docstring from ParentClass.parent_method
            pass

Parameters:

Name Type Description Default
cls

The class to apply docstring inheritance to

required

Returns:

Type Description

The modified class with inherited docstrings

Source code in src/ls_mlkit/util/decorators.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def _find_parent_docstring(cls, method_name):
    """Return the docstring of the first callable ``method_name`` with one on the bases of ``cls``, else None."""
    for base in cls.__mro__[1:]:  # Skip the class itself
        if hasattr(base, method_name):
            parent_method = getattr(base, method_name)
            if callable(parent_method) and parent_method.__doc__:
                return parent_method.__doc__
    return None


def inherit_docstrings(cls):
    r"""
    Class decorator that automatically inherits docstrings from parent class methods.

    This decorator will:
    1. Find methods in the class that don't have docstrings
    2. Look for the same method in parent classes
    3. Copy the docstring from the first parent class that has one
    4. Handle methods marked with @inherit_docstring_from_parent

    Only public callables defined in the class's own ``__dict__`` are
    considered. Class methods and static methods are supported: the docstring
    is written to the underlying function, since the bound method object
    returned by attribute access does not allow assigning ``__doc__``.

    Usage:

        .. code-block:: python

            @inherit_docstrings
            class ChildClass(ParentClass):
                def some_method(self):
                    # This method will inherit docstring from ParentClass.some_method
                    pass

                @inherit_docstring_from_parent('parent_method')
                def child_method(self):
                    # This method will inherit docstring from ParentClass.parent_method
                    pass

    Args:
        cls: The class to apply docstring inheritance to

    Returns:
        The modified class with inherited docstrings
    """
    for attr_name in dir(cls):
        # Skip private/magic methods and non-callable attributes
        if attr_name.startswith("_"):
            continue

        attr = getattr(cls, attr_name)
        if not callable(attr):
            continue

        # Check if this method exists in the class's own __dict__ (not inherited)
        if attr_name not in cls.__dict__:
            continue

        # classmethod/staticmethod wrappers expose the real function as __func__;
        # that function is the object whose __doc__ can be assigned.
        raw_attr = cls.__dict__[attr_name]
        doc_target = getattr(raw_attr, "__func__", raw_attr)

        # Handle methods marked with @inherit_docstring_from_parent
        if hasattr(attr, "_inherit_docstring_from"):
            inherited = _find_parent_docstring(cls, attr._inherit_docstring_from)
            if inherited:
                doc_target.__doc__ = inherited
            continue

        # Check if the method already has a docstring
        if attr.__doc__:
            continue

        # Automatic inheritance: same method name in the parent classes
        inherited = _find_parent_docstring(cls, attr_name)
        if inherited:
            doc_target.__doc__ = inherited

    return cls

register_class_to_dict(cls=None, *, key_name=None, global_dict=None)

Register a class to a global dictionary

Parameters:

Name Type Description Default
cls class

the class to register. Defaults to None.

None
key_name str

the name of the key to register the class. Defaults to None.

None
global_dict dict

the global dictionary to register the class. Defaults to None.

None
Source code in src/ls_mlkit/util/decorators.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def register_class_to_dict(cls=None, *, key_name=None, global_dict: dict | None = None):
    """Register a class to a global dictionary

    Args:
        cls (class, optional): the class to register. Defaults to None.
        key_name (str, optional): the name of the key to register the class. Defaults to None.
        global_dict (dict, optional): the global dictionary to register the class. Defaults to None.

    """
    if global_dict is None:
        raise ValueError("global_dict must be provided")

    def _register(cls):
        if key_name is None:
            local_key_name = cls.__name__
        else:
            local_key_name = key_name
        if local_key_name in global_dict:
            raise ValueError(f"Already registered model with name: {local_key_name}")
        global_dict[local_key_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)

require_keys(*required_keys)

Decorator to ensure returned dictionary contains required keys

Source code in src/ls_mlkit/util/decorators.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def require_keys(*required_keys):
    """Decorator checking that the wrapped function returns a dict with the given keys.

    Args:
        *required_keys: keys that must be present in the returned dictionary.

    Returns:
        A decorator. The wrapped function raises ``TypeError`` when the result
        is not a ``dict`` and ``KeyError`` when any required key is absent;
        otherwise the result is returned unchanged.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            outcome = func(*args, **kwargs)
            if not isinstance(outcome, dict):
                raise TypeError(f"{func.__name__} must return a dictionary")

            missing_keys = set(required_keys) - outcome.keys()
            if not missing_keys:
                return outcome
            raise KeyError(f"{func.__name__} missing required keys: {missing_keys}")

        return wrapper

    return decorator

timer(format='ms')

Time the execution of a function

Parameters:

Name Type Description Default
format str

the format of the execution time. Defaults to "ms".

'ms'
Source code in src/ls_mlkit/util/decorators.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def timer(format="ms"):
    """Time the execution of a function and print the elapsed wall-clock time.

    The report is printed as ``"<name> run <M> min <S>s"`` using whole seconds.
    The elapsed time is taken from ``timedelta.total_seconds()`` so that runs
    longer than a day are reported correctly (``timedelta.seconds`` only holds
    the sub-day remainder).

    Args:
        format (str, optional): accepted for interface compatibility; it is
            currently ignored and the output is always minutes and seconds.
            Defaults to "ms".

    Returns:
        A decorator that wraps a function with the timing report.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            begin_time = datetime.now()
            result = func(*args, **kwargs)
            elapsed = datetime.now() - begin_time
            # Clamp at zero in case the wall clock was adjusted backwards mid-run.
            cost = max(0, int(elapsed.total_seconds()))
            print(
                func.__name__ + " run" + f" {cost // 60} min {cost % 60}s",
            )
            return result

        return wrapper

    return decorator

inf_iterator(iterable)

An infinite iterator

Parameters:

Name Type Description Default
iterable iterable

the iterable to iterate over

required

Yields:

Name Type Description
any

the next element in the iterable

Source code in src/ls_mlkit/util/iterator.py
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
def inf_iterator(iterable):
    """An infinite iterator

    Cycles over ``iterable`` by re-iterating it from the start each time it is
    exhausted. A pass that produces no element at all (an empty iterable, or a
    one-shot iterator that is already used up) raises instead of spinning
    forever without yielding.

    Args:
        iterable (iterable): the iterable to iterate over

    Yields:
        any: the next element in the iterable

    Raises:
        ValueError: if a full pass over ``iterable`` yields nothing.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            raise ValueError("inf_iterator: iterable produced no elements")

get_and_create_new_log_dir(root='./logs', prefix='', suffix='')

Get and create a new log directory

Parameters:

Name Type Description Default
root str

the root directory to save the logs. Defaults to "./logs".

'./logs'
prefix str

the prefix of the log directory. Defaults to "".

''
suffix str

the suffix of the log directory. Defaults to "".

''

Returns:

Name Type Description
str str

the new log directory

Source code in src/ls_mlkit/util/log.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def get_and_create_new_log_dir(root="./logs", prefix="", suffix="") -> str:
    """Get and create a new log directory

    The directory name is the local time formatted as
    ``%Y_%m_%d__%H_%M_%S``, optionally surrounded by ``prefix`` and ``suffix``
    (joined with "_"). If that directory already exists - for example when the
    function is called twice within the same second - a numeric counter
    (``_1``, ``_2``, ...) is appended until a fresh directory can be created.

    Args:
        root (str, optional): the root directory to save the logs. Defaults to "./logs".
        prefix (str, optional): the prefix of the log directory. Defaults to "".
        suffix (str, optional): the suffix of the log directory. Defaults to "".

    Returns:
        str: the new log directory
    """
    filename = time.strftime("%Y_%m_%d__%H_%M_%S", time.localtime())
    if prefix != "":
        filename = prefix + "_" + filename
    if suffix != "":
        filename = filename + "_" + suffix
    base_dir = os.path.join(root, filename)

    log_dir = base_dir
    attempt = 0
    while True:
        try:
            # Creating and catching the error is race-free, unlike a prior exists() check.
            os.makedirs(log_dir)
            return log_dir
        except FileExistsError:
            attempt += 1
            log_dir = f"{base_dir}_{attempt}"

get_logger(name='unnamed', log_dir=None)

Get a logger

Parameters:

Name Type Description Default
name str

the name of the logger. Defaults to "unnamed".

'unnamed'
log_dir str

the directory to save the logs. Defaults to None.

None

Returns:

Type Description
Logger

logging.Logger: the logger

Source code in src/ls_mlkit/util/log.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def get_logger(name="unnamed", log_dir: str | None = None) -> logging.Logger:
    """Get a logger

    Args:
        name (str, optional): the name of the logger. Defaults to "unnamed".
        log_dir (str, optional): the directory to save the logs. Defaults to None.

    Returns:
        logging.Logger: the logger
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s::%(name)s::%(levelname)s] %(message)s")

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if log_dir is not None:
        file_handler = logging.FileHandler(os.path.join(log_dir, "log.txt"))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger

find_linear_modules(model)

Find the linear modules in a model

Parameters:

Name Type Description Default
model Module

the model to find the linear modules

required

Returns:

Type Description
List[str]

List[str]: the names of the linear modules

Source code in src/ls_mlkit/util/lora.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def find_linear_modules(model) -> List[str]:
    """Collect the leaf names of the ``torch.nn.Linear`` sub-modules of a model.

    Modules whose qualified name contains "lm_head" or "embed_tokens" are
    left out. Only the last dotted component of each name is kept, and
    duplicates are merged.

    Args:
        model (torch.nn.Module): the model to find the linear modules

    Returns:
        List[str]: the names of the linear modules (in no particular order)
    """
    excluded_fragments = ["lm_head", "embed_tokens"]

    def _is_excluded(qualified_name):
        return any(fragment in qualified_name for fragment in excluded_fragments)

    leaf_names = {
        qualified_name.split(".")[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear) and not _is_excluded(qualified_name)
    }
    return list(leaf_names)

get_lora_model(model, lora_config)

Get a LoRA model

Parameters:

Name Type Description Default
model Module

the model to get the LoRA model

required
lora_config LoraConfig

the LoRA configuration

required

Returns:

Type Description

torch.nn.Module: the LoRA model

Source code in src/ls_mlkit/util/lora.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def get_lora_model(model, lora_config):
    """Wrap a model with LoRA adapters on the modules found by ``find_linear_modules``.

    Args:
        model (torch.nn.Module): the model to get the LoRA model
        lora_config: mapping read with the keys "lora_r", "lora_alpha" and
            "lora_dropout"; a ``LoraConfig`` is built from these values

    Returns:
        The object returned by ``get_peft_model`` for the model and the built config
    """
    linear_module_names = find_linear_modules(model)
    peft_config = LoraConfig(
        r=lora_config["lora_r"],
        target_modules=linear_module_names,
        lora_alpha=lora_config["lora_alpha"],
        lora_dropout=lora_config["lora_dropout"],
    )
    return get_peft_model(model, peft_config)

get_nma_displacement_from_node_coordinates(node_coordinates, cutoff_distance=10.0, indexes=[6], node_mask=None)

node_coordinates: shape = (..., n, 3) node_mask: shape = (..., n)

Source code in src/ls_mlkit/util/nma/nma.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_nma_displacement_from_node_coordinates(
    node_coordinates: Tensor,
    cutoff_distance: float = 10.0,
    indexes: list[int] = [6],
    node_mask: Tensor | None = None,
) -> Tensor:
    """Build an ``ANM`` with a ``HinsenForceField`` and return its normal-mode displacements.

    node_coordinates: shape = (..., n, 3)
    node_mask: shape = (..., n)

    Args:
        node_coordinates: passed to ``ANM`` as ``atoms``; its device is used for the model
        cutoff_distance: cutoff handed to ``HinsenForceField``
        indexes: forwarded to ``get_displacements_from_normal_modes``. The default
            list object is shared between calls; this function does not modify it.
        node_mask: passed through to ``ANM``

    Returns:
        The result of ``ANM.get_displacements_from_normal_modes(indexes=indexes)``
    """
    model = ANM(
        atoms=node_coordinates,
        force_field=HinsenForceField(cutoff_distance=cutoff_distance),
        masses=None,
        device=node_coordinates.device,
        node_mask=node_mask,
    )
    displacements = model.get_displacements_from_normal_modes(indexes=indexes)
    return displacements

gradient_norm_fn(module)

Compute the gradient norm of a module

Parameters:

Name Type Description Default
module Module

the module to compute the gradient norm

required

Returns:

Name Type Description
float

the gradient norm of the module

Source code in src/ls_mlkit/util/observer.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def gradient_norm_fn(module: Module):
    """Compute the L2 norm over all gradients currently stored on a module.

    Parameters whose ``grad`` is None are skipped.

    Args:
        module (Module): the module to compute the gradient norm

    Returns:
        Tensor: the square root of the summed squared gradient entries; a
        zero tensor of shape ``(1,)`` when no parameter has a gradient
    """
    squared_terms = [torch.sum(p.grad * p.grad) for p in module.parameters() if p.grad is not None]
    if not squared_terms:
        return torch.sqrt(torch.zeros(1))

    # Accumulate left to right, one parameter at a time.
    total = squared_terms[0]
    for term in squared_terms[1:]:
        total = total + term
    return torch.sqrt(total)

gradients_fn(module)

Get the gradients of a module

Parameters:

Name Type Description Default
module Module

the module to get the gradients

required

Returns:

Name Type Description
list list[Tensor]

the gradients of the module

Source code in src/ls_mlkit/util/observer.py
63
64
65
66
67
68
69
70
71
72
def gradients_fn(module: Module) -> list[Tensor]:
    """Collect detached CPU copies of the gradients stored on a module.

    Args:
        module (Module): the module to get the gradients

    Returns:
        list: one tensor per parameter that has a gradient, in parameter order
    """
    collected: list[Tensor] = []
    for param in module.parameters():
        if param.grad is None:
            continue
        collected.append(param.grad.detach().cpu())
    return collected

weight_norm_fn(module)

Compute the weight norm of a module

Parameters:

Name Type Description Default
module Module

the module to compute the weight norm

required

Returns:

Name Type Description
float

the weight norm of the module

Source code in src/ls_mlkit/util/observer.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def weight_norm_fn(module: Module):
    """Compute the L2 norm over all trainable parameters of a module.

    Parameters with ``requires_grad`` set to False are skipped.

    Args:
        module (Module): the module to compute the weight norm

    Returns:
        Tensor: the square root of the summed squared parameter entries; a
        zero tensor of shape ``(1,)`` when no parameter requires a gradient
    """
    squared_terms = [torch.sum(p.data * p.data) for p in module.parameters() if p.requires_grad]
    if not squared_terms:
        return torch.sqrt(torch.zeros(1))

    # Accumulate left to right, one parameter at a time.
    total = squared_terms[0]
    for term in squared_terms[1:]:
        total = total + term
    return torch.sqrt(total)

weights_fn(module)

Get the weights of a module

Parameters:

Name Type Description Default
module Module

the module to get the weights

required

Returns:

Name Type Description
list list[Tensor]

the weights of the module

Source code in src/ls_mlkit/util/observer.py
51
52
53
54
55
56
57
58
59
60
def weights_fn(module: Module) -> list[Tensor]:
    """Collect detached CPU copies of the trainable parameters of a module.

    Args:
        module (Module): the module to get the weights

    Returns:
        list: one tensor per parameter with ``requires_grad`` set, in parameter order
    """
    snapshots: list[Tensor] = []
    for param in module.parameters():
        if param.requires_grad:
            snapshots.append(param.detach().cpu())
    return snapshots

print_cpu_memory()

Print the CPU memory usage

Returns:

Type Description

None

Source code in src/ls_mlkit/util/resource_monitor.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def print_cpu_memory():
    """Print the CPU memory usage

    Prints the resident set size of the current process (in GB) and the free
    system memory reported by ``psutil.virtual_memory()`` (rounded to whole GB).

    Returns:
        None
    """
    mem = psutil.virtual_memory()
    free = str(round(mem.free / 1024**3))
    # RSS of this process only, not system-wide usage.
    memory_usage_in_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"CPU memory used:{memory_usage_in_bytes / 1024**3}GB")
    print("CPU memory available :" + free + "GB")

print_gpu_memory()

Print the GPU memory usage

Returns:

Type Description

None

Source code in src/ls_mlkit/util/resource_monitor.py
27
28
29
30
31
32
33
34
35
36
def print_gpu_memory():
    """Print the current and peak GPU memory allocated through torch, in GB.

    Returns:
        None
    """
    bytes_per_gb = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    peak_bytes = torch.cuda.max_memory_allocated()
    print(f"GPU Allocated Memory: {allocated_gb:.2f} GB")
    print(f"GPU max_memory_allocated {peak_bytes / bytes_per_gb} GB")

show_gpu_and_cpu_memory()

Show the GPU and CPU memory usage

Returns:

Type Description

None

Source code in src/ls_mlkit/util/resource_monitor.py
39
40
41
42
43
44
45
46
def show_gpu_and_cpu_memory():
    """Print the GPU memory report followed by the CPU memory report.

    Returns:
        None
    """
    # GPU first, then CPU.
    for report in (print_gpu_memory, print_cpu_memory):
        report()

get_model_fn(model, train=False)

Create a function to give the output of the score-based model.

Parameters:

Name Type Description Default
model

The score model.

required
train

True for training and False for evaluation.

False

Returns:

Type Description

A model function.

Source code in src/ls_mlkit/util/sde/score_fn_utils.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def get_model_fn(model, train=False):
    """Build a callable that switches the model's mode and then evaluates it.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function.
    """

    def model_fn(x, labels):
        """Put the model in the configured mode and call it on ``(x, labels)``.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
            for different models.

        Returns:
          Whatever ``model(x, labels)`` returns.
        """
        # The mode is re-applied on every call, before the forward pass.
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn

get_pc_sampler(sde, shape, predictor_class, corrector_class, inverse_scaler, snr, n_correct_steps=1, use_probability_flow=False, denoise_at_final=True, eps=0.001, device='cuda')

Create a Predictor-Corrector (PC) sampler.

Parameters:

Name Type Description Default
sde SDE

An SDE object representing the forward SDE.

required
shape Tuple[int, ...]

A sequence of integers. The expected shape of a single sample. First dimension is batch size.

required
predictor_class Predictor

A subclass of Predictor representing the predictor algorithm.

required
corrector_class Corrector

A subclass of Corrector representing the corrector algorithm.

required
inverse_scaler Callable

The inverse data normalizer.

required
snr float

A float number. The signal-to-noise ratio for configuring correctors.

required
n_correct_steps int

An integer. The number of corrector steps per predictor update.

1
use_probability_flow bool

If True, solve the reverse-time probability flow ODE when running the predictor.

False
denoise_at_final bool

If True, add one-step denoising to the final samples.

True
eps float

A float number. The reverse-time SDE and ODE are integrated to epsilon to avoid numerical issues.

0.001
device str

PyTorch device.

'cuda'

Returns:

Type Description

A sampling function that returns samples and the number of function evaluations during sampling.

Source code in src/ls_mlkit/util/sde/sampler.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def get_pc_sampler(
    sde: SDE,
    shape: Tuple[int, ...],
    predictor_class: Predictor,
    corrector_class: Corrector,
    inverse_scaler: Callable,
    snr: float,
    n_correct_steps: int = 1,
    use_probability_flow: bool = False,
    denoise_at_final: bool = True,
    eps: float = 1e-3,
    device: str = "cuda",
):
    """Build a Predictor-Corrector (PC) sampling function.

    Args:
      sde: The forward `SDE`; supplies `prior_sampling`, `T` and `n_discretization_steps`.
      shape: Shape of the generated batch. `shape[0]` is used as the batch size.
      predictor_class: The `Predictor` subclass forwarded to the shared predictor update.
      corrector_class: The `Corrector` subclass forwarded to the shared corrector update.
      inverse_scaler: Callable applied to the final tensor before it is returned.
      snr: Signal-to-noise ratio forwarded to the corrector update.
      n_correct_steps: Number of corrector steps per predictor update.
      use_probability_flow: Forwarded to the predictor update; selects the probability flow ODE.
      denoise_at_final: If `True`, return the last `x_mean` instead of the last noisy `x`.
      eps: Final time of the integration grid (the grid runs from `sde.T` down to `eps`).
      device: PyTorch device on which the samples and time grid are placed.

    Returns:
      A function `pc_sampler(score_fn)` returning a tuple of
      (inverse-scaled samples, number of function evaluations).
    """
    # Bind everything except (x, t, score_fn) once, so the loop body stays small.
    corrector_update_fn = functools.partial(
        shared_corrector_update_fn,
        sde=sde,
        corrector_class=corrector_class,
        snr=snr,
        n_steps=n_correct_steps,
    )
    predictor_update_fn = functools.partial(
        shared_predictor_update_fn,
        sde=sde,
        predictor_class=predictor_class,
        use_probability_flow=use_probability_flow,
    )

    def pc_sampler(score_fn):
        with torch.no_grad():
            n_steps = sde.n_discretization_steps
            batch_size = shape[0]

            # Start from the prior and walk the time grid from sde.T down to eps.
            x = sde.prior_sampling(shape).to(device)
            time_grid = torch.linspace(sde.T, eps, n_steps, device=device)

            for step in range(n_steps):
                t = time_grid[step]
                # Broadcast the scalar time to one entry per batch element.
                vec_t = torch.ones(batch_size, device=t.device) * t
                # Corrector first, then predictor, at the same time value.
                x, x_mean = corrector_update_fn(x, vec_t, score_fn=score_fn)
                x, x_mean = predictor_update_fn(x, vec_t, score_fn=score_fn)

            final = x_mean if denoise_at_final else x
            n_function_evals = n_steps * (n_correct_steps + 1)
            return inverse_scaler(final), n_function_evals

    return pc_sampler

get_score_fn(sde, model, train=False, continuous=False)

Wraps score_fn so that the model output corresponds to a real time-dependent score function.

Parameters:

Name Type Description Default
sde

An sde_lib.SDE object that represents the forward SDE.

required
model

A score model.

required
train

True for training and False for evaluation.

False
continuous

If True, the score-based model is expected to directly take continuous time steps.

False

Returns:

Type Description

A score function.

Source code in src/ls_mlkit/util/sde/score_fn_utils.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
      sde: The forward SDE. Must be an instance of `VPSDE`, `SubVPSDE` or `VESDE`.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
      A score function `score_fn(x, t)`.

    Raises:
      NotImplementedError: If `sde` is not one of the supported SDE classes.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, (VPSDE, SubVPSDE)):

        def score_fn(x, t):
            # Scale neural network output by standard deviation and flip sign
            if continuous or isinstance(sde, SubVPSDE):
                # For VP-trained models, t=0 corresponds to the lowest noise level
                # The maximum value of time embedding is assumed to 999 for
                # continuously-trained models.
                labels = t * 999
                score = model_fn(x, labels)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level
                labels = t * (sde.N - 1)
                score = model_fn(x, labels)
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

            # Append one singleton axis per non-batch dimension of `x` so the
            # per-sample std broadcasts against inputs of any rank. For 4-D
            # inputs this is exactly `std[:, None, None, None]`.
            broadcast_index = (slice(None),) + (None,) * (x.dim() - 1)
            score = -score / std[broadcast_index]
            return score

    elif isinstance(sde, VESDE):

        def score_fn(x, t):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()

            score = model_fn(x, labels)
            return score

    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn

seed_everything(seed)

Fix the seed for all the random number generators.

Parameters:

Name Type Description Default
seed ``int``

the seed to use for the random number generators

required

Returns:

Type Description

None

Source code in src/ls_mlkit/util/seed.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def seed_everything(seed: int):
    """Fix the seed for all the random number generators.

    Seeds Python's `random`, NumPy, PyTorch (CPU and every CUDA device) and
    `accelerate`, and configures cuDNN for deterministic algorithm selection.

    Args:
        seed (``int``): the seed to use for the random number generators

    Returns:
        None
    """
    random.seed(seed)
    # NOTE: setting PYTHONHASHSEED at runtime does not change hashing in the
    # already-running interpreter; it is only inherited by child processes.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not only the current one. This is a no-op when
    # CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms and may pick different
    # ones between runs, which defeats reproducibility; keep it disabled.
    torch.backends.cudnn.benchmark = False

    accelerate.utils.set_seed(seed)