ls_mlkit.flow_matching.euclidean_ot_fm module

class ls_mlkit.flow_matching.euclidean_ot_fm.EuclideanOTFlow(config: EuclideanOTFlowConfig, time_scheduler: FlowMatchingTimeScheduler, masker: MaskerInterface, model: Module, loss_fn: Callable)

Bases: BaseFlow

compute_loss(batch, *args, **kwargs)

Compute the training loss on a batch.

Parameters:

batch (dict[str, Any]) – the batch of data

Returns:

A dictionary that must contain the key “loss”, or the loss tensor itself.

Return type:

dict | Tensor
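compute_loss is documented only by its return contract. As a sketch of what a Euclidean optimal-transport (OT) flow-matching loss typically computes, the function below draws prior noise x0 and one time t per sample, forms the straight-line interpolant x_t = (1 - t) x0 + t x1 toward the data x1, and regresses a vector field onto the constant path velocity x1 - x0. It is written in NumPy rather than PyTorch, and the standard Gaussian prior, the uniform time distribution, the convention that t = 0 is noise and t = 1 is data, and the helper name ot_flow_matching_loss are all assumptions. The time_scheduler, masker and loss_fn passed to EuclideanOTFlow are not modeled here.

```python
import numpy as np


def ot_flow_matching_loss(vector_field, x1, rng):
    """Conditional OT flow-matching loss (NumPy sketch, not the ls_mlkit API).

    Assumed convention: x0 ~ N(0, I) is prior noise, x1 is a data batch, and
    the path is x_t = (1 - t) * x0 + t * x1 with constant velocity x1 - x0.
    """
    x0 = rng.standard_normal(x1.shape)  # prior noise with the data's shape
    # One time per sample, shaped (batch, 1, ..., 1) so it broadcasts over x1.
    t = rng.uniform(size=(x1.shape[0],) + (1,) * (x1.ndim - 1))
    x_t = (1.0 - t) * x0 + t * x1  # point on the straight path at time t
    target = x1 - x0  # velocity of the straight path
    pred = vector_field(x_t, t)
    # Mirror the documented contract: a dict containing the key "loss".
    return {"loss": float(np.mean((pred - target) ** 2))}
```

A vector field that recovers x1 - x0 exactly drives this loss to zero, which is why the learned field can be integrated as an ODE at sampling time.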

get_condition_post_compute_loss_hook(conditioner_list: list[Conditioner])
get_condition_pre_update_in_step_fn_hook(conditioner_list: list[Conditioner])
get_posterior_mean_fn(vf, vf_fn=None)
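get_posterior_mean_fn(vf, vf_fn=None) carries no docstring. For orientation only: under the straight-line path x_t = (1 - t) x0 + t x1 (x0 prior noise, x1 data), the identity x1 = x_t + (1 - t)(x1 - x0) holds exactly, so a vector field that predicts E[x1 - x0 | x_t] yields the posterior mean E[x1 | x_t] = x_t + (1 - t) v(x_t, t). The helper posterior_mean_from_velocity below is hypothetical and only illustrates this relation; whether the library's function uses the same path and time convention is an assumption.

```python
import numpy as np


def posterior_mean_from_velocity(x_t, t, velocity):
    """Hypothetical helper: E[x1 | x_t] from a predicted velocity.

    Assumes the path x_t = (1 - t) * x0 + t * x1, for which
    x1 = x_t + (1 - t) * (x1 - x0) holds exactly.
    """
    return x_t + (1.0 - t) * velocity
```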
inpainting(x, padding_mask, inpainting_mask, device, x_init_posterior=None, inpainting_mask_key='inpainting_mask', sapmling_condition_key='sapmling_condition', return_all=False, sampling_condition=None, *args, **kwargs) → dict

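Sample the unknown entries of x conditioned on its known entries, with inpainting_mask marking which entries are known.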

Parameters:
  • x (Tensor) – data whose known entries condition the sample.

  • padding_mask (Tensor) – mask separating valid positions from padding.

  • inpainting_mask (Tensor) – mask separating the known entries of x from those to be generated.

  • device – device on which sampling is run.

  • x_init_posterior (optional) – initial state from which to start sampling. Defaults to None.

  • inpainting_mask_key (str, optional) – key used for the inpainting mask. Defaults to “inpainting_mask”.

  • sapmling_condition_key (str, optional) – key used for sampling_condition. Defaults to “sapmling_condition”.

  • return_all (bool, optional) – whether to return all outputs rather than only the final result. Defaults to False.

  • sampling_condition (optional) – conditioning information for sampling. Defaults to None.

Returns:

A dictionary with the inpainting results.

Return type:

dict
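inpainting is documented only by its signature. One common way to inpaint with a flow model, shown here as a NumPy sketch, is replacement: at each Euler step the known entries are reset to the point that the straight path x_t = (1 - t) x_prior + t x_known occupies at time t, and only the unknown entries are driven by the vector field. The convention that known_mask is 1 on known entries, the Euler solver, the Gaussian prior and the name inpaint_euler are assumptions; this is not claimed to be how EuclideanOTFlow.inpainting or recover_bright_region are implemented.

```python
import numpy as np


def inpaint_euler(vector_field, x_known, known_mask, n_steps, rng):
    """Replacement-style inpainting with an Euler ODE solver (NumPy sketch).

    Assumed convention: known_mask is 1 where x_known is kept, 0 where values
    are generated. Time runs from t = 0 (prior) to t = 1 (data).
    """
    x_prior = rng.standard_normal(x_known.shape)  # assumed Gaussian prior
    x = x_prior.copy()
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        # Reset known entries to where the straight path puts them at time t.
        x_known_t = (1.0 - t) * x_prior + t * x_known
        x = known_mask * x_known_t + (1.0 - known_mask) * x
        x = x + dt * vector_field(x, t)
    # After the last step, paste the known data back into the known region.
    return known_mask * x_known + (1.0 - known_mask) * x
```

Resetting the known region at every step, rather than only at the end, lets the unknown entries see a context that is consistent with the data throughout the trajectory.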

prior_sampling(shape) → Tensor

Draw samples from the prior distribution.

Parameters:

shape (tuple[int, ...]) – the shape of the sample

Returns:

Samples drawn from the prior distribution.

Return type:

Tensor
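The docstring does not name the prior. Euclidean OT flow matching is commonly paired with a standard Gaussian prior; assuming that, prior sampling reduces to drawing i.i.d. N(0, 1) entries of the requested shape. The NumPy function gaussian_prior_sampling below is a hypothetical stand-in, not the library method (which returns a torch Tensor).

```python
import numpy as np


def gaussian_prior_sampling(shape, rng=None):
    """Hypothetical stand-in: draw i.i.d. N(0, 1) samples of the given shape."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.standard_normal(shape)
```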

recover_bright_region(x_known, x_t, t, padding_mask, inpainting_mask, x_prior, *args, **kwargs) → Tensor
sampling(shape, device, x_init_posterior=None, return_all=False, sampling_condition=None, sapmling_condition_key='sampling_condition', *args, **kwargs) → dict

Generate samples of the given shape by integrating the learned flow.

Parameters:
  • shape (tuple[int, ...]) – shape of the samples to generate.

  • device – device on which sampling is run.

  • x_init_posterior (optional) – initial state from which to start sampling. Defaults to None.

  • return_all (bool, optional) – whether to return all outputs rather than only the final result. Defaults to False.

  • sampling_condition (optional) – conditioning information for sampling. Defaults to None.

  • sapmling_condition_key (str, optional) – key used for sampling_condition. Defaults to “sampling_condition”.

Returns:

A dictionary with the sampling results.

Return type:

dict
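sampling is documented only by its signature. A minimal sketch of flow-matching sampling, assuming a Gaussian prior at t = 0, data at t = 1 and a fixed-step Euler solver, integrates dx/dt = v(x, t) over a uniform time grid. Reading return_all as "also return the trajectory" is an assumption made for this sketch, as are the name euler_sample and the output key "trajectory".

```python
import numpy as np


def euler_sample(vector_field, shape, n_steps, rng, return_all=False):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with n_steps Euler steps."""
    x = rng.standard_normal(shape)  # start from the assumed Gaussian prior
    dt = 1.0 / n_steps
    trajectory = [x]
    for k in range(n_steps):
        t = k * dt  # left endpoint of the k-th interval
        x = x + dt * vector_field(x, t)
        trajectory.append(x)
    out = {"x": x}
    if return_all:
        # Hypothetical key: the n_steps + 1 states, stacked along a new leading axis.
        out["trajectory"] = np.stack(trajectory)
    return out
```

Evaluating the field only at left endpoints means it is never queried at t = 1, which matters for fields such as (x1 - x) / (1 - t) that are singular there.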

step(x_t, t, padding_mask=None, *args, **kwargs) → dict

Perform one sampling step from state x_t at time t.

Parameters:
  • x_t (Tensor) – current state at time t.

  • t (Tensor) – current time.

  • padding_mask (Tensor, optional) – mask separating valid positions from padding. Defaults to None.

Returns:

A dictionary that must contain the key “x”

Return type:

dict
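step must return a dictionary containing the key “x”. A sketch of a single explicit Euler step consistent with that contract follows. It assumes that padding_mask is 1 on valid positions and 0 on padding, that it broadcasts against x_t, and that padded positions are simply frozen. The step size dt is passed explicitly here because the documented signature has no such parameter and how the library determines it is not documented in this section; the name euler_step is hypothetical.

```python
import numpy as np


def euler_step(vector_field, x_t, t, dt, padding_mask=None):
    """One Euler step x_{t+dt} = x_t + dt * v(x_t, t), returned as {"x": ...}."""
    velocity = vector_field(x_t, t)
    if padding_mask is not None:
        # Assumed convention: 1 = valid, 0 = padding; padded entries stay put.
        velocity = velocity * padding_mask
    return {"x": x_t + dt * velocity}
```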

class ls_mlkit.flow_matching.euclidean_ot_fm.EuclideanOTFlowConfig(n_discretization_steps: int, ndim_micro_shape: int = 2, n_inference_steps: int = None, *args: list[Any], **kwargs: dict[Any, Any])

Bases: BaseFlowConfig