Source code for ls_mlkit.util.shape_class

from torch import Tensor

from .decorators import inherit_docstrings


@inherit_docstrings
class ShapeConfig:
    """Configuration describing how a tensor shape splits into macro/micro parts.

    Args:
        ndim_micro_shape (``int``): number of trailing dimensions of a tensor
            that make up its *micro* shape; all leading dimensions before them
            make up the *macro* shape. Must be non-negative.

    Raises:
        ValueError: if ``ndim_micro_shape`` is negative.
    """

    def __init__(
        self,
        ndim_micro_shape: int,
    ):
        # A negative dimension count has no meaning; accepting it would make
        # any macro/micro split based on it silently return nonsense
        # (e.g. ``shape[:-(-1)]`` keeps only the first dimension).
        if ndim_micro_shape < 0:
            raise ValueError(
                f"ndim_micro_shape must be non-negative, got {ndim_micro_shape}"
            )
        # Number of trailing (micro) dimensions.
        self.ndim_micro_shape: int = ndim_micro_shape
@inherit_docstrings
class Shape(object):
    """Split tensor shapes into a leading *macro* part and a trailing *micro* part.

    The number of trailing (micro) dimensions is read from
    ``config.ndim_micro_shape``; every dimension before them is macro.

    Args:
        config (``ShapeConfig``): holds ``ndim_micro_shape``.
    """

    def __init__(
        self,
        config: ShapeConfig,
    ):
        super().__init__()
        self.config: ShapeConfig = config

    def get_macro_shape(self, x: Tensor) -> tuple[int, ...]:
        r"""Get the macro shape of :math:`x`

        The macro shape is every dimension of :math:`x` except the trailing
        ``config.ndim_micro_shape`` ones.

        Args:
            x (``Tensor``): :math:`x`

        Returns:
            ``tuple[int, ...]``: the shape of the macro part of :math:`x`
            (a ``torch.Size``). When ``ndim_micro_shape`` is ``0`` this is the
            full shape of :math:`x`; when it is at least ``x.dim()`` it is empty.
        """
        # Use an explicit, non-negative end index rather than ``x.shape[:-n]``:
        # for n == 0 that form becomes ``x.shape[:0]`` and would wrongly drop
        # every dimension. Clamping at 0 keeps the empty result when n exceeds
        # the rank of ``x``.
        end = max(x.dim() - self.config.ndim_micro_shape, 0)
        return x.shape[:end]

    def complete_micro_shape(self, x: Tensor) -> Tensor:
        """Complete the micro shape of :math:`x`, assuming the macro shape is already known

        Appends ``config.ndim_micro_shape`` trailing singleton dimensions to
        :math:`x`.

        Args:
            x (``Tensor``): :math:`x`, expected to carry only macro dimensions

        Returns:
            ``Tensor``: a view of :math:`x` with shape ``(*x.shape, 1, ..., 1)``
            containing ``ndim_micro_shape`` trailing ones.
        """
        # Build the target shape as a single tuple so ``view`` always receives
        # a size argument, even when it is empty (0-dim input, no micro dims).
        target = tuple(x.shape) + (1,) * self.config.ndim_micro_shape
        return x.view(target)