
Training PEFT models with new tokens being added to the embedding layers and tokenizer #1147

Merged: 20 commits into main on Nov 29, 2023

Conversation

@pacman100 (Contributor) commented Nov 17, 2023

What does this PR do?

  1. Training PEFT models with new tokens being added to the embedding layers and tokenizer.
  2. Now, users can train LoRAs targeting the embedding layers with the newly added tokens. During saving, the embedding layers also need to be saved, because the LoRAs were trained on the specific random initialization of the added tokens. To achieve this, users simply need to pass save_embedding_layers=True to the save_pretrained method. If it is not passed explicitly, we try our best to guess whether the common embedding module names are present in the target_modules of the config (see the sketch after this list).
  3. Currently, save_pretrained doesn't restrict saving to the main process, so in a multi-GPU environment many processes write to the same checkpoint simultaneously and leave it broken. This is fixed now with the is_main_process argument, which is in line with that of transformers.
  4. Many lines were being duplicated by create_or_update_model_card when checkpoints were saved every n steps/epochs. Fixed this so that the created README is clean.
  5. End-to-end example and tests added.
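A minimal sketch of the intended workflow, assuming a causal LM whose embedding modules are named embed_tokens and lm_head; the base model name, token list, and LoRA hyperparameters below are illustrative, not taken from this PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base_model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; any causal LM with embed_tokens/lm_head
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name)

# Add the new tokens and resize the embedding matrix to match the new vocab size.
tokenizer.add_special_tokens({"additional_special_tokens": ["<|FUNCTIONS|>", "<|BROWSE|>"]})
model.resize_token_embeddings(len(tokenizer))

# Target the embedding layers in addition to the attention projections so that
# LoRA adapters are also trained on top of the (resized) embeddings.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "embed_tokens", "lm_head"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

# ... training loop ...

# Save the adapter together with the resized embedding layers; without the flag,
# PEFT tries to detect this situation automatically.
model.save_pretrained("output_dir", save_embedding_layers=True)
```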

Todo:

  • Add tests
  • Add an end-to-end example


@HuggingFaceDocBuilderDev commented Nov 17, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada (Contributor) left a comment

Makes sense, thanks!

@BenjaminBossan (Member) left a comment

Thanks for adding this feature, it should help many users with this type of issue.

I have a couple of comments, please take a look. Also, some more general concerns:

  1. base_model_layers_to_save only makes sense when the corresponding layer is a target_module, right? IIUC, when a user adds a module to base_model_layers_to_save that is not a target_module, this is currently just silently ignored. I wonder if we shouldn't perform a check to ensure that each entry in base_model_layers_to_save has found a match and raise an error otherwise. I imagine it can be very frustrating for users to save the model after a long training run and only later find out that one part was silently ignored. This requirement should also be documented.
  2. This feature doesn't work for prompt learning because we rely on base_layer, right? I think we can make it work for prompt learning, though, by just checking the name without base_layer. WDYT?
  3. Let's add a unit test for this feature.

1. Add `is_embedding_layer_resized` parameter to `save_pretrained`
2. Fix the deduplication in README when adding PEFT details.
3. `save_pretrained` should only save the model when `is_main_process=True` which is one of the parameters of `save_pretrained`.
@pacman100 (Contributor, Author) commented Nov 28, 2023

Hello @BenjaminBossan, I have changed the code quite a bit based on the feedback [here](https://huggingface.slack.com/archives/C04L3MWLE6B/p1700238526932559?thread_ts=1699626507.649479&cid=C04L3MWLE6B):

Would be great if the base_model_layers_to_save could be automatically handled based on what has changed or was trainable.

As the embedding layers only need to be saved when they are resized, I have changed the code to do exactly that.

Also fixed the below issues:

  1. Currently, save_pretrained doesn't restrict saving to the main process, so in a multi-GPU environment many processes write to the same checkpoint simultaneously and leave it broken. This is fixed now with the is_main_process argument, which is in line with that of transformers (see the sketch after this list).
  2. Many lines were being duplicated by create_or_update_model_card when checkpoints were saved every n steps/epochs. Fixed this so that the created README is clean.
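For example, when training with Hugging Face Accelerate, the fix can be used like this; a sketch, where the accelerator object, the trained PeftModel, and the output path come from the surrounding training script and are not part of this PR:

```python
from accelerate import Accelerator

accelerator = Accelerator()
# ... model setup and training produce a PeftModel `model` ...

# Only the main process writes the checkpoint; the other ranks skip the write
# so they don't corrupt each other's files.
model.save_pretrained(
    "output_dir",
    is_main_process=accelerator.is_main_process,
)
```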

@BenjaminBossan (Member) left a comment

Thanks a lot for reworking the PR to make it more user-friendly, as well as for adding the is_main_process argument and for fixing the model card. Also well done adding a notebook to illustrate the new feature.

The logic for determining the embedding layer, even if only a few lines of code, was actually more intricate than I thought and took me some time to understand, but it looks correct to me.

I left a couple of comments. I'd say none of them are hard blockers, but please do take a look. I think my main concern would be the comment about when to automatically toggle the saving of embedding layers.

Co-Authored-By: Benjamin Bossan <[email protected]>
@pacman100 (Contributor, Author) commented
Hello @BenjaminBossan, I've addressed all the comments. Thank you!

@BenjaminBossan (Member) left a comment

Looks fantastic, great implementation, examples, docs, and tests.

I found a typo in the test name, apart from that this can be merged.

"2. Finetuning on a specific language wherein language spoecific tokens are added, e.g., korean tokens being added to vocabulary for finetuning LLM on Korean datasets.\n",
"3. Instruction finetuning to return outputs in certain format to enable agent behaviour new tokens such as `<|FUNCTIONS|>`, `<|BROWSE|>`, `<|TEXT2IMAGE|>`, `<|ASR|>`, `<|TTS|>`, `<|GENERATECODE|>`, `<|RAG|>`.\n",
"\n",
"In such cases, you add the Embedding modules to the LORA `target_modules`. PEFT will take care of saving the embedding layers with the new added tokens along with the adapter weights that were trained on the specific initialization of the embeddings weights of the added tokens."
Member left a comment

Great explanation.

pacman100 merged commit 8298f1a into main on Nov 29, 2023 (14 checks passed).
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Nov 30, 2023
Training PEFT models with new tokens being added to the embedding layers and tokenizer (huggingface#1147)

* add support for saving base layers weights along with adapter weights

* Update save_and_load.py

* Add an example showing the usage of the added feature

* refactor the functionality

* fix

* refactoring code

1. Add `is_embedding_layer_resized` parameter to `save_pretrained`
2. Fix the deduplication in README when adding PEFT details.
3. `save_pretrained` should only save the model when `is_main_process=True` which is one of the parameters of `save_pretrained`.

* update example

* fix the model card

* fix model card

* 😅

* fix model card

* automate setting `is_embedding_layer_resized`

* nits

* Update peft_lora_clm_with_additional_tokens.ipynb

* add test

* fix tests

* maybe fixes the issue?

* address comments

Co-Authored-By: Benjamin Bossan <[email protected]>

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <[email protected]>

---------

Co-authored-by: Benjamin Bossan <[email protected]>
pacman100 deleted the smangrul/add-support-for-additional-tokens-training branch on December 14, 2023 at 12:07.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jan 9, 2024
Resolves huggingface#1300

Sourab added the feature to store the embedding layers alongside the
adapter in huggingface#1147. This PR adds an entry to the documentation to explain
the new feature.
BenjaminBossan added a commit that referenced this pull request Jan 12, 2024
Resolves #1300

Sourab added the feature to store the embedding layers alongside the
adapter in #1147. This PR adds an entry to the documentation to explain
the new feature.

---------

Co-authored-by: Steven Liu <[email protected]>
@debraj135 commented:

I have a question about this PR. Before it was merged, was it not possible to both do LoRA fine-tuning and save the LoRA adapters for "embed_tokens" and "lm_head"?

I ask because, when I run the snippet included in this comment, I see a massive difference in the size of the saved adapter files between 0.6.2 and 0.7.0.

@amitagh commented Apr 11, 2024

I still see this size increase with peft 0.10.0. Is the size increase due to a bug, or is it fine? Is it going to be fixed, or will the behavior with the increased size remain?
I have seen this issue with both Axolotl (axolotl-ai-cloud/axolotl#1511) and LLaMA-Factory (hiyouga/LLaMA-Factory#3137).

@BenjaminBossan (Member) commented:

I still see this size increase with peft 0.10.0. Is the size increase due to a bug, or is it fine? Is it going to be fixed, or will the behavior with the increased size remain?

Please check out Sourab's comment: does it make clear why you see the size increase and what you can do about it? If not, could you please clarify what's still unclear to you?

@amitagh commented Apr 11, 2024

Thanks, Benjamin. I understand that the embedding layers (embed_tokens & lm_head) will be saved by default, but does it mean the size inflation is to be expected? Earlier, these modules also used to be saved explicitly, but with a smaller size. For cases where the token vocabulary is not changing, will users now have to live with the bigger size? Or is it recommended not to save the embedding layers if there is no change in the token vocabulary?

@BenjaminBossan (Member) commented:

I understand that the embedding layers (embed_tokens & lm_head) will be saved by default

For cases where the token vocabulary is not changing, will users now have to live with the bigger size?

No, the embedding is not saved by default. The default setting is "auto", which means PEFT tries to detect if saving the embedding layer is needed or not. If you find that this automatic detection leads to false positives, let us know and we can try to fix that.

Note that you can always pass False to save_embedding_layers to prevent this. If you use a library that uses PEFT under the hood but does not expose this parameter, open an issue there and ask them to do so.
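In code, the three settings look like this; a sketch, where model is any LoRA PeftModel and the output path is a placeholder:

```python
# Default ("auto"): PEFT checks whether the embedding modules are targeted or the
# vocabulary was resized, and includes the embedding weights only in that case.
model.save_pretrained("output_dir", save_embedding_layers="auto")

# Force the embedding layers to be included (e.g. after adding new tokens).
model.save_pretrained("output_dir", save_embedding_layers=True)

# Never include them, keeping the adapter file small.
model.save_pretrained("output_dir", save_embedding_layers=False)
```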

Or is it recommended not to save the embedding layers if there is no change in the token vocabulary?

Indeed, when the vocab is not changed (no fine-tuning, no additional tokens being added), it does not need to be included in the saved file.

but does it mean the size inflation is to be expected?

Sorry, I don't understand this question, could you please clarify?

@amitagh commented Apr 11, 2024

I am doing continuous pretraining with LoRA on a pretrained LLM for a non-English language, without extending the tokenizer as it already has the needed tokens. For non-English languages it is recommended to also train the embedding layers. So in such a case, since I need the embedding layers and PEFT will save them when the save_embedding_layers flag is true, will it lead to a bigger adapter?

@BenjaminBossan (Member) commented:

I am doing continuous pretraining with LoRA on a pretrained LLM for a non-English language, without extending the tokenizer as it already has the needed tokens. For non-English languages it is recommended to also train the embedding layers.

You have to be more precise here: Are you fully fine-tuning the embedding layer -- in that case, it has to be saved, leading to a much larger file size. Or are you training a LoRA weight on top of the embedding layer -- in that case, only those LoRA weights must be saved, not the whole embedding layer, leading to only a slight increase of the file size (compared to not training the embedding at all).
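The two setups map to different LoraConfig fields; a sketch, where the module names assume a LLaMA-style architecture and the hyperparameters are illustrative:

```python
from peft import LoraConfig

# (a) Fully fine-tune the embedding layers: the whole embedding matrices are
#     trained and stored in the checkpoint, so the saved file is much larger.
full_ft_embeddings = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
    task_type="CAUSAL_LM",
)

# (b) Train LoRA weights on top of the embedding layers: only the small A/B
#     matrices are stored (assuming the vocabulary was not resized), so the
#     checkpoint grows only slightly.
lora_on_embeddings = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj", "embed_tokens", "lm_head"],
    task_type="CAUSAL_LM",
)
```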

@amitagh commented Apr 11, 2024

Thanks, Benjamin. I am doing pretraining with LoRA on the existing pretrained model Gemma-7B to augment it with a large non-English corpus. I know full fine-tuning would have been better, but the cost would have been higher, hence pretraining with LoRA.
So yes, a LoRA adapter will be generated and saved separately. But since the adapter size is increasing significantly, I guess it will slow down training, and when I finally merge the adapter I hope it doesn't increase the merged model size significantly compared to the input base model.

@BenjaminBossan (Member) commented Apr 12, 2024

@amitagh Since this discussion is only tangentially related to this PR, could you please open a thread in the PEFT discussions and describe your problem there? Please include as many details as you can share, like the LoraConfig, training code, size of the adapter file, etc.
