ls_mlkit.util.base_class.base_time_class module

class ls_mlkit.util.base_class.base_time_class.BaseTimeScheduler(continuous_time_start: float = 0.0, continuous_time_end: float = 1.0, num_train_timesteps: int = 1000, num_inference_steps: int = None, idx_start: int = 0)

Bases: ABC

Base class for time schedulers in diffusion models.

Notation Convention

Let the total diffusion time be \(T\), discretized into \(N\) diffusion steps, corresponding to \(N+1\) continuous time points:

\[0 = t_0 < t_1 < \cdots < t_N = T\]

where \(\{t_i\}_{i=0}^N\) represents continuous time. For uniform discretization:

\[t_i = \frac{i}{N} \cdot T\]
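
As a quick check of the uniform discretization, the sketch below builds the grid \(t_0, \ldots, t_N\) with plain Python floats. The helper name uniform_grid is hypothetical and not part of ls_mlkit.

```python
from typing import List


def uniform_grid(num_steps: int, total_time: float = 1.0) -> List[float]:
    """Return [t_0, ..., t_N] with t_i = i / N * T (hypothetical helper)."""
    return [i / num_steps * total_time for i in range(num_steps + 1)]


grid = uniform_grid(4)  # N = 4, T = 1.0
print(grid)  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

The grid has \(N+1\) points. Only \(t_1, \ldots, t_N\) are used for training and sampling.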

The corresponding discrete time steps are defined as:

\[i \in \{0, 1, \ldots, N\}\]

In diffusion models, \(t_0\) corresponds to the clean data distribution \(q(x_0)\), so training and sampling typically only consider:

\[i \in \{1, \ldots, N\}\]

For engineering convenience (0-based array indexing), we use:

\[\text{idx} = i - 1\]

Therefore:

  • idx = 0 corresponds to discrete step \(i=1\), i.e., continuous time \(t_1\)

  • idx = N-1 corresponds to discrete step \(i=N\), i.e., continuous time \(t_N = T\)

In this implementation:

  • num_train_timesteps = \(N\) (number of diffusion steps)

  • timestep_index (or idx) \(\in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}\)

  • continuous_time \(\in [t_1, t_N] = [\frac{T}{N}, T]\) for training/sampling

The idx_start parameter controls the starting value of timestep indices:

  • When idx_start = 0 (default): \(\text{idx} = i - 1\), so \(\text{idx} \in \{0, \ldots, N-1\}\)

  • When idx_start = 1: \(\text{idx} = i\), so \(\text{idx} \in \{1, \ldots, N\}\)
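
Both cases follow from the single rule \(\text{idx} = i - 1 + \text{idx\_start}\). The plain-Python sketch below uses the illustrative helpers step_to_index and index_range. Neither is a library function.

```python
from typing import List


def step_to_index(i: int, idx_start: int = 0) -> int:
    """Map discrete step i in {1, ..., N} to idx = i - 1 + idx_start."""
    return i - 1 + idx_start


def index_range(num_steps: int, idx_start: int = 0) -> List[int]:
    """All valid indices {idx_start, ..., idx_start + N - 1}."""
    return [step_to_index(i, idx_start) for i in range(1, num_steps + 1)]


print(index_range(4, idx_start=0))  # [0, 1, 2, 3]
print(index_range(4, idx_start=1))  # [1, 2, 3, 4]
```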

Parameters:
  • continuous_time_start (float, optional) – Start of the continuous time range, \(t_0\) (typically 0). Defaults to 0.0.

  • continuous_time_end (float, optional) – End of the continuous time range, \(t_0 + T\) (equal to \(T\) when continuous_time_start is 0). Defaults to 1.0.

  • num_train_timesteps (int, optional) – Number of diffusion steps \(N\). Defaults to 1000.

  • num_inference_steps (int, optional) – Number of inference steps. If None, num_train_timesteps is used. Defaults to None.

  • idx_start (int, optional) – Starting value for timestep indices. Set to 1 for 1-based indexing, where idx equals the discrete step \(i\). Defaults to 0.
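
The stand-in below shows how the constructor arguments relate to the notation above. It is a hypothetical TimeGrid dataclass, not the real class, which is abstract and tensor-based. It derives \(T\), \(t_1\), and the index bounds, and it applies the documented num_inference_steps fallback.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class TimeGrid:
    """Hypothetical stand-in mirroring BaseTimeScheduler's constructor arguments."""

    continuous_time_start: float = 0.0
    continuous_time_end: float = 1.0
    num_train_timesteps: int = 1000
    num_inference_steps: Optional[int] = None
    idx_start: int = 0

    def __post_init__(self) -> None:
        # Documented default: fall back to num_train_timesteps when None.
        if self.num_inference_steps is None:
            self.num_inference_steps = self.num_train_timesteps

    @property
    def span(self) -> float:
        # T = continuous_time_end - continuous_time_start
        return self.continuous_time_end - self.continuous_time_start

    @property
    def t_1(self) -> float:
        # First usable time: t_0 + T / N
        return self.continuous_time_start + self.span / self.num_train_timesteps

    @property
    def index_bounds(self) -> Tuple[int, int]:
        # (idx_start, idx_start + N - 1)
        return (self.idx_start, self.idx_start + self.num_train_timesteps - 1)


g = TimeGrid(continuous_time_start=1.0, continuous_time_end=3.0, num_train_timesteps=4, idx_start=1)
print(g.num_inference_steps)  # 4
print(g.span, g.t_1)          # 2.0 1.5
print(g.index_bounds)         # (1, 4)
```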

continuous_time_to_timestep_index(continuous_time: Tensor) → Tensor

Convert continuous time to timestep index.

Given continuous time \(t\), compute the timestep index:

\[\text{idx} = \text{round}\left(\frac{t - t_0}{T} \cdot N\right) - 1 + \text{idx\_start}\]

where \(t_0\) is continuous_time_start, \(T\) is the total time span (continuous_time_end - continuous_time_start), \(N\) is num_train_timesteps, and \(\text{idx\_start}\) is the starting index.

The result is clamped to \([\text{idx\_start}, \text{idx\_start} + N - 1]\).

Parameters:

continuous_time (Tensor) – Continuous time values \(t \in [t_0, t_0 + T]\).

Returns:

Timestep indices \(\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}\).

Return type:

Tensor
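
The conversion and the clamp can be traced on plain Python floats. This is a sketch of the documented formula, not the library's tensor implementation. The function name continuous_time_to_index is hypothetical.

```python
from typing import List


def continuous_time_to_index(
    times: List[float],
    t0: float = 0.0,
    t_end: float = 1.0,
    num_steps: int = 1000,
    idx_start: int = 0,
) -> List[int]:
    """Scalar sketch of idx = round((t - t_0) / T * N) - 1 + idx_start, then clamp."""
    span = t_end - t0  # T
    lo, hi = idx_start, idx_start + num_steps - 1
    out = []
    for t in times:
        # Python's round() returns an int and rounds exact halves to even.
        idx = round((t - t0) / span * num_steps) - 1 + idx_start
        out.append(min(max(idx, lo), hi))  # clamp to [idx_start, idx_start + N - 1]
    return out


print(continuous_time_to_index([0.25, 0.5, 1.0], num_steps=4))  # [0, 1, 3]
print(continuous_time_to_index([0.0], num_steps=4))             # [0]
```

The second call shows the clamp in action. The time \(t_0\) would map to index \(-1\), so it is raised to idx_start.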

get_continuous_time_start() → float

get_continuous_timesteps_schedule() → Tensor

Get the continuous timesteps schedule for sampling/inference.

Returns:

1D tensor of continuous time values \(t \in [t_1, t_N]\).

Return type:

Tensor

get_coutinuous_time_end() → float

get_timestep_index_end() → int

get_timestep_index_start() → int

get_timestep_indices_schedule() → Tensor

Get the timestep indices schedule for sampling/inference.

Returns:

1D tensor of timestep indices \(\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}\).

Return type:

Tensor

abstractmethod initialize_timesteps_schedule() → None

Initialize the timesteps schedule for sampling/inference. Subclasses must implement this method.
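
Because initialize_timesteps_schedule is abstract, each subclass decides how to build its schedule. One possible choice is shown below. It is an assumption, not taken from ls_mlkit: pick num_inference_steps indices evenly spaced over the valid range, in descending order from the noisiest index to the cleanest. The helper evenly_spaced_indices is hypothetical.

```python
from typing import List


def evenly_spaced_indices(num_train_steps: int, num_inference_steps: int, idx_start: int = 0) -> List[int]:
    """Hypothetical schedule: indices from idx_start + N - 1 down to idx_start.

    Assumes 1 <= num_inference_steps <= num_train_steps so no index repeats.
    """
    lo, hi = idx_start, idx_start + num_train_steps - 1
    if num_inference_steps == 1:
        return [hi]
    stride = (hi - lo) / (num_inference_steps - 1)
    # Descending order: start at t_N (idx = hi) and move toward t_1 (idx = lo).
    return [round(hi - k * stride) for k in range(num_inference_steps)]


print(evenly_spaced_indices(1000, 4))  # [999, 666, 333, 0]
```

A subclass following this approach could pass the result to set_timestep_indices_schedule as a 1D tensor.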

sample_continuous_time_uniformly(macro_shape: Tuple[int, ...], same_for_all_samples: bool = False) → Tensor

Sample continuous time uniformly from \([t_1, t_N]\).

Note: This samples from \([t_0 + \frac{T}{N}, t_0 + T]\) to exclude \(t_0\), which corresponds to clean data.

Parameters:
  • macro_shape (Tuple[int, ...]) – Shape of the output tensor.

  • same_for_all_samples (bool, optional) – If True, use the same time for all samples. Defaults to False.

Returns:

Continuous time values with shape macro_shape.

Return type:

Tensor
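
The stdlib-only sketch below illustrates the sampling range. Draws come from \([t_0 + T/N, t_0 + T]\), and same_for_all_samples repeats a single draw. Several details are assumptions for illustration: the function name, the flat list output (the real method returns a tensor shaped macro_shape), and the rng argument.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_continuous_time(
    macro_shape: Tuple[int, ...],
    t0: float = 0.0,
    t_end: float = 1.0,
    num_steps: int = 1000,
    same_for_all_samples: bool = False,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Draw times uniformly from [t_0 + T/N, t_0 + T]; returns prod(macro_shape) values, flattened."""
    if rng is None:
        rng = random.Random()
    span = t_end - t0            # T
    low = t0 + span / num_steps  # t_1, so t_0 (clean data) is excluded
    high = t0 + span             # t_N
    count = math.prod(macro_shape)
    if same_for_all_samples:
        return [rng.uniform(low, high)] * count  # one shared draw
    return [rng.uniform(low, high) for _ in range(count)]


times = sample_continuous_time((2, 3), num_steps=4, rng=random.Random(0))
print(len(times))  # 6
```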

sample_timestep_index_uniformly(macro_shape: Tuple[int, ...], same_for_all_samples: bool = False) → Tensor

Sample timestep indices uniformly from \(\{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}\).

This corresponds to sampling discrete steps \(i\) uniformly from \(\{1, \ldots, N\}\) and converting to index via \(\text{idx} = i - 1 + \text{idx\_start}\).

Parameters:
  • macro_shape (Tuple[int, ...]) – Shape of the output tensor.

  • same_for_all_samples (bool, optional) – If True, use the same timestep index for all samples. Defaults to False.

Returns:

Timestep indices with shape macro_shape.

Return type:

Tensor
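
A matching sketch for the discrete case draws \(i\) from \(\{1, \ldots, N\}\) and shifts it into index space. As above, the function name, the flat list output, and the rng argument are illustrative assumptions rather than the library's API.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_timestep_index(
    macro_shape: Tuple[int, ...],
    num_steps: int = 1000,
    idx_start: int = 0,
    same_for_all_samples: bool = False,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Draw i ~ U{1, ..., N} and return idx = i - 1 + idx_start (flattened)."""
    if rng is None:
        rng = random.Random()
    count = math.prod(macro_shape)

    def draw() -> int:
        i = rng.randint(1, num_steps)  # randint includes both endpoints
        return i - 1 + idx_start

    if same_for_all_samples:
        return [draw()] * count
    return [draw() for _ in range(count)]


indices = sample_timestep_index((100,), num_steps=4, idx_start=1, rng=random.Random(0))
print(min(indices) >= 1 and max(indices) <= 4)  # True
```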

set_continuous_timesteps_schedule(continuous_timesteps: Tensor) → None

Set the continuous timesteps schedule for sampling/inference.

Parameters:

continuous_timesteps (Tensor) – 1D tensor of continuous time values.

set_timestep_indices_schedule(timestep_indices: Tensor) → None

Set the timestep indices schedule for sampling/inference.

Parameters:

timestep_indices (Tensor) – 1D tensor of timestep indices \(\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}\).

timestep_index_to_continuous_time(timestep_index: Tensor) → Tensor

Convert timestep index to continuous time.

Given timestep index \(\text{idx}\), compute the continuous time:

\[t = t_0 + \frac{\text{idx} + 1 - \text{idx\_start}}{N} \cdot T\]

where \(t_0\) is continuous_time_start, \(T\) is the total time span (continuous_time_end - continuous_time_start), \(N\) is num_train_timesteps, and \(\text{idx\_start}\) is the starting index.

Parameters:

timestep_index (Tensor) – Timestep indices \(\text{idx} \in \{\text{idx\_start}, \ldots, \text{idx\_start} + N - 1\}\).

Returns:

Continuous time values \(t \in [t_1, t_N]\).

Return type:

Tensor
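
This formula inverts the one used by continuous_time_to_timestep_index on grid points. The scalar sketch below checks the round trip, including the case of a nonzero \(t_0\) and idx_start = 1. Both function names are hypothetical stand-ins for the tensor methods.

```python
from typing import List


def index_to_continuous_time(indices, t0=0.0, t_end=1.0, num_steps=1000, idx_start=0) -> List[float]:
    """t = t_0 + (idx + 1 - idx_start) / N * T."""
    span = t_end - t0  # T
    return [t0 + (idx + 1 - idx_start) / num_steps * span for idx in indices]


def continuous_time_to_index(times, t0=0.0, t_end=1.0, num_steps=1000, idx_start=0) -> List[int]:
    """idx = round((t - t_0) / T * N) - 1 + idx_start, clamped to the valid range."""
    span = t_end - t0
    lo, hi = idx_start, idx_start + num_steps - 1
    return [min(max(round((t - t0) / span * num_steps) - 1 + idx_start, lo), hi) for t in times]


times = index_to_continuous_time([0, 1, 2, 3], num_steps=4)
print(times)                                          # [0.25, 0.5, 0.75, 1.0]
print(continuous_time_to_index(times, num_steps=4))  # [0, 1, 2, 3]

shifted = index_to_continuous_time([1, 2, 3, 4], t0=1.0, t_end=3.0, num_steps=4, idx_start=1)
print(shifted)  # [1.5, 2.0, 2.5, 3.0]
```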