Skip to content

Commit

Permalink
[Misc] Avoid loading incorrect LoRA config (#3777)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeejeelee authored Apr 10, 2024
1 parent 6c0b045 commit 11dd6eb
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
40 changes: 40 additions & 0 deletions tests/lora/test_lora_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM


@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_modules = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
if lora_name == "baichuan7B":
# For the baichuan7B model, load it's LoRA,
# and the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
else:
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501
with pytest.raises(ValueError, match=expected_error):
LoRAModel.from_local_checkpoint(
chatglm3_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
17 changes: 15 additions & 2 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def from_lora_tensors(
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
Expand All @@ -206,6 +207,20 @@ def from_local_checkpoint(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")
with open(lora_config_path) as f:
config = json.load(f)
target_modules = config["target_modules"]
unexpected_modules = []
for module in target_modules:
if module not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path)
elif os.path.isfile(lora_bin_file_path):
Expand All @@ -220,8 +235,6 @@ def from_local_checkpoint(
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path)

with open(lora_config_path) as f:
config = json.load(f)
rank = config["r"]
lora_alpha = config["lora_alpha"]
return cls.from_lora_tensors(
Expand Down
11 changes: 11 additions & 0 deletions vllm/lora/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,19 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:

def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
try:
model = self._lora_manager.model
supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping
expected_lora_modules = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(
packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
expected_lora_modules,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
Expand Down

0 comments on commit 11dd6eb

Please sign in to comment.