ls_mlkit.util.base_class.base_gm_class module¶
- class ls_mlkit.util.base_class.base_gm_class.BaseGenerativeModel(config: BaseGenerativeModelConfig)[source]¶
Bases:
BaseLoss
Abstract methods: compute_loss, step, sampling, inpainting
- forward(**batch) dict | Tensor[source]¶
Forward pass: takes a batch of data and returns either a dictionary containing the loss or a loss tensor.
- Parameters:
batch (dict[str, Any]) – the batch of data
- Returns:
a dictionary that must contain the key “loss”, or a tensor holding the loss
- Return type:
dict | Tensor
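A minimal usage sketch. MyModel, the config, and the batch keys are assumptions; only forward’s contract (keyword batch in, loss dict or loss tensor out) comes from the documentation above.

```python
import torch

# MyModel / config are hypothetical; any concrete BaseGenerativeModel works the same way.
model = MyModel(config)

batch = {
    "x": torch.randn(4, 16, 8),         # assumed (batch, length, dim) layout
    "padding_mask": torch.ones(4, 16),  # assumed 1 = real position
}
out = model(**batch)  # dispatches to forward(**batch)

# forward may return either a dict containing "loss" or a bare loss tensor
loss = out["loss"] if isinstance(out, dict) else out
loss.backward()
```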
- abstractmethod inpainting(x, padding_mask, inpainting_mask, device, x_init_posterior=None, inpainting_mask_key='inpainting_mask', sapmling_condition_key='sapmling_condition', return_all=False, sampling_condition=None, *args, **kwargs) dict[source]¶
Inpainting: conditional sampling in which part of x is kept fixed and the remainder is generated, as indicated by the inpainting mask.
- Parameters:
x (Tensor) – the data to inpaint
padding_mask (Tensor) – mask marking padded positions
inpainting_mask (Tensor) – mask distinguishing the known region from the region to be generated
device – the device on which sampling runs
x_init_posterior (Tensor, optional) – optional initial state for the sampling process. Defaults to None.
inpainting_mask_key (str, optional) – key under which the inpainting mask is passed. Defaults to “inpainting_mask”.
sapmling_condition_key (str, optional) – key under which the sampling condition is passed. Defaults to “sapmling_condition”.
return_all (bool, optional) – whether to return all intermediate sampling states. Defaults to False.
sampling_condition (optional) – conditioning information for sampling. Defaults to None.
- Returns:
a dictionary containing the inpainted samples
- Return type:
dict
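A hedged call sketch: the tensor shapes, the mask value convention, and the “cuda” device are assumptions; only the keyword names come from the signature above.

```python
import torch

x = torch.randn(4, 16, 8)            # assumed (batch, length, dim) layout
padding_mask = torch.ones(4, 16)     # assumed 1 = real position
inpainting_mask = torch.zeros(4, 16)
inpainting_mask[:, :8] = 1           # which region is kept vs. generated depends on the subclass

result = model.inpainting(
    x=x,
    padding_mask=padding_mask,
    inpainting_mask=inpainting_mask,
    device="cuda",
    return_all=False,
)
# result is a dict; the exact keys depend on the concrete subclass
```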
- abstractmethod prior_sampling(shape: tuple[int, ...]) Tensor[source]¶
Sample from the prior distribution.
- Parameters:
shape (tuple[int, ...]) – the shape of the sample
- Returns:
data drawn from the prior distribution
- Return type:
Tensor
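In a concrete subclass this often reduces to drawing Gaussian noise. The sketch below assumes a standard normal prior, which is one common choice rather than a requirement of the base class.

```python
import torch
from torch import Tensor

class MyModel(BaseGenerativeModel):  # illustrative subclass, other abstract methods omitted
    def prior_sampling(self, shape: tuple[int, ...]) -> Tensor:
        # Assumption: a standard normal prior, as in many diffusion-style models.
        return torch.randn(*shape)
```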
- register_hook(hook: GMHook) GMHookHandler[source]¶
- register_hooks(hooks: list[GMHook]) list[GMHookHandler][source]¶
- register_post_compute_loss_hook(name: str, fn: Callable[[...], Any], priority: int = 0, enabled: bool = True) GMHookHandler[source]¶
Register a hook to be called after loss computation
- Parameters:
name (str) – the name of the hook
fn (Callable[..., Any]) – the function to be called
priority (int, optional) – the priority of the hook. Defaults to 0.
enabled (bool, optional) – whether the hook is enabled. Defaults to True.
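A registration sketch. The hook function’s exact call signature is not documented here, so the callback accepts *args/**kwargs defensively.

```python
def log_loss(*args, **kwargs):
    # Assumption: the computed loss (or loss dict) is reachable from the hook arguments.
    print("compute_loss finished:", args, kwargs)

handler = model.register_post_compute_loss_hook(
    name="log_loss",
    fn=log_loss,
    priority=10,  # ordering semantics are defined by the hook manager
    enabled=True,
)
```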
- abstractmethod sampling(shape, device, x_init_posterior=None, return_all=False, sampling_condition=None, sapmling_condition_key='sapmling_condition', *args, **kwargs) dict[source]¶
Generate samples from the model.
- Parameters:
shape (tuple[int, ...]) – the shape of the samples to generate
device – the device on which sampling runs
x_init_posterior (Tensor, optional) – optional initial state for the sampling process. Defaults to None.
return_all (bool, optional) – whether to return all intermediate sampling states. Defaults to False.
sampling_condition (optional) – conditioning information for sampling. Defaults to None.
sapmling_condition_key (str, optional) – key under which the sampling condition is passed. Defaults to “sapmling_condition”.
- Returns:
a dictionary containing the generated samples
- Return type:
dict
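A hedged call sketch; the sample shape and device are placeholders.

```python
result = model.sampling(
    shape=(4, 16, 8),  # assumed sample shape
    device="cuda",
    return_all=True,   # also collect intermediate states, per the flag above
)
# result is a dict; the exact keys depend on the concrete subclass
```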
- abstractmethod step(x_t: Tensor, t: Tensor, padding_mask: Tensor = None, *args: list[Any], **kwargs: dict[Any, Any]) dict[source]¶
Perform a single generation step, updating the state x_t at time step t.
- Parameters:
x_t (Tensor) – the current state at time step t
t (Tensor) – the current time step
padding_mask (Tensor, optional) – mask marking padded positions. Defaults to None.
- Returns:
A dictionary that must contain the key “x”
- Return type:
dict
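A sketch of a concrete step in a hypothetical subclass. The update rule and the self.denoiser attribute are assumptions; only the “x” key in the returned dict is required by the contract above.

```python
from torch import Tensor

class MyModel(BaseGenerativeModel):  # illustrative subclass, other methods omitted
    def step(self, x_t: Tensor, t: Tensor, padding_mask: Tensor = None, *args, **kwargs) -> dict:
        # Assumption: self.denoiser predicts the next state from (x_t, t);
        # any one-step update works as long as the returned dict contains "x".
        x_next = self.denoiser(x_t, t)
        if padding_mask is not None:
            x_next = x_next * padding_mask.unsqueeze(-1)  # assumed mask broadcasting
        return {"x": x_next}
```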
- class ls_mlkit.util.base_class.base_gm_class.BaseGenerativeModelConfig(ndim_micro_shape: int, n_discretization_steps: int, use_batch_flattening: bool = False, n_inference_steps: int = None, *args: list[Any], **kwargs: dict[Any, Any])[source]¶
Bases:
BaseLossConfig
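A construction example based on the signature above; the concrete values are placeholders.

```python
config = BaseGenerativeModelConfig(
    ndim_micro_shape=2,           # e.g. per-sample shape (length, dim); value is illustrative
    n_discretization_steps=1000,  # illustrative
    use_batch_flattening=False,
    n_inference_steps=50,         # illustrative; defaults to None
)
```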
- class ls_mlkit.util.base_class.base_gm_class.GMHook(name: str, stage: HookStageType, fn: Callable[[...], Any | None], priority: int = 0, enabled: bool = True)[source]¶
Bases:
Hook[GMHookStageType]
- class ls_mlkit.util.base_class.base_gm_class.GMHookHandler(manager: HookManager[HookStageType], hook: Hook[HookStageType])[source]¶
Bases:
HookHandler[GMHookStageType]
- class ls_mlkit.util.base_class.base_gm_class.GMHookManager[source]¶
Bases:
HookManager[GMHookStageType]
- class ls_mlkit.util.base_class.base_gm_class.GMHookStageType(*values)[source]¶
Bases:
Enum
- POST_COMPUTE_LOSS = 'post_compute_loss'¶
- POST_GET_MACRO_SHAPE = 'get_macro_shape'¶
- POST_SAMPLING_TIME_STEP = 'post_sampling_time_step'¶
- POST_UPDATE_IN_STEP_FN = 'post_update_in_step_fn'¶
- PRE_COMPUTE_LOSS = 'pre_compute_loss'¶
- PRE_UPDATE_IN_STEP_FN = 'pre_update_in_step_fn'¶
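The stage values combine with GMHook and register_hook to attach a callback at any of the stages above; the callback body below is a placeholder.

```python
hook = GMHook(
    name="after_step_update",
    stage=GMHookStageType.POST_UPDATE_IN_STEP_FN,
    fn=lambda *args, **kwargs: None,  # placeholder callback
    priority=0,
    enabled=True,
)
handler = model.register_hook(hook)  # returns a GMHookHandler for the registered hook
```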