Adds VeRA (Vector-based Random Matrix Adaptation) #2 #1564

Merged 29 commits on Apr 19, 2024
Changes from 21 commits
Commits (29):
c324ab2
[WIP] Initial commit
BenjaminBossan Mar 14, 2024
d8fe06b
Make style
BenjaminBossan Mar 14, 2024
463cd8e
Make style 2
BenjaminBossan Mar 14, 2024
20bf5d9
Feat: Support for Conv2D DoRA (#1516)
sayakpaul Mar 12, 2024
a344708
TST Report slowest tests (#1556)
BenjaminBossan Mar 12, 2024
2cc70ca
Changes to support fsdp+qlora and dsz3+qlora (#1550)
pacman100 Mar 13, 2024
1c4a765
Merge branch 'main' into add-vera-2
BenjaminBossan Mar 15, 2024
e7fabe3
More make style
BenjaminBossan Mar 15, 2024
0abb38d
Some further work, still WIP
BenjaminBossan Mar 18, 2024
2778657
More tests and fixes for VeRA
BenjaminBossan Mar 20, 2024
5898220
More tests, remove support for Embedding
BenjaminBossan Mar 20, 2024
a829035
Add checks for require_grad
BenjaminBossan Mar 20, 2024
9b80b1b
Some minor fixes, don't raise unnecessary errors
BenjaminBossan Mar 20, 2024
2ccb286
Fix issue caused by order of init for VeRA
BenjaminBossan Mar 20, 2024
f4dd9a3
Merge branch 'main' into add-vera-2
BenjaminBossan Mar 20, 2024
2d9687b
projection_prng_key now defaults to 0
BenjaminBossan Mar 21, 2024
209abd2
Skip failing Deberta + Vera tests
BenjaminBossan Mar 21, 2024
f0a319d
Add a sanity check to data_ptr test
BenjaminBossan Mar 21, 2024
30755a9
More sanity checks for data_ptr
BenjaminBossan Mar 21, 2024
b707ea8
Add VeRA example notebook
BenjaminBossan Mar 21, 2024
1f86941
Add some docs
BenjaminBossan Mar 21, 2024
4cc496f
Address reviewer feedback
BenjaminBossan Apr 8, 2024
ee86485
Merge branch 'main' into add-vera-2
BenjaminBossan Apr 11, 2024
924c235
Merge branch 'main' into add-vera-2
BenjaminBossan Apr 15, 2024
4739ef9
Make style
BenjaminBossan Apr 15, 2024
eefcc4f
Reviewer feedback: Adjust docstring
BenjaminBossan Apr 18, 2024
a0dab53
Update supported models for VeRA
BenjaminBossan Apr 18, 2024
5979b7b
Merge branch 'main' into add-vera-2
BenjaminBossan Apr 18, 2024
fec23e7
Fix adapter name handling
BenjaminBossan Apr 18, 2024
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -98,6 +98,8 @@
title: Prefix tuning
- local: package_reference/prompt_tuning
title: Prompt tuning
- local: package_reference/vera
title: VeRA
title: Adapters
- sections:
- local: package_reference/merge_utils
41 changes: 41 additions & 0 deletions docs/source/package_reference/vera.md
@@ -0,0 +1,41 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# VeRA: Vector-based Random Matrix Adaptation

[VeRA](https://huggingface.co/papers/2310.11454) is a parameter-efficient fine-tuning technique that is similar to LoRA but requires even fewer extra parameters while promising similar or even better performance. As such, it is particularly useful when the parameter budget is very limited, e.g. when scaling to very large models. The reduction in trainable parameters is achieved by sharing the same low-rank matrices across all layers and training only two additional vectors per layer.

When saving the adapter parameters, it's possible to eschew storing the low-rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default).

VeRA currently has the following constraints:

- All targeted parameters must have the same shape.
- Only `nn.Linear` layers are supported.
- Quantized layers are not supported.

If these constraints don't work for your use case, use LoRA instead.
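
To make the configuration options above concrete, here is a minimal usage sketch; the base model, target modules, and rank are illustrative assumptions rather than settings prescribed by this PR.

```python
# A minimal sketch of applying VeRA via PEFT. The base model, target module
# names, and rank are illustrative assumptions, not recommendations from this PR.
from transformers import AutoModelForSequenceClassification
from peft import VeraConfig, get_peft_model

base_model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)

config = VeraConfig(
    r=256,                              # rank of the shared random projections
    target_modules=["query", "value"],  # targeted layers must all share the same shape
    projection_prng_key=0,              # seed used to (re)create the frozen random matrices
    save_projection=True,               # store the matrices in the checkpoint (default)
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```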

The abstract from the paper is:

> Low-rank adaptation (LoRA) is a popular method that reduces the number of trainable parameters when finetuning large language models, but still faces acute storage challenges when scaling to even larger models or deploying numerous per-user or per-task adapted models. In this work, we present Vector-based Random Matrix Adaptation (VeRA), which significantly reduces the number of trainable parameters compared to LoRA, yet maintains the same performance. It achieves this by using a single pair of low-rank matrices shared across all layers and learning small scaling vectors instead. We demonstrate its effectiveness on the GLUE and E2E benchmarks, image classification tasks, and show its application in instruction-tuning of 7B and 13B language models.

## VeRAConfig

[[autodoc]] tuners.vera.config.VeraConfig

## VeRAModel

[[autodoc]] tuners.vera.model.VeraModel
543 changes: 543 additions & 0 deletions examples/sequence_classification/VeRA.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/peft/__init__.py
@@ -73,6 +73,8 @@
OFTModel,
PolyConfig,
PolyModel,
VeraConfig,
VeraModel,
)
from .utils import (
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
11 changes: 8 additions & 3 deletions src/peft/mapping.py
@@ -49,15 +49,18 @@
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
VeraConfig,
VeraModel,
)
from .tuners.tuners_utils import BaseTuner as _BaseTuner
from .utils import _prepare_prompt_learning_config


if TYPE_CHECKING:
from transformers import PreTrainedModel


MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, PeftModel] = {
MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, type[PeftModel]] = {
"SEQ_CLS": PeftModelForSequenceClassification,
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
"CAUSAL_LM": PeftModelForCausalLM,
@@ -66,7 +69,7 @@
"FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
}

PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, PeftConfig] = {
PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = {
"ADAPTION_PROMPT": AdaptionPromptConfig,
"PROMPT_TUNING": PromptTuningConfig,
"PREFIX_TUNING": PrefixTuningConfig,
@@ -79,16 +82,18 @@
"MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
"OFT": OFTConfig,
"POLY": PolyConfig,
"VERA": VeraConfig,
}

PEFT_TYPE_TO_TUNER_MAPPING = {
PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = {
"LORA": LoraModel,
"LOHA": LoHaModel,
"LOKR": LoKrModel,
"ADALORA": AdaLoraModel,
"IA3": IA3Model,
"OFT": OFTModel,
"POLY": PolyModel,
"VERA": VeraModel,
}


2 changes: 2 additions & 0 deletions src/peft/peft_model.py
@@ -50,6 +50,7 @@
PrefixEncoder,
PromptEmbedding,
PromptEncoder,
VeraModel,
)
from .utils import (
SAFETENSORS_WEIGHTS_NAME,
@@ -82,6 +83,7 @@
PeftType.IA3: IA3Model,
PeftType.OFT: OFTModel,
PeftType.POLY: PolyModel,
PeftType.VERA: VeraModel,
}


1 change: 1 addition & 0 deletions src/peft/tuners/__init__.py
@@ -30,3 +30,4 @@
from .oft import OFTConfig, OFTModel
from .mixed import MixedModel
from .poly import PolyConfig, PolyModel
from .vera import VeraConfig, VeraModel
20 changes: 18 additions & 2 deletions src/peft/tuners/tuners_utils.py
@@ -145,6 +145,7 @@ def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]],
self.peft_config.update(peft_config)

self.active_adapter = adapter_name
self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)
self.inject_adapter(self.model, adapter_name)

# Copy the peft_config in the injected model.
@@ -160,6 +161,21 @@ def active_adapters(self) -> list[str]:
def forward(self, *args: Any, **kwargs: Any):
return self.model.forward(*args, **kwargs)

def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:
r"""
A hook to be called before the adapter is injected into the model. This method can be overridden by child
classes to perform any pre-injection operations.

Args:
model (`nn.Module`):
The model to be adapted.
config (`PeftConfig`):
The adapter config.
adapter_name (`str`):
The adapter name.
"""
pass
Contributor:

[nit] why not raise NotImplementedError? Avoid silent failures if something incorrectly calls the hook.

Member (Author):

Passing is a valid outcome here; if we raised, all non-VeRA adapters would suddenly error ;)


@abstractmethod
def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -> PeftConfig:
r"""
@@ -398,9 +414,9 @@ class BaseTunerLayer(ABC):
active_adapter = None

# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str] = ()
adapter_layer_names: tuple[str, ...] = ()
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str] = ()
other_param_names: tuple[str, ...] = ()

# indicates whether all adapters should be disabled
_disable_adapters: bool = False
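
For context on the hook added above, here is a hypothetical sketch (not code from this PR) of how a `BaseTuner` subclass might override `_pre_injection_hook` to set up state that all adapted layers share; the attribute name and the way the seed is read from the config are assumptions.

```python
# Hypothetical sketch, not part of this diff: a BaseTuner subclass could use
# _pre_injection_hook to prepare state shared by all adapted layers before
# inject_adapter() replaces any module. Attribute and config field names here
# are illustrative.
import torch
from torch import nn


class MySharedStateTuner:
    """Stand-in for a BaseTuner subclass; only the hook override is shown."""

    def _pre_injection_hook(self, model: nn.Module, config, adapter_name: str) -> None:
        # Called once per adapter from BaseTuner.__init__, right before injection.
        generator = torch.Generator().manual_seed(getattr(config, "projection_prng_key", 0))
        # e.g. a frozen random projection that every adapted layer will reuse
        self.shared_projection = torch.randn(4, 16, generator=generator)
```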
20 changes: 20 additions & 0 deletions src/peft/tuners/vera/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config import VeraConfig
from .layer import Linear, VeraLayer
from .model import VeraModel


__all__ = ["VeraConfig", "VeraLayer", "Linear", "VeraModel"]
160 changes: 160 additions & 0 deletions src/peft/tuners/vera/buffer_dict.py
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from https://botorch.org/api/_modules/botorch/utils/torch.html

# TODO: To be removed once (if) https://github.com/pytorch/pytorch/pull/37385 lands

from __future__ import annotations

import collections
from collections import OrderedDict

import torch
from torch.nn import Module


class BufferDict(Module):
r"""
Holds buffers in a dictionary.

BufferDict can be indexed like a regular Python dictionary, but buffers it contains are properly registered, and
will be visible by all Module methods. `torch.nn.BufferDict` is an **ordered** dictionary that respects

* the order of insertion, and
* in `torch.nn.BufferDict.update`, the order of the merged `OrderedDict` or another `torch.nn.BufferDict` (the
argument to `torch.nn.BufferDict.update`).

Note that `torch.nn.BufferDict.update` with other unordered mapping types (e.g., Python's plain `dict`) does not
preserve the order of the merged mapping.

Args:
buffers (iterable, optional):
a mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string,
`torch.Tensor`)

```python
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.buffers = nn.BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)})

def forward(self, x, choice):
x = self.buffers[choice].mm(x)
return x
```
"""

def __init__(self, buffers=None, persistent: bool = False):
r"""
Args:
buffers (`dict`):
A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
(string, `torch.Tensor`).
"""
super().__init__()
if buffers is not None:
self.update(buffers)

self.persistent = persistent

def __getitem__(self, key):
return self._buffers[key]

def __setitem__(self, key, buffer):
self.register_buffer(key, buffer, persistent=self.persistent)

def __delitem__(self, key):
del self._buffers[key]

def __len__(self):
return len(self._buffers)

def __iter__(self):
return iter(self._buffers.keys())

def __contains__(self, key):
return key in self._buffers

def clear(self):
"""Remove all items from the BufferDict."""
self._buffers.clear()

def pop(self, key):
r"""Remove key from the BufferDict and return its buffer.

Args:
key (`str`):
Key to pop from the BufferDict
"""
v = self[key]
del self[key]
return v

def keys(self):
r"""Return an iterable of the BufferDict keys."""
return self._buffers.keys()

def items(self):
r"""Return an iterable of the BufferDict key/value pairs."""
return self._buffers.items()

def values(self):
r"""Return an iterable of the BufferDict values."""
return self._buffers.values()

def update(self, buffers):
r"""
Update the `torch.nn.BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing
keys.

Note:
If `buffers` is an `OrderedDict`, a `torch.nn.BufferDict`, or an iterable of key-value pairs, the order of
new elements in it is preserved.

Args:
buffers (iterable):
a mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
(string, `torch.Tensor`).
"""
if not isinstance(buffers, collections.abc.Iterable):
raise TypeError(
"BuffersDict.update should be called with an "
"iterable of key/value pairs, but got " + type(buffers).__name__
)

if isinstance(buffers, collections.abc.Mapping):
if isinstance(buffers, (OrderedDict, BufferDict)):
for key, buffer in buffers.items():
self[key] = buffer
else:
for key, buffer in sorted(buffers.items()):
self[key] = buffer
else:
for j, p in enumerate(buffers):
if not isinstance(p, collections.abc.Iterable):
raise TypeError(
"BufferDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(p).__name__
)
if not len(p) == 2:
raise ValueError(
"BufferDict update sequence element "
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
)
self[p[0]] = p[1]

def extra_repr(self):
child_lines = []
for k, p in self._buffers.items():
size_str = "x".join(str(size) for size in p.size())
device_str = "" if not p.is_cuda else f" (GPU {p.get_device()})"
parastr = f"Buffer containing: [{torch.typename(p)} of size {size_str}{device_str}]"
child_lines.append(" (" + k + "): " + parastr)
tmpstr = "\n".join(child_lines)
return tmpstr

def __call__(self, input):
raise RuntimeError("BufferDict should not be called.")
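
To round off this file, a short usage sketch for `BufferDict`, assuming the package is installed with this PR's changes; the wrapping module, buffer names, and tensor shapes are illustrative.

```python
# Illustrative usage of the BufferDict defined above: buffers registered
# through it move with the module and, with persistent=True, appear in
# state_dict().
import torch
from torch import nn

from peft.tuners.vera.buffer_dict import BufferDict


class SharedProjections(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=True registers the buffers so they are saved in state_dict()
        self.projections = BufferDict(
            {"vera_A": torch.randn(4, 16), "vera_B": torch.randn(16, 4)},
            persistent=True,
        )

    def forward(self, x):
        # Buffers follow the module on .to(device) like any registered buffer
        return x @ self.projections["vera_A"].T


module = SharedProjections()
print(sorted(module.state_dict().keys()))  # ['projections.vera_A', 'projections.vera_B']
print(module(torch.randn(2, 16)).shape)    # torch.Size([2, 4])
```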