Source code for ls_mlkit.util.offload.forward_hook
import torch
[docs]
class ForwardHookForDevice:
    """Factories for forward hooks that move a module and its tensors across devices.

    All public members are static methods. The hook factories return plain
    callables; presumably they are meant to be passed to
    ``torch.nn.Module.register_forward_pre_hook`` /
    ``torch.nn.Module.register_forward_hook`` (the signatures match) --
    registration itself is not done here.
    """

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Private helpers shared by the hook factories
    # ------------------------------------------------------------------
    @staticmethod
    def _to_device(value, device):
        """Return ``value`` moved to ``device`` if it is a tensor, else ``value`` unchanged."""
        return value.to(device) if isinstance(value, torch.Tensor) else value

    @staticmethod
    def _args_to_device(args, device):
        """Return ``args`` as a tuple with every tensor element moved to ``device``."""
        return tuple(ForwardHookForDevice._to_device(arg, device) for arg in args)

    @staticmethod
    def _kwargs_to_device(kwargs, device):
        """Return a new dict with every tensor value of ``kwargs`` moved to ``device``."""
        return {key: ForwardHookForDevice._to_device(value, device) for key, value in kwargs.items()}

    @staticmethod
    def _output_to_device(output, device):
        """Move a tensor output, or the tensor elements of a tuple output, to ``device``.

        Any other output type (list, dict, None, ...) is returned untouched.
        A tuple subclass is returned as a plain ``tuple``.
        """
        if isinstance(output, torch.Tensor):
            return output.to(device)
        if isinstance(output, tuple):
            return tuple(ForwardHookForDevice._to_device(item, device) for item in output)
        return output

    # ------------------------------------------------------------------
    # Hook factories
    # ------------------------------------------------------------------
    @staticmethod
    def get_align_device_pre_forward_hook(device="cuda", with_kwargs=False):
        """Build a pre-forward hook that puts the module and its inputs on one device.

        Args:
            device: Target device. When ``None``, the device of the module's
                first parameter is used; a module without parameters falls
                back to ``"cuda"``.
            with_kwargs: When true, the returned hook has the signature
                ``hook(module, args, kwargs)`` and returns ``(args, kwargs)``;
                otherwise it is ``hook(module, args)`` and returns ``args``.

        Returns:
            The hook callable. It calls ``module.to(...)`` (in place) and
            returns copies of the inputs in which every ``torch.Tensor`` has
            been moved; non-tensor inputs are passed through unchanged.
        """

        def resolve_device(module: torch.nn.Module):
            if device is not None:
                return device
            # Single pass over the parameter iterator instead of building a list.
            first_param = next(module.parameters(), None)
            return first_param.device if first_param is not None else "cuda"

        def hook(module: torch.nn.Module, args):
            align_device = resolve_device(module)
            module.to(align_device)
            return ForwardHookForDevice._args_to_device(args, align_device)

        def hook_with_kwargs(module: torch.nn.Module, args, kwargs):
            align_device = resolve_device(module)
            module.to(align_device)
            return (
                ForwardHookForDevice._args_to_device(args, align_device),
                ForwardHookForDevice._kwargs_to_device(kwargs, align_device),
            )

        return hook_with_kwargs if with_kwargs else hook

    @staticmethod
    def get_forward_hook(pre: bool, device=None, with_kwargs=False):
        """Build a pre- or post-forward hook for running a module on ``device``.

        ``device`` is the executing device (``"cuda"`` when ``None``).
        After the forward pass the module and its output are moved to the
        origin device, which is fixed to ``"cpu"``.

        Args:
            pre: When true, return the pre-forward hook (moves the module and
                tensor inputs to ``device``); otherwise return the
                post-forward hook (moves the module and the output to
                ``"cpu"``).
            device: Executing device; defaults to ``"cuda"`` when ``None``.
            with_kwargs: Selects the hook signature that additionally
                receives the keyword arguments:
                ``pre_hook(module, args, kwargs) -> (args, kwargs)`` or
                ``after_hook(module, args, kwargs, output) -> output``.

        Returns:
            The selected hook callable. The post-forward hook only relocates
            a tensor output or the tensor elements of a tuple output; other
            output types are returned as they are.
        """
        origin_device = "cpu"
        exec_device = device if device is not None else "cuda"

        def pre_hook(module: torch.nn.Module, args):
            module.to(exec_device)
            return ForwardHookForDevice._args_to_device(args, exec_device)

        def after_hook(module: torch.nn.Module, args, output):
            module.to(origin_device)
            return ForwardHookForDevice._output_to_device(output, origin_device)

        def pre_hook_with_kwargs(module, args, kwargs):
            module.to(exec_device)
            return (
                ForwardHookForDevice._args_to_device(args, exec_device),
                ForwardHookForDevice._kwargs_to_device(kwargs, exec_device),
            )

        def after_hook_with_kwargs(module, args, kwargs, output):
            module.to(origin_device)
            return ForwardHookForDevice._output_to_device(output, origin_device)

        if pre:
            return pre_hook_with_kwargs if with_kwargs else pre_hook
        return after_hook_with_kwargs if with_kwargs else after_hook

    # ------------------------------------------------------------------
    # Module-tree helpers
    # ------------------------------------------------------------------
    @staticmethod
    def get_full_name_list(model):
        """Return the dotted names of all leaf modules of ``model``.

        A leaf is a module without children. Names are collected depth-first
        in ``named_children()`` order. If ``model`` itself has no children
        the result is ``[""]`` (the root has an empty name).
        """
        full_name_list = []

        def collect(module, prefix=""):
            children = list(module.named_children())
            if not children:
                full_name_list.append(prefix)
                return
            for name, child in children:
                collect(child, f"{prefix}.{name}" if prefix else name)

        collect(model)
        return full_name_list

    @staticmethod
    def get_module_list(model, no_split_module_classes=None):
        """Return dotted module names, stopping at leaves or at "no split" classes.

        Works like :meth:`get_full_name_list`, except that recursion also
        stops at any module whose class name (``module.__class__.__name__``)
        is contained in ``no_split_module_classes``; that module's name is
        recorded and its children are not visited.

        Args:
            model: Root module to walk.
            no_split_module_classes: Container of class names that must not
                be split. ``None`` (the default) means no such classes, so
                the result equals ``get_full_name_list(model)``.

        Returns:
            List of dotted names; the root module is named ``""``.
        """
        # The default ``None`` used to reach the ``in`` test below and raise
        # TypeError; treat it as "nothing to stop at".
        stop_classes = () if no_split_module_classes is None else no_split_module_classes
        module_list = []

        def collect(module: torch.nn.Module, prefix=""):
            if module.__class__.__name__ in stop_classes:
                module_list.append(prefix)
                return
            children = list(module.named_children())
            if not children:
                module_list.append(prefix)
                return
            for name, child in children:
                collect(child, f"{prefix}.{name}" if prefix else name)

        collect(model)
        return module_list