From f3c7c6e5c1e1b543e1f006ef38b3fdab3a4fefaf Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Thu, 22 Aug 2024 17:10:39 +0200
Subject: [PATCH] ENH Raise error when applying modules_to_save on tuner layer
 (#2028)

Relates to #2027

Normally, when selecting the layers for fine-tuning, PEFT already ensures
that the same layer is not targeted for both parameter-efficient fine-tuning
(e.g. a LoRA layer) and full fine-tuning (via modules_to_save), as that
makes no sense.

However, there is a loophole when modules_to_save is applied ex post. This
happens, for instance, with task types like sequence classification, where
PEFT automatically adds the classification head to modules_to_save for the
user. This loophole is now closed by adding a check to ModulesToSaveWrapper
that validates that the targeted layer is not a tuner layer.

This does not fully resolve #2027, but it will raise an early error going
forward to avoid confusion.

On top of this, the error message inside ModulesToSaveWrapper.check_module
has been slightly adjusted. Previously, only the class name was used, which
can be confusing: for LoRA, the class name of the linear LoRA layer is just
"Linear", which looks the same as nn.Linear. Therefore, the full class is
now shown.
---
 src/peft/utils/other.py     | 10 +++++++++-
 tests/test_custom_models.py | 25 +------------------------
 tests/test_other.py         | 19 ++++++++++++++++++-
 3 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index dcbaabc60c..b3b46e3b5f 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -202,7 +202,15 @@ def check_module(self):
         # ModuleList, even though their forward methods cannot be called
         forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList)
         if isinstance(self.original_module, forbidden_classes):
-            cls_name = self.original_module.__class__.__name__
+            cls_name = self.original_module.__class__
             raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}")
+
+        # local import to avoid circular import
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        if isinstance(self.original_module, BaseTunerLayer):
+            # e.g. applying modules_to_save to a lora layer makes no sense
+            cls_name = self.original_module.__class__
+            raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}")
 
     @property
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index c3f4d1368f..27f367a536 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -49,7 +49,7 @@
     get_peft_model,
 )
 from peft.tuners.tuners_utils import BaseTunerLayer
-from peft.utils import ModulesToSaveWrapper, infer_device
+from peft.utils import infer_device
 
 from .testing_common import PeftCommonTester
 from .testing_utils import get_state_dict, require_non_cpu
@@ -1530,29 +1530,6 @@ def test_adapter_name_makes_no_difference(self, config0):
         assert torch.allclose(output_custom1, output_custom2)
         assert torch.allclose(output_default, output_custom1)
 
-    @parameterized.expand(["merge_and_unload", "unload"])
-    def test_double_wrapping_merge_and_unload(self, method):
-        # see issue #1485
-        from transformers import AutoModelForTokenClassification
-
-        model = AutoModelForTokenClassification.from_pretrained("hf-internal-testing/tiny-random-RobertaModel")
-        config = LoraConfig(task_type="TOKEN_CLS", target_modules="all-linear")
-        model = get_peft_model(model, config)
-
-        # first check that double-wrapping happened
-        # Note: this may get fixed in a future PR, in which case this test can be removed
-        assert isinstance(model.base_model.model.classifier, ModulesToSaveWrapper)
-        assert hasattr(model.base_model.model.classifier.original_module, "lora_A")
-        assert hasattr(model.base_model.model.classifier.modules_to_save.default, "lora_A")
-
-        # after unloading, despite double wrapping, the classifier module should be a normal nn.Linear layer
-        if method == "merge_and_unload":
-            unloaded = model.merge_and_unload()
-        else:
-            unloaded = model.unload()
-
-        assert isinstance(unloaded.classifier, nn.Linear)
-
     def test_gpt2_dora_merge_and_unload(self):
         # see https://github.com/huggingface/peft/pull/1588#discussion_r1537914207
         model = AutoModelForCausalLM.from_pretrained("gpt2")
diff --git a/tests/test_other.py b/tests/test_other.py
index bc1315a41b..04c67d3bf2 100644
--- a/tests/test_other.py
+++ b/tests/test_other.py
@@ -15,7 +15,7 @@
 import pytest
 import torch
 from torch import nn
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
 
 from peft import LoraConfig, get_peft_model
 
@@ -76,6 +76,23 @@ def test_modules_to_save_targets_module_dict_raises(cls):
         get_peft_model(model=model, peft_config=peft_config)
 
 
+def test_modules_to_save_targets_tuner_layer_raises():
+    # See e.g. issue 2027
+    # Prevent users from (accidentally) targeting the same layer both with a tuner and modules_to_save. Normally, PEFT
+    # will not target the same layer with both a tuner and ModulesToSaveWrapper. However, if modules_to_save is
+    # automatically inferred, e.g. when using AutoModelForSequenceClassification, the ModulesToSaveWrapper is applied ex
+    # post, which can lead to the double wrapping.
+    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+    model = AutoModelForSequenceClassification.from_pretrained(model_id)
+
+    # Note: target_modules="all-linear" would also work and is closer to the original issue, but let's explicitly target
+    # "score" here in case that "all-linear" will be fixed to no longer target the score layer.
+    peft_config = LoraConfig(target_modules=["score"], task_type="SEQ_CLS")
+    msg = "modules_to_save cannot be applied to modules of type"
+    with pytest.raises(TypeError, match=msg):
+        get_peft_model(model, peft_config)
+
+
 def test_get_peft_model_revision_warning(tmp_path):
     base_model_id = "peft-internal-testing/tiny-random-BertModel"
     base_revision = "v2.0.0"
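For reference, a minimal sketch (not part of the patch) of the early error this change introduces. It reuses the tiny checkpoint from the new test; any sequence-classification model whose score/classification head is also matched by target_modules should hit the same check.

    from transformers import AutoModelForSequenceClassification

    from peft import LoraConfig, get_peft_model

    model = AutoModelForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
    # task_type="SEQ_CLS" makes PEFT add the "score" head to modules_to_save automatically,
    # while target_modules=["score"] also turns that same head into a LoRA layer,
    # i.e. the double wrapping described above.
    config = LoraConfig(target_modules=["score"], task_type="SEQ_CLS")

    try:
        get_peft_model(model, config)
    except TypeError as exc:
        # e.g. "modules_to_save cannot be applied to modules of type <class 'peft.tuners.lora.layer.Linear'>"
        print(exc)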