Source code for ls_mlkit.diffusion.time_scheduler
r"""
Time Scheduler for Diffusion
"""
import torch
from ..util.base_class.base_time_class import BaseTimeScheduler
from ..util.decorators import inherit_docstrings
[docs]
@inherit_docstrings
class DiffusionTimeScheduler(BaseTimeScheduler):
    """Time scheduler for diffusion sampling.

    Builds two aligned, descending inference schedules of length
    ``num_inference_timesteps``:

    * integer timestep indices drawn from
      ``[idx_start, idx_start + num_train_timesteps - 1]``;
    * continuous times drawn from ``[t_1, t_N]`` where
      ``t_1 = continuous_time_start + T / num_train_timesteps`` and
      ``t_N = continuous_time_end`` (``t_0`` itself is excluded).
    """

    # No method docstring on purpose: ``@inherit_docstrings`` is expected to
    # provide it from ``BaseTimeScheduler`` (NOTE(review): confirm).
    def initialize_timesteps_schedule(self) -> None:
        num_train = self.num_train_timesteps
        num_inference = self.num_inference_timesteps

        # Timestep indices: [idx_start + N - 1, ..., idx_start] (descending for
        # reverse diffusion). In the full schedule, element-wise pairing with the
        # continuous schedule below matches index ``idx_start + k - 1`` to t_k.
        idx_min = self.idx_start
        idx_max = self.idx_start + num_train - 1
        if num_inference == num_train:
            # Full schedule: every training index, built exactly in integers.
            self._timesteps_idx = torch.arange(idx_max, idx_min - 1, -1, dtype=torch.int64)
        else:
            # Sub-sampled schedule, spaced *descending* (idx_max -> idx_min) just
            # like the continuous schedule below. This keeps both schedules
            # anchored at the same end for every step count: a single-step
            # schedule is [idx_max], matching [t_N], instead of [idx_min].
            # float64 keeps the pre-rounding values accurate even when the
            # index range exceeds float32's exact-integer range (2**24).
            self._timesteps_idx = (
                torch.linspace(idx_max, idx_min, num_inference, dtype=torch.float64)
                .round()
                .to(torch.int64)
            )

        # Continuous times: [t_N, ..., t_1] (descending), excluding t_0.
        # t_1 = t_0 + T/N, t_N = t_0 + T
        t_min = self.continuous_time_start + self.T / num_train  # t_1
        t_max = self.continuous_time_end  # t_N
        self._continuous_timesteps = torch.linspace(t_max, t_min, num_inference).to(torch.float32)
if __name__ == "__main__":
    # Quick manual demo of the scheduler. Run it with:
    #     uv run python -m ls_mlkit.diffusion.time_scheduler
    demo_scheduler = DiffusionTimeScheduler(
        continuous_time_start=0.0,
        continuous_time_end=1.0,
        num_train_timesteps=1000,
        idx_start=1,
    )

    # The two inference schedules as built by the scheduler.
    index_schedule = demo_scheduler.get_timestep_indices_schedule()
    print(index_schedule)
    time_schedule = demo_scheduler.get_continuous_timesteps_schedule()
    print(time_schedule)

    # Push each schedule through the opposite conversion to eyeball consistency.
    index_schedule_again = demo_scheduler.get_timestep_indices_schedule()
    print(demo_scheduler.timestep_index_to_continuous_time(index_schedule_again))
    time_schedule_again = demo_scheduler.get_continuous_timesteps_schedule()
    print(demo_scheduler.continuous_time_to_timestep_index(time_schedule_again))