Source code for ls_mlkit.diffusion.time_scheduler
r"""
Time Scheduler for Diffusion
"""
import torch
from ..util.base_class.base_time_class import BaseTimeScheduler
from ..util.decorators import inherit_docstrings
[docs]
# NOTE(review): documentation below is written as comments rather than
# docstrings on purpose. `inherit_docstrings` presumably copies missing
# docstrings from `BaseTimeScheduler`; adding docstrings here could change
# what it does -- confirm against the decorator before converting these.
@inherit_docstrings
class DiffusionTimeScheduler(BaseTimeScheduler):
    # Time scheduler that builds the discrete and continuous timestep
    # tensors used by the diffusion code.

    def initialize_timesteps_schedule(self) -> None:
        # Populates two attributes on the instance:
        #   * ``_discrete_timesteps``  : int64 tensor of step indices taken
        #     from ``[0, num_train_timesteps - 1]``, ordered from the highest
        #     index down to the lowest.
        #   * ``_continuous_timesteps``: float32 tensor holding
        #     ``num_inference_timesteps + 1`` evenly spaced values, ordered
        #     from ``continuous_time_end`` back to ``continuous_time_start``.
        n_infer = self.num_inference_timesteps
        n_train = self.num_train_timesteps

        if n_infer != n_train:
            # Sub-sample (or re-sample) the training grid: spread n_infer
            # points evenly over the index range, snap them to the nearest
            # integer index, then reverse so the largest index comes first.
            index_grid = torch.linspace(0, n_train - 1, n_infer)
            discrete = index_grid.round().flip(0).to(torch.int64)
        else:
            # One inference step per training step: count down directly with
            # exact integer arithmetic.
            discrete = torch.arange(n_train - 1, -1, -1, dtype=torch.int64)
        self._discrete_timesteps = discrete

        # One more point than there are inference steps, i.e. both endpoints
        # of the continuous time range are included.
        time_grid = torch.linspace(
            self.continuous_time_start,
            self.continuous_time_end,
            n_infer + 1,
        )
        self._continuous_timesteps = time_grid.flip(0).to(torch.float32)