Source code for ls_mlkit.util.show

import torch
from tabulate import tabulate
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import Dataset


def table_print_dict(sample, prefix_priority_list=None, show_value=False):
    """Print a GitHub-style table summarising every entry of a dict.

    One row is printed per key with the columns ``Key``, ``Type``,
    ``Shape``, ``Length`` and ``Value``.

    Args:
        sample: Mapping whose entries are described. Its ``keys()`` are
            sorted before printing (see ``prefix_priority_list``).
        prefix_priority_list: Iterable of string prefixes. Keys starting with
            an earlier prefix are listed first; keys matching no prefix (and
            keys that are not strings) are listed last. Keys with equal
            priority keep their original relative order. ``None`` (the
            default) means "no prefixes", i.e. the original order is kept.
        show_value: When true, the raw value is placed in the ``Value``
            column; otherwise that column is left empty.

    Returns:
        None. The table is written to standard output via ``print``.
    """
    # Materialise once so generators work and the default is never shared.
    prefixes = () if prefix_priority_list is None else tuple(prefix_priority_list)

    def key_priority(key):
        # Only strings can be prefix-matched; everything else sorts last.
        if isinstance(key, str):
            for rank, prefix in enumerate(prefixes):
                if key.startswith(prefix):
                    return rank
        return len(prefixes)

    def safe_len(value):
        # Zero-dimensional tensors/arrays expose ``__len__`` but raise
        # TypeError when it is called, so probing with hasattr is not enough.
        try:
            return len(value)
        except TypeError:
            return ""

    table_data = []
    # ``sorted`` is stable, so ties preserve the mapping's own key order.
    for key in sorted(sample.keys(), key=key_priority):
        value = sample[key]
        shape = value.shape if isinstance(value, torch.Tensor) else ""
        shown = value if show_value else ""
        table_data.append([key, type(value).__name__, shape, safe_len(value), shown])

    print(
        tabulate(
            table_data,
            headers=["Key", "Type", "Shape", "Length", "Value"],
            tablefmt="github",
        )
    )
[docs] def show_info( model: Module | None = None, dataset: Dataset | None = None, optimizer: Optimizer | None = None, batch_size: int = 7, ) -> None: def print_kv_table(info: dict, title: str) -> None: table_data = [] for k, v in info.items(): if isinstance(v, float): v = f"{v:.6g}" table_data.append([k, v]) print("\n" + title) print(tabulate(table_data, headers=["Key", "Value"], tablefmt="github")) def get_model_info(model: Module) -> dict: param_size = 4 # float32 total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) total_bytes = total_params * param_size return { "class": model.__class__.__name__, "total_params": f"{total_params:,}", "trainable_params": f"{trainable_params:,}", "size_mb": total_bytes / (1024**2), "size_gb": total_bytes / (1024**3), } if model is not None: model_info = get_model_info(model) print_kv_table(model_info, title="Model Information") if dataset is not None: dataset_info = { "class": dataset.__class__.__name__, "batch_size": batch_size, } try: dataset_info["dataset_size"] = len(dataset) except TypeError: dataset_info["dataset_size"] = "unknown" print_kv_table(dataset_info, title="Dataset Information") if optimizer is not None: opt_info = { "class": optimizer.__class__.__name__, "num_param_groups": len(optimizer.param_groups), } print_kv_table(opt_info, title="Optimizer Information") for i, group in enumerate(optimizer.param_groups): group_info = { "lr": group.get("lr", "unknown"), "weight_decay": group.get("weight_decay", 0.0), "num_params": sum(p.numel() for p in group["params"]), } print_kv_table(group_info, title=f"Optimizer Param Group {i}")
[docs] def find_tensor_devices(obj, visited=None, path=""): if visited is None: visited = set() if id(obj) in visited: return {} visited.add(id(obj)) devices = {} if isinstance(obj, torch.Tensor): devices[path] = obj.device elif isinstance(obj, (list, tuple)): for idx, item in enumerate(obj): devices.update(find_tensor_devices(item, visited, f"{path}[{idx}]")) elif isinstance(obj, dict): for key, value in obj.items(): devices.update(find_tensor_devices(value, visited, f"{path}['{key}']")) elif hasattr(obj, "__dict__"): for attr, value in vars(obj).items(): devices.update(find_tensor_devices(value, visited, f"{path}.{attr}" if path else attr)) return devices