Source code for ls_mlkit.util.observer

from typing import Callable, Dict, List, Literal

import torch
from datasets import Dataset as HFDataset
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset


def weight_norm_fn(module: Module) -> Tensor:
    """Compute the weight norm of a module.

    Args:
        module (Module): the module whose weight norm to compute

    Returns:
        Tensor: the L2 norm over all trainable parameters of the module
    """
    return torch.sqrt(sum(torch.sum(p.data * p.data) for p in module.parameters() if p.requires_grad))
def gradient_norm_fn(module: Module) -> Tensor:
    """Compute the gradient norm of a module.

    Args:
        module (Module): the module whose gradient norm to compute

    Returns:
        Tensor: the L2 norm over all populated parameter gradients of the module
    """
    return torch.sqrt(sum(torch.sum(p.grad.data * p.grad.data) for p in module.parameters() if p.grad is not None))
def weights_fn(module: Module) -> list[Tensor]:
    """Get the weights of a module.

    Args:
        module (Module): the module whose weights to collect

    Returns:
        list[Tensor]: detached CPU copies of the module's trainable parameters
    """
    return [p.detach().cpu() for p in module.parameters() if p.requires_grad]
def gradients_fn(module: Module) -> list[Tensor]:
    """Get the gradients of a module.

    Args:
        module (Module): the module whose gradients to collect

    Returns:
        list[Tensor]: detached CPU copies of the module's parameter gradients
    """
    return [p.grad.detach().cpu() for p in module.parameters() if p.grad is not None]
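
# Usage sketch for the helpers above (illustrative; the example layer is an
# assumption, not part of this module). weight_norm_fn equals the L2 norm of
# the concatenation of all flattened trainable parameters:
#
#     layer = torch.nn.Linear(4, 4)
#     flat = torch.cat([p.reshape(-1) for p in layer.parameters()])
#     torch.isclose(weight_norm_fn(layer), torch.linalg.norm(flat))  # tensor(True)
#
# gradient_norm_fn is the analogous quantity over p.grad, so it is only
# meaningful after a backward pass has populated the gradients.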
class Observer(object):
    function_mapping = {
        "weight_norm": weight_norm_fn,
        "gradient_norm": gradient_norm_fn,
        "weights": weights_fn,
        "gradients": gradients_fn,
    }

    def __init__(
        self,
        model: Module = None,
        optimizer: Optimizer = None,
        scheduler: LambdaLR = None,
        dataset: Dataset | HFDataset = None,
        target_modules: List[str] = None,
        no_split_classes: List[str] = None,
    ):
        """Initialize the Observer.

        Args:
            model (Module, optional): the model to observe. Defaults to None.
            optimizer (Optimizer, optional): the optimizer to observe. Defaults to None.
            scheduler (LambdaLR, optional): the scheduler to observe. Defaults to None.
            dataset (Dataset | HFDataset, optional): the dataset to observe. Defaults to None.
            target_modules (List[str], optional): the modules to observe. Defaults to None.
                If target_modules is not None, then no_split_classes and strategy are ignored.
            no_split_classes (List[str], optional): class names whose modules should not be
                split into sub-modules. Defaults to None.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.dataset = dataset
        self.no_split_classes = no_split_classes
        self.target_modules = target_modules

    # get something =================================================================
    @staticmethod
    @torch.no_grad()
    def _get_something(
        model: Module,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
        function: Callable = None,
    ):
        info = dict()

        def __get_something(module: Module, prefix=""):
            # A module counts as a leaf block if it has no children, or its class
            # is listed in no_split_classes; it is recorded only if it owns at
            # least one trainable parameter.
            if (
                len(list(module.named_children())) == 0
                or (no_split_classes is not None and module.__class__.__name__ in no_split_classes)
            ) and any(param.requires_grad for param in module.parameters()):
                info[prefix] = function(module)
                return
            for name, sub_module in module.named_children():
                sub_module_name = f"{prefix}.{name}" if prefix != "" else name
                __get_something(sub_module, sub_module_name)

        match strategy:
            case "all":
                something = function(model)
                return {"total_model": something}
            case "block":
                __get_something(model, "")
                return info
            case _:
                raise ValueError(f"Unsupported strategy: {strategy}")

    @staticmethod
    @torch.no_grad()
    def _get_target_modules(model: Module, target_modules: List[str]):
        info = dict()

        def __get_target_modules(module: Module, prefix=""):
            if any(target_module in prefix for target_module in target_modules):
                info[prefix] = module
                return
            for name, sub_module in module.named_children():
                sub_module_name = f"{prefix}.{name}" if prefix != "" else name
                __get_target_modules(sub_module, sub_module_name)

        __get_target_modules(model, "")
        return info

    @staticmethod
    @torch.no_grad()
    def _get_something_from_targets(
        model: Module = None,
        target_modules_dict: Dict[str, Module] = None,
        target_modules: List[str] = None,
        function: Callable = None,
    ):
        info = dict()
        if target_modules_dict is None:
            target_modules_dict = Observer._get_target_modules(model, target_modules)
        for module_path, module in target_modules_dict.items():
            info[module_path] = function(module)
        return info
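    # Note on matching (descriptive comment, not part of the original source):
    # _get_target_modules uses a substring test (`target_module in prefix`), so
    # target_modules=["attn"] matches the shallowest path containing "attn"
    # (e.g. both "encoder.layer0.attn" and "encoder.layer0.cross_attn"), and
    # recursion stops at the first match on each branch.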
    @torch.no_grad()
    def get_something_from_targets(self, function: Callable):
        return Observer._get_something_from_targets(
            model=self.model,
            target_modules_dict=None,
            target_modules=self.target_modules,
            function=function,
        )
    @torch.no_grad()
    def get_something(
        self,
        name,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        if self.target_modules is None:
            if no_split_classes is None:
                no_split_classes = self.no_split_classes
            return Observer._get_something(
                model=self.model,
                strategy=strategy,
                no_split_classes=no_split_classes,
                function=Observer.function_mapping[name],
            )
        return self.get_something_from_targets(function=Observer.function_mapping[name])
    @torch.no_grad()
    def get_weight_norm(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("weight_norm", strategy, no_split_classes)

    @torch.no_grad()
    def get_gradient_norm(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("gradient_norm", strategy, no_split_classes)

    @torch.no_grad()
    def get_weights(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("weights", strategy, no_split_classes)

    @torch.no_grad()
    def get_gradients(
        self,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        return self.get_something("gradients", strategy, no_split_classes)
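    # Shape of the results (illustrative, not from the original source): with
    # strategy="all" each getter returns a single-entry dict keyed "total_model";
    # with strategy="block" the keys are dotted module paths of the leaf blocks,
    # e.g. {"encoder.0": ..., "encoder.2": ...} for an nn.Sequential attribute
    # "encoder" whose layers 0 and 2 hold trainable parameters.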
    @staticmethod
    @torch.no_grad()
    def _get_statistics(data: List[Tensor]):
        flattened_tensor = torch.cat([item.reshape(-1) for item in data], dim=0)
        mean = flattened_tensor.mean()
        std = flattened_tensor.std()
        median = flattened_tensor.median()
        var = flattened_tensor.var()
        return {"mean": mean, "std": std, "median": median, "variance": var}
    @torch.no_grad()
    def get_statistics(
        self,
        name,
        strategy: Literal["all", "block"] = "all",
        no_split_classes: List[str] = None,
    ):
        # Meaningful for names that yield lists of tensors ("weights",
        # "gradients"); the norm getters return 0-dim tensors, which
        # _get_statistics cannot iterate over.
        something = self.get_something(name, strategy=strategy, no_split_classes=no_split_classes)
        return {key: Observer._get_statistics(value) for key, value in something.items()}
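

# Minimal smoke-test sketch (added for illustration; the tiny MLP and the
# printed keys below are assumptions, not part of the original module).
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()  # populate gradients so the gradient getters have data

    observer = Observer(model=model)
    print(observer.get_weight_norm(strategy="all"))      # {"total_model": tensor(...)}
    print(observer.get_gradient_norm(strategy="block"))  # {"0": tensor(...), "2": tensor(...)}
    print(observer.get_statistics("weights", strategy="block"))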