# Source code for ls_mlkit.diffusion.conditioner.conditioner
import abc
from typing import Any
import torch
from torch import Tensor
from ...util.decorators import inherit_docstrings
@inherit_docstrings
class Conditioner(abc.ABC):
    r"""Abstract base class for diffusion-model conditioners.

    A conditioner supplies a conditional score term that guides sampling.
    Subclasses must implement :meth:`prepare_condition_dict`,
    :meth:`set_condition`, and :meth:`get_conditional_score`.

    Args:
        guidance_scale (``float``, *optional*): multiplier applied to the
            conditional score. Defaults to 1.0.
    """

    def __init__(self, guidance_scale: float = 1.0):
        # Whether guidance is currently active; toggled via enable()/disable().
        self._enabled: bool = True
        # Set by subclasses once set_condition() has been called.
        self.ready: bool = False
        self._guidance_scale: float = guidance_scale

    @abc.abstractmethod
    def prepare_condition_dict(self, train: bool = True, *args: Any, **kwargs: Any) -> dict[str, Any]:
        r"""Prepare the condition dictionary.

        Args:
            train (``bool``, *optional*): whether the conditioner is used in
                training. Defaults to True.

        Returns:
            ``dict[str, Any]``: the condition dictionary
        """

    @abc.abstractmethod
    def set_condition(self, *args: Any, **kwargs: Any) -> None:
        r"""Set the condition.

        Implementations are expected to mark the conditioner as ready
        (``self.ready = True``) once the condition is in place.

        Args:
            *args: additional arguments
            **kwargs: additional keyword arguments
        """

    @abc.abstractmethod
    def get_conditional_score(self, x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        r"""Get conditional score.

        Args:
            x_t (``Tensor``): the input tensor
            t (``Tensor``): the time tensor
            padding_mask (``Tensor``): the padding mask

        Returns:
            ``Tensor``: the conditional score
        """

    @property
    def guidance_scale(self) -> float:
        """Current guidance scale multiplier."""
        return self._guidance_scale

    def get_guidance_scale(self) -> float:
        """Return the current guidance scale (alias of the property)."""
        return self._guidance_scale

    def set_guidance_scale(self, guidance_scale: float) -> None:
        """Set the guidance scale multiplier."""
        self._guidance_scale = guidance_scale

    def enable(self) -> None:
        """Activate guidance; get_conditional_score returns a real score."""
        self._enabled = True

    def disable(self) -> None:
        """Deactivate guidance; implementations should return a zero score."""
        self._enabled = False

    def is_enabled(self) -> bool:
        """Return whether guidance is currently active."""
        return self._enabled
@inherit_docstrings
class LGDConditioner(Conditioner):
    r"""Loss Guidance Diffusion Conditioner.

    Guides sampling by differentiating a user-defined conditional loss on the
    predicted clean data with respect to the noisy input ``x_t``. Subclasses
    must implement :meth:`compute_conditional_loss`; ``posterior_mean_fn``
    must be assigned (a callable ``(x_t, t, padding_mask) -> Tensor``) before
    :meth:`get_conditional_score` is used.

    Args:
        guidance_scale (``float``, *optional*): multiplier applied to the
            conditional score. Defaults to 1.0.
    """

    def __init__(
        self,
        guidance_scale: float = 1.0,
    ):
        super().__init__(guidance_scale)
        # Callable mapping (x_t, t, padding_mask) -> predicted clean data.
        # Must be set externally before get_conditional_score is called.
        self.posterior_mean_fn = None

    @abc.abstractmethod
    def compute_conditional_loss(self, p_gt_data: Tensor, padding_mask: Tensor) -> Tensor:
        r"""Compute the conditional loss.

        Args:
            p_gt_data (``Tensor``): predicted clean data.
            padding_mask (``Tensor``): the padding mask

        Returns:
            ``Tensor``: the conditional loss
        """

    def get_conditional_score(self, x_t: Tensor, t: Tensor, padding_mask: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        r"""Get conditional score.

        Computes ``-grad_{x_t} L(posterior_mean_fn(x_t, t, padding_mask))``
        scaled by the guidance scale. Returns a zero tensor when guidance is
        disabled.

        Args:
            x_t (``Tensor``): the input tensor
            t (``Tensor``): the time tensor
            padding_mask (``Tensor``): the padding mask

        Returns:
            ``Tensor``: the conditional score
        """
        if not self._enabled:
            # zeros_like already matches x_t's device and dtype.
            return torch.zeros_like(x_t)
        assert self.ready, "Conditioner is not ready, please call set_condition first"
        # NOTE(review): anomaly detection is always on here, which slows the
        # backward pass considerably — consider making it opt-in for debugging.
        with torch.autograd.set_detect_anomaly(True, check_nan=True):
            with torch.enable_grad():
                # Detach from any surrounding graph so the gradient is taken
                # only through the loss-guidance computation below.
                x_t = x_t.detach().clone().requires_grad_(True)
                p_gt_data = self.posterior_mean_fn(x_t, t, padding_mask)
                conditional_loss = self.compute_conditional_loss(p_gt_data, padding_mask)
                grad = torch.autograd.grad(conditional_loss, x_t)[0]
                # Score points in the direction of decreasing loss.
                score = -grad
        return score * self.guidance_scale