ls_mlkit.util.base_class.base_gm_class module

class ls_mlkit.util.base_class.base_gm_class.BaseGenerativeModel(config: BaseGenerativeModelConfig)[source]

Bases: BaseLoss

Abstract methods that subclasses must implement: compute_loss, step, sampling, inpainting, prior_sampling.

forward(**batch) dict | Tensor[source]

Forward pass. Takes a batch of data and returns the loss, either inside a dictionary or as a tensor.

Parameters:

batch (dict[str, Any]) – the batch of data

Returns:

A dictionary that must contain the key “loss”, or a loss tensor.

Return type:

dict | Tensor
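Because forward may return either a dictionary or a bare loss, calling code usually has to handle both cases. The helper below is a minimal sketch of that normalization; `extract_loss` is a hypothetical name and is not part of ls_mlkit, and plain floats stand in for tensors so the example runs without any dependencies.

```python
from typing import Any


def extract_loss(output: Any) -> Any:
    """Return the loss from a forward() result (hypothetical helper, not in ls_mlkit).

    The result is either a dict that must contain the key "loss",
    or the loss value itself.
    """
    if isinstance(output, dict):
        if "loss" not in output:
            raise KeyError('forward() returned a dict without the required "loss" key')
        return output["loss"]
    # Anything that is not a dict is treated as the loss itself.
    return output


# Floats stand in for tensors here.
dict_style = extract_loss({"loss": 0.5, "extra": "ignored"})
tensor_style = extract_loss(0.25)
```

Both calls yield the loss value regardless of which return style the model uses.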

abstractmethod inpainting(x, padding_mask, inpainting_mask, device, x_init_posterior=None, inpainting_mask_key='inpainting_mask', sapmling_condition_key='sapmling_condition', return_all=False, sampling_condition=None, *args, **kwargs) dict[source]

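Inpaint x using the given inpainting mask. The source docstring is an unfilled template, so the descriptions below follow from the parameter names and defaults only.

Parameters:
  • x – the input data to inpaint.

  • padding_mask – mask marking padded positions in x.

  • inpainting_mask – mask selecting the region of x involved in inpainting.

  • device – the device on which to run.

  • x_init_posterior (optional) – an initial value for the process; its exact semantics are not specified in the source. Defaults to None.

  • inpainting_mask_key (str, optional) – the key associated with the inpainting mask. Defaults to “inpainting_mask”.

  • sapmling_condition_key (str, optional) – the key associated with the sampling condition (the identifier is spelled “sapmling” in the source). Defaults to “sapmling_condition”.

  • return_all (bool, optional) – presumably whether to return intermediate results in addition to the final one. Defaults to False.

  • sampling_condition (optional) – a condition passed to the sampling process. Defaults to None.

Returns:

A dictionary of inpainting results; its keys are not specified in the source.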

Return type:

dict

abstractmethod prior_sampling(shape: tuple[int, ...]) Tensor[source]

Sample from the prior distribution.

Parameters:

shape (tuple[int, ...]) – the shape of the sample

Returns:

A sample drawn from the prior distribution.

Return type:

Tensor

register_hook(hook: GMHook) GMHookHandler[source]

register_hooks(hooks: list[GMHook]) list[GMHookHandler][source]

register_post_compute_loss_hook(name: str, fn: Callable[[...], Any], priority: int = 0, enabled: bool = True) GMHookHandler[source]

Register a hook to be called after loss computation

Parameters:
  • name (str) – the name of the hook

  • fn (Callable[..., Any]) – the function to be called

  • priority (int, optional) – the priority of the hook. Defaults to 0.

  • enabled (bool, optional) – whether the hook is enabled. Defaults to True.
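The registration parameters above can be illustrated with a small stand-alone registry. This is a sketch, not the ls_mlkit implementation: `SketchHook` and `SketchHookRegistry` are hypothetical names, and the ordering rule (higher priority runs first) is an assumption, since the source does not say how priority is interpreted.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class SketchHook:
    """Hypothetical stand-in for a hook record (not the library's GMHook)."""
    name: str
    fn: Callable[..., Any]
    priority: int = 0
    enabled: bool = True


class SketchHookRegistry:
    """Hypothetical stand-in for a hook manager."""

    def __init__(self) -> None:
        self._hooks: List[SketchHook] = []

    def register(self, name: str, fn: Callable[..., Any],
                 priority: int = 0, enabled: bool = True) -> SketchHook:
        hook = SketchHook(name, fn, priority, enabled)
        self._hooks.append(hook)
        return hook

    def run(self, **kwargs: Any) -> List[str]:
        """Call every enabled hook and return the names in call order."""
        called = []
        # ASSUMPTION: higher priority runs first. sorted() is stable,
        # so hooks with equal priority keep their registration order.
        for hook in sorted(self._hooks, key=lambda h: -h.priority):
            if hook.enabled:
                hook.fn(**kwargs)
                called.append(hook.name)
        return called


registry = SketchHookRegistry()
seen = []
registry.register("record", lambda loss: seen.append(loss))
registry.register("scale", lambda loss: seen.append(loss * 2), priority=10)
registry.register("off", lambda loss: seen.append(-1.0), enabled=False)
order = registry.run(loss=0.5)
```

Under the stated ordering assumption, "scale" runs before "record", and the disabled hook is skipped entirely.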

abstractmethod sampling(shape, device, x_init_posterior=None, return_all=False, sampling_condition=None, sapmling_condition_key='sapmling_condition', *args, **kwargs) dict[source]

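Generate samples of the given shape. The source docstring is an unfilled template, so the descriptions below follow from the parameter names and defaults only.

Parameters:
  • shape – the shape of the samples to generate.

  • device – the device on which to generate the samples.

  • x_init_posterior (optional) – an initial value for the process; its exact semantics are not specified in the source. Defaults to None.

  • return_all (bool, optional) – presumably whether to return intermediate results in addition to the final sample. Defaults to False.

  • sampling_condition (optional) – a condition passed to the sampling process. Defaults to None.

  • sapmling_condition_key (str, optional) – the key associated with the sampling condition (the identifier is spelled “sapmling” in the source). Defaults to “sapmling_condition”.

Returns:

A dictionary of sampling results; its keys are not specified in the source.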

Return type:

dict

abstractmethod step(x_t: Tensor, t: Tensor, padding_mask: Tensor = None, *args: list[Any], **kwargs: dict[Any, Any]) dict[source]

Apply one update to the state x_t at time step t.

Parameters:
  • x_t (Tensor) – the current state at time step t.

  • t (Tensor) – the current time step.

  • padding_mask (Tensor, optional) – mask marking padded positions in x_t. Defaults to None.

Returns:

A dictionary that must contain the key “x”

Return type:

dict
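One way prior_sampling and step could fit together is a loop that draws an initial state from the prior and repeatedly replaces it with the “x” entry returned by step. The sketch below only illustrates that contract: every function name is hypothetical, lists of floats stand in for tensors, the halving update is a placeholder rather than a real model, and iterating t from high to low is an assumption.

```python
import random
from typing import Dict, List


def toy_prior_sampling(n: int, rng: random.Random) -> List[float]:
    """Draw n values from a standard normal prior (stand-in for prior_sampling)."""
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def toy_step(x_t: List[float], t: int) -> Dict[str, List[float]]:
    """Placeholder update: halve every value.

    Mirrors the documented contract that the result is a dict containing "x".
    A real model would compute the update from x_t and t.
    """
    return {"x": [0.5 * v for v in x_t]}


def toy_sampling(n: int, n_steps: int, seed: int = 0) -> Dict[str, List[float]]:
    """Start from the prior and apply toy_step n_steps times."""
    rng = random.Random(seed)
    x = toy_prior_sampling(n, rng)
    # ASSUMPTION: time runs from n_steps - 1 down to 0.
    for t in reversed(range(n_steps)):
        x = toy_step(x, t)["x"]
    return {"x": x}


initial = toy_sampling(4, 0, seed=1)["x"]   # no steps: the prior draw itself
final = toy_sampling(4, 3, seed=1)["x"]     # three halvings of the same draw
```

With the same seed, each value in `final` is one eighth of the corresponding value in `initial`, which makes the loop easy to check by hand.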

class ls_mlkit.util.base_class.base_gm_class.BaseGenerativeModelConfig(ndim_micro_shape: int, n_discretization_steps: int, use_batch_flattening: bool = False, n_inference_steps: int = None, *args: list[Any], **kwargs: dict[Any, Any])[source]

Bases: BaseLossConfig

class ls_mlkit.util.base_class.base_gm_class.GMHook(name: str, stage: HookStageType, fn: Callable[[...], Any | None], priority: int = 0, enabled: bool = True)[source]

Bases: Hook[GMHookStageType]

class ls_mlkit.util.base_class.base_gm_class.GMHookHandler(manager: HookManager[HookStageType], hook: Hook[HookStageType])[source]

Bases: HookHandler[GMHookStageType]

class ls_mlkit.util.base_class.base_gm_class.GMHookManager[source]

Bases: HookManager[GMHookStageType]

class ls_mlkit.util.base_class.base_gm_class.GMHookStageType(*values)[source]

Bases: Enum

POST_COMPUTE_LOSS = 'post_compute_loss'
POST_GET_MACRO_SHAPE = 'get_macro_shape'
POST_SAMPLING_TIME_STEP = 'post_sampling_time_step'
POST_UPDATE_IN_STEP_FN = 'post_update_in_step_fn'
PRE_COMPUTE_LOSS = 'pre_compute_loss'
PRE_UPDATE_IN_STEP_FN = 'pre_update_in_step_fn'
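The members above can be looked up either by name or by string value. The sketch below reproduces the listed members in a local enum, named `GMHookStageTypeSketch` here so that it runs without the library installed, and shows that the value of POST_GET_MACRO_SHAPE is 'get_macro_shape', without the 'post_' prefix that the other POST_ members carry.

```python
from enum import Enum


class GMHookStageTypeSketch(Enum):
    """Local copy of the members listed above (not imported from ls_mlkit)."""
    POST_COMPUTE_LOSS = "post_compute_loss"
    POST_GET_MACRO_SHAPE = "get_macro_shape"
    POST_SAMPLING_TIME_STEP = "post_sampling_time_step"
    POST_UPDATE_IN_STEP_FN = "post_update_in_step_fn"
    PRE_COMPUTE_LOSS = "pre_compute_loss"
    PRE_UPDATE_IN_STEP_FN = "pre_update_in_step_fn"


# Lookup by value uses the string on the right-hand side.
stage = GMHookStageTypeSketch("get_macro_shape")

# Lookup by name uses the member name on the left-hand side.
same_stage = GMHookStageTypeSketch["POST_GET_MACRO_SHAPE"]

# The name-derived string is not a valid value for this member.
try:
    GMHookStageTypeSketch("post_get_macro_shape")
    lookup_failed = False
except ValueError:
    lookup_failed = True
```

Code that builds stage values from strings should therefore use the literal value rather than lower-casing the member name.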