Source code for ls_mlkit.diffusion.euclidean_diffuser

from typing import Any, Optional, Tuple

import torch
from torch import Tensor
from tqdm.auto import tqdm

from ..util.decorators import inherit_docstrings
from ..util.mask.masker_interface import MaskerInterface
from .base_diffuser import BaseDiffuser, BaseDiffuserConfig
from .time_scheduler import DiffusionTimeScheduler


@inherit_docstrings
class EuclideanDiffuserConfig(BaseDiffuserConfig):
    def __init__(
        self,
        n_discretization_steps: int = 1000,
        ndim_micro_shape: int = 2,
        n_inference_steps: Optional[int] = None,
        use_batch_flattening: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__(
            n_discretization_steps=n_discretization_steps,
            ndim_micro_shape=ndim_micro_shape,
            n_inference_steps=n_inference_steps,
            use_batch_flattening=use_batch_flattening,
            *args,
            **kwargs,
        )
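# Usage sketch (the values below are illustrative assumptions, not defaults
# taken from elsewhere in the library): a config for samples whose two
# trailing dimensions (e.g. length x channels) form the "micro" shape that
# the diffusion process noises.
#
#     config = EuclideanDiffuserConfig(
#         n_discretization_steps=1000,
#         ndim_micro_shape=2,
#         n_inference_steps=50,
#     )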
@inherit_docstrings
class EuclideanDiffuser(BaseDiffuser):
    def __init__(
        self,
        config: EuclideanDiffuserConfig,
        time_scheduler: DiffusionTimeScheduler,
        masker: MaskerInterface,
    ):
        super().__init__(config=config, time_scheduler=time_scheduler)
        self.config: EuclideanDiffuserConfig = config
        self.time_scheduler: DiffusionTimeScheduler = time_scheduler
        self.masker = masker
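    # Construction sketch: the constructor arguments of DiffusionTimeScheduler
    # and of a concrete MaskerInterface implementation are not shown in this
    # module, so the lines below are assumptions:
    #
    #     diffuser = EuclideanDiffuser(
    #         config=config,
    #         time_scheduler=DiffusionTimeScheduler(...),  # hypothetical arguments
    #         masker=my_masker,  # any MaskerInterface implementation
    #     )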
    def forward_process_n_step(
        self, x: Tensor, t: Tensor, next_t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any
    ) -> Tensor:
        r"""Run the forward (noising) process for n steps, from ``t`` to ``next_t``.

        Args:
            x (``Tensor``): the sample
            t (``Tensor``): the current timestep
            next_t (``Tensor``): the next timestep
            padding_mask (``Tensor``): the padding mask

        Returns:
            ``Tensor``: the sample at the next timestep
        """
        # No implementation is provided here; concrete diffusers must override
        # this method (it is called by `inpainting` when n_repaint_steps > 1).
        raise NotImplementedError
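    # A hedged sketch of one possible override, for a DDPM-style
    # (variance-preserving) Gaussian diffuser. `self.alphas_cumprod` is an
    # assumed attribute (the cumulative product of the noise schedule), not
    # something defined in this module:
    #
    #     a_ratio = self.alphas_cumprod[next_t] / self.alphas_cumprod[t]
    #     noise = torch.randn_like(x)
    #     return torch.sqrt(a_ratio) * x + torch.sqrt(1.0 - a_ratio) * noise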
    @torch.no_grad()
    def sampling(
        self,
        shape: Tuple[int, ...],
        device,
        x_init_posterior: Optional[Tensor] = None,
        return_all: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        config = self.config
        if x_init_posterior is not None:
            shape = x_init_posterior.shape
        macro_shape = shape[: -self.config.ndim_micro_shape]
        masker = self.masker
        if x_init_posterior is None:
            x_t = self.prior_sampling(shape).to(device)
        else:
            x_t = x_init_posterior
        padding_mask = kwargs.get("padding_mask", None)
        if padding_mask is None:
            padding_mask = masker.get_full_bright_mask(x_t)
        # Noise the initial state to the last discretization step T - 1.
        x_t = self.forward_process(
            x_t,
            torch.ones(macro_shape, device=device, dtype=torch.long) * (config.n_discretization_steps - 1),
            padding_mask,
        )["x_t"]
        x_list = [x_t]
        E_x0_xt_list = [x_t]
        time_steps = self.time_scheduler.get_discrete_timesteps_schedule().to(device)
        for idx, t in enumerate(tqdm(time_steps)):
            t = torch.ones(macro_shape, device=device, dtype=torch.long) * t
            no_padding_mask = masker.get_full_bright_mask(x_t)
            kwargs["idx"] = idx
            step_output = self.step(x_t=x_t, t=t, padding_mask=no_padding_mask, *args, **kwargs)
            x_t = step_output["x"]
            if "E_x0_xt" in step_output:
                E_x0_xt_list.append(step_output["E_x0_xt"])
            if return_all:
                x_list.append(x_t)
        return {"x": x_t, "x_list": x_list, "E_x0_xt_list": E_x0_xt_list}
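    # Usage sketch (shapes and device are illustrative assumptions): with
    # config.ndim_micro_shape == 2, a shape of (8, 128, 3) means macro_shape
    # is (8,) and the micro shape is (128, 3):
    #
    #     out = diffuser.sampling(shape=(8, 128, 3), device="cuda", return_all=True)
    #     x_final = out["x"]          # samples at the final step
    #     trajectory = out["x_list"]  # all intermediate samples (return_all=True)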
    @torch.no_grad()
    def inpainting(
        self,
        x: Tensor,
        padding_mask: Tensor,
        inpainting_mask: Tensor,
        device,
        x_init_posterior: Optional[Tensor] = None,
        inpainting_mask_key: str = "inpainting_mask",
        n_repaint_steps: int = 1,
        return_all: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dict:
        x_0 = x
        shape = x_0.shape
        config = self.config
        macro_shape = shape[: -config.ndim_micro_shape]
        masker = self.masker
        # Add inpainting_mask to kwargs so it gets passed to the model
        kwargs[inpainting_mask_key] = inpainting_mask
        if x_init_posterior is None:
            x_t = self.prior_sampling(shape).to(device)
        else:
            x_t = x_init_posterior
        x_t = self.forward_process(
            x_t,
            torch.ones(macro_shape, device=device, dtype=torch.long) * (config.n_discretization_steps - 1),
            padding_mask,
        )["x_t"]
        x_0 = masker.apply_mask(x_0, padding_mask)
        x_T = x_t.detach().clone()
        x_list = [x_t]
        E_x0_xt_list = [x_t]
        timesteps = self.time_scheduler.get_discrete_timesteps_schedule().to(device)
        for i, t in enumerate(tqdm(timesteps)):
            for u in range(1, n_repaint_steps + 1):
                t = torch.ones(macro_shape, device=device, dtype=torch.long) * t
                x_t = self.recover_bright_region(
                    x_known=x_0,
                    x_t=x_t,
                    t=t,
                    inpainting_mask=inpainting_mask,
                    padding_mask=padding_mask,
                    x_prior=x_T,
                )
                step_output = self.step(x_t, t, padding_mask, *args, **kwargs)  # get x_tm1
                x_t = step_output["x"]
                if "E_x0_xt" in step_output:
                    E_x0_xt_list.append(step_output["E_x0_xt"])
                x_t = masker.apply_mask(x_t, padding_mask)
                if u < n_repaint_steps and (t > 0).all():
                    assert i < len(timesteps) - 1
                    prev_t = timesteps[i + 1].to(device)
                    # Re-noise from prev_t back up to t for the next repaint pass.
                    x_t = self.forward_process_n_step(x_t, prev_t, t, padding_mask, *args, **kwargs)
            if return_all:
                x_list.append(x_t)
        x_t = masker.apply_inpainting_mask(x_0, x_t, inpainting_mask)
        return {"x": x_t, "x_list": x_list, "E_x0_xt_list": E_x0_xt_list}
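    # Usage sketch (argument values are illustrative assumptions): here
    # `keep_mask` marks the known region to hold fixed, and n_repaint_steps > 1
    # enables RePaint-style resampling, which repeatedly re-noises x_t via
    # forward_process_n_step before stepping again:
    #
    #     out = diffuser.inpainting(
    #         x=x_known,
    #         padding_mask=pad_mask,
    #         inpainting_mask=keep_mask,
    #         device="cuda",
    #         n_repaint_steps=4,
    #     )
    #     x_inpainted = out["x"]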
    def recover_bright_region(self, x_known, x_t, t, padding_mask, inpainting_mask, x_prior) -> Tensor:
        x_0 = x_known
        # Noise the known sample to the current timestep, then paste it over
        # the known ("bright") region of x_t.
        x_0t = self.forward_process(x_0, t, padding_mask)["x_t"]
        x_t = self.masker.apply_inpainting_mask(x_0t, x_t, inpainting_mask)
        return x_t
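    # In RePaint terms this computes x_t <- m * q(x_t | x_known) + (1 - m) * x_t
    # for inpainting mask m, so each reverse step only has to generate content
    # for the unknown region while the known region stays consistent with
    # x_known at every noise level. (x_prior is accepted but unused here.)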