"""Source code for ls_mlkit.util.offload.context: a context manager combining model and gradient offload hooks."""
import torch.nn
from .gradient_offload import GradientOffloadHookContext
from .model_offload import ModelOffloadHookContext
# [docs]
class OffloadContext:
    """Context manager that installs model-offload and gradient-offload hooks together.

    On entry it enters a ``ModelOffloadHookContext`` and then a
    ``GradientOffloadHookContext`` built for the same ``model``. On exit it
    leaves them in reverse order, exactly like two nested ``with`` statements.
    The model context is always exited, even if tearing down the gradient
    context raises.

    Args:
        model: Module the hooks are attached to.
        named_grads: Dict handed to ``GradientOffloadHookContext`` as
            ``record_dict``. Presumably filled with gradients keyed by
            parameter name -- confirm against ``GradientOffloadHookContext``.
        num_block: Forwarded as ``num_block`` to ``ModelOffloadHookContext``
            (which is always built with ``strategy="block"``).
        no_split_module_classes: Forwarded unchanged to
            ``ModelOffloadHookContext``.
        enable_gradient_offload: ``enable`` flag for the gradient-offload context.
        enable_model_offload: ``enable`` flag for the model-offload context.
        device: Device forwarded to ``ModelOffloadHookContext``. Defaults to
            ``"cuda"``, the value that used to be hard-coded.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        named_grads: dict,
        num_block=2,
        no_split_module_classes=None,
        enable_gradient_offload=True,
        enable_model_offload=True,
        device="cuda",
    ):
        self.modelOffloadHookContext = ModelOffloadHookContext(
            model=model,
            no_split_module_classes=no_split_module_classes,
            num_block=num_block,
            enable=enable_model_offload,
            device=device,
            # Fixed settings: block-wise offloading, no backward hook.
            strategy="block",
            with_backward_hook=False,
        )
        self.gradientOffloadHookContext = GradientOffloadHookContext(
            model=model,
            enable=enable_gradient_offload,
            record_dict=named_grads,
        )

    def __enter__(self):
        """Enter the model context, then the gradient context; return ``self``."""
        self.modelOffloadHookContext.__enter__()
        try:
            self.gradientOffloadHookContext.__enter__()
        except BaseException as err:
            # A failing __enter__ means the caller's ``with`` will never call
            # our __exit__, so undo the model hooks here before propagating.
            self.modelOffloadHookContext.__exit__(type(err), err, err.__traceback__)
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit both contexts in reverse entry order; never suppresses exceptions."""
        try:
            self.gradientOffloadHookContext.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Runs even if the gradient context's teardown raised, so the
            # model hooks are always removed.
            self.modelOffloadHookContext.__exit__(exc_type, exc_val, exc_tb)