Commit e8ab799: Merge remote-tracking branch 'upstream/main' into main

chiayi-hsu committed Nov 14, 2024
2 parents 5ee3c83 + 221965b

Showing 18 changed files with 2,087 additions and 20 deletions.
31 changes: 31 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -63,6 +63,37 @@ from peft import LoraConfig
config = LoraConfig(init_lora_weights="olora", ...)
```
For more advanced usage, please refer to our [documentation](https://github.com/huggingface/peft/tree/main/examples/olora_finetuning).

### EVA
[EVA](https://arxiv.org/pdf/2410.07170) performs SVD on the input activations of each layer and uses the right-singular vectors to initialize the LoRA weights. It is therefore a data-driven initialization scheme. Furthermore, EVA adaptively allocates ranks across layers based on their "explained variance ratio", a metric derived from the SVD analysis.

You can use EVA by setting `init_lora_weights="eva"` and defining [`EvaConfig`] in [`LoraConfig`]:
```python
from peft import LoraConfig, EvaConfig
peft_config = LoraConfig(
    init_lora_weights="eva",
    eva_config=EvaConfig(rho=2.0),
    ...
)
```
The parameter `rho` (≥ 1.0) determines how much rank redistribution is allowed. When `rho=1.0` and `r=16`, each adapted layer is limited to exactly 16 ranks, so no redistribution can occur. A recommended value for EVA with redistribution is 2.0, meaning the maximum rank allowed for a layer is `2r`.
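To make the rank budget concrete, here is a small back-of-the-envelope sketch (illustrative arithmetic only; the actual per-layer allocation is determined by the explained variance ratios computed during initialization, and the layer count is just an example):
```python
r, rho, num_adapted_layers = 16, 2.0, 4

total_rank_budget = r * num_adapted_layers  # 64 ranks shared across all adapted layers
max_rank_per_layer = int(rho * r)           # 32: upper bound for any single layer
print(total_rank_budget, max_rank_per_layer)
```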

It is recommended to perform EVA initialization on a GPU as it is much faster. To maximize the amount of memory available during EVA initialization, you can set the `low_cpu_mem_usage` flag in [`get_peft_model`]:
```python
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
```
Then, call [`initialize_lora_eva_weights`] to initialize the EVA weights (in most cases, the dataloader used for EVA initialization can be the same as the one used for finetuning):
```python
initialize_lora_eva_weights(peft_model, dataloader)
```
EVA works out of the box with bitsandbytes. Simply initialize the model with a `quantization_config` and call [`initialize_lora_eva_weights`] as usual; a minimal sketch is shown below.
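This sketch reuses the `peft_config` and `dataloader` from above; the model name and the 4-bit settings are just examples:
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, initialize_lora_eva_weights, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # example checkpoint, substitute your own
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = prepare_model_for_kbit_training(model)
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
initialize_lora_eva_weights(peft_model, dataloader)
```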

<Tip>

For further instructions on using EVA, please refer to our [documentation](https://github.com/huggingface/peft/tree/main/examples/eva_finetuning).

</Tip>

### LoftQ

#### Standard approach
16 changes: 16 additions & 0 deletions docs/source/package_reference/lora.md
@@ -32,4 +32,20 @@ The abstract from the paper is:

## Utility

### LoftQ

[[autodoc]] utils.loftq_utils.replace_lora_weights_loftq

### EVA

#### EvaConfig

[[autodoc]] tuners.lora.config.EvaConfig

#### initialize_lora_eva_weights

[[autodoc]] tuners.lora.eva.initialize_lora_eva_weights

#### get_eva_state_dict

[[autodoc]] tuners.lora.eva.get_eva_state_dict
153 changes: 153 additions & 0 deletions examples/eva_finetuning/README.md
@@ -0,0 +1,153 @@
# EVA: Explained Variance Adaptation
## Introduction ([Paper](https://arxiv.org/abs/2410.07170), [code](https://github.com/ml-jku/EVA))
Explained Variance Adaptation (EVA) is a novel initialization method for LoRA-style adapters which initializes adapter weights in a data-driven manner and adaptively allocates ranks according to the variance they explain. EVA improves average performance on a multitude of tasks across various domains, such as language generation and understanding, image classification, and decision making.

The abstract from the paper is:

*Foundation models (FMs) are pre-trained on large-scale datasets and then fine-tuned on a downstream task for a specific application. The most successful and most commonly used fine-tuning method is to update the pre-trained weights via a low-rank adaptation (LoRA). LoRA introduces new weight matrices that are usually initialized at random with a uniform rank distribution across model weights. Recent works focus on weight-driven initialization or learning of adaptive ranks during training. Both approaches have only been investigated in isolation, resulting in slow convergence or a uniform rank distribution, in turn leading to sub-optimal performance. We propose to enhance LoRA by initializing the new weights in a data-driven manner by computing singular value decomposition on minibatches of activation vectors. Then, we initialize the LoRA matrices with the obtained right-singular vectors and re-distribute ranks among all weight matrices to explain the maximal amount of variance and continue the standard LoRA fine-tuning procedure. This results in our new method **E**xplained **V**ariance **A**daptation (EVA). We apply EVA to a variety of fine-tuning tasks ranging from language generation and understanding to image classification and reinforcement learning. EVA exhibits faster convergence than competitors and attains the highest average score across a multitude of tasks per domain.*

## Quick Start
Below is an example of how to use EVA with a causal language model. For a more detailed example, see [eva_finetuning.py](https://github.com/huggingface/peft/blob/main/examples/eva_finetuning/eva_finetuning.py).
```python
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import EvaConfig, LoraConfig, get_peft_model, initialize_lora_eva_weights


# config
model_name = "meta-llama/Llama-3.1-8B"
max_seq_len = 512
rank = 16
alpha = 1
rho = 2.0
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
svd_batch_size = 4 # can be different from the batch size used in finetuning

# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# load dataset
dataset = load_dataset("Rowan/hellaswag")
dataset = dataset.map(
    lambda x: tokenizer(x["ctx"], padding="max_length", truncation=True, max_length=max_seq_len),
    batched=True,
    remove_columns=dataset["train"].column_names,
)
dataset.set_format(type="torch")

# create dataloader for SVD
# typically this is the same as the dataloader used for finetuning
dataloader = DataLoader(
    dataset["train"],
    batch_size=svd_batch_size,
    collate_fn=lambda examples: {k: torch.stack([v[k] for v in examples], dim=0) for k in examples[0].keys()},
)

# setup peft config
eva_config = EvaConfig(
    rho=rho
)
peft_config = LoraConfig(
    r=rank,
    lora_alpha=alpha,
    target_modules=target_modules,
    init_lora_weights="eva",
    eva_config=eva_config
)

# move model to GPU
model = model.cuda()

# to optimize memory usage during EVA initialization, set low_cpu_mem_usage=True
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)

initialize_lora_eva_weights(peft_model, dataloader)
```
`initialize_lora_eva_weights` will compute the SVD and load the components into the model. After this, continue with standard LoRA finetuning, for example as sketched below.
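A minimal continuation with the `transformers` `Trainer` might look like this (a sketch; `dataset` and `data_collator` stand for whatever you use for finetuning, and the training arguments are just examples):
```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    remove_unused_columns=False,
)
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=data_collator,  # assumed: a collator that yields input_ids, attention_mask and labels
)
trainer.train()
```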

## Using EVA with Bitsandbytes
EVA is fully compatible with bitsandbytes. Simply initialize the pretrained model with a `BitsAndBytesConfig`, then create the PEFT model and initialize the EVA weights as usual.
```python
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
quantization_config=BitsAndBytesConfig(load_in_4bit=True)
)
model = prepare_model_for_kbit_training(model)
peft_model = get_peft_model(model, peft_config)
initialize_lora_eva_weights(peft_model, dataloader)
```

## Getting the EVA state_dict without loading the adapter weights
In some cases, you might just want to get the state_dict after EVA initialization without loading the adapter weights. This can be useful, for example, if:
- you want to precompute and store the state_dict for different downstream tasks.
- you need to quantize the model for finetuning but want to perform EVA initialization with the model weights in full/half precision.
- you do not intend to use a PEFT model for LoRA finetuning.

You can do this by calling `get_eva_state_dict` directly (you only need to pass `peft_config` if `model` is not a PeftModel):
```python
from peft import get_eva_state_dict

eva_state_dict = get_eva_state_dict(model, dataloader, peft_config)
```
Later, you can load the precomputed state_dict into a PEFT model by passing it via the `eva_state_dict` argument of `initialize_lora_eva_weights`:
```python
initialize_lora_eva_weights(peft_model, eva_state_dict=eva_state_dict)
```
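For instance, to cover the quantized-finetuning use case above, one possible workflow is to compute the state_dict on the half-precision model and then load it into a quantized copy (a sketch, assuming both models come from the same checkpoint and share the same `peft_config`; the model name is an example):
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_eva_state_dict, get_peft_model, initialize_lora_eva_weights, prepare_model_for_kbit_training

# 1. compute the EVA state dict on the model in half precision
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16).cuda()
eva_state_dict = get_eva_state_dict(model, dataloader, peft_config)

# 2. reload the model quantized and initialize the adapter weights from the precomputed state dict
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
quantized_model = prepare_model_for_kbit_training(quantized_model)
peft_model = get_peft_model(quantized_model, peft_config)
initialize_lora_eva_weights(peft_model, eva_state_dict=eva_state_dict)
```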

## Customizing EVA

By default, EVA is designed to work with standard transformer language models. However, we integrated three parameters that can be used to customize EVA for other types of models:
1. `forward_fn`: defines how the forward pass during EVA initialization should be computed.
2. `prepare_model_inputs_fn`: can be used if it is necessary to use information contained in the original model input to prepare the input for SVD in individual layers.
3. `prepare_layer_inputs_fn`: defines how layer inputs should be prepared for SVD.

All three parameters can be passed to `initialize_lora_eva_weights` and `get_eva_state_dict`.

### forward_fn

`forward_fn` defines how the forward pass during EVA initialization is computed. It receives two arguments: `model` and `inputs`. By default, this is set to `forward_fn_dict`, which simply returns `model(**inputs)`.
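As an example, if your model is called with positional rather than keyword arguments, a custom `forward_fn` could be sketched like this (`forward_fn_positional` is a hypothetical name, not part of PEFT):
```python
def forward_fn_positional(model, inputs):
    # assumption: `inputs` is a tuple/list of tensors rather than a dict
    return model(*inputs)


initialize_lora_eva_weights(peft_model, dataloader, forward_fn=forward_fn_positional)
```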

### prepare_model_inputs_fn

`prepare_model_inputs_fn` can be used if it is necessary to use information contained in the original model input to prepare the input for SVD in individual layers. It receives two arguments: `model_input` and `peft_config`. This component is separate from `prepare_layer_inputs_fn` because its output only needs to be computed once per batch. By default, this parameter is set to `prepare_model_inputs_fn_language_modeling`, which is used to get a subset of indices based on the attention mask and label mask, to avoid including padding tokens in the SVD computation. If you do not want to use this component, set `prepare_model_inputs_fn` to `None`. The default logic is:
```python
def prepare_model_inputs_fn_language_modeling(model_input, peft_config: LoraConfig):
    mask = model_input.get("attention_mask", torch.ones_like(model_input["input_ids"])).bool()
    if peft_config.eva_config.use_label_mask and hasattr(model_input, "labels"):
        mask = torch.logical_and(mask, model_input["labels"] != peft_config.eva_config.label_mask_value)
    return mask.nonzero()
```

### prepare_layer_inputs_fn

`prepare_layer_inputs_fn` can be used to preprocess the layer inputs before passing them to the SVD algorithm. It receives three arguments: `layer_input`, `model_input` and `layer_name`. It can either be a callable or a dictionary where the keys are layer names and the values are callables; in the dictionary case, functions are assigned to adapter layers based on the layer names. By default, a language modeling setting is assumed where `model_input` is the output of `prepare_model_inputs_fn_language_modeling`, i.e. a mask of indices. If this parameter is set to `None`, only two modifications are made to the layer inputs:
- take the first element in case of a tuple or list.
- if the input has more than two dimensions, flatten all but the last dimension.

The function must always return a tensor. The default logic is:
```python
def prepare_layer_inputs_fn_default(layer_input, model_input, layer_name) -> torch.Tensor:
    if isinstance(layer_input, (tuple, list)):
        layer_input = layer_input[0]
    return layer_input[model_input.T.unbind()]
```
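As a sketch of customizing this interface, the function below simply flattens all token positions without any masking and is passed together with `prepare_model_inputs_fn=None` (`prepare_layer_inputs_flatten` is a hypothetical helper, shown only to illustrate the interface):
```python
import torch


def prepare_layer_inputs_flatten(layer_input, model_input, layer_name) -> torch.Tensor:
    # take the first element in case the layer receives a tuple/list
    if isinstance(layer_input, (tuple, list)):
        layer_input = layer_input[0]
    # flatten everything except the last (feature) dimension
    return layer_input.flatten(0, -2)


initialize_lora_eva_weights(
    peft_model,
    dataloader,
    prepare_model_inputs_fn=None,  # disable the default language-modeling masking
    prepare_layer_inputs_fn=prepare_layer_inputs_flatten,
)
```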

## Citation
If you find our work useful, please consider citing it.

```
@article{paischer2024eva,
  title={One Initialization to Rule them All: Fine-tuning via Explained Variance Adaptation},
  author={Fabian Paischer and Lukas Hauzenberger and Thomas Schmied and Benedikt Alkin and Marc Peter Deisenroth and Sepp Hochreiter},
  journal={arXiv preprint arXiv:2410.07170},
  year={2024}
}
```
87 changes: 87 additions & 0 deletions examples/eva_finetuning/eva_finetuning.py
@@ -0,0 +1,87 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from utils import DataCollator, TokenizerMetaMath

from peft import EvaConfig, LoraConfig, get_peft_model, initialize_lora_eva_weights


# config
model_name = "meta-llama/Llama-3.1-8B"
max_seq_len = 512
rank = 16
alpha = 1
rho = 2.0
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
svd_batch_size = 4 # can be different from the batch size used in finetuning
batch_size = 4
num_epochs = 1
output_dir = "outputs"
device = "cuda:0"

# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# load dataset
dataset = load_dataset("meta-math/MetaMathQA")
dataset = dataset.map(
    TokenizerMetaMath(model_name),
    batched=True,
    remove_columns=dataset["train"].column_names,
)
dataset.set_format(type="torch")

# data collator
data_collator = DataCollator(tokenizer.eos_token_id, max_length=max_seq_len)

# dataloader
dataloader = DataLoader(
    dataset["train"],
    batch_size=svd_batch_size,
    collate_fn=data_collator,
)

# setup peft config
eva_config = EvaConfig(rho=rho)
peft_config = LoraConfig(
    r=rank, lora_alpha=alpha, target_modules=target_modules, init_lora_weights="eva", eva_config=eva_config
)

# move model to GPU
model = model.to(device)

# to optimize memory usage during eva initialization, set low_cpu_mem_usage=True
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
initialize_lora_eva_weights(peft_model, dataloader)

# setup training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=batch_size,
    num_train_epochs=num_epochs,
    output_dir=output_dir,
    remove_unused_columns=False,
)

# continue with standard finetuning
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=data_collator,
)
trainer.train()
76 changes: 76 additions & 0 deletions examples/eva_finetuning/utils.py
@@ -0,0 +1,76 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import AutoTokenizer


class TokenizerMetaMath:
    PROMPT_NO_INPUT = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query}\n\n### Response: "
    )
    PROMPT = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query}\n\n### Input:\n{input}\n\n### Response: "
    )

    def format_prompt(self, query):
        query = query.split("\n", 1)
        if len(query) == 1 or query[1].strip("\n") == "":
            return self.PROMPT_NO_INPUT.format(query=query[0])
        else:
            return self.PROMPT.format(query=query[0], input=query[1])

    def __init__(self, tokenizer_path):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    def __call__(self, examples):
        prompts = [self.format_prompt(text) for text in examples["query"]]
        completions = examples["response"]
        return self._tokenize_fn(prompts, completions)

    def _tokenize_fn(self, prompts, completions):
        prompt_tokens = self.tokenizer(prompts, add_special_tokens=False)["input_ids"]
        input_tokens = self.tokenizer([x + y for x, y in zip(prompts, completions)], add_special_tokens=False)[
            "input_ids"
        ]
        input_tokens = [[self.tokenizer.bos_token_id] + x + [self.tokenizer.eos_token_id] for x in input_tokens]
        prompt_length = [len(x) + 1 for x in prompt_tokens]  # +1 for the bos token
        input_length = [len(x) for x in input_tokens]
        return {"input_ids": input_tokens, "prompt_length": prompt_length, "input_length": input_length}


class DataCollator:
    def __init__(self, eos_token_id, max_length=None):
        self.eos_token_id = eos_token_id
        self.max_length = max_length

    def __call__(self, batch):
        # convert a list of example dicts into a dict of lists
        batch = {k: [item[k] for item in batch] for k in batch[0]}
        input_lengths = torch.stack(batch["input_length"])
        prompt_lengths = torch.stack(batch["prompt_length"])
        # pad input_ids to the longest sequence in the batch
        input_ids = torch.nn.utils.rnn.pad_sequence(
            batch["input_ids"], batch_first=True, padding_value=self.eos_token_id
        )
        col_indices = torch.arange(input_ids.size(1)).unsqueeze(0)
        # attend only to non-padding positions
        attention_mask = col_indices < input_lengths.unsqueeze(1)
        # ignore prompt tokens and padding in the loss
        label_mask = torch.logical_or(col_indices < prompt_lengths.unsqueeze(1), ~attention_mask)
        labels = input_ids.masked_fill(label_mask, -100)
        if self.max_length is not None:
            input_ids = input_ids[:, : self.max_length]
            attention_mask = attention_mask[:, : self.max_length]
            labels = labels[:, : self.max_length]
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
@@ -264,7 +264,7 @@ def main():

        dataset = DatasetDict({"train": train_dataset, "validation": val_dataset})
    else:
-        dataset = load_dataset(args.dataset_name)
+        dataset = load_dataset(args.dataset_name, revision="main")

    def preprocess_function(examples):
        queries = examples["query"]