ls_mlkit.flow_matching.euclidean_ot_fm module
- class ls_mlkit.flow_matching.euclidean_ot_fm.EuclideanOTFlow(config: EuclideanOTFlowConfig, time_scheduler: FlowMatchingTimeScheduler, masker: MaskerInterface, model: Module, loss_fn: Callable)
Bases: BaseFlow
- compute_loss(batch, *args, **kwargs)
Compute the loss.
- Parameters:
batch (dict[str, Any]) – the batch of data
- Returns:
a dictionary that must contain the key “loss”, or a tensor of loss
- Return type:
dict | Tensor
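Because compute_loss may return either a dictionary containing the key “loss” or a bare tensor, calling code typically has to normalize the result before using it. A minimal sketch of that normalization, assuming a hypothetical helper named extract_loss (it is not part of ls_mlkit) and using plain Python values as stand-ins for tensors:

```python
from typing import Any, Mapping


def extract_loss(output: Any) -> Any:
    """Return the loss from a compute_loss-style result.

    Hypothetical helper, not part of ls_mlkit: accepts either a mapping
    with a "loss" entry or the loss value itself.
    """
    if isinstance(output, Mapping):
        if "loss" not in output:
            raise KeyError('loss dictionary must contain the key "loss"')
        return output["loss"]
    # Anything that is not a mapping is treated as the loss itself.
    return output
```

A training loop could then call extract_loss(flow.compute_loss(batch)) and handle both return shapes the same way; any extra dictionary entries remain available for logging.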
- get_condition_post_compute_loss_hook(conditioner_list: list[Conditioner])
- get_condition_pre_update_in_step_fn_hook(conditioner_list: list[Conditioner])
- inpainting(x, padding_mask, inpainting_mask, device, x_init_posterior=None, inpainting_mask_key='inpainting_mask', sapmling_condition_key='sapmling_condition', return_all=False, sampling_condition=None, *args, **kwargs) → dict
_summary_
- Parameters:
x (_type_) – _description_
padding_mask (_type_) – _description_
inpainting_mask (_type_) – _description_
device (_type_) – _description_
x_init_posterior (_type_, optional) – _description_. Defaults to None.
inpainting_mask_key (str, optional) – _description_. Defaults to “inpainting_mask”.
sapmling_condition_key (str, optional) – _description_. Defaults to “sapmling_condition”.
return_all (bool, optional) – _description_. Defaults to False.
sampling_condition (_type_, optional) – _description_. Defaults to None.
- Returns:
_description_
- Return type:
dict
- prior_sampling(shape) → Tensor
Sample from the prior distribution.
- Parameters:
shape (tuple[int, ...]) – the shape of the sample
- Returns:
data from the prior distribution
- Return type:
Tensor
- recover_bright_region(x_known, x_t, t, padding_mask, inpainting_mask, x_prior, *args, **kwargs) → Tensor
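The arguments of recover_bright_region suggest the usual inpainting pattern, in which the known part of the current state x_t is overwritten with the known data brought to time t. The sketch below illustrates that pattern under stated assumptions and is not the library's implementation: it assumes the linear path x_t = (1 - t) * x_prior + t * x_known, assumes a mask value of 1 marks the known entries, ignores padding_mask, and uses flat Python lists in place of tensors. The function name recover_known_region is hypothetical.

```python
from typing import List, Sequence


def recover_known_region(
    x_known: Sequence[float],
    x_t: Sequence[float],
    t: float,
    inpainting_mask: Sequence[int],
    x_prior: Sequence[float],
) -> List[float]:
    """Overwrite the known entries of x_t with the known data at time t."""
    out = []
    for known, current, keep, prior in zip(x_known, x_t, inpainting_mask, x_prior):
        # Assumed linear path from the prior sample (t = 0) to the data (t = 1).
        known_at_t = (1.0 - t) * prior + t * known
        # Assumption: mask value 1 means "known"; other entries keep the current state.
        out.append(known_at_t if keep else current)
    return out
```

With this convention the known entries follow the same path as the generated ones, so at t = 1 they coincide with x_known exactly.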
- sampling(shape, device, x_init_posterior=None, return_all=False, sampling_condition=None, sapmling_condition_key='sampling_condition', *args, **kwargs) → dict
_summary_
- Parameters:
shape (_type_) – _description_
device (_type_) – _description_
x_init_posterior (_type_, optional) – _description_. Defaults to None.
return_all (bool, optional) – _description_. Defaults to False.
sampling_condition (_type_, optional) – _description_. Defaults to None.
sapmling_condition_key (str, optional) – _description_. Defaults to “sampling_condition”.
- Returns:
_description_
- Return type:
dict
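Sampling from a flow-matching model generally means drawing a starting point from the prior and integrating a learned velocity field over time. A minimal sketch of that idea follows, assuming fixed-step Euler integration from t = 0 to t = 1 and a standard normal prior. The integrator, time direction, and returned dictionary keys that EuclideanOTFlow actually uses are not documented on this page, so the names euler_sample, velocity_fn, x_init, "x", and "trajectory" are all hypothetical, and flat lists stand in for tensors.

```python
import random
from typing import Callable, Dict, List, Optional


def euler_sample(
    velocity_fn: Callable[[List[float], float], List[float]],
    n_dims: int,
    n_steps: int,
    x_init: Optional[List[float]] = None,
    return_all: bool = False,
    seed: Optional[int] = None,
) -> Dict[str, object]:
    """Integrate dx/dt = velocity_fn(x, t) from t = 0 to t = 1 with Euler steps."""
    rng = random.Random(seed)
    # Assumption: standard normal prior unless an initial state is supplied.
    if x_init is not None:
        x = list(x_init)
    else:
        x = [rng.gauss(0.0, 1.0) for _ in range(n_dims)]
    dt = 1.0 / n_steps
    trajectory = [list(x)]
    for step in range(n_steps):
        t = step * dt
        v = velocity_fn(x, t)
        # Euler update: move along the velocity for one time step.
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        trajectory.append(list(x))
    result: Dict[str, object] = {"x": x}
    if return_all:
        # Hypothetical key holding every intermediate state, including the start.
        result["trajectory"] = trajectory
    return result
```

In this sketch, return_all only controls whether the intermediate states are included alongside the final sample, which mirrors the role the flag's name implies.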
- class ls_mlkit.flow_matching.euclidean_ot_fm.EuclideanOTFlowConfig(n_discretization_steps: int, ndim_micro_shape: int = 2, n_inference_steps: int = None, *args: list[Any], **kwargs: dict[Any, Any])
Bases: BaseFlowConfig