From cfe610192d24d775cabb01bb48007f9e41e7cad5 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Mon, 1 Apr 2024 20:37:24 +0800 Subject: [PATCH 1/4] coding done --- vllm/lora/models.py | 17 +++++++++++++++-- vllm/lora/worker_manager.py | 11 +++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 945917a5aa86b..62f1502458008 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -191,6 +191,7 @@ def from_lora_tensors( def from_local_checkpoint( cls, lora_dir: str, + expected_lora_modules: List[str], lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, @@ -206,6 +207,20 @@ def from_local_checkpoint( lora_dir, "new_embeddings.safetensors") new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") + with open(lora_config_path) as f: + config = json.load(f) + target_modules = config["target_modules"] + unexpected_modules = [] + for module in target_modules: + if module not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of expected_lora_modules + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") if os.path.isfile(lora_tensor_path): tensors = safetensors.torch.load_file(lora_tensor_path) elif os.path.isfile(lora_bin_file_path): @@ -220,8 +235,6 @@ def from_local_checkpoint( elif os.path.isfile(new_embeddings_bin_file_path): embeddings = torch.load(new_embeddings_bin_file_path) - with open(lora_config_path) as f: - config = json.load(f) rank = config["r"] lora_alpha = config["lora_alpha"] return cls.from_lora_tensors( diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 3224b3a9e3eb0..a0868defbd3ca 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -136,8 +136,19 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: try: + model = self._lora_manager.model + supported_lora_modules = model.supported_lora_modules + packed_modules_mapping = model.packed_modules_mapping + expected_lora_modules = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) lora = self._lora_model_cls.from_local_checkpoint( lora_request.lora_local_path, + expected_lora_modules, lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, From d25d1a0f34591a2b074e0e5f00f7f1af6aa59285 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Tue, 9 Apr 2024 18:28:02 +0800 Subject: [PATCH 2/4] add unit test --- tests/lora/test_lora_checkpoints.py | 41 +++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/lora/test_lora_checkpoints.py diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py new file mode 100644 index 0000000000000..df68fb58961e2 --- /dev/null +++ b/tests/lora/test_lora_checkpoints.py @@ -0,0 +1,41 @@ +import pytest +from vllm.lora.request import LoRARequest +from vllm.lora.models import LoRAModel +from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM + + +@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"]) +def test_load_checkpoints(lora_name, chatglm3_lora_files, 
baichuan_lora_files):
+    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
+    embedding_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    expected_lora_modules = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+    if lora_name == "baichuan7B":
+        # For the baichuan7B model, load its LoRA, 
+        # and the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embedding_padding_modules)
+    else:
+        # For the baichuan7B model, load chatglm3-6b's LoRA,
+        # and the test should raise the following error
+        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        with pytest.raises(ValueError, match=expected_error):
+            LoRAModel.from_local_checkpoint(
+                chatglm3_lora_files,
+                expected_lora_modules,
+                lora_model_id=1,
+                device="cpu",
+                embedding_modules=embedding_modules,
+                embedding_padding_modules=embedding_padding_modules)
+

From ccf5e7b31afef026b937e70f6fe5868506a916f0 Mon Sep 17 00:00:00 2001
From: jeejeeli
Date: Tue, 9 Apr 2024 18:32:59 +0800
Subject: [PATCH 3/4] code format

---
 tests/lora/test_lora_checkpoints.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
index df68fb58961e2..3d23944a3f8de 100644
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -1,5 +1,5 @@
 import pytest
-from vllm.lora.request import LoRARequest
+
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
@@ -9,7 +9,7 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
-    embedding_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
     expected_lora_modules = []
     for module in supported_lora_modules:
         if module in packed_modules_mapping:
@@ -17,7 +17,7 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
         else:
             expected_lora_modules.append(module)
     if lora_name == "baichuan7B":
-        # For the baichuan7B model, load its LoRA, 
+        # For the baichuan7B model, load its LoRA,
         # and the test should pass.
LoRAModel.from_local_checkpoint( baichuan_lora_files, @@ -25,9 +25,9 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embedding_padding_modules) + embedding_padding_modules=embed_padding_modules) else: - # For the baichuan7B model, load chatglm3-6b's LoRA, + # For the baichuan7B model, load chatglm3-6b's LoRA, # and the test should raise the following error expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501 with pytest.raises(ValueError, match=expected_error): @@ -37,5 +37,4 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embedding_padding_modules) - + embedding_padding_modules=embed_padding_modules) From 484e20adc6d5fdd5f55b1c18e2caf384ff2b504b Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Wed, 10 Apr 2024 08:02:56 +0800 Subject: [PATCH 4/4] retry CI --- tests/lora/test_lora_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index 3d23944a3f8de..35ad7342944cd 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -28,7 +28,7 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): embedding_padding_modules=embed_padding_modules) else: # For the baichuan7B model, load chatglm3-6b's LoRA, - # and the test should raise the following error + # and the test should raise the following error. expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501 with pytest.raises(ValueError, match=expected_error): LoRAModel.from_local_checkpoint(