ls_mlkit.util.base_class.base_loss_class module

Base Loss Config and Base Loss.

class ls_mlkit.util.base_class.base_loss_class.BaseLoss(config: BaseLossConfig)

Bases: Module, ABC

Subclasses must implement the abstract method compute_loss.

complete_micro_shape(x: Tensor) → Tensor

abstractmethod compute_loss(**batch) → dict | Tensor

Compute the loss.

Parameters:

batch (dict[str, Any]) – the batch of data, passed as keyword arguments

Returns:

either a dictionary that must contain the key “loss”, or the loss tensor itself

Return type:

dict | Tensor
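The compute_loss contract can be illustrated with a minimal, dependency-free sketch. It is not the library's code: SketchBaseLoss is a hypothetical stand-in for BaseLoss (the real class also subclasses torch.nn.Module and returns tensors; plain Python floats and lists are used here so the sketch runs without PyTorch), MeanSquaredErrorLoss is a hypothetical subclass that reads made-up batch keys "pred" and "target", and extract_loss is a hypothetical helper showing how a caller might accept either return form.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class SketchBaseLoss(ABC):
    """Hypothetical stand-in for BaseLoss (the real class also subclasses torch.nn.Module)."""

    @abstractmethod
    def compute_loss(self, **batch: Any) -> dict[str, Any] | float:
        """Return a dict containing the key "loss", or the loss value itself."""


class MeanSquaredErrorLoss(SketchBaseLoss):
    """Hypothetical subclass: mean squared error between batch["pred"] and batch["target"]."""

    def compute_loss(self, **batch: Any) -> dict[str, Any]:
        pred, target = batch["pred"], batch["target"]
        squared_errors = [(p - t) ** 2 for p, t in zip(pred, target)]
        loss = sum(squared_errors) / len(squared_errors)
        # "loss" is the required key; the extra key is assumed to be permitted.
        return {"loss": loss, "num_elements": len(squared_errors)}


def extract_loss(output: dict[str, Any] | float) -> float:
    """Normalize either return form of compute_loss to the scalar loss."""
    return output["loss"] if isinstance(output, dict) else output


if __name__ == "__main__":
    loss_fn = MeanSquaredErrorLoss()
    # The batch is passed as keyword arguments, matching compute_loss(**batch).
    out = loss_fn.compute_loss(pred=[1.0, 2.0], target=[0.0, 2.0])
    print(extract_loss(out))  # 0.5

    try:
        SketchBaseLoss()  # abstract: compute_loss is not implemented
    except TypeError as err:
        print("cannot instantiate:", err)
```

Because compute_loss is declared with @abstractmethod, the base class itself cannot be instantiated; only subclasses that implement it can.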

get_macro_shape(x: Tensor) → tuple[int, ...]

class ls_mlkit.util.base_class.base_loss_class.BaseLossConfig(ndim_micro_shape: int, use_batch_flattening: bool = False, *args: list[Any], **kwargs: dict[Any, Any])

Bases: DeviceConfig
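The names get_macro_shape and ndim_micro_shape suggest that a tensor's shape is split into leading "macro" dimensions (for example the batch) and the trailing ndim_micro_shape "micro" dimensions of each sample. The sketch below is one plausible reading of that split, not the library's implementation: SketchLossConfig, sketch_macro_shape and sketch_complete_micro_shape are hypothetical, they operate on shape tuples rather than torch tensors, and the behaviour given to complete_micro_shape (appending singleton dimensions for broadcasting) is an assumption, since that method is undocumented here.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SketchLossConfig:
    """Hypothetical stand-in for BaseLossConfig; the DeviceConfig fields are omitted."""

    ndim_micro_shape: int
    # Present in the real signature; its behaviour is not modelled in this sketch.
    use_batch_flattening: bool = False


def sketch_macro_shape(shape: tuple[int, ...], config: SketchLossConfig) -> tuple[int, ...]:
    """Assumed semantics: the leading dims left after dropping the trailing micro dims."""
    n = config.ndim_micro_shape
    if not 0 <= n <= len(shape):
        raise ValueError(f"ndim_micro_shape={n} is invalid for shape {shape}")
    # Use an explicit end index: shape[:-0] would wrongly return ().
    return shape[: len(shape) - n]


def sketch_complete_micro_shape(
    macro_shape: tuple[int, ...], config: SketchLossConfig
) -> tuple[int, ...]:
    """Assumed semantics: append ndim_micro_shape singleton dims so a per-sample
    value (e.g. a loss weight) broadcasts against the micro dims."""
    return macro_shape + (1,) * config.ndim_micro_shape


if __name__ == "__main__":
    cfg = SketchLossConfig(ndim_micro_shape=3)
    images = (8, 3, 32, 32)  # 8 samples, each of shape (3, 32, 32)
    macro = sketch_macro_shape(images, cfg)
    print(macro)  # (8,)
    print(sketch_complete_micro_shape(macro, cfg))  # (8, 1, 1, 1)
```

Under this reading, a tensor of shape (8, 3, 32, 32) with ndim_micro_shape=3 has macro shape (8,), and a per-sample tensor of shape (8,) would be reshaped to (8, 1, 1, 1) so that it broadcasts against the full tensor.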