Source code for ls_mlkit.util.base_class.base_loss_class
r"""
Base Diffuser Config and Base Diffuser.
"""
import abc
from typing import Any
from torch import Tensor
from torch.nn import Module
from ..decorators import inherit_docstrings
from ..shape_class import Shape, ShapeConfig
from .base_config_class import DeviceConfig
[docs]
@inherit_docstrings
class BaseLossConfig(DeviceConfig):
def __init__(
self, ndim_micro_shape: int, use_batch_flattening: bool = False, *args: list[Any], **kwargs: dict[Any, Any]
):
super().__init__(*args, **kwargs)
self.ndim_micro_shape: int = ndim_micro_shape
self.use_batch_flattening = use_batch_flattening
[docs]
@inherit_docstrings
class BaseLoss(Module, abc.ABC):
r"""
abstract method: compute_loss
"""
def __init__(
self,
config: BaseLossConfig,
):
Module.__init__(self)
abc.ABC.__init__(self)
self.config: BaseLossConfig = config
self.shape_util = Shape(
config=ShapeConfig(ndim_micro_shape=config.ndim_micro_shape),
)
[docs]
@abc.abstractmethod
def compute_loss(self, **batch) -> dict | Tensor:
r"""Compute loss
Args:
batch (``dict[str, Any]``): the batch of data
Returns:
``dict``|``Tensor``: a dictionary that must contain the key "loss" or a tensor of loss
"""
[docs]
def get_macro_shape(self, x: Tensor) -> tuple[int, ...]:
return self.shape_util.get_macro_shape(x)
[docs]
def complete_micro_shape(self, x: Tensor) -> Tensor:
return self.shape_util.complete_micro_shape(x)