Source code for ls_mlkit.flow_matching.euclidean_ot_fm

from typing import Any, Callable

import torch
from torch import Tensor
from torch.nn import Module
from tqdm.auto import tqdm

from ..diffusion.conditioner.conditioner import Conditioner
from ..diffusion.conditioner.utils import get_accumulated_conditional_score
from ..util.base_class.base_gm_class import GMHook, GMHookStageType
from ..util.decorators import inherit_docstrings
from ..util.mask.masker_interface import MaskerInterface
from .base_fm import BaseFlow, BaseFlowConfig
from .time_scheduler import FlowMatchingTimeScheduler

EPS = 1e-5


@inherit_docstrings
class EuclideanOTFlowConfig(BaseFlowConfig):
    def __init__(
        self,
        n_discretization_steps: int,
        ndim_micro_shape: int = 2,
        n_inference_steps: int | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            ndim_micro_shape=ndim_micro_shape,
            n_discretization_steps=n_discretization_steps,
            n_inference_steps=n_inference_steps,
            *args,
            **kwargs,
        )
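# A minimal configuration sketch (the values are hypothetical, and
# BaseFlowConfig may accept additional keyword arguments not shown here):
#
#   config = EuclideanOTFlowConfig(
#       n_discretization_steps=1000,
#       ndim_micro_shape=2,
#       n_inference_steps=50,
#   )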
@inherit_docstrings
class EuclideanOTFlow(BaseFlow):
    def __init__(
        self,
        config: EuclideanOTFlowConfig,
        time_scheduler: FlowMatchingTimeScheduler,
        masker: MaskerInterface,
        model: Module,
        loss_fn: Callable,
    ) -> None:
        super().__init__(config=config, time_scheduler=time_scheduler)
        self.config: EuclideanOTFlowConfig = config
        self.masker: MaskerInterface = masker
        self.model: Module = model
        self.loss_fn = loss_fn
    def compute_loss(self, batch, *args, **kwargs):
        batch = self.model.prepare_batch_data_for_input(batch)
        x_1 = batch["gt_data"]
        device = x_1.device
        macro_shape: tuple[int, ...] = self.get_macro_shape(x_1)
        t = batch.get("t", torch.rand(macro_shape, device=device))
        padding_mask: Any | None = batch.get("padding_mask", None)
        copied_t = t.clone().detach()
        t: torch.Tensor = self.complete_micro_shape(t)
        x_0: torch.Tensor = torch.randn_like(x_1, device=device)
        # Straight-line (OT) interpolant between the Gaussian prior x_0 and
        # the data x_1; its constant velocity x_1 - x_0 is the regression target.
        x_t = x_0 * (1 - t) + x_1 * t
        dx_t = x_1 - x_0
        model_input_dict = batch
        model_input_dict.pop("gt_data")
        model_input_dict.pop("padding_mask", None)
        model_input_dict.pop("t", None)
        p_dx_t = self.model(x_t, copied_t, padding_mask, **model_input_dict)["x"]
        loss = self.loss_fn(p_dx_t, dx_t, padding_mask)
        result = {
            "loss": loss,
            "x_1": x_1,
            "x_t": x_t,
            "x_0": x_0,
            "t": t,
            "p_dx_t": p_dx_t,
            "padding_mask": padding_mask,
            "config": self.config,
            "loss_fn": self.loss_fn,
        }
        return result
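    # The target above follows the straight-line (OT) probability path
    # x_t = (1 - t) * x_0 + t * x_1 with velocity dx_t/dt = x_1 - x_0, which
    # implies x_1 = x_t + (1 - t) * (x_1 - x_0). A standalone check of this
    # identity (shapes are hypothetical):
    #
    #   x_0, x_1 = torch.randn(4, 8), torch.randn(4, 8)
    #   t = torch.rand(4, 1)
    #   x_t = (1 - t) * x_0 + t * x_1
    #   assert torch.allclose(x_t + (1 - t) * (x_1 - x_0), x_1, atol=1e-6)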
    def step(self, x_t, t, padding_mask=None, *args, **kwargs) -> dict:
        device = x_t.device
        idx = kwargs.get("idx")
        ones = torch.ones_like(t)
        schedule = self.time_scheduler.get_continuous_timesteps_schedule().to(device)
        t_start = schedule[idx] * ones
        t_end = schedule[idx + 1] * ones
        copied_t_start = t_start.clone().detach()
        copied_t_end = t_end.clone().detach()
        t_start: torch.Tensor = self.complete_micro_shape(copied_t_start)
        t_end: torch.Tensor = self.complete_micro_shape(copied_t_end)
        p_dx_t = self.model(x_t, copied_t_start, padding_mask, *args, **kwargs)["x"]
        hook_input = {
            "x_t": x_t,
            "t": copied_t_start,
            "padding_mask": padding_mask,
            "p_dx_t": p_dx_t,
            "sampling_condition": kwargs.get("sampling_condition"),
        }
        hook_output = self.hook_manager.run_hooks(GMHookStageType.PRE_UPDATE_IN_STEP_FN, **hook_input)
        # Hooks (e.g. conditional guidance) may replace the predicted velocity field.
        vf = hook_output if hook_output is not None else p_dx_t
        # Explicit Euler update of the ODE dx/dt = v(x, t).
        return {"x": x_t + (t_end - t_start) * vf}
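    # One `step` call is a single explicit Euler update:
    # x(t_end) ≈ x(t_start) + (t_end - t_start) * v(x(t_start), t_start).
    # The sampler supplies `idx`, so consecutive schedule entries define the
    # step size; a generic sketch of that loop (names are illustrative):
    #
    #   for idx in range(len(schedule) - 1):
    #       x = x + (schedule[idx + 1] - schedule[idx]) * vf_model(x, schedule[idx])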
    @torch.no_grad()
    def sampling(
        self,
        shape,
        device,
        x_init_posterior=None,
        return_all=False,
        sampling_condition=None,
        sampling_condition_key="sampling_condition",
        *args,
        **kwargs,
    ) -> dict:
        x_0 = self.prior_sampling(shape).to(device)
        if x_init_posterior is not None:
            # Blend a tiny fraction of the provided posterior sample into the prior.
            x_0 = x_init_posterior * EPS + (1 - EPS) * x_0
        x_t = x_0
        masker = self.masker
        macro_shape = self.get_macro_shape(x_t)
        x_list = [x_0]
        kwargs[sampling_condition_key] = sampling_condition
        time_steps = self.time_scheduler.get_discrete_timesteps_schedule().to(device)
        for idx, t in enumerate(tqdm(time_steps)):
            t = torch.ones(macro_shape, device=device, dtype=torch.long) * t
            no_padding_mask = masker.get_full_bright_mask(x_t)
            kwargs["idx"] = idx
            x_t = self.step(x_t=x_t, t=t, padding_mask=no_padding_mask, *args, **kwargs)["x"]
            if return_all:
                x_list.append(x_t)
        return {"x": x_t, "x_list": x_list}
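    # Hypothetical usage sketch (the shape, device, and constructor arguments
    # are assumptions, not API guarantees):
    #
    #   flow = EuclideanOTFlow(config, time_scheduler, masker, model, loss_fn)
    #   out = flow.sampling(shape=(16, 64, 3), device="cuda", return_all=True)
    #   samples, trajectory = out["x"], out["x_list"]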
    @torch.no_grad()
    def inpainting(
        self,
        x,
        padding_mask,
        inpainting_mask,
        device,
        x_init_posterior=None,
        inpainting_mask_key="inpainting_mask",
        sampling_condition_key="sampling_condition",
        return_all=False,
        sampling_condition=None,
        *args,
        **kwargs,
    ) -> dict:
        x_1 = x
        shape = x_1.shape
        config = self.config
        masker = self.masker
        macro_shape = shape[: -config.ndim_micro_shape]
        x_0 = self.prior_sampling(shape).to(device)
        if x_init_posterior is not None:
            x_0 = x_init_posterior * EPS + (1 - EPS) * x_0
        x_t = x_0
        x_1 = masker.apply_mask(x_1, padding_mask)
        x_list = [x_0]
        kwargs[sampling_condition_key] = sampling_condition
        kwargs[inpainting_mask_key] = inpainting_mask
        timesteps = self.time_scheduler.get_discrete_timesteps_schedule().to(device)
        for idx, t in enumerate(tqdm(timesteps)):
            kwargs["idx"] = idx
            t = torch.ones(macro_shape, device=device, dtype=torch.long) * t
            # Overwrite the known (bright) region with its interpolant at the current time.
            x_t = self.recover_bright_region(
                x_known=x_1,
                x_t=x_t,
                t=t,
                padding_mask=padding_mask,
                inpainting_mask=inpainting_mask,
                x_prior=x_0,
                **kwargs,
            )
            x_t = self.step(x_t=x_t, t=t, padding_mask=padding_mask, *args, **kwargs)["x"]
            x_t = masker.apply_mask(x_t, padding_mask)
            if return_all:
                x_list.append(x_t)
        # Re-impose the known data exactly in the final output.
        x_t = masker.apply_inpainting_mask(x_1, x_t, inpainting_mask)
        return {"x": x_t, "x_list": x_list}
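    # Inpainting alternates two operations per step: (i) project the known
    # region onto the training-time interpolant via `recover_bright_region`,
    # then (ii) advance the full state with `step`. A hypothetical call
    # (argument names are illustrative):
    #
    #   out = flow.inpainting(x=gt_data, padding_mask=pad_mask,
    #                         inpainting_mask=hole_mask, device="cuda")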
    def prior_sampling(self, shape) -> torch.Tensor:
        return torch.randn(shape)
    def recover_bright_region(self, x_known, x_t, t, padding_mask, inpainting_mask, x_prior, *args, **kwargs) -> Tensor:
        device = x_t.device
        idx = kwargs.get("idx")
        t_start = self.time_scheduler.get_continuous_timesteps_schedule().to(device)[idx]
        t_start = self.complete_micro_shape(t_start)
        # Evaluate the straight-line interpolant for the known data at time t_start.
        x_1_t = t_start * x_known + (1 - t_start) * x_prior
        x_t = self.masker.apply_inpainting_mask(x_1_t, x_t, inpainting_mask)
        return x_t
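    # Because x_1_t = t * x_known + (1 - t) * x_prior is exactly the
    # interpolant used in `compute_loss`, the bright region follows the same
    # path the model saw during training; at t = 1 it reduces to x_known.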
    def get_posterior_mean_fn(self, vf, vf_fn=None):
        def _otfm_posterior_mean_fn(x_t, t, padding_mask):
            nonlocal vf, vf_fn
            assert vf is not None or vf_fn is not None, "Either vf or vf_fn must be provided"
            if vf is None:
                vf = vf_fn(x_t, t, padding_mask)
            t = t.view(*t.shape, *([1] * (vf.ndim - t.ndim)))
            x_1 = (1 - t) / (t + EPS) * (t * vf - x_t) + 1 / (t + EPS) * x_t
            return x_1

        return _otfm_posterior_mean_fn
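    # Derivation: with x_t = (1 - t) * x_0 + t * x_1 and vf = x_1 - x_0,
    # we have x_0 = x_1 - vf, hence x_t = x_1 - (1 - t) * vf and therefore
    # x_1 = x_t + (1 - t) * vf. The expression in `_otfm_posterior_mean_fn`
    # is this identity with an EPS-stabilized division; a numeric check
    # (t kept away from 0 so the EPS term stays negligible):
    #
    #   t = torch.rand(3, 1) * 0.9 + 0.1
    #   vf, x_t = torch.randn(3, 5), torch.randn(3, 5)
    #   lhs = (1 - t) / (t + EPS) * (t * vf - x_t) + 1 / (t + EPS) * x_t
    #   assert torch.allclose(lhs, x_t + (1 - t) * vf, atol=1e-3)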
    def get_condition_post_compute_loss_hook(self, conditioner_list: list[Conditioner]):
        def _hook_fn(**kwargs):
            nonlocal conditioner_list
            loss = kwargs.get("loss")
            x_0 = kwargs.get("x_0")
            x_t = kwargs.get("x_t")
            x_1 = kwargs.get("x_1")
            t = kwargs.get("t", None)
            padding_mask = kwargs.get("padding_mask")
            p_dx_t = kwargs.get("p_dx_t")
            loss_fn = kwargs.get("loss_fn")
            tgt_mask = padding_mask
            for conditioner in conditioner_list:
                if not conditioner.is_enabled():
                    continue
                conditioner.set_condition(
                    **conditioner.prepare_condition_dict(
                        train=True,
                        tgt_mask=tgt_mask,
                        gt_data=x_1,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(vf=p_dx_t, vf_fn=None),
                    )
                )
            acc_c_score = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)
            gt_vf = (x_1 - x_0) + acc_c_score
            # Recompute the loss against the score-shifted target.
            p_vf = p_dx_t
            total_loss = loss_fn(p_vf, gt_vf, padding_mask)
            kwargs["loss"] = total_loss
            return kwargs

        return GMHook(
            name="OTFM_condition_post_compute_loss_hook",
            stage=GMHookStageType.POST_COMPUTE_LOSS,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )
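    # The two hook factories are the train/inference halves of the same
    # conditioning mechanism: POST_COMPUTE_LOSS shifts the training target to
    # gt_vf = (x_1 - x_0) + acc_c_score, while PRE_UPDATE_IN_STEP_FN (below)
    # applies the same shift to the predicted velocity during sampling, where
    # it is consumed by `step` via `run_hooks`.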
    def get_condition_pre_update_in_step_fn_hook(self, conditioner_list: list[Conditioner]):
        def _hook_fn(**kwargs):
            nonlocal conditioner_list
            x_t = kwargs.get("x_t")
            t = kwargs.get("t", None)
            padding_mask = kwargs.get("padding_mask")
            p_dx_t = kwargs.get("p_dx_t")
            sampling_condition = kwargs.get("sampling_condition")
            tgt_mask = padding_mask
            for conditioner in conditioner_list:
                if not conditioner.is_enabled():
                    continue
                conditioner.set_condition(
                    **conditioner.prepare_condition_dict(
                        train=False,
                        tgt_mask=tgt_mask,
                        sampling_condition=sampling_condition,
                        padding_mask=padding_mask,
                        posterior_mean_fn=self.get_posterior_mean_fn(vf=p_dx_t, vf_fn=None),
                    )
                )
            acc_c_score = get_accumulated_conditional_score(conditioner_list, x_t, t, padding_mask)
            # Guide sampling by shifting the predicted velocity field.
            vf = p_dx_t + acc_c_score
            return vf

        return GMHook(
            name="OTFM_condition_pre_update_in_step_fn_hook",
            stage=GMHookStageType.PRE_UPDATE_IN_STEP_FN,
            fn=_hook_fn,
            priority=0,
            enabled=True,
        )
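# Standalone illustration (not part of the ls_mlkit API): the straight-line
# OT flow-matching objective this module implements, reduced to a toy 2-D
# problem with a small MLP. All names and hyperparameters below are
# hypothetical.
if __name__ == "__main__":
    import torch.nn as nn

    torch.manual_seed(0)
    vf_net = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 2))
    opt = torch.optim.Adam(vf_net.parameters(), lr=1e-3)

    for _ in range(500):
        x_1 = torch.randn(256, 2) * 0.5 + 2.0  # toy "data" distribution
        x_0 = torch.randn(256, 2)              # Gaussian prior
        t = torch.rand(256, 1)
        x_t = (1 - t) * x_0 + t * x_1          # straight-line (OT) path
        target = x_1 - x_0                     # constant velocity target
        loss = ((vf_net(torch.cat([x_t, t], dim=-1)) - target) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Euler integration from the prior at t=0 to the model distribution at t=1.
    x = torch.randn(256, 2)
    n_steps = 50
    for i in range(n_steps):
        t = torch.full((256, 1), i / n_steps)
        x = x + (1 / n_steps) * vf_net(torch.cat([x, t], dim=-1))
    print("sample mean:", x.mean(dim=0))  # should approach roughly (2, 2)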