diff --git a/docs/source/developer_guides/lora.md b/docs/source/developer_guides/lora.md
index a87a9e8b1a..680804f0d0 100644
--- a/docs/source/developer_guides/lora.md
+++ b/docs/source/developer_guides/lora.md
@@ -40,6 +40,20 @@ from peft import LoraConfig
 config = LoraConfig(init_lora_weights=False, ...)
 ```
 
+### PiSSA
+[PiSSA](https://arxiv.org/abs/2404.02948) initializes the LoRA adapter using the principal singular values and singular vectors of the base weights. This straightforward modification allows PiSSA to converge more rapidly than LoRA and ultimately attain superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to further improvements when fine-tuning quantized models.
+
+Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model:
+```python
+from peft import LoraConfig
+config = LoraConfig(init_lora_weights="pissa", ...)
+```
+Alternatively, execute fast SVD, which takes only a few seconds. The number of iterations determines the trade-off between approximation error and computation time:
+```python
+lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
+```
+For detailed instructions on using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning).
+
 ### LoftQ
 
 #### Standard approach
diff --git a/examples/pissa_finetuning/README.md b/examples/pissa_finetuning/README.md
new file mode 100644
index 0000000000..a80aab8f24
--- /dev/null
+++ b/examples/pissa_finetuning/README.md
@@ -0,0 +1,131 @@
+# PiSSA: Principal Singular values and Singular vectors Adaptation
+## Introduction ([Paper](https://arxiv.org/abs/2404.02948), [code](https://github.com/GraphPKU/PiSSA))
+PiSSA represents a matrix $W\in\mathbb{R}^{m\times n}$ within the model by the product of two trainable matrices $A \in \mathbb{R}^{m\times r}$ and $B \in \mathbb{R}^{r\times n}$, where $r \ll \min(m, n)$, plus a residual matrix $W^{res}\in\mathbb{R}^{m\times n}$ for error correction. Singular value decomposition (SVD) is employed to factorize $W$, and the principal singular values and vectors of $W$ are used to initialize $A$ and $B$. The residual singular values and vectors initialize the residual matrix $W^{res}$, which is kept frozen during fine-tuning. This straightforward modification allows PiSSA to converge more rapidly than LoRA and ultimately attain superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to further improvements when fine-tuning quantized models.
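+
+For intuition, the decomposition above can be written in a few lines of PyTorch. This is only an illustrative sketch (the helper `pissa_decompose` is not part of PEFT; the actual implementation is `LoraLayer.pissa_init`, which additionally handles the LoRA scaling factor, fast SVD, and dtype casting):
+```python
+import torch
+
+def pissa_decompose(W: torch.Tensor, r: int):
+    """Split W (m x n) into trainable factors A (m x r), B (r x n) and a frozen residual W_res."""
+    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
+    # The top-r singular triplets initialize the adapter; sqrt(S) is shared between the two factors.
+    A = U[:, :r] @ torch.diag(torch.sqrt(S[:r]))  # corresponds to lora_B in PEFT
+    B = torch.diag(torch.sqrt(S[:r])) @ Vh[:r]    # corresponds to lora_A in PEFT
+    # The residual carries the remaining singular triplets, so the model output is unchanged at initialization.
+    W_res = W - A @ B
+    return A, B, W_res
+
+W = torch.randn(64, 32)
+A, B, W_res = pissa_decompose(W, r=4)
+assert torch.allclose(W, W_res + A @ B, atol=1e-5)
+```
+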
+## Quick Start
+```python
+import torch
+from peft import LoraConfig, get_peft_model
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import SFTTrainer
+from datasets import load_dataset
+
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+tokenizer.pad_token_id = tokenizer.eos_token_id
+lora_config = LoraConfig(
+    # init_lora_weights="pissa",  # Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model.
+    init_lora_weights="pissa_niter_4",  # Initialize PiSSA with fast SVD, which completes in just a few seconds.
+)
+peft_model = get_peft_model(model, lora_config)
+
+peft_model.print_trainable_parameters()
+
+dataset = load_dataset("imdb", split="train[:1%]")
+
+trainer = SFTTrainer(
+    model=peft_model,
+    train_dataset=dataset,
+    dataset_text_field="text",
+    max_seq_length=128,
+    tokenizer=tokenizer,
+)
+trainer.train()
+peft_model.save_pretrained("pissa-llama-2-7b")
+```
+When utilizing fast SVD, reducing the rank and the number of iterations decreases the time required, but it leads to higher errors in the computed matrices $A$ and $B$. To preserve the model's initial capabilities, we calculate the residual matrix by $W^{res} = W - AB$. Even with potential errors in $A$ and $B$, the sum of $W^{res}$ and $AB$ still exactly equals $W$.
+
+To load and use the fine-tuned PiSSA modules, simply run the following:
+```python
+import torch
+from peft import PeftModel
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
+)
+# Performs SVD again to initialize the residual model and loads the state_dict of the fine-tuned PiSSA modules.
+peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b")
+```
+
+## Advanced Usage
+
+### Access the preprocessed models
+We recommend downloading decomposed models directly from the [Hugging Face Collections](https://huggingface.co/collections/fxmeng/pissa-661ce700721235e542a5d7a8) instead of performing SVD every time.
+If the existing models do not meet your needs, apply PiSSA initialization to a pre-trained model and store the decomposed model locally:
+```bash
+python preprocess.py \
+    --base_model_name_or_path meta-llama/Llama-2-7b-hf \
+    --init_lora_weights pissa \
+    --output_dir pissa-llama-2-7b-r32-alpha-32 \
+    --lora_r 32 \
+    --lora_alpha 32 \
+    --lora_dropout 0 \
+    --bits bf16
+```
+
+### Convert PiSSA to LoRA
+The main advantage of PiSSA comes from the training phase. Once a PiSSA adapter has been trained, we recommend converting it into an equivalent LoRA adapter for sharing and reuse.
+```python
+# The fine-tuned matrices $A$ and $B$ in the PiSSA adapter are saved and should be combined with the residual model.
+peft_model.save_pretrained(output_dir)
+# Given the initial, untrained matrices $A_0$ and $B_0$ (stored in the "pissa_init" adapter directory, e.g. by `preprocess.py`)
+# and the trained matrices $A$ and $B$, we can convert PiSSA to LoRA by setting
+# $\Delta W = A \times B - A_0 \times B_0 = [A \mid A_0] \times [B^\top \mid -B_0^\top]^\top = A'B'$.
+peft_model.save_pretrained(output_dir, convert_pissa_to_lora="pissa_init")
+```
+This conversion enables loading a LoRA adapter on top of the standard base model:
+
+```python
+import torch
+from peft import PeftModel
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
+)
+# No SVD is performed during this step, and the base model remains unaltered.
+peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")
+```
+Using the converted LoRA does not require modifying the parameters of the base model. When multiple converted LoRAs are needed simultaneously, each adapter operates independently without interference, and adapters can be freely added or removed.
+
+### Fine-tune in 4-bit or 8-bit
+If quantized fine-tuning is desired, the original model must first be decomposed at full precision; the residual model can then be reloaded in a 4-bit or 8-bit configuration, as sketched below.
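+A rough sketch of what this looks like in code (assuming a pre-processed residual model such as `fxmeng/pissa-llama-2-7b-r16-alpha-16`, which stores the initial adapter in its `pissa_init` subfolder; see `pissa_finetuning.py` for the complete version):
+```python
+import torch
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+from peft import PeftModel, prepare_model_for_kbit_training
+
+# Load the residual model (already decomposed at full precision) in 4-bit NF4.
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+res_model = AutoModelForCausalLM.from_pretrained(
+    "fxmeng/pissa-llama-2-7b-r16-alpha-16", quantization_config=quantization_config
+)
+res_model = prepare_model_for_kbit_training(res_model)
+# Attach the full-precision PiSSA adapter stored in the "pissa_init" subfolder on top of the quantized residual model.
+peft_model = PeftModel.from_pretrained(
+    res_model, "fxmeng/pissa-llama-2-7b-r16-alpha-16", subfolder="pissa_init", is_trainable=True
+)
+```
+The full workflow, including training and the optional conversion back to LoRA, is implemented in `pissa_finetuning.py`: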
+```shell +python pissa_finetuning.py \ + --residual_model_name_or_path fxmeng/pissa-llama-2-7b-r16-alpha-16 \ + --output_dir output/pissa-llama-2-7b-r16-alpha-16-metamath-10k \ + --bits nf4 \ + --data_path meta-math/MetaMathQA \ + --dataset_split train[:100000] \ + --dataset_field query response \ + --bf16 True \ + --num_train_epochs 1 \ + --per_device_train_batch_size 32 \ + --gradient_accumulation_steps 4 \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 1 \ + --logging_steps 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --tf32 True \ + --report_to none \ + --convert_pissa_to_lora +``` + +This approach ensures the preservation of high-frequency, out-of-distribution parameters in the low-rank PiSSA modules, resulting in reduced quantization errors during the quantization of the residual model. + +## Citation +``` +@article{meng2024pissa, + title={PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models}, + author={Meng, Fanxu and Wang, Zhaohui and Zhang, Muhan}, + journal={arXiv preprint arXiv:2404.02948}, + year={2024} +} +``` \ No newline at end of file diff --git a/examples/pissa_finetuning/pissa_finetuning.py b/examples/pissa_finetuning/pissa_finetuning.py new file mode 100644 index 0000000000..f92cd74de6 --- /dev/null +++ b/examples/pissa_finetuning/pissa_finetuning.py @@ -0,0 +1,156 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments +from trl import SFTTrainer + +from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + + +@dataclass +class TrainingArguments(TrainingArguments): + # model configs + base_model_name_or_path: Optional[str] = field( + default=None, metadata={"help": "The name or path of the fp32/16 base model."} + ) + residual_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The name or path of the fp32/16 residual model. 
(`['fxmeng/pissa-llama-2-7b-r16-alpha-16']`)" + }, + ) + bits: str = field(default="fp32", metadata={"help": "(`['fp4', 'nf4', 'int8', 'bf16', 'fp16', fp32]`)"}) + init_lora_weights: str = field(default="pissa", metadata={"help": "(`['gaussian', 'pissa', 'pissa_niter_4']`)"}) + lora_r: int = field(default=16) + lora_alpha: int = field(default=16) + lora_dropout: float = field(default=0) + convert_pissa_to_lora: bool = field(default=False) + merge_and_save: bool = field(default=False) + # dataset configs + data_path: str = field(default="imdb", metadata={"help": "Path to the training data."}) + dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"}) + dataset_field: List[str] = field(default=None, metadata={"help": "Fields of dataset input and output."}) + max_seq_length: int = field( + default=512, + metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, + ) + + +parser = HfArgumentParser(TrainingArguments) +script_args = parser.parse_args_into_dataclasses()[0] +print(script_args) + +print(f"Load pre-processed residual model in {script_args.bits} bits.") +if script_args.bits in ["nf4", "fp4", "int8"]: + quantization_config = BitsAndBytesConfig( + load_in_4bit=(script_args.bits == "nf4" or script_args.bits == "fp4"), + load_in_8bit=script_args.bits == "int8", + bnb_4bit_quant_type=script_args.bits, + bnb_4bit_use_double_quant=True, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + res_model = AutoModelForCausalLM.from_pretrained( + script_args.residual_model_name_or_path, quantization_config=quantization_config, low_cpu_mem_usage=True + ) + res_model = prepare_model_for_kbit_training(res_model) + print("Wrapping the residual model with PiSSA.") + peft_model = PeftModel.from_pretrained( + res_model, script_args.residual_model_name_or_path, subfolder="pissa_init", is_trainable=True + ) + tokenizer = AutoTokenizer.from_pretrained(script_args.residual_model_name_or_path) + +elif script_args.residual_model_name_or_path is not None: + res_model = AutoModelForCausalLM.from_pretrained( + script_args.residual_model_name_or_path, + torch_dtype=( + torch.float16 + if script_args.bits == "fp16" + else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32) + ), + device_map="auto", + ) + print("Wrapping the residual model with PiSSA.") + peft_model = PeftModel.from_pretrained( + res_model, script_args.residual_model_name_or_path, subfolder="pissa_init", is_trainable=True + ) + tokenizer = AutoTokenizer.from_pretrained(script_args.residual_model_name_or_path) + +elif script_args.base_model_name_or_path is not None: + print( + f"No available pre-processed model, manually initialize a PiSSA using {script_args.base_model_name_or_path}." 
+ ) + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name_or_path, + torch_dtype=( + torch.float16 + if script_args.bits == "fp16" + else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32) + ), + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path) + tokenizer.pad_token_id = tokenizer.eos_token_id + lora_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + init_lora_weights=script_args.init_lora_weights, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"], + bias="none", + task_type="CAUSAL_LM", + ) + peft_model = get_peft_model(model, lora_config) + +print(peft_model) +peft_model.print_trainable_parameters() + +print(f"Training PiSSA with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.") +dataset = load_dataset(script_args.data_path, split=script_args.dataset_split) +dataset = dataset.map( + lambda example: { + "text": f"### USER: {example[script_args.dataset_field[0]]}\n### ASSISTANT: {example[script_args.dataset_field[1]]}" + } +) + +trainer = SFTTrainer( + model=peft_model, + args=script_args, + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=script_args.max_seq_length, + tokenizer=tokenizer, +) +trainer.train() +trainer.save_state() +############################## Upon training completion, convert and save PiSSA in LoRA format ############################## +if script_args.convert_pissa_to_lora: + peft_model.save_pretrained( + os.path.join(script_args.output_dir, "pissa_lora"), + convert_pissa_to_lora=os.path.join(script_args.residual_model_name_or_path, "pissa_init"), + ) +else: + peft_model.save_pretrained( + os.path.join(script_args.output_dir, "pissa_ft"), + ) + +if script_args.merge_and_save: + model = peft_model.merge_and_unload() + model.save_pretrained(os.path.join(script_args.output_dir, "pissa_merged")) + tokenizer.save_pretrained(os.path.join(script_args.output_dir, "pissa_merged")) diff --git a/examples/pissa_finetuning/preprocess.py b/examples/pissa_finetuning/preprocess.py new file mode 100644 index 0000000000..c17f75e6e5 --- /dev/null +++ b/examples/pissa_finetuning/preprocess.py @@ -0,0 +1,67 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from peft import LoraConfig, get_peft_model + + +parser = argparse.ArgumentParser( + description="Merge Adapter to Base Model", help="The name or path of the fp32/16 base model." 
+) +parser.add_argument("--base_model_name_or_path", type=str, default="bf16") +parser.add_argument("--bits", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) +parser.add_argument( + "--init_lora_weights", type=str, default="pissa", help="(`['pissa', 'pissa_niter_[number of iters]']`)" +) +parser.add_argument("--lora_r", type=int, default=128) +parser.add_argument("--lora_alpha", type=int, default=128) +parser.add_argument("--lora_dropout", type=int, default=0) +script_args = parser.parse_args() +print(script_args) + +model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name_or_path, + torch_dtype=( + torch.float16 + if script_args.bits == "fp16" + else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32) + ), + device_map="auto", +) +tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path) +tokenizer.pad_token_id = tokenizer.eos_token_id +lora_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + init_lora_weights=script_args.init_lora_weights, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"], + bias="none", + task_type="CAUSAL_LM", +) +peft_model = get_peft_model(model, lora_config) + +# Save PiSSA modules: +peft_model.peft_config["default"].init_lora_weights = True +peft_model.save_pretrained(os.path.join(script_args.output_dir, "pissa_init")) +# Save residual model: +peft_model = peft_model.unload() +peft_model.save_pretrained(script_args.output_dir) +# Save the tokenizer: +tokenizer.save_pretrained(script_args.output_dir) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 308413460f..adaa23cd6e 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -177,6 +177,7 @@ def save_pretrained( selected_adapters: Optional[list[str]] = None, save_embedding_layers: Union[str, bool] = "auto", is_main_process: bool = True, + convert_pissa_to_lora: Optional[str] = None, **kwargs: Any, ) -> None: r""" @@ -199,6 +200,13 @@ def save_pretrained( is_main_process (`bool`, *optional*): Whether the process calling this is the main process or not. Will default to `True`. Will not save the checkpoint if not on the main process, which is important for multi device setups (e.g. DDP). + convert_pissa_to_lora (`str`): + The path to the initialized PiSSA adapter, which is obtained after initializing the model with PiSSA + and before performing any training. When `convert_pissa_to_lora` is not None, the difference in PISSA + before and after fine-tuning is calculated. This difference can be represented as the parameters of a + of a standard LoRA adapter. Using this converted adapter does not require changes to the base model, + thus conveniently allowing the use of multiple PISSA and LoRA adapters, and the activation or + deactivation of any adapters. kwargs (additional keyword arguments, *optional*): Additional keyword arguments passed along to the `push_to_hub` method. """ @@ -217,6 +225,22 @@ def save_pretrained( f" {list(self.peft_config.keys())} - got {selected_adapters}." 
) + def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kwargs): + if not str(peft_config.init_lora_weights).startswith("pissa"): + warnings.warn("`convert_pissa_to_lora` only works for converting a PiSSA adapter to a LoRA adapter") + initial_adapter = os.path.basename(convert_pissa_to_lora) + self.load_adapter( + os.path.dirname(convert_pissa_to_lora), subfolder=initial_adapter, adapter_name=initial_adapter + ) + if str(self.peft_config[initial_adapter].init_lora_weights).startswith("pissa"): + raise ValueError( + "The `init_lora_weights` parameter of the initial PiSSA adapter should be set to `True`. " + "Otherwise, `self.load_adapter` will subtract the principal singular value and vector again based on the residual model." + ) + output_state_dict = self.base_model.subtract_pissa_init(output_state_dict, initial_adapter, kwargs) + self.delete_adapter(adapter_name) + return output_state_dict + if is_main_process: os.makedirs(save_directory, exist_ok=True) self.create_or_update_model_card(save_directory) @@ -255,13 +279,20 @@ def save_pretrained( # not supported in safetensors. for shared_tensor_name in names[1:]: output_state_dict[shared_tensor_name] = output_state_dict[shared_tensor_name].clone() - + if convert_pissa_to_lora is not None: + output_state_dict = save_pissa_as_lora( + peft_config, convert_pissa_to_lora, output_state_dict, kwargs + ) safe_save_file( output_state_dict, os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), metadata={"format": "pt"}, ) elif is_main_process: + if convert_pissa_to_lora is not None: + output_state_dict = save_pissa_as_lora( + peft_config, convert_pissa_to_lora, output_state_dict, kwargs + ) torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) # save the config and change the inference mode to `True` @@ -289,6 +320,10 @@ def save_pretrained( auto_mapping_dict = None if is_main_process: + if convert_pissa_to_lora is not None: + peft_config.init_lora_weights = True + peft_config.r *= 2 + peft_config.lora_alpha *= 2 peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict) peft_config.inference_mode = inference_mode diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index cc5c60a753..c463c5319d 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -73,11 +73,19 @@ class LoraConfig(PeftConfig): Otherwise, it will use the original default value of `lora_alpha/r`. modules_to_save (`List[str]`): List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint. - init_lora_weights (`bool` | `Literal["gaussian", "loftq"]`): + init_lora_weights (`bool` | `Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"]`): How to initialize the weights of the adapter layers. Passing True (default) results in the default initialization from the reference implementation from Microsoft. Passing 'gaussian' results in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization to False leads to - completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. + completely random initialization and is discouraged. Pass `'loftq'` to use LoftQ initialization. Passing + 'pissa' results in the initialization of PiSSA, which converge more rapidly than LoRA and ultimately + achieve superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, leading to + further enhancements. 
Passing 'pissa_niter_[number of iters]' initiates Fast-SVD-based PiSSA + initialization, where [number of iters] indicates the number of subspace iterations to perform FSVD, and + must be a nonnegative integer. When the [number of iters] is set to 16, it can complete the initialization + of a 7b model within seconds, and the training effect is approximately equivalent to using SVD. For more + information, see Principal Singular values and Singular vectors + Adaptation. layers_to_transform (`Union[List[int], int]`): The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices that are specified in this list. If a single integer is passed, it will apply the transformations on the @@ -155,7 +163,7 @@ class LoraConfig(PeftConfig): "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." }, ) - init_lora_weights: bool | Literal["gaussian", "loftq"] = field( + init_lora_weights: bool | Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"] = field( default=True, metadata={ "help": ( @@ -163,6 +171,9 @@ class LoraConfig(PeftConfig): "initialization from the reference implementation from Microsoft. Passing 'gaussian' results " "in Gaussian initialization scaled by the LoRA rank for linear and layers. Setting the initialization " "to False leads to completely random initialization and is discouraged." + "Passing 'pissa' results in PiSSA initialization." + "Passing 'pissa_niter_[number of iters]' initiates Fast-SVD-based PiSSA initialization, " + "where [number of iters] indicates the number of subspace iterations to perform fsvd, and must be a nonnegative integer." "Pass `'loftq'` to use LoftQ initialization" ), }, diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 56ab7c4a1a..e5404d8bb4 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -20,6 +20,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import svd_lowrank from transformers.pytorch_utils import Conv1D from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge @@ -109,7 +110,9 @@ def update_layer( else: self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights == "loftq": + if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): + self.pissa_init(adapter_name, init_lora_weights) + elif init_lora_weights == "loftq": self.loftq_init(adapter_name) elif init_lora_weights: self.reset_lora_parameters(adapter_name, init_lora_weights) @@ -152,6 +155,42 @@ def reset_lora_parameters(self, adapter_name, init_lora_weights): nn.init.zeros_(self.lora_embedding_A[adapter_name]) nn.init.normal_(self.lora_embedding_B[adapter_name]) + def pissa_init(self, adapter_name, init_lora_weights): + weight = self.get_base_layer().weight + dtype = weight.dtype + if dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError( + "Please initialize PiSSA under float32, float16, or bfloat16. " + "Subsequently, re-quantize the residual model to help minimize quantization errors." 
+ ) + weight = weight.to(torch.float32) + + if init_lora_weights == "pissa": + # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, + V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False) + Vr = V[:, : self.r[adapter_name]] + Sr = S[: self.r[adapter_name]] + Sr /= self.scaling[adapter_name] + Uhr = Uh[: self.r[adapter_name]] + elif len(init_lora_weights.split("_niter_")) == 2: + Vr, Sr, Ur = svd_lowrank( + weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1]) + ) + Sr /= self.scaling[adapter_name] + Uhr = Ur.t() + else: + raise ValueError( + f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead." + ) + + lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr + lora_B = Vr @ torch.diag(torch.sqrt(Sr)) + self.lora_A[adapter_name].weight.data = lora_A + self.lora_B[adapter_name].weight.data = lora_B + weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A + weight = weight.to(dtype) + self.get_base_layer().weight.data = weight + def loftq_init(self, adapter_name): from peft.utils.loftq_utils import loftq_init diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 6b3fbc6a69..d126720a28 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -41,6 +41,7 @@ ModulesToSaveWrapper, _freeze_adapter, _get_submodules, + get_peft_model_state_dict, get_quantization_config, ) from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties @@ -829,3 +830,42 @@ def unload(self) -> torch.nn.Module: model. """ return self._unload_and_optionally_merge(merge=False) + + def subtract_pissa_init( + self, output_state_dict: dict[str, torch.Tensor], adapter_name: str = "pissa_init", kwargs=None + ): + """ + This function can calculate the updates of the PiSSA by comparing the parameters of the PiSSA adapter in + `output_state_dict` with the initial values of PiSSA in `adapter_name`, thus converting PiSSA to LoRA. + """ + for name, param in self.model.named_parameters(): + if ( + param.data.dtype != torch.float32 + and param.data.dtype != torch.float16 + and param.data.dtype != torch.bfloat16 + ): + warnings.warn( + r"Note that Quant(W_res) + AB != Quant(W) + \Delta(AB); " + "the converted LoRA, when combined with W or Quant(W), may introduce a certain gap in the fine-tuned model. " + "Therefore, we recommend directly using the Quant(W_res) in conjunction with the PiSSA adapter. " + ) + pissa_init_state_dict = get_peft_model_state_dict( + self, + state_dict=kwargs.get("state_dict", None), + adapter_name=adapter_name, + ) + tensors_lora = {} + for name in output_state_dict.keys(): + ## W = W^{res} + A_0 \times B_0, + ## W + \Delta W = W^{res} + A \times B, + ## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'. 
+ if "lora_A" in name: + tensors_lora[name] = torch.cat( + [output_state_dict[name], pissa_init_state_dict[".".join(name.split(".")[1:])]], dim=0 + ) + elif "lora_B" in name: + tensors_lora[name] = torch.cat( + [output_state_dict[name], -pissa_init_state_dict[".".join(name.split(".")[1:])]], dim=1 + ) + + return tensors_lora diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index f2c2ae63d7..9e7bafc981 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -16,6 +16,7 @@ import os import tempfile import unittest +from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, List, Union @@ -54,6 +55,7 @@ replace_lora_weights_loftq, ) from peft.utils import SAFETENSORS_WEIGHTS_NAME +from peft.utils.loftq_utils import NFQuantizer from .testing_utils import ( require_aqlm, @@ -1435,6 +1437,208 @@ def test_offload_merge(self): assert torch.allclose(post_unload_merge_olayer, pre_merge_olayer) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU") +class TestPiSSA: + r""" + Tests for PiSSA to ensure that it reduces the quantization error compared to normal LoRA quantization. + """ + + # The error factor indicates by how much the quantization error should be decreased when using PiSSA compared to + # quantization without PiSSA. Thus 1.03 means that the error should be decreased by 3% at least. This is a very + # conservative value to prevent flakiness, in practice most gains are > 1.5 + error_factor = 1.03 + + def quantize_model(self, model, num_bits=4, device="cuda"): + # Quantize the `weight.data` of the linear layer in the model to `num_bits` and store it with full precision. + quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and "lm_head" not in name: + quantized_weight, max_abs, shape = quantizer.quantize_block(module.weight.data.to(device)) + module.weight.data = quantizer.dequantize_block(quantized_weight, max_abs, shape) + return model + + def nuclear_norm(self, base_model, quantized_model): + # Calculate the nuclear norm (sum of singular values) of the error matrices between the `quantized_model` and the `base_model`. + error_list = [] + for name, module in base_model.named_modules(): + if isinstance(module, torch.nn.Linear) and "lm_head" not in name: + quant_module = quantized_model.get_submodule(name) + error_list.append(torch.linalg.svdvals(module.weight.data - quant_module.weight.data).sum()) + return torch.Tensor(error_list).sum() + + def get_errors( + self, + tmp_path, + bits=4, + device="cuda", + model_id="hf-internal-testing/tiny-random-BloomForCausalLM", + ): + # Comparing the quantized LoRA model to the base model, vs the PiSSA quantized model to the base model. + # We expect the PiSSA quantized model to have less error than the normal LoRA quantized model. 
+ + cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM + base_model = cls.from_pretrained(model_id).eval().to(device) + task_type = TaskType.SEQ_2_SEQ_LM if base_model.config.is_encoder_decoder else TaskType.CAUSAL_LM + + # logits from the normal quantized LoRA model + target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"] + lora_config = LoraConfig(task_type=task_type, target_modules=target_modules) + + qlora_model = self.quantize_model(cls.from_pretrained(model_id).eval().to(device), bits, device) + qlora_model = get_peft_model( + qlora_model, + lora_config, + ) + qlora_model = qlora_model.merge_and_unload() + qlora_error = self.nuclear_norm(base_model, qlora_model) + del qlora_model + gc.collect() + torch.cuda.empty_cache() + + # logits from quantized LoRA model using PiSSA + lora_config = LoraConfig( + task_type=task_type, + init_lora_weights="pissa", + target_modules=target_modules, + ) + pissa_model = cls.from_pretrained(model_id).eval().to(device) + pissa_model = get_peft_model(pissa_model, lora_config) + + # save LoRA weights, they should be initialized such that they minimize the quantization error + pissa_model.base_model.peft_config["default"].init_lora_weights = True + pissa_model.save_pretrained(tmp_path / "pissa_model") + + pissa_model = pissa_model.unload() + pissa_model.save_pretrained(tmp_path / "residual_model") + + del pissa_model + gc.collect() + torch.cuda.empty_cache() + + # now load quantized model and apply PiSSA-initialized weights on top + qpissa_model = self.quantize_model( + cls.from_pretrained(tmp_path / "residual_model").eval().to(device), bits, device + ) + qpissa_model = PeftModel.from_pretrained(qpissa_model, tmp_path / "pissa_model") + qpissa_model = qpissa_model.merge_and_unload() + qpissa_error = self.nuclear_norm(base_model, qpissa_model) + del qpissa_model + gc.collect() + torch.cuda.empty_cache() + + assert qlora_error > 0.0 + assert qpissa_error > 0.0 + + # next, check that PiSSA quantization errors are smaller than LoRA errors by a certain margin + assert qpissa_error < (qlora_error / self.error_factor) + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + def test_bloomz_pissa_4bit(self, device, tmp_path): + # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model + # using PiSSA. When quantizing, we expect a certain level of error. However, we expect the PiSSA quantized + # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the + # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training. + # We still apply LoRA for the test for consistency. + + self.get_errors(bits=4, device=device, tmp_path=tmp_path) + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + def test_bloomz_pissa_8bit(self, device, tmp_path): + # Same test as test_bloomz_pissa_4bit but with 8 bits. 
+ self.get_errors(bits=8, device=device, tmp_path=tmp_path) + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + def test_t5_pissa_4bit(self, device, tmp_path): + self.get_errors(bits=4, device=device, model_id="t5-small", tmp_path=tmp_path) + + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + def test_t5_pissa_8bit(self, device, tmp_path): + self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path) + + @require_bitsandbytes + def test_lora_pissa_conversion_same_output_after_loading_with_quantization(self, tmp_path): + # A copy of the test `test_lora_pissa_conversion_same_output_after_loading` in peft/tests/test_initialization.py, + # that would fail if bitsandbytes quantization is used because Quant(W_res) + AB !=Quant(W) + \Delta(AB). + import bitsandbytes as bnb + + torch.manual_seed(0) + data = torch.rand(10, 1000).to("cuda") + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + # choose a large weight so that averages are close to expected values + self.linear = torch.nn.Linear(1000, 1000) + self.embed = torch.nn.Embedding(1000, 1000) + self.conv2d = torch.nn.Conv2d(100, 100, 3) + + def forward(self, x): + x_int = (100 * x).int() + x_4d = x.flatten().reshape(1, 100, 10, 10) + return self.linear(x), self.embed(x_int), self.conv2d(x_4d) + + model = MyModule().to("cuda") + output_base = model(data)[0] + + config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8) + peft_model = get_peft_model(deepcopy(model), config) + # save the initial model + peft_model.peft_config["default"].init_lora_weights = True + peft_model.save_pretrained(tmp_path / "init-model") + peft_model = peft_model.unload() + torch.save(peft_model.state_dict(), tmp_path / "residual-model") + del peft_model + + # create 4bit base model + base_model = deepcopy(model) + base_model.load_state_dict(torch.load(tmp_path / "residual-model")) + # sanity check: the base model weights were indeed changed + tol = 1e-06 + assert not torch.allclose(model.linear.weight, base_model.linear.weight, atol=tol, rtol=tol) + # quantize the linear layer + linear4bit = bnb.nn.Linear4bit(base_model.linear.in_features, base_model.linear.out_features) + linear4bit.load_state_dict(base_model.linear.state_dict()) + linear4bit.to(0) + base_model.linear = linear4bit + peft_model = PeftModel.from_pretrained(deepcopy(base_model), tmp_path / "init-model") + output_quantized_pissa = peft_model(data)[0] + # sanity check + tol = 1e-06 + assert not torch.allclose(output_base, output_quantized_pissa, atol=tol, rtol=tol) + + # modify the weights, or else the adapter performs an identity transformation + peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0 + output_finetuned_pissa = peft_model(data)[0] + # sanity check + tol = 1e-06 + assert not torch.allclose(output_quantized_pissa, output_finetuned_pissa, atol=tol, rtol=tol) + + # save the model normally + peft_model.save_pretrained(tmp_path / "pissa-model") + model_loaded = PeftModel.from_pretrained(deepcopy(base_model), tmp_path / "pissa-model") + output_loaded = model_loaded(data)[0] + + assert torch.allclose(output_finetuned_pissa, output_loaded, atol=tol, rtol=tol) + # sanity check: ranks should still be 8 as initially + assert model_loaded.peft_config["default"].r == 8 + assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8 + + # save the model with conversion + peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_pissa_to_lora=tmp_path / "init-model") + 
model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted") + output_converted = model_converted(data)[0] + + # rank should be double of what it was initially + assert model_converted.peft_config["default"].r == 16 + assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16 + # base model weights should be the same as the initial model + assert torch.allclose( + model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol + ) + # This check is expected to fail when using bnb + assert not torch.allclose(output_finetuned_pissa, output_converted, atol=tol, rtol=tol) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU") class TestLoftQ: r""" diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 8e0e091b56..589e046ad6 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -13,13 +13,14 @@ # limitations under the License. import re +from copy import deepcopy import pytest import torch from scipy import stats from torch import nn -from peft import AdaLoraConfig, LoraConfig, PromptTuningConfig, VeraConfig, get_peft_model +from peft import AdaLoraConfig, LoraConfig, PeftModel, PromptTuningConfig, VeraConfig, get_peft_model from peft.utils import infer_device @@ -253,6 +254,65 @@ def test_lora_scaling_default(self): assert model.embed.scaling["default"] == expected_scaling assert model.conv2d.scaling["default"] == expected_scaling + def test_lora_pissa_linear_init_default(self, data): + model = self.get_model() + output = model(data)[0] + + config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"]) + peft_model = get_peft_model(deepcopy(model), config) + assert torch.allclose(output, peft_model(data)[0], atol=1e-06) + + config = LoraConfig(init_lora_weights="pissa_niter_16", target_modules=["linear"]) + peft_model = get_peft_model(deepcopy(model), config) + assert torch.allclose(output, peft_model(data)[0], atol=1e-06) + + def test_lora_pissa_conversion_same_output_after_loading(self, data, tmp_path): + model = self.get_model() + output_base = model(data)[0] + + config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8) + peft_model = get_peft_model(deepcopy(model), config) + # save the initial model + peft_model.peft_config["default"].init_lora_weights = True + peft_model.save_pretrained(tmp_path / "init-model") + peft_model.peft_config["default"].init_lora_weights = "pissa" + + # modify the weights, or else the adapter performs an identity transformation + peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0 + output_pissa = peft_model(data)[0] + + # sanity check + tol = 1e-06 + assert not torch.allclose(output_base, output_pissa, atol=tol, rtol=tol) + + # save the model normally + peft_model.save_pretrained(tmp_path / "pissa-model") + model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model") + output_loaded = model_loaded(data)[0] + + assert torch.allclose(output_pissa, output_loaded, atol=tol, rtol=tol) + # sanity check: ranks should still be 8 as initially + assert model_loaded.peft_config["default"].r == 8 + assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8 + # sanity check: the base model weights were indeed changed + assert not torch.allclose( + model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol + ) + + # save the model with conversion + 
peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_pissa_to_lora=tmp_path / "init-model") + model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted") + output_converted = model_converted(data)[0] + + assert torch.allclose(output_pissa, output_converted, atol=tol, rtol=tol) + # rank should be double of what it was initially + assert model_converted.peft_config["default"].r == 16 + assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16 + # base model weights should be the same as the initial model + assert torch.allclose( + model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol + ) + def test_lora_rslora_scaling(self): # default is True torch.manual_seed(0)