ENH Raise error when applying modules_to_save on tuner layer #2028
Conversation
Relates to huggingface#2027

Normally, when selecting the layers for fine-tuning, PEFT already ensures that the same layer is not targeted for both parameter-efficient fine-tuning (e.g. a LoRA layer) and full fine-tuning (via modules_to_save), as that makes no sense. However, there is a loophole when modules_to_save is applied ex post. This happens, for instance, with a task type like sequence classification, where PEFT will automatically add the classification head to modules_to_save for the user. This loophole is now closed by adding a check to ModulesToSaveWrapper that validates that the targeted layer is not a tuner layer. This does not fully resolve huggingface#2027 but will raise an early error in the future to avoid confusion.

On top of this, the error message inside of ModulesToSaveWrapper.check_module has been slightly adjusted. Previously, only the class name would be used, which can be confusing: e.g. for LoRA, the class name of the linear LoRA layer is just "Linear", which looks the same as nn.Linear. Therefore, the full name is now shown.
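As a rough illustration of the safeguard described here, below is a minimal sketch, not the actual PEFT source: DummyTunerLayer and check_not_tuner_layer are made-up names standing in for PEFT's tuner layer class and for the check added to ModulesToSaveWrapper, and the exception type is an assumption.

```python
# Minimal sketch of the idea: refuse to fully fine-tune (wrap for saving) a
# module that is already a tuner layer. DummyTunerLayer stands in for a PEFT
# tuner layer such as the LoRA Linear layer.
import torch.nn as nn


class DummyTunerLayer(nn.Linear):
    """Stand-in for a PEFT tuner layer (e.g. the LoRA Linear layer)."""


def check_not_tuner_layer(module: nn.Module) -> None:
    if isinstance(module, DummyTunerLayer):
        # Use the full class (not just __name__) so that "Linear" is not
        # mistaken for torch.nn.Linear in the error message.
        raise TypeError(
            f"modules_to_save cannot be applied to modules of type {module.__class__}"
        )


check_not_tuner_layer(nn.Linear(4, 4))          # plain layer: passes
# check_not_tuner_layer(DummyTunerLayer(4, 4))  # tuner layer: raises TypeError
```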
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
As the test notes:
> this may get fixed in a future PR, in which case this test can be removed
Fixes huggingface#2027

When using a transformers sequence classification model, target_modules="all-linear" should not wrap the classification head with LoRA. This is because it is already wrapped with ModulesToSave, i.e. it will be fully fine-tuned, which is generally the desired behavior. Before this bug fix, the classification head would be double-wrapped. With huggingface#2028, this now raises an error. With this PR, it is avoided completely. Still, keeping huggingface#2028 is good because it helps prevent other situations where double-wrapping might occur due to misconfiguration. Note that there is no foolproof method to detect the classification head; we have to rely on transformers' conventions.
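To make the convention-based detection concrete, here is a hedged sketch of how a classification head could be skipped when collecting linear target modules. The helper name and the set of head names are illustrative assumptions, not PEFT's actual implementation.

```python
# Illustrative sketch: collect nn.Linear module names while skipping anything
# that looks like a task head according to common transformers naming.
import torch.nn as nn

# Head names seen in transformers sequence classification models; this set is
# an assumption for illustration, not an exhaustive or official list.
LIKELY_HEAD_NAMES = {"score", "classifier", "classification_head"}


def linear_targets_excluding_head(model: nn.Module) -> list[str]:
    targets = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if any(part in LIKELY_HEAD_NAMES for part in name.split(".")):
            continue  # leave the head to modules_to_save / full fine-tuning
        targets.append(name)
    return targets
```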
This makes sense. Thank you!
@@ -202,7 +202,15 @@ def check_module(self):
         # ModuleList, even though their forward methods cannot be called
         forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList)
         if isinstance(self.original_module, forbidden_classes):
-            cls_name = self.original_module.__class__.__name__
+            cls_name = self.original_module.__class__
Why discard __name__?
This is what I referred to above:
> the error message inside of ModulesToSaveWrapper.check_module has been slightly adjusted. Previously, the class name would be used, which can be confusing. E.g. for LoRA, the class name of the linear LoRA layer is just "Linear", which looks the same as nn.Linear. Therefore, the full name is now shown.
So the error message used to be:

    modules_to_save cannot be applied to modules of type Linear

which is very confusing, as the user has no idea that "Linear" refers to PEFT's LoRA Linear layer. Now the message states:

    modules_to_save cannot be applied to modules of type <class 'peft.tuners.lora.layer.Linear'>
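For a quick standalone illustration of the difference, here is a small sketch; the Linear subclass below is a stand-in for peft.tuners.lora.layer.Linear, not the real class.

```python
# Why printing the class beats printing __name__ when the class name collides
# with a torch name: the full class includes the module path.
import torch.nn as nn


class Linear(nn.Linear):  # stand-in for PEFT's LoRA Linear layer
    pass


layer = Linear(2, 2)
print(layer.__class__.__name__)  # Linear                     -> ambiguous
print(layer.__class__)           # <class '__main__.Linear'>  -> unambiguous
```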
Relates to #2027

Normally, when selecting the layers for fine-tuning, PEFT already ensures that the same layer is not targeted for both parameter-efficient fine-tuning (e.g. LoRA layer) and full fine-tuning (via modules_to_save), as that makes no sense.

However, there is a loophole when modules_to_save is applied ex post. This happens for instance when having a task type like sequence classification, where PEFT will automatically add the classification head to modules_to_save for the user. This loophole is now closed by adding a check to ModulesToSaveWrapper that validates that the targeted layer is not a tuner layer.

This does not fully resolve #2027 but will raise an early error in the future to avoid confusion. Edit: #2028 should fully resolve it, but having this PR would still be useful.

On top of this, the error message inside of ModulesToSaveWrapper.check_module has been slightly adjusted. Previously, the class name would be used, which can be confusing. E.g. for LoRA, the class name of the linear LoRA layer is just "Linear", which looks the same as nn.Linear. Therefore, the full name is now shown.

Moreover, a now obsolete test was removed.
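For context, here is a hedged sketch of the kind of configuration that could run into the loophole described above: with a sequence classification task type, PEFT adds the classification head to modules_to_save automatically, and a broad target_modules setting could match that same head. The exact arguments are illustrative; behavior depends on the PEFT version.

```python
# Illustrative configuration only. With task_type="SEQ_CLS", PEFT adds the
# classification head to modules_to_save on its own; if target_modules also
# matched the head, it would previously be double-wrapped. After this change,
# that situation raises an early error instead of failing silently.
from peft import LoraConfig

config = LoraConfig(
    task_type="SEQ_CLS",
    target_modules="all-linear",  # broad matching that could include the head
    r=8,
    lora_alpha=16,
)
```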