[Misc][LoRA] Support loading LoRA weights for target_modules in reg f…
jeejeelee authored Oct 11, 2024
1 parent e808156 commit 36ea790
Showing 4 changed files with 59 additions and 5 deletions.
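For context: PEFT's adapter_config.json allows `target_modules` to be either a list of module names or a single regular-expression string, and vLLM previously handled only the list form. A minimal sketch of the two shapes (values are illustrative, not taken from this commit):

# List form -- handled before this commit:
config_list = {"target_modules": ["W_pack", "o_proj"]}

# Regex form -- what this commit adds support for (`\\` is Python escaping
# for one literal backslash, matching the escaped dot in the JSON file):
config_regex = {"target_modules": "model\\..*(W_pack|o_proj)"}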
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
@@ -199,6 +199,11 @@ def baichuan_zero_lora_files():
    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")


@pytest.fixture(scope="session")
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
17 changes: 15 additions & 2 deletions tests/lora/test_lora_checkpoints.py
@@ -5,14 +5,17 @@
from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM

lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]


@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
    lora_name,
    baichuan_lora_files,
    baichuan_zero_lora_files,
    baichuan_regex_lora_files,
    chatglm3_lora_files,
):
    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
@@ -36,7 +39,7 @@ def test_load_checkpoints(
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    elif lora_name == "baichuan7B-zero":
-        #Test that the target_modules contain prefix
+        # Test that the target_modules contain prefix
        # such as "model.layers.0.self_attn.W_pack", and
        # the test should pass.
        LoRAModel.from_local_checkpoint(
Expand All @@ -46,6 +49,16 @@ def test_load_checkpoints(
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero-regex":
# Test that the `target_modules` in the form of regular expressions,
# such as `model\\..*(W_pack|o_proj)`, and the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_regex_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
else:
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
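As a side note, the regex used by the new test case is expected to match fully-qualified module names in Baichuan's layout; a quick sketch of what `model\..*(W_pack|o_proj)` accepts (module names are illustrative):

import re

pattern = re.compile(r"model\..*(W_pack|o_proj)")
assert pattern.match("model.layers.0.self_attn.W_pack")
assert pattern.match("model.layers.0.self_attn.o_proj")
assert pattern.match("model.layers.0.mlp.gate_proj") is None  # no matching suffix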
7 changes: 5 additions & 2 deletions vllm/lora/models.py
@@ -23,6 +23,7 @@
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                             is_regex_target_modules,
                             parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -233,6 +234,8 @@ def from_local_checkpoint(
            # modules.
            unexpected_modules = []
            target_modules = config["target_modules"]
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as: layers.11.self_attn.k_proj
@@ -243,8 +246,8 @@
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
-            if unexpected_modules:
-                print(unexpected_modules, "modules")
+            if unexpected_modules and not is_regex_target_modules(
+                    config["target_modules"], expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
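Tracing the new control flow: a regex-form `target_modules` cannot be matched by the per-module prefix check, so the pattern itself lands in `unexpected_modules`; the added guard is what lets loading proceed. A sketch with illustrative values (not the commit's own test data):

from vllm.lora.utils import is_regex_target_modules

config = {"target_modules": "model\\..*(W_pack|o_proj)"}  # regex form from PEFT
expected_lora_modules = ["W_pack", "o_proj"]
# The loop above cannot match a regex string against module names, so:
unexpected_modules = [config["target_modules"]]
# ...but the guard detects a regex whose suffixes are all supported,
# so the ValueError is not raised:
assert is_regex_target_modules(config["target_modules"], expected_lora_modules)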
35 changes: 34 additions & 1 deletion vllm/lora/utils.py
@@ -1,5 +1,6 @@
import os
-from typing import List, Optional, Set, Tuple, Type
+import re
+from typing import List, Optional, Set, Tuple, Type, Union

import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
raise ValueError(f"{name} is unsupported LoRA weight")


def is_regex_target_modules(load_modules: Union[str, List[str]],
                            expected_lora_modules: List[str]) -> bool:
    """
    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffixes in the regular expression are present in
    `expected_lora_modules`.
    """

    def is_valid_regex(pattern):
        try:
            re.compile(pattern)
            return True
        except re.error:
            return False

    def is_subset(sub_list, full_list):
        return set(sub_list).issubset(set(full_list))

    # Similar to PEFT's processing logic, regex-related operations are only
    # executed when `load_modules` is a `str`.
    if not isinstance(load_modules, str):
        return False

    if is_valid_regex(load_modules):
        match = re.search(r"\((.*?)\)\$?$", load_modules)
        if match:
            suffix = match.group(1).split("|")
            return is_subset(suffix, expected_lora_modules)
    return False
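A few usage examples of the helper (module names are illustrative):

# Trailing alternation group: suffixes are extracted and checked.
assert is_regex_target_modules(r"model\..*(W_pack|o_proj)",
                               ["W_pack", "o_proj", "gate_up_proj"])
# Lists short-circuit to False -- regex handling applies to `str` only.
assert not is_regex_target_modules(["W_pack", "o_proj"], ["W_pack", "o_proj"])
# A valid regex without a trailing `(...)` group also returns False.
assert not is_regex_target_modules(r"model\..*W_pack", ["W_pack"])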


def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.