
Adds Vera (Vector Based Random Matrix Adaption) #2 #1564

Merged: 29 commits merged into huggingface:main from the add-vera-2 branch on Apr 19, 2024

Conversation

@BenjaminBossan (Member, Author) commented Mar 14, 2024

Continuation of #1039.

Should now be 95% on par with that PR, with some minor changes on my part + resolving merge conflicts.

Examples and docs have not been included yet.

TODOS:

  • Add documentation
  • Add examples

https://arxiv.org/abs/2310.11454

Notable changes vis-à-vis #1039:

  • Some refactoring of how the initialization of VeraModel proceeds; it should be more straightforward now.
  • Added tests around saving and loading, which require some special considerations for VeRA.
  • Fixed some issues with multiple adapters; this requires more strictness (e.g. not allowing multiple different prng keys on the same model).
  • projection_prng_key now has a valid default value (0) in the config.
  • Removed support for Embedding to reduce complexity: Supporting Embedding layers with VeRA makes very little sense because their shape is always different from the linear layers' shapes. Therefore, they cannot share the vera_A and vera_B matrices, which results in an error (a sketch of this shape constraint follows below). The only conceivable way to support Embedding layers would be to target only that layer (and possibly the output layer if it shares the weight), but that more or less defeats the purpose of using VeRA. We may revisit support for Embeddings in the future, perhaps if we can enable vera_A and vera_B to have different shapes. Until then, let's support the most common use cases and simplify our lives.
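
To make the shape constraint concrete, here is a minimal, hypothetical sketch of the VeRA update (not the PEFT implementation; names and sizes are made up for illustration):

import torch

hidden = 768  # in_features == out_features for the adapted Linear layers in this sketch
r = 256

# One frozen random pair, shared by every adapted layer:
vera_A = torch.randn(r, hidden)   # shape (r, in_features)
vera_B = torch.randn(hidden, r)   # shape (out_features, r)

def vera_delta(lambda_d, lambda_b):
    # Per-layer trainable vectors scale the shared random matrices:
    # delta_W = diag(lambda_b) @ vera_B @ diag(lambda_d) @ vera_A
    return (lambda_b.unsqueeze(1) * vera_B) @ (lambda_d.unsqueeze(1) * vera_A)

# Works for any Linear whose weight has shape (hidden, hidden):
linear_weight = torch.randn(hidden, hidden)
delta = vera_delta(torch.full((r,), 0.1), torch.zeros(hidden))
assert delta.shape == linear_weight.shape

# An Embedding weight has shape (vocab_size, hidden), e.g. (50272, 768), so the
# shared pair above cannot produce an update of that shape.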

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan and others added 6 commits March 14, 2024 18:17
* changes to support fsdp+qlora and dsz3+qlora

* address comments

* add example and start docs

* quality

* deepspeed fixes

* dsz3+qlora docs

* section link fix

* add fsdp+qlora docs

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>

* address comments

---------

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
@BenjaminBossan BenjaminBossan changed the title [WIP] Initial commit [WIP] Adds Vera (Vector Based Random Matrix Adaption) #2 Mar 15, 2024
Needed to update hf-doc-builder
It was annoying that the default value was invalid and would raise an
error.
Same as for LoRA and IA3, these Deberta tests fail for some reason.
@BenjaminBossan (Member, Author):

To ensure that the vera_A and vera_B weights are shared (but not other tensors), I added some tests that check their corresponding data_ptr()s.
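
For illustration, such a check can be sketched roughly as follows, given the model and the VeraLayer import from the script below (the attribute and key names vera_A, vera_B, vera_lambda_d and "default" are assumptions on my part, not necessarily the exact internals):

vera_layers = [m for m in model.modules() if isinstance(m, VeraLayer)]

# The frozen random projections should all point at the same underlying storage ...
assert len({layer.vera_A["default"].data_ptr() for layer in vera_layers}) == 1
assert len({layer.vera_B["default"].data_ptr() for layer in vera_layers}) == 1

# ... while the trainable per-layer vectors must not be shared.
assert len({layer.vera_lambda_d["default"].data_ptr() for layer in vera_layers}) == len(vera_layers)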

Moreover, I wrote a small script to check the amount of memory taken by the model. For this, I used a very high rank of 10000, so that vera_A and vera_B should be quite large. Then I compared the GPU memory taken by a model with a single layer having a VeRA adapter vs a model with many layers having a VeRA adapter. We would expect both to take roughly the same amount of memory, since most parameters are shared.

Here is the script:

from transformers import AutoModelForCausalLM
from peft import get_peft_model, VeraConfig, LoraConfig
from peft.tuners.vera import VeraLayer
from peft.tuners.lora import LoraLayer
import gc
import torch

RANK = 10000
model_id = "facebook/opt-125m"

config_cls = VeraConfig  # swap in LoraConfig / LoraLayer here to run the LoRA comparison below
layer_cls = VeraLayer

def get_gpu_memory():
    torch.cuda.synchronize()  # Wait for all kernels to finish
    gpu_info = {
        'allocated': f"{torch.cuda.memory_allocated(0) / 2**30:.4f}GB",
        'reserved': f"{torch.cuda.memory_reserved(0) / 2**30:.4f}GB",
    }
    print(gpu_info)

print("before loading the base model")
get_gpu_memory()

model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
print("after loading the model")
get_gpu_memory()

config = config_cls(task_type="CAUSAL_LM", target_modules=["model.decoder.layers.0.self_attn.k_proj"], r=RANK)
model = get_peft_model(model, config)
num_vera_layers = len([m for m in model.modules() if isinstance(m, layer_cls)])
print(f"after adding {num_vera_layers} adapted layers with rank {RANK}")
get_gpu_memory()

del model
torch.cuda.empty_cache()
gc.collect()

print("after resetting")
get_gpu_memory()

model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
print("after loading the base model")
get_gpu_memory()

config = config_cls(task_type="CAUSAL_LM", target_modules=["v_proj", "q_proj"], r=RANK)
model = get_peft_model(model, config)
num_vera_layers = len([m for m in model.modules() if isinstance(m, layer_cls)])
print(f"after adding {num_vera_layers} adapted layers with rank {RANK}")
get_gpu_memory()

For VeRA, the results are:

before loading the base model
{'allocated': '0.0000GB', 'reserved': '0.0000GB'}
after loading the model
{'allocated': '0.4677GB', 'reserved': '0.5176GB'}
after adding 1 adapted layers with rank 10000
{'allocated': '0.5264GB', 'reserved': '0.5762GB'}
after resetting
{'allocated': '0.0000GB', 'reserved': '0.0000GB'}
after loading the base model
{'allocated': '0.4677GB', 'reserved': '0.5176GB'}
after adding 24 adapted layers with rank 10000
{'allocated': '0.5273GB', 'reserved': '0.5762GB'}

As we can see, when adapting 24 layers vs 1 layer, the memory used is almost identical. We expect a small increase because vera_lambda_b and vera_lambda_d are not shared, so this is in line with our expectations.

As a sanity check, if we do the same with LoRA instead of VeRA (i.e. setting config_cls = LoraConfig and layer_cls = LoraLayer in the script above), we see a big increase in memory used:

before loading the base model
{'allocated': '0.0000GB', 'reserved': '0.0000GB'}
after loading the model
{'allocated': '0.4677GB', 'reserved': '0.5176GB'}
after adding 1 adapted layers with rank 10000
{'allocated': '0.5263GB', 'reserved': '0.5762GB'}
after resetting
{'allocated': '0.0000GB', 'reserved': '0.0000GB'}
after loading the base model
{'allocated': '0.4677GB', 'reserved': '0.5176GB'}
after adding 24 adapted layers with rank 10000
{'allocated': '1.8740GB', 'reserved': '1.9238GB'}

All of this is a strong indicator to me that the memory sharing actually works. If anyone has ideas for more tests, let me know.

@BenjaminBossan (Member, Author):

@dkopi @vvvm23 I think I'm pretty much finished with the implementation itself; docs and examples are yet to come. Still, if you have time, I'd be happy about a review, or if you can run some tests to see whether the implementation performs as expected. The changes compared to the original PR are documented above; the core VeRA computation hasn't been changed, though.


@vvvm23 (Contributor) commented Mar 24, 2024

Hi @BenjaminBossan, I can do a review some time this week.

@dkopi (Contributor) left a comment:

Looks good 👌

@BenjaminBossan (Member, Author):

@vvvm23 Did you have time to take a look?

@vvvm23 (Contributor) left a comment:

Looks good to me, a few small nitpicks. Sorry for the delay on this!

adapter_name (`str`):
The adapter name.
"""
pass
Contributor:

[nit] why not raise NotImplementedError? Avoid silent failures if something incorrectly calls the hook.

BenjaminBossan (Member, Author):

Passing is a valid outcome here; if we raised, all non-VeRA adapters would suddenly error ;)
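
To illustrate the design choice (a generic sketch, not the actual PEFT code; class and method names are made up):

class BaseTunerSketch:
    def _post_injection_hook(self, model, adapter_name: str) -> None:
        # Intentionally a no-op: raising here would break every adapter
        # that has nothing to do at this point.
        pass

class VeraTunerSketch(BaseTunerSketch):
    def _post_injection_hook(self, model, adapter_name: str) -> None:
        # VeRA overrides the hook, e.g. to set up its shared buffers.
        print(f"preparing shared buffers for adapter {adapter_name!r}")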

pattern is not in the common layers pattern.
"""

r: int = field(default=8, metadata={"help": "Vera attention dimension"})
Contributor:

Perhaps we should increase the default value? 8 is rather small for VeRA (the paper used 256-1024 for its experiments).

BenjaminBossan (Member, Author):

Yes, makes sense, I'll go with 256.

},
)
vera_dropout: float = field(default=0.0, metadata={"help": "Vera dropout"})
d_initial: float = field(default=1.0, metadata={"help": "Initial init value for d vector."})
Contributor:

0.1 may be a better default value; see Table 6 in the paper.

BenjaminBossan (Member, Author):

Right, makes sense.

Comment on lines 154 to 155
if isinstance(module, Conv1D): # TODO: feels fragile, thoughts?
module_shape = module_shape[::-1]
Contributor:

Remove this TODO? I feel this behaviour is actually fine; the semantics of Conv1D are unlikely to change.

BenjaminBossan (Member, Author):

Done.
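
For context, the Conv1D handling discussed above exists because transformers' Conv1D (used e.g. by GPT-2) stores its weight transposed relative to nn.Linear, so the shape has to be flipped before comparing; a quick illustration:

import torch.nn as nn
from transformers.pytorch_utils import Conv1D

linear = nn.Linear(768, 3072)
conv1d = Conv1D(3072, 768)          # Conv1D(nf=out_features, nx=in_features)
print(linear.weight.shape)          # torch.Size([3072, 768])
print(conv1d.weight.shape)          # torch.Size([768, 3072]), i.e. transposed
print(conv1d.weight.shape[::-1])    # matches the Linear convention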

- better default for r
- better default for d_initial
- remove unnecessary comment
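
For reference, with these defaults applied a typical configuration might look like the following (a sketch; only r, d_initial and vera_dropout come from the discussion above, the other arguments are illustrative):

from peft import VeraConfig

config = VeraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=256,           # new default, in line with the ranks used in the paper
    d_initial=0.1,   # new default for the d vector, see Table 6 of the paper
    vera_dropout=0.0,
)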
@BenjaminBossan (Member, Author) left a comment:

Thanks for the feedback, Alex, your comments should be addressed now.


@BenjaminBossan BenjaminBossan marked this pull request as ready for review April 15, 2024 11:51
@BenjaminBossan BenjaminBossan changed the title [WIP] Adds Vera (Vector Based Random Matrix Adaption) #2 Adds Vera (Vector Based Random Matrix Adaption) #2 Apr 15, 2024
@pacman100 (Contributor) left a comment:

Thank you @BenjaminBossan for all the work on Vera, continuing the efforts of @vvvm23; it all looks great with examples, documentation and tests! 🔥🚀✨

It would be great to add @vvvm23 and @dkopi as co-authors for all their guidance and work!

Left a minor nit.

>>> import transformers
>>> from peft import VeraConfig, PeftModel, get_peft_model

>>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"]
Contributor:

I don't think all the target modules have the same shape, and this list also includes an embedding layer.

A few models that work with LoRA don't work with VeRA (yet) because the
weight shapes of the target layers are not identical.
@BenjaminBossan (Member, Author):

@pacman100 Thanks for the feedback, indeed I hadn't checked the docstring example. It is now changed to a working model.

Your comment also prompted me to take a look at the models that are pre-configured in TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING. This was just a copy of the LoRA settings. Unfortunately, not all models work; I had to exclude some popular ones like Mistral, Mixtral, Phi, and Gemma. The issue is again that the shapes of the target layer weights can differ. Hopefully, we can add support for multiple different weight shapes in the future.
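
As an illustration, a quick way to check whether a model's intended target layers share a single weight shape could look like this (a rough sketch; the module name endings are just examples):

from collections import Counter
import torch.nn as nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
shapes = Counter(
    tuple(module.weight.shape)
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear) and name.endswith(("q_proj", "v_proj"))
)
print(shapes)  # VeRA currently requires a single shape here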

It would be great to add @vvvm23 and @dkopi as co-authors for all their guidance and work!

Yes, that was indeed my plan. @vvvm23 @dkopi could you please let me know how you want to be added as co-authors?

@dkopi (Contributor) commented Apr 18, 2024

@BenjaminBossan You can add:
Co-authored-by: Dawid <[email protected]>
Thanks!

@vvvm23 (Contributor) commented Apr 18, 2024

Likewise, you can add
Co-authored-by: Alex McKinney <[email protected]>

Thanks @BenjaminBossan for bringing this PR to completion!

@BenjaminBossan BenjaminBossan merged commit 5a4b9ca into huggingface:main Apr 19, 2024
14 checks passed
@BenjaminBossan BenjaminBossan deleted the add-vera-2 branch April 19, 2024 08:56
@BenjaminBossan (Member, Author):

Done 🎉

Thanks again so much @vvvm23 for doing the majority of the work and @dkopi for your constant feedback.

Let's hope that VeRA gains traction in the community. For the future, I'll add this list of improvements for VeRA that have yet to be implemented (contributions are welcome):

  1. Make VeRA work with different weight shapes. This is IMO the biggest limitation right now. The most straightforward way would be to have one pair of fixed vera_A/vera_B weights per target weight shape (a rough sketch of this idea follows after this list).
    There are cases where the shapes are the same when transposed, e.g. (4096, 1024) vs (1024, 4096) (down and up projections, for instance); I wonder if we can use the same weights here and just transpose them.
  2. If this is done, implement VeRA for more layer types than just Linear. Right now, supporting Embedding, for instance, makes little sense, as it almost always has a different shape than Linear.
  3. Support quantized weights, most notably bnb.
  4. Support DoRA ("DVoRA")
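
To make improvement 1 a bit more concrete, here is a rough, hypothetical sketch (not part of this PR or of PEFT) of keeping one fixed vera_A/vera_B pair per distinct target weight shape:

import torch

def get_shared_pair(shared, shape, r, generator):
    # Return (vera_A, vera_B) for this weight shape, creating them only once per shape.
    out_features, in_features = shape
    if shape not in shared:
        vera_A = torch.randn(r, in_features, generator=generator)
        vera_B = torch.randn(out_features, r, generator=generator)
        shared[shape] = (vera_A, vera_B)
    return shared[shape]

shared = {}
gen = torch.Generator().manual_seed(0)  # would correspond to projection_prng_key
pair_1 = get_shared_pair(shared, (4096, 4096), r=256, generator=gen)
pair_2 = get_shared_pair(shared, (4096, 4096), r=256, generator=gen)
pair_3 = get_shared_pair(shared, (11008, 4096), r=256, generator=gen)
assert pair_1 is pair_2        # same shape -> same shared pair
assert pair_3 is not pair_1    # different shape -> its own pair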

@vvvm23 (Contributor) commented Apr 19, 2024

Thanks again @BenjaminBossan! Please tag me in issues and PRs related to improvements :)
