Source code for ls_mlkit.util.mask.image_masker

from torch import Tensor

from .masker import Masker


class ImageMasker(Masker):
    """Masker variant whose mask-shape check tolerates a size-1 mask axis at dim -3.

    NOTE(review): dim -3 is presumably the channel axis of an image batch
    laid out as (batch, channel, height, width) -- confirm against callers.
    """

    def check_mask_shape(self, x: Tensor, mask: Tensor):
        """Assert that ``mask`` has a shape compatible with ``x``.

        Two cases, selected by ``self.ndim_mini_micro_shape``:

        * Non-zero: ``mask.shape`` must equal ``x.shape`` with the last
          ``ndim_mini_micro_shape`` dimensions removed.
        * Zero: ``mask.shape`` must equal ``x.shape``, except that a mask
          whose dim -3 has size 1 is first expanded along that dim to
          ``x.shape[-3]``.

        The expanded mask is only used for the comparison; it is not
        returned and the caller's tensor is not modified.

        Args:
            x: Tensor the mask is meant to be applied to.
            mask: Mask tensor to validate.

        Raises:
            AssertionError: If the shapes are not compatible.
        """
        n_trailing = self.ndim_mini_micro_shape

        # Mask covers only the leading dimensions of ``x``: compare against
        # ``x.shape`` with the trailing ``n_trailing`` dimensions dropped.
        if n_trailing != 0:
            assert x.shape[:-n_trailing] == mask.shape
            return

        # Full-shape comparison. A singleton at dim -3 is broadcast to the
        # size ``x`` has there. ``expand`` is given four sizes, so this
        # branch expects a 4-D mask.
        has_singleton_dim = mask.shape[-3] == 1
        if has_singleton_dim:
            mask = mask.expand(-1, x.shape[-3], -1, -1)
        assert x.shape == mask.shape