Source code for ls_mlkit.util.lora
from typing import List
import torch
from peft import LoraConfig, get_peft_model
def find_linear_modules(model) -> List[str]:
    """Find the names of the linear modules in a model.

    Args:
        model (torch.nn.Module): the model to search for linear modules.

    Returns:
        List[str]: the deduplicated leaf names of the linear modules.
    """
    linear_cls = torch.nn.Linear
    # Output/embedding layers are excluded by name so LoRA is not applied to them.
    output_layer_names = ["lm_head", "embed_tokens"]
    module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
            # Keep only the last path component,
            # e.g. "model.layers.0.self_attn.q_proj" -> "q_proj".
            module_names.add(name.split(".")[-1])
    return list(module_names)
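
# A minimal usage sketch for find_linear_modules on a toy model; ToyModel and
# the attribute names "proj_in" and "lm_head" are illustrative only, not part
# of ls_mlkit:
#
#     class ToyModel(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.proj_in = torch.nn.Linear(8, 8)
#             self.lm_head = torch.nn.Linear(8, 8)  # skipped: name matches the exclusion list
#
#     find_linear_modules(ToyModel())  # -> ["proj_in"]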
def get_lora_model(model, lora_config):
    """Wrap a model with LoRA adapters.

    Args:
        model (torch.nn.Module): the model to wrap with LoRA adapters.
        lora_config (dict): the LoRA hyperparameters, with keys
            "lora_r", "lora_alpha", and "lora_dropout".

    Returns:
        torch.nn.Module: the PEFT-wrapped LoRA model.
    """
    # Apply LoRA to every linear module except the output/embedding layers.
    target_modules = find_linear_modules(model)
    lora_config = LoraConfig(
        r=lora_config["lora_r"],
        target_modules=target_modules,
        lora_alpha=lora_config["lora_alpha"],
        lora_dropout=lora_config["lora_dropout"],
    )
    model = get_peft_model(model, lora_config)
    return model
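
# A minimal usage sketch, assuming a Hugging Face causal LM whose projection
# layers are plain torch.nn.Linear (e.g. Llama-style architectures); the
# checkpoint name and hyperparameter values below are illustrative only:
#
#     from transformers import AutoModelForCausalLM
#
#     base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     lora_model = get_lora_model(
#         base, {"lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05}
#     )
#     lora_model.print_trainable_parameters()  # PEFT helper: shows the trainable fraction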