
_is_peft_model update to recognise peft submodules, allowing training quantised models with peft submodules #30884

Closed
wants to merge 9 commits

Conversation

ambroser53

What does this PR do?

Users don't necessarily have PEFT models as the top-level wrapper for their models, especially when working with custom-built multi-modal models. For example:

# Library imports shown for context
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

model = AutoModelForVision2Seq.from_pretrained(
  args.pretrained_ckpt,
  torch_dtype=compute_dtype,
  quantization_config=BitsAndBytesConfig(
      load_in_4bit=bits == 4,
      load_in_8bit=bits == 8,
      llm_int8_threshold=6.0,
      int8_quant_skip_modules=int8_quant_skip_modules,
      llm_int8_has_fp16_weight=False,
      bnb_4bit_compute_dtype=compute_dtype,
      bnb_4bit_use_double_quant=True,
      bnb_4bit_quant_type='nf4'  # {'fp4', 'nf4'}
  ) if bits < 16 else None,
  attn_implementation=args.attn_implementation,
)

if (args.use_lora and not resume_from_checkpoint and not ft_checkpoint_dir):
  target_modules = get_target_modules(model.model.text_model, args, bits)
  peft_config = LoraConfig(
      target_modules=target_modules,
      inference_mode=args.inference_mode,
      r=args.lora_r,
      lora_alpha=args.lora_alpha,
      lora_dropout=args.lora_dropout,
      use_dora=args.use_dora
  )
  model.model.text_model = get_peft_model(model.model.text_model, peft_config)

  if args.vit_train:
      target_modules = get_target_modules(model.model.vision_model, args, args.vit_bits, vit=True)
      peft_config = LoraConfig(
          target_modules=target_modules,
          inference_mode=args.inference_mode,
          r=args.vit_lora_r,
          lora_alpha=args.vit_lora_alpha,
          lora_dropout=args.lora_dropout,
          use_dora=args.use_dora_vit
      )
      model.model.vision_model = get_peft_model(model.model.vision_model, peft_config)

  if args.lora_abstractor:
      target_modules = get_target_modules(model.model.connector, args, args.bits)
      peft_config = LoraConfig(
          target_modules=target_modules,
          inference_mode=args.inference_mode,
          r=args.lora_r,
          lora_alpha=args.lora_alpha,
          lora_dropout=args.lora_dropout,
          use_dora=args.use_dora
      )
      model.model.connector = get_peft_model(model.model.connector, peft_config)

This allows the HF Trainer to recognise such models as still being PEFT models and thereby permits quantised (QLoRA) training.
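
For illustration, with the snippet above the top-level wrapper stays a plain transformers model while one or more of its submodules are PeftModel instances, which is exactly the situation _is_peft_model needs to recognise. A quick check of that structure (a hedged sketch reusing the variable names from the snippet; it assumes peft is installed):

from peft import PeftModel

# The top-level wrapper is a plain transformers model, not a PeftModel...
print(isinstance(model, PeftModel))                   # False
# ...while the LoRA-wrapped text submodule is a PeftModel.
print(isinstance(model.model.text_model, PeftModel))  # True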

Fixes #30878

Who can review?

@younesbelkada

younesbelkada (Contributor) left a comment

Thanks for adding submodule support for PEFT + Trainer! I left one suggestion - what do you think?

[Review thread on src/transformers/trainer.py - outdated, resolved]
younesbelkada requested a review from amyeroberts on May 20, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

amyeroberts (Collaborator) left a comment

Thanks for adding! Could you add a test with a dummy model that has PEFT submodules, which behaves correctly with this change but fails on main?

younesbelkada (Contributor) left a comment

Thanks @ambroser53!
For the styling checks, can you try running pip install -U ".[quality]" and then re-running make fixup?

ambroser53 (Author) commented

@younesbelkada I've just done as you asked; it ran successfully, but I have no working-tree changes to commit.

younesbelkada (Contributor) commented

Thanks! Hmm, I think something is off with our CI currently; let's wait for #30932 to be merged first.

amyeroberts (Collaborator) left a comment

Thanks for iterating on this.

I'm a bit confused about the intended behaviour here from the tests.

[Review thread on tests/trainer/test_trainer.py - outdated, resolved]
Comment on lines 1023 to 1024:

with self.assertRaises(ValueError):
    _ = Trainer(tiny_model, args, train_dataset=train_dataset)  # noqa
Collaborator commented

Wait - I'm confused 😅

My understanding of this PR was that it's meant to allow training on quantized models which have a PEFT submodule.

This test doesn't obviously quantize (I might be missing something here). Is it meant to fail if the model isn't quantized?

Author commented

It does quantise the model. See load_in_4bit=True on line 1002.
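
For context, the scenario being exercised looks roughly like the sketch below (the model name, LoRA settings and dataset are illustrative placeholders, not the exact values in tests/trainer/test_trainer.py; it assumes a GPU and bitsandbytes are available):

from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments

# Quantised base model (cf. load_in_4bit=True on line 1002).
tiny_model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM",  # placeholder tiny checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# Attach LoRA to a submodule instead of wrapping the whole model in a PeftModel.
lora_config = LoraConfig(r=4, lora_alpha=8, target_modules=["q_proj", "v_proj"])
tiny_model.model = get_peft_model(tiny_model.model, lora_config)

train_dataset = Dataset.from_dict({"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]})
args = TrainingArguments(output_dir="./tmp_peft_submodule_test", report_to="none")

# On main the Trainer rejects this quantised model because the top-level object is
# not a PeftModel; with this PR it should be accepted as trainable.
trainer = Trainer(model=tiny_model, args=args, train_dataset=train_dataset)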

Collaborator commented

Ah, sorry, I missed that. In this case, why is this error being thrown? Shouldn't this model now be trainable (it's quantized and has a PEFT submodule)?

huggingface deleted a comment from the github-actions bot on Jun 17, 2024
ambroser53 (Author) commented

Is there any update on this? I'm still running the local version for training my model that uses this. I still don't understand why it couldn't be merged in the first place.

amyeroberts (Collaborator) commented

@ambroser53 It's still open as there are two outstanding comments/questions on the tests which were marked as resolved but were not. In the first case, the comment needs to be updated; in the second, could you explain the intended behaviour?

ambroser53 and others added 2 commits on June 25, 2024:
…l architectures that vary their multi-modal set-up. This should work with a wider range of models out the box.

- Updated test to have sample subclass with save_pretrained support.
ambroser53 (Author) commented

@amyeroberts I see where the confusion was here. I was under the assumption that the code was fine and that we were waiting on a fix for the workflow checks (i.e. styling). I have changed the comments in the test to hopefully make them clearer, expanded the test to run trainer.train() instead of just initialising the Trainer object, and included a simple custom subclass with save_pretrained so that it works in that case and shows the full use-case with training.

Intended behaviour: the model is trainable when it is quantised and has a PEFT submodule.

Workflow tests are still failing, so if there is actually anything I need to do there, let me know.

amyeroberts (Collaborator) commented

@ambroser53 Great. For some of the failing tests regarding passing trust_remote_code for datasets, there have been upstream fixes. Could you rebase on main to include these? I think this should also fix a lot of the quality checks.

amyeroberts (Collaborator) commented

@ambroser53 Apologies for the continued delay. Could you try rebasing again? This should resolve the timeout errors we're having on the CI.

amyeroberts (Collaborator) left a comment

Thanks for iterating!

As mentioned in my PR - needing to override save_pretrained in the test is an indication that we should update the standard save_pretrained to account for models with PEFT submodules.

Comment on lines +1005 to +1007:
# Due to the way the Trainer is implemented we must be able to save the model with 'save_pretrained'
# Therefore to use peft submodules you must specify your own 'save_pretrained' method
# This example subclass allows for the saving of any and all submodules that are of type PeftModel
Collaborator commented

This isn't right, as it means we're no longer testing the transformers functionality. Instead, this PR should update save_pretrained to enable saving of models with PEFT submodules.

Not sure who's best to tag here: @BenjaminBossan @SunMarc

Member commented

I would suggest splitting that issue off into a separate PR. IIUC, this PR is originally about avoiding a false-positive ValueError claiming that the model is not trainable. I think we should focus on that here.

Now this test touches on another issue, namely that save_pretrained of a PEFT model only saves the trainable adapter part, but if the PEFT model itself is just a submodule, save_pretrained acts like it would on a normal transformers model and saves the whole checkpoint (unless I misunderstand the intent).

IMO, first of all, this should not be handled here in the test. Instead, I'd just add a comment that says something along the lines of: "Note that save_pretrained saves the whole model here, not just the PEFT adapter". The test should still pass, right?

Regarding the question of what the right thing to do is if the model has PEFT submodule(s): IMO this is not quite clear, and changing save_pretrained for all transformers users to only save the PEFT submodules could break existing code. Probably 90% of users in this situation would only want to save the PEFT adapters, but they would already require special handling to load these submodules correctly (right?), so maybe it's fine that they also need special handling to save them? IMHO, if we want to make it possible to only save the PEFT adapters of the submodules, it should be configurable, with the default being the status quo.
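
For concreteness, the special handling being discussed could look roughly like the helper below (a sketch under the assumption that each PeftModel submodule's adapter should be saved into its own subdirectory; the subclass in the actual test may differ):

import os

from peft import PeftModel

def save_peft_submodules(model, save_directory, **kwargs):
    """Save only the adapter weights of every PeftModel submodule."""
    for name, module in model.named_modules():
        if isinstance(module, PeftModel):
            # PeftModel.save_pretrained writes just the adapter, not the base weights.
            module.save_pretrained(os.path.join(save_directory, name), **kwargs)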

if _is_peft_model(model):
    return True
elif is_peft_available():
    classes_to_check = (PeftModel,) if is_peft_available() else ()
Member commented

Suggested change:

- classes_to_check = (PeftModel,) if is_peft_available() else ()
+ classes_to_check = (PeftModel,)

is_peft_available() is already checked in the line above, so no need to check again.

Also, how about changing this slightly to:

    if not is_peft_available():
        return False
    ...

That way, the early returns are all handled in one place and we also save one level of indentation for the rest of the function body.

Comment on lines +271 to +274:

for submodule in model.modules():
    if isinstance(submodule, classes_to_check):
        return True
return False
Member commented

Could be condensed to:

return any(isinstance(submodule, classes_to_check) for submodule in model.modules())
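
Putting the two suggestions together, the check could read roughly as follows (a hedged sketch of the combined shape; the wrapping function name is assumed rather than taken from the diff):

from transformers.trainer import _is_peft_model
from transformers.utils import is_peft_available

def _is_peft_model_or_has_peft_submodule(model) -> bool:  # assumed name
    # Early return when PEFT is not installed, as suggested above.
    if not is_peft_available():
        return False
    from peft import PeftModel

    if _is_peft_model(model):
        return True
    # Condensed submodule scan, as suggested above.
    return any(isinstance(submodule, PeftModel) for submodule in model.modules())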



This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ambroser53 closed this on Aug 23, 2024
Successfully merging this pull request may close these issues.

Have _is_peft_model check if there's any peft submodule/Allow quantised training