Adds VeRA (Vector-based Random Matrix Adaptation) #2 #1564
Merged
Changes from 21 commits (29 commits total):
- c324ab2 [WIP] Initial commit (BenjaminBossan)
- d8fe06b Make style (BenjaminBossan)
- 463cd8e Make style 2 (BenjaminBossan)
- 20bf5d9 Feat: Support for Conv2D DoRA (#1516) (sayakpaul)
- a344708 TST Report slowest tests (#1556) (BenjaminBossan)
- 2cc70ca Changes to support fsdp+qlora and dsz3+qlora (#1550) (pacman100)
- 1c4a765 Merge branch 'main' into add-vera-2 (BenjaminBossan)
- e7fabe3 More make style (BenjaminBossan)
- 0abb38d Some further work, still WIP (BenjaminBossan)
- 2778657 More tests and fixes for VeRA (BenjaminBossan)
- 5898220 More tests, remove support for Embedding (BenjaminBossan)
- a829035 Add checks for require_grad (BenjaminBossan)
- 9b80b1b Some minor fixes, don't raise unnecessary errors (BenjaminBossan)
- 2ccb286 Fix issue caused by order of init for VeRA (BenjaminBossan)
- f4dd9a3 Merge branch 'main' into add-vera-2 (BenjaminBossan)
- 2d9687b projection_prng_key now defaults to 0 (BenjaminBossan)
- 209abd2 Skip failing Deberta + Vera tests (BenjaminBossan)
- f0a319d Add a sanity check to data_ptr test (BenjaminBossan)
- 30755a9 More sanity checks for data_ptr (BenjaminBossan)
- b707ea8 Add VeRA example notebook (BenjaminBossan)
- 1f86941 Add some docs (BenjaminBossan)
- 4cc496f Address reviewer feedback (BenjaminBossan)
- ee86485 Merge branch 'main' into add-vera-2 (BenjaminBossan)
- 924c235 Merge branch 'main' into add-vera-2 (BenjaminBossan)
- 4739ef9 Make style (BenjaminBossan)
- eefcc4f Reviewer feedback: Adjust docstring (BenjaminBossan)
- a0dab53 Update supported models for VeRA (BenjaminBossan)
- 5979b7b Merge branch 'main' into add-vera-2 (BenjaminBossan)
- fec23e7 Fix adapter name handling (BenjaminBossan)
@@ -0,0 +1,41 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# VeRA: Vector-based Random Matrix Adaptation

[VeRA](https://huggingface.co/papers/2310.11454) is a parameter-efficient fine-tuning technique that is similar to LoRA but requires even fewer extra parameters while promising similar or even better performance. As such, it is particularly useful when the parameter budget is very limited, e.g. when scaling to very large models. The reduction of the number of trainable parameters is achieved by sharing the same low-rank matrices across all layers and only training two additional vectors per layer.

When saving the adapter parameters, it's possible to eschew storing the low-rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default).

VeRA currently has the following constraints:

- All targeted parameters must have the same shape.
- Only `nn.Linear` layers are supported.
- Quantized layers are not supported.

If these constraints don't work for your use case, use LoRA instead.

The abstract from the paper is:

> Low-rank adaptation (LoRA) is a popular method that reduces the number of trainable parameters when finetuning large language models, but still faces acute storage challenges when scaling to even larger models or deploying numerous per-user or per-task adapted models. In this work, we present Vector-based Random Matrix Adaptation (VeRA), which significantly reduces the number of trainable parameters compared to LoRA, yet maintains the same performance. It achieves this by using a single pair of low-rank matrices shared across all layers and learning small scaling vectors instead. We demonstrate its effectiveness on the GLUE and E2E benchmarks, image classification tasks, and show its application in instruction-tuning of 7B and 13B language models.

## VeraConfig

[[autodoc]] tuners.vera.config.VeraConfig

## VeraModel

[[autodoc]] tuners.vera.model.VeraModel
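Not part of the diff: a minimal usage sketch of the API documented above, assuming the PR is installed. The base model checkpoint and `target_modules` below are illustrative placeholders; the config fields mirror the options described in the docs (`save_projection`, `projection_prng_key`).

```python
from transformers import AutoModelForCausalLM
from peft import VeraConfig, get_peft_model

# illustrative base model; all targeted nn.Linear layers must share one shape
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = VeraConfig(
    r=256,  # rank of the shared low-rank projections
    target_modules=["q_proj", "v_proj"],  # placeholder module names
    projection_prng_key=0,  # seed used to (re)generate the shared matrices
    save_projection=True,  # default: store the matrices for guaranteed reproducibility
)
model = get_peft_model(base, config)
model.print_trainable_parameters()
```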
Large diffs are not rendered by default.
@@ -0,0 +1,20 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import VeraConfig
from .layer import Linear, VeraLayer
from .model import VeraModel


__all__ = ["VeraConfig", "VeraLayer", "Linear", "VeraModel"]
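A quick sanity check of the public surface exported by this `__init__.py`; the `peft.tuners.vera` package path is an assumption based on the relative imports above.

```python
# assumes this PR is installed and the package lives at peft.tuners.vera
from peft.tuners.vera import Linear, VeraConfig, VeraLayer, VeraModel

# the names match __all__ above
print([cls.__name__ for cls in (VeraConfig, VeraLayer, Linear, VeraModel)])
```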
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from https://botorch.org/api/_modules/botorch/utils/torch.html

# TODO: To be removed once (if) https://github.com/pytorch/pytorch/pull/37385 lands

from __future__ import annotations

import collections
from collections import OrderedDict

import torch
from torch.nn import Module


class BufferDict(Module):
    r"""
    Holds buffers in a dictionary.

    BufferDict can be indexed like a regular Python dictionary, but the buffers it contains are properly registered
    and will be visible to all Module methods. `BufferDict` is an **ordered** dictionary that respects

    * the order of insertion, and
    * in `BufferDict.update`, the order of the merged `OrderedDict` or another `BufferDict` (the argument to
      `BufferDict.update`).

    Note that `BufferDict.update` with other unordered mapping types (e.g., Python's plain `dict`) does not preserve
    the order of the merged mapping.

    Args:
        buffers (iterable, optional):
            a mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string,
            `torch.Tensor`)

    ```python
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # avoid naming the attribute `buffers`, which would shadow nn.Module.buffers()
            self.buffer_dict = BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)})

        def forward(self, x, choice):
            x = self.buffer_dict[choice].mm(x)
            return x
    ```
    """

    def __init__(self, buffers=None, persistent: bool = False):
        r"""
        Args:
            buffers (`dict`):
                A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).
        """
        super().__init__()
        # set persistent before update(), since __setitem__ reads self.persistent
        self.persistent = persistent
        if buffers is not None:
            self.update(buffers)

    def __getitem__(self, key):
        return self._buffers[key]

    def __setitem__(self, key, buffer):
        self.register_buffer(key, buffer, persistent=self.persistent)

    def __delitem__(self, key):
        del self._buffers[key]

    def __len__(self):
        return len(self._buffers)

    def __iter__(self):
        return iter(self._buffers.keys())

    def __contains__(self, key):
        return key in self._buffers

    def clear(self):
        """Remove all items from the BufferDict."""
        self._buffers.clear()

    def pop(self, key):
        r"""Remove key from the BufferDict and return its buffer.

        Args:
            key (`str`):
                Key to pop from the BufferDict
        """
        v = self[key]
        del self[key]
        return v

    def keys(self):
        r"""Return an iterable of the BufferDict keys."""
        return self._buffers.keys()

    def items(self):
        r"""Return an iterable of the BufferDict key/value pairs."""
        return self._buffers.items()

    def values(self):
        r"""Return an iterable of the BufferDict values."""
        return self._buffers.values()

    def update(self, buffers):
        r"""
        Update the `BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing keys.

        Note:
            If `buffers` is an `OrderedDict`, a `BufferDict`, or an iterable of key-value pairs, the order of new
            elements in it is preserved.

        Args:
            buffers (iterable):
                a mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).
        """
        if not isinstance(buffers, collections.abc.Iterable):
            raise TypeError(
                "BufferDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(buffers).__name__
            )

        if isinstance(buffers, collections.abc.Mapping):
            if isinstance(buffers, (OrderedDict, BufferDict)):
                for key, buffer in buffers.items():
                    self[key] = buffer
            else:
                # plain dicts are unordered, so sort keys for a deterministic order
                for key, buffer in sorted(buffers.items()):
                    self[key] = buffer
        else:
            for j, p in enumerate(buffers):
                if not isinstance(p, collections.abc.Iterable):
                    raise TypeError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(p).__name__
                    )
                if not len(p) == 2:
                    raise ValueError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                self[p[0]] = p[1]

    def extra_repr(self):
        child_lines = []
        for k, p in self._buffers.items():
            size_str = "x".join(str(size) for size in p.size())
            device_str = "" if not p.is_cuda else f" (GPU {p.get_device()})"
            parastr = f"Buffer containing: [{torch.typename(p)} of size {size_str}{device_str}]"
            child_lines.append(" (" + k + "): " + parastr)
        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError("BufferDict should not be called.")
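Not part of the diff: a small sketch of how `BufferDict` behaves once registered on a module (the import path is assumed from where this file sits in the PR). With `persistent=True`, the tensors are registered as persistent buffers and therefore appear in the module's `state_dict`:

```python
import torch
from torch import nn

from peft.tuners.vera.buffer_dict import BufferDict  # assumed module path


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=True -> the buffers are included in the state_dict
        self.projections = BufferDict({"vera_A": torch.randn(4, 16)}, persistent=True)

    def forward(self, x):
        return x @ self.projections["vera_A"]


demo = Demo()
print(list(demo.state_dict().keys()))  # ['projections.vera_A']
print(demo(torch.randn(2, 4)).shape)  # torch.Size([2, 16])
```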
[nit] Why not `raise NotImplementedError`? That would avoid silent failures if something incorrectly calls the hook.
Passing is a valid outcome here; if we raised, all non-VeRA adapters would suddenly error ;)
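To make the design choice concrete, here is a hypothetical sketch of the pattern under discussion (all names are invented for illustration): the base layer defines the hook as a no-op, and only the VeRA layer overrides it, so adapters that have nothing to do at this point keep working.

```python
class BaseTunerLayerSketch:
    def _post_init_hook(self) -> None:
        # Intentionally a no-op: for non-VeRA adapters there is nothing to
        # do here, so passing is a valid outcome. Raising NotImplementedError
        # instead would break every adapter that doesn't override the hook.
        pass


class VeraLayerSketch(BaseTunerLayerSketch):
    def _post_init_hook(self) -> None:
        # Only VeRA has shared projection state to set up at this point.
        print("initializing shared vera_A/vera_B projections")
```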