# Source code for ls_mlkit.util.shape

from typing import Tuple, Union

from torch import Tensor


def get_macroscopic_shape(obj: Union[Tensor, tuple], ndim_microscopic: int) -> Tuple[int, ...]:
    """Get the macroscopic (leading) part of an object's shape.

    The trailing ``ndim_microscopic`` dimensions are treated as "microscopic"
    and stripped. Whatever precedes them is the macroscopic shape. If no
    leading dimensions remain, ``(1,)`` is returned so the result always has
    at least one entry.

    Args:
        obj: A shape tuple (``torch.Size`` is a tuple subclass and is accepted
            here) or a tensor, whose ``.shape`` is used.
        ndim_microscopic: Number of trailing dimensions to strip. Must satisfy
            ``0 <= ndim_microscopic <= len(shape)``.

    Returns:
        The first ``len(shape) - ndim_microscopic`` entries of the shape, or
        ``(1,)`` when that would be empty. For tensor input the non-``(1,)``
        result is a slice of ``obj.shape`` (a ``torch.Size``).

    Raises:
        ValueError: If ``obj`` is neither a tuple nor a Tensor, or if
            ``ndim_microscopic`` is negative or exceeds the number of dims.
    """
    if isinstance(obj, tuple):
        shape = obj
    elif isinstance(obj, Tensor):
        shape = obj.shape
    else:
        raise ValueError(f"Invalid type: {type(obj)}")

    ndim = len(shape)
    if not 0 <= ndim_microscopic <= ndim:
        raise ValueError(
            f"ndim_microscopic must be in [0, {ndim}] for shape {tuple(shape)}, "
            f"got {ndim_microscopic}"
        )
    if ndim == ndim_microscopic:
        return (1,)
    # Use an explicit stop index: ``shape[:-0]`` is ``shape[:0]`` (empty), so
    # the negative-index form is wrong when ndim_microscopic == 0.
    return shape[: ndim - ndim_microscopic]


def show_shape(x, prefix="", *, max_items=3):
    """Recursively print a compact structural summary of ``x`` for debugging.

    Output depends on the type of ``x``:

    - Tensors print their shape, dtype and device.
    - numpy arrays print their shape and dtype.
    - Dicts print nothing themselves; each value is described with
      ``"<key>: "`` appended to the prefix.
    - Lists and tuples print their type and length, then describe at most
      ``max_items`` leading elements with ``"[<i>] "`` appended to the prefix.
    - Anything else prints only its type.

    Args:
        x: Object to describe.
        prefix: Text printed at the start of each line; accumulates the path
            to nested items during recursion.
        max_items: Maximum number of leading list/tuple elements to recurse
            into (non-negative). ``None`` recurses into every element.
            Defaults to 3, the previously hard-coded limit.
    """
    # Local imports keep numpy/torch lookups lazy, as in the original.
    import numpy as np
    import torch

    if torch.is_tensor(x):
        print(prefix, "Tensor", tuple(x.shape), x.dtype, x.device)
    elif isinstance(x, dict):
        for k, v in x.items():
            show_shape(v, prefix + f"{k}: ", max_items=max_items)
    elif isinstance(x, (list, tuple)):
        print(prefix, type(x), len(x))
        # Only a leading slice is expanded to keep output short for long
        # sequences; x[:None] expands everything.
        for i, v in enumerate(x[:max_items]):
            show_shape(v, prefix + f"[{i}] ", max_items=max_items)
    elif isinstance(x, np.ndarray):
        print(prefix, "ndarray", x.shape, x.dtype)
    else:
        print(prefix, type(x))