Source code for ls_mlkit.util.base_class.base_time_class
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from torch import Tensor
from ..decorators import inherit_docstrings
[docs]
@inherit_docstrings
class BaseTimeScheduler(ABC):
def __init__(
self,
continuous_time_start: float = 0.0,
continuous_time_end: float = 1.0,
num_train_timesteps: int = 1000,
num_inference_steps: int = None,
) -> None:
self.continuous_time_start: float = continuous_time_start
self.continuous_time_end: float = continuous_time_end
self.num_train_timesteps: int = num_train_timesteps
self.num_inference_timesteps: int = (
num_inference_steps if num_inference_steps is not None else num_train_timesteps
)
self._discrete_timesteps: list[int] = None
self._continuous_timesteps: list[float] = None
self.T: float = continuous_time_end - continuous_time_start
self.initialize_timesteps_schedule()
[docs]
@abstractmethod
def initialize_timesteps_schedule(self) -> None:
"""initialize timesteps for sampling."""
[docs]
def continuous_time_to_discrete_time(self, continuous_time: Tensor) -> Tensor:
"""Convert a continuous time to a discrete time.
.. math::
t = \round(\frac{t'}{T} \cdot N - 1)
Args:
continuous_time (Tensor): the continuous time
Returns:
Tensor: the discrete time
"""
return torch.clamp(
torch.round(continuous_time / self.T * (self.num_train_timesteps) - 1),
min=0,
max=self.num_train_timesteps - 1,
).to(torch.int64)
[docs]
def discrete_time_to_continuous_time(self, discrete_time: Tensor) -> Tensor:
"""Convert a discrete time to a continuous time.
.. math::
t' = \frac{(t + 1)}{N} \cdot T
Args:
discrete_time (Tensor): the discrete time
Returns:
Tensor: the continuous time
"""
return 1.0 * (discrete_time + 1) / (self.num_train_timesteps) * self.T
[docs]
def get_discrete_timesteps_schedule(self) -> Tensor:
"""Get the discrete one-dimensional timesteps for sampling/inference.
Returns:
Tensor: the discrete one-dimensional timesteps for sampling/inference
"""
assert self._discrete_timesteps is not None, "discrete_timesteps is not set"
assert isinstance(self._discrete_timesteps, Tensor), "discrete_timesteps must be a Tensor"
assert self._discrete_timesteps.ndim == 1, "discrete_timesteps must be a one-dimensional Tensor"
return self._discrete_timesteps
[docs]
def get_continuous_timesteps_schedule(self) -> Tensor:
"""Get the continuous one-dimensional timesteps for sampling/inference.
Returns:
Tensor: the continuous one-dimensional timesteps for sampling/inference
"""
assert self._continuous_timesteps is not None, "_continuous_timesteps is not set"
assert isinstance(self._continuous_timesteps, Tensor), "continuous_timesteps must be a Tensor"
assert self._continuous_timesteps.ndim == 1, "continuous_timesteps must be a one-dimensional Tensor"
return self._continuous_timesteps
[docs]
def set_discrete_timesteps_schedule(self, discrete_timesteps: Tensor) -> None:
"""Set the discrete one-dimensional timesteps for sampling/inference.
Args:
discrete_timesteps (``Tensor``): the discrete one-dimensional timesteps for sampling/inference.
"""
self._discrete_timesteps = discrete_timesteps
[docs]
def set_continuous_timesteps_schedule(self, continuous_timesteps: Tensor) -> None:
"""Set the continuous one-dimensional timesteps for sampling/inference.
Args:
continuous_timesteps (``Tensor``): the continuous one-dimensional timesteps for sampling/inference.
"""
self._continuous_timesteps = continuous_timesteps