
[ModulesToSave] add correct hook management for modules to save #755

Merged (4 commits) on Jul 27, 2023

Conversation

@younesbelkada (Contributor) commented on Jul 26, 2023

What does this PR do?

Fixes #602; the solution was found by @BenjaminBossan.
When the base model is loaded with accelerate, the old hooks remained attached to the new module, causing ModulesToSaveWrapper to call the previous forward method, so gradients were not backpropagated to the correct module.

A reproduction script shared by @BenjaminBossan:

import torch
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForSequenceClassification, set_seed

set_seed(123)

PRETRAIN = 'bigscience/bloomz-560m'
tokenizer = AutoTokenizer.from_pretrained(PRETRAIN)

# Toggle these to reproduce the different configurations discussed in #602.
load_in_4bit = True
device_map = None
should_resize_token_embeddings = False
should_prepare_model_for_kbit_training = True

model = AutoModelForSequenceClassification.from_pretrained(
    PRETRAIN,
    load_in_4bit=load_in_4bit,
    torch_dtype=torch.float32,
    device_map=device_map,
)
if should_prepare_model_for_kbit_training:
    model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_CLS,
)

peft_model = get_peft_model(model, config)
if should_resize_token_embeddings:
    model.resize_token_embeddings(len(tokenizer))

# For SEQ_CLS, the classification head `score` is wrapped in ModulesToSaveWrapper:
# `original_module` is the frozen original head, `modules_to_save.default` its trainable copy.
lm_head = peft_model.base_model.model.score
original_module = lm_head.original_module
modules_to_save = lm_head.modules_to_save.default

inputs = torch.randn((1024))
o1 = lm_head(inputs)
o1.mean().backward()

# Without the fix, the stale hook made the wrapper call the previous forward,
# so gradients did not reach the trainable copy and the checks below failed.
assert modules_to_save.weight.requires_grad is True
assert original_module.weight.grad is None
assert modules_to_save.weight.grad is not None

The fix is to create a fresh copy of the previous hook that keeps the same attributes, remove the old hook, and attach the new copy to the module; see the sketch below.
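
For illustration, here is a minimal sketch of that approach (not the exact diff in this PR; the helper name `_recreate_and_replace_hook` is made up, and it assumes accelerate stores its hook on the module's `_hf_hook` attribute and exposes `add_hook_to_module` / `remove_hook_from_module` in `accelerate.hooks`):

import inspect

import accelerate
from accelerate.hooks import add_hook_to_module, remove_hook_from_module


def _recreate_and_replace_hook(module):
    """Sketch only: rebuild the accelerate hook on `module` so it wraps the current forward."""
    old_hook = getattr(module, "_hf_hook", None)
    if old_hook is None:
        return  # no accelerate hook attached, nothing to do

    # Re-instantiate the same hook class, copying over only the attributes
    # that its __init__ accepts, so the new hook keeps the old configuration.
    hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    accepted = inspect.signature(hook_cls.__init__).parameters
    kwargs = {k: v for k, v in old_hook.__dict__.items() if k in accepted}
    new_hook = hook_cls(**kwargs)

    # Detach the stale hook (which still wraps the pre-copy forward)
    # and attach the fresh copy to the module.
    remove_hook_from_module(module)
    add_hook_to_module(module, new_hook)

Re-attaching a freshly built hook means the wrapped forward now points at the copied module's own forward, which is what lets gradients flow to the trainable copy in modules_to_save.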

cc @BenjaminBossan @pacman100

@HuggingFaceDocBuilderDev commented on Jul 26, 2023

The documentation is not available anymore as the PR was closed or merged.

@pacman100 (Contributor) left a comment

Thank you @BenjaminBossan for the fixes, LGTM! 🚀

@BenjaminBossan (Member) left a comment

LGTM too, thanks for the PR. Let's wait and see whether the accelerate devs have any concerns before merging.

@require_torch_gpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_modules_to_save_grad(self):

@BenjaminBossan (Member) commented on this test:

Do we want to throw in the device_map="auto" option too?

@BenjaminBossan (Member), in a follow-up:

Okay. I was just thinking about the users reporting that the issue also occurs with device_map="auto" and without quantization, so we could cover that too. But it's not super important.
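
For reference, a hypothetical parametrized variant of the test could cover both cases. This is only a sketch based on the reproduction script above; the test name, the parameter grid, and the omission of the GPU/bitsandbytes decorators are illustrative, not what this PR adds.

import pytest
import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification


# Hypothetical sketch: exercise the gradient check both with 4-bit quantization
# and with plain device_map="auto" (no quantization), as discussed above.
@pytest.mark.parametrize(
    "load_in_4bit, device_map",
    [(True, None), (False, "auto")],
)
def test_modules_to_save_grad_variants(load_in_4bit, device_map):
    model = AutoModelForSequenceClassification.from_pretrained(
        "bigscience/bloomz-560m",
        load_in_4bit=load_in_4bit,
        torch_dtype=torch.float32,
        device_map=device_map,
    )
    config = LoraConfig(r=16, lora_alpha=16, task_type=TaskType.SEQ_CLS)
    peft_model = get_peft_model(model, config)

    lm_head = peft_model.base_model.model.score
    device = lm_head.modules_to_save.default.weight.device
    output = lm_head(torch.randn(1024).to(device))
    output.mean().backward()

    # Gradients must flow to the trainable copy, not the frozen original module.
    assert lm_head.original_module.weight.grad is None
    assert lm_head.modules_to_save.default.weight.grad is not None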

Successfully merging this pull request may close these issues:

modules_to_save incompatible with load_in_4bit / load_in_8bit ?
4 participants