Source code for ls_mlkit.flow_matching.time_scheduler
import torch
from ..util.base_class.base_time_class import BaseTimeScheduler
from ..util.decorators import inherit_docstrings
[docs]
@inherit_docstrings
class FlowMatchingTimeScheduler(BaseTimeScheduler):
    # Time scheduler for flow matching: builds an ascending schedule of
    # discrete timestep indices together with a matching-length grid of
    # continuous times.
    #
    # NOTE(review): documentation is kept as comments rather than docstrings
    # because the class is wrapped in ``inherit_docstrings``; presumably that
    # decorator supplies docstrings from the base class -- confirm before
    # adding explicit ones.

    def initialize_timesteps_schedule(self) -> None:
        # Populates ``self._timesteps_idx`` (int64 tensor of indices) and
        # ``self._continuous_timesteps`` (float32 tensor of times). Both have
        # ``num_inference_timesteps`` entries and are ascending.
        n_train = self.num_train_timesteps
        n_infer = self.num_inference_timesteps

        # Discrete indices cover [idx_start, idx_start + N - 1], N = n_train.
        first_idx = self.idx_start
        last_idx = first_idx + n_train - 1
        if n_infer != n_train:
            # Sub-/re-sample the index range uniformly, then snap each point
            # to an integer index (torch.round, then cast to int64).
            evenly_spaced = torch.linspace(first_idx, last_idx, n_infer)
            schedule_idx = evenly_spaced.round().to(torch.int64)
        else:
            # Same number of inference and training steps: use every index.
            schedule_idx = torch.arange(first_idx, last_idx + 1, dtype=torch.int64)
        self._timesteps_idx = schedule_idx

        # Continuous grid runs from t_1 to t_N and deliberately skips t_0:
        # t_1 = continuous_time_start + T / N, t_N = continuous_time_end.
        first_time = self.continuous_time_start + self.T / n_train
        last_time = self.continuous_time_end
        self._continuous_timesteps = torch.linspace(
            first_time,
            last_time,
            n_infer,
            dtype=torch.float32,
        )