diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 4aba89fa23..3a09200217 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -38,7 +38,7 @@
 from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.utils import PushToHubMixin

-from peft.utils.constants import DUMMY_MODEL_CONFIG
+from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING

 from . import __version__
 from .config import PeftConfig
@@ -1185,6 +1185,19 @@ def load_adapter(
             ignore_mismatched_sizes=ignore_mismatched_sizes,
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
+
+        tuner = self.peft_config[adapter_name].peft_type
+        tuner_prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(tuner, "")
+        adapter_missing_keys = []
+
+        # Filter missing keys specific to the current adapter and tuner prefix.
+        for key in load_result.missing_keys:
+            if tuner_prefix in key and adapter_name in key:
+                adapter_missing_keys.append(key)
+
+        load_result.missing_keys.clear()
+        load_result.missing_keys.extend(adapter_missing_keys)
+
         if (
             (getattr(self, "hf_device_map", None) is not None)
             and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py
index e7c305926a..4365c878fe 100644
--- a/src/peft/utils/constants.py
+++ b/src/peft/utils/constants.py
@@ -15,6 +15,8 @@
 import torch
 from transformers import BloomPreTrainedModel

+from .peft_types import PeftType
+

 # needed for prefix-tuning of bloom model
 def bloom_model_postprocess_past_key_value(past_key_values):
@@ -284,6 +286,22 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
     "qwen2": ["q_proj", "v_proj"],
 }

+PEFT_TYPE_TO_PREFIX_MAPPING = {
+    PeftType.IA3: "ia3_",
+    PeftType.LORA: "lora_",
+    PeftType.ADALORA: "lora_",
+    PeftType.LOHA: "hada_",
+    PeftType.LOKR: "lokr_",
+    PeftType.OFT: "oft_",
+    PeftType.POLY: "poly_",
+    PeftType.BOFT: "boft_",
+    PeftType.LN_TUNING: "ln_tuning_",
+    PeftType.VERA: "vera_lambda_",
+    PeftType.FOURIERFT: "fourierft_",
+    PeftType.HRA: "hra_",
+    PeftType.VBLORA: "vblora_",
+}
+
 WEIGHTS_NAME = "adapter_model.bin"
 SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
 CONFIG_NAME = "adapter_config.json"
diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py
index d407d79fd1..5b40b4314c 100644
--- a/src/peft/utils/save_and_load.py
+++ b/src/peft/utils/save_and_load.py
@@ -24,6 +24,7 @@
 from packaging import version
 from safetensors.torch import load_file as safe_load_file

+from .constants import PEFT_TYPE_TO_PREFIX_MAPPING
 from .other import (
     EMBEDDING_LAYER_NAMES,
     SAFETENSORS_WEIGHTS_NAME,
@@ -357,21 +358,7 @@ def set_peft_model_state_dict(
         PeftType.VBLORA,
     ):
         peft_model_state_dict = {}
-        parameter_prefix = {
-            PeftType.IA3: "ia3_",
-            PeftType.LORA: "lora_",
-            PeftType.ADALORA: "lora_",
-            PeftType.LOHA: "hada_",
-            PeftType.LOKR: "lokr_",
-            PeftType.OFT: "oft_",
-            PeftType.POLY: "poly_",
-            PeftType.BOFT: "boft_",
-            PeftType.LN_TUNING: "ln_tuning_",
-            PeftType.VERA: "vera_lambda_",
-            PeftType.FOURIERFT: "fourierft_",
-            PeftType.HRA: "hra_",
-            PeftType.VBLORA: "vblora_",
-        }[config.peft_type]
+        parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
         if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights:
             num_vectors, _ = model.vblora_vector_bank[adapter_name].shape
             state_dict_keys = list(state_dict.keys())
diff --git a/tests/testing_common.py b/tests/testing_common.py
index b4cf2ffb56..fe354edde2 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -532,8 +532,16 @@ def _test_load_multiple_adapters(self, model_id, config_cls, config_kwargs):

             model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
             model = PeftModel.from_pretrained(model, tmp_dirname, torch_device=self.torch_device)
-            model.load_adapter(tmp_dirname, adapter_name="other")
-            model.load_adapter(tmp_dirname, adapter_name="yet-another")
+
+            load_result1 = model.load_adapter(tmp_dirname, adapter_name="other")
+            load_result2 = model.load_adapter(tmp_dirname, adapter_name="yet-another")
+
+            # VBLoRA uses a shared "vblora_vector_bank" across all layers, so it is expected to
+            # show up in the missing keys list, which would make this check fail. Therefore,
+            # skip the missing keys check for VBLoRA.
+            if config.peft_type != "VBLORA":
+                assert load_result1.missing_keys == []
+                assert load_result2.missing_keys == []

     def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
         if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig, VBLoRAConfig):
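
For reviewers, a minimal usage sketch of the behavior this patch targets, mirroring the updated test. The base model name and adapter path below are placeholders, not part of the patch. With the new filtering in `load_adapter`, the returned load result should only report missing keys that contain both the tuner prefix from `PEFT_TYPE_TO_PREFIX_MAPPING` (e.g. `lora_`) and the requested adapter name, so for a correctly saved adapter the list is expected to be empty (VBLoRA's shared `vblora_vector_bank` being the known exception):

```python
# Sketch only: "facebook/opt-125m" and "path/to/lora-adapter" are placeholder choices.
from transformers import AutoModelForCausalLM

from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# Load the default adapter from a locally saved PEFT checkpoint.
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# load_adapter returns the load result (missing_keys / unexpected_keys); with this
# patch, missing_keys is filtered down to keys that belong to the newly loaded
# adapter and carry its tuner prefix, instead of listing keys of other adapters.
load_result = model.load_adapter("path/to/lora-adapter", adapter_name="other")
assert load_result.missing_keys == []  # expected for a correctly saved LoRA adapter
```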