[ModulesToSave] add correct hook management for modules to save #755
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thank you @BenjaminBossan for the fixes, LGTM! 🚀
LGTM too, thanks for the PR. Let's wait and see if the accelerate devs have any concerns before merging.
@require_torch_gpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_modules_to_save_grad(self):
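The excerpt above shows only the test's decorators and signature. A hypothetical sketch of such a gradient check, assuming a small causal LM loaded in 4-bit with LoRA plus an extra trainable module in modules_to_save (the model name, prompt, and module names below are illustrative assumptions, not the actual test body):

# Hypothetical sketch of the gradient check; requires a GPU and bitsandbytes,
# per the decorators above. Model name, prompt, and modules_to_save entry are
# illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "facebook/opt-350m"  # assumption: any small causal LM would do
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Train LoRA adapters and additionally keep a full trainable copy of lm_head.
config = LoraConfig(task_type="CAUSAL_LM", modules_to_save=["lm_head"])
model = get_peft_model(model, config)

inputs = tokenizer("hello world", return_tensors="pt").to(model.device)
loss = model(**inputs, labels=inputs["input_ids"]).loss
loss.backward()

# With correct hook management, the trainable copies created for
# modules_to_save must receive gradients during the backward pass.
for name, param in model.named_parameters():
    if "modules_to_save" in name and param.requires_grad:
        assert param.grad is not None, f"no gradient for {name}"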
Do we want to throw in the device_map="auto" option too?
load_in_4bit should automatically set device_map to the correct value: https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/modeling_utils.py#L2318
Okay. I was just thinking about the users reporting that the issue also occurs with device_map="auto"
and without quantization, so we could cover that too. But it's not super important.
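A minimal sketch of the load_in_4bit point above, assuming a CUDA machine with bitsandbytes installed; the model name is an illustrative assumption. No explicit device_map is passed, so transformers/accelerate choose the placement themselves.

from transformers import AutoModelForCausalLM

# Loading in 4-bit without an explicit device_map; the placement is decided
# automatically by transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", load_in_4bit=True)

# If a device map was derived automatically, it is recorded on the model.
print(getattr(model, "hf_device_map", None))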
What does this PR do?
Fixes #602; the solution was found by @BenjaminBossan.
When loading the base model with accelerate, the old hooks were still attached to the new module, causing the ModulesToSaveWrapper module to call the previous forward method. As a result, gradients were not properly backpropagated to the right module. A reproducible script was shared by Benjamin.
The fix is to create a fresh copy of the previous hook (keeping the same attributes), remove the old hook, and attach the new hook to the module.
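A minimal sketch of this approach, assuming the stale hook is stored on the module's _hf_hook attribute (as accelerate's dispatch utilities do); the helper names are illustrative, not the exact code added in the PR.

import inspect

import accelerate
from accelerate.hooks import add_hook_to_module, remove_hook_from_module


def create_new_hook(old_hook):
    # Re-instantiate the same hook class with the same attributes, so the new
    # hook no longer references the module it was originally attached to.
    hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    accepted = inspect.signature(hook_cls.__init__).parameters
    kwargs = {k: v for k, v in old_hook.__dict__.items() if k in accepted}
    return hook_cls(**kwargs)


def refresh_hook(module):
    # Replace a stale accelerate hook on `module` with a fresh copy.
    old_hook = getattr(module, "_hf_hook", None)
    if old_hook is not None:
        new_hook = create_new_hook(old_hook)
        remove_hook_from_module(module)       # detach the old hook and its forward wrapper
        add_hook_to_module(module, new_hook)  # re-wrap forward with the fresh hook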
cc @BenjaminBossan @pacman100