diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
new file mode 100644
index 0000000000000..35ad7342944cd
--- /dev/null
+++ b/tests/lora/test_lora_checkpoints.py
@@ -0,0 +1,40 @@
+import pytest
+
+from vllm.lora.models import LoRAModel
+from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+
+
+@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
+def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
+    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    expected_lora_modules = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+    if lora_name == "baichuan7B":
+        # For the baichuan7B model, load its LoRA,
+        # and the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
+    else:
+        # For the baichuan7B model, load chatglm3-6b's LoRA,
+        # and the test should raise the following error.
+        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        with pytest.raises(ValueError, match=expected_error):
+            LoRAModel.from_local_checkpoint(
+                chatglm3_lora_files,
+                expected_lora_modules,
+                lora_model_id=1,
+                device="cpu",
+                embedding_modules=embedding_modules,
+                embedding_padding_modules=embed_padding_modules)
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 945917a5aa86b..62f1502458008 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -191,6 +191,7 @@ def from_lora_tensors(
     def from_local_checkpoint(
             cls,
             lora_dir: str,
+            expected_lora_modules: List[str],
             lora_model_id: Optional[int] = None,
             device: str = "cuda",
             dtype: Optional[torch.dtype] = None,
@@ -206,6 +207,20 @@ def from_local_checkpoint(
             lora_dir, "new_embeddings.safetensors")
         new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                     "new_embeddings.bin")
+        with open(lora_config_path) as f:
+            config = json.load(f)
+        target_modules = config["target_modules"]
+        unexpected_modules = []
+        for module in target_modules:
+            if module not in expected_lora_modules:
+                unexpected_modules.append(module)
+        # The loaded LoRA's target modules must be a subset of expected_lora_modules.
+        if unexpected_modules:
+            raise ValueError(
+                f"While loading {lora_dir}, expected"
+                f" target modules in {expected_lora_modules}"
+                f" but received {unexpected_modules}."
+                f" Please verify that the loaded LoRA module is correct")
         if os.path.isfile(lora_tensor_path):
             tensors = safetensors.torch.load_file(lora_tensor_path)
         elif os.path.isfile(lora_bin_file_path):
@@ -220,8 +235,6 @@ def from_local_checkpoint(
         elif os.path.isfile(new_embeddings_bin_file_path):
             embeddings = torch.load(new_embeddings_bin_file_path)
 
-        with open(lora_config_path) as f:
-            config = json.load(f)
         rank = config["r"]
         lora_alpha = config["lora_alpha"]
         return cls.from_lora_tensors(
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 3224b3a9e3eb0..a0868defbd3ca 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -136,8 +136,19 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
 
     def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
         try:
+            model = self._lora_manager.model
+            supported_lora_modules = model.supported_lora_modules
+            packed_modules_mapping = model.packed_modules_mapping
+            expected_lora_modules = []
+            for module in supported_lora_modules:
+                if module in packed_modules_mapping:
+                    expected_lora_modules.extend(
+                        packed_modules_mapping[module])
+                else:
+                    expected_lora_modules.append(module)
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_request.lora_local_path,
+                expected_lora_modules,
                 lora_model_id=lora_request.lora_int_id,
                 device="cpu",
                 dtype=self.lora_config.lora_dtype,