
Adding PiSSA as an optional initialization method of LoRA #1626

Merged
merged 112 commits on May 15, 2024

Changes from all commits (112 commits)
a320ff7
add pisa_init
MFX12138 Mar 14, 2024
13155fe
Update layer.py
fxmeng Mar 15, 2024
89bc355
Update layer.py
fxmeng Mar 15, 2024
9f00cb9
Update layer.py
fxmeng Mar 15, 2024
05edb1d
svd_lowrank and bitsandbytes
fxmeng Mar 25, 2024
58b2fc2
Update layer.py
fxmeng Mar 27, 2024
f105336
Update layer.py
fxmeng Apr 2, 2024
ec15caf
Merge branch 'huggingface:main' into main
fxmeng Apr 2, 2024
3a9e9bd
make style
fxmeng Apr 9, 2024
d1bb5dd
support scaling != 1
fxmeng Apr 11, 2024
823c166
usv^t=w->vsu^t=w^t
fxmeng Apr 11, 2024
65286be
from torch import svd_lowrank
fxmeng Apr 11, 2024
bae0da9
raise init_lora_weights error
fxmeng Apr 11, 2024
e912970
dequantize_bnb_weight
fxmeng Apr 11, 2024
c679a50
use pissa without changing base model
fxmeng Apr 11, 2024
7fabf84
comment on \Delta AB
fxmeng Apr 12, 2024
bc2b566
save_initial_pissa_and_residual_model
fxmeng Apr 12, 2024
0b8aea6
finetune in 4 bit
fxmeng Apr 12, 2024
cf5eac3
finetune in 4 bit
fxmeng Apr 13, 2024
decaea6
finetune in 4 bit
fxmeng Apr 13, 2024
5dd82e3
finetune in 4 bit
fxmeng Apr 13, 2024
d2c8aab
add pissa_utils
fxmeng Apr 15, 2024
ce2948a
add pissa_utils
fxmeng Apr 15, 2024
0e15216
add pissa_utils
fxmeng Apr 15, 2024
3488f69
add pissa_utils
fxmeng Apr 15, 2024
814f1c7
add pissa_utils
fxmeng Apr 15, 2024
db99cb3
add pissa_utils
fxmeng Apr 15, 2024
374e309
add pissa_utils
fxmeng Apr 15, 2024
fb75ab3
make style
fxmeng Apr 15, 2024
47f8534
make style
fxmeng Apr 15, 2024
9a5067a
svd first quantization
fxmeng Apr 15, 2024
079d316
optinal path
fxmeng Apr 15, 2024
2edfb07
optinal path
fxmeng Apr 15, 2024
8408959
optinal path
fxmeng Apr 15, 2024
528e49f
optinal path
fxmeng Apr 15, 2024
e3458c8
optinal path
fxmeng Apr 15, 2024
f543fc1
optinal path
fxmeng Apr 15, 2024
11d9843
optinal path
fxmeng Apr 15, 2024
1d1e162
optinal path
fxmeng Apr 15, 2024
5565b9b
optinal path
fxmeng Apr 15, 2024
bdfd306
quantization_config
fxmeng Apr 15, 2024
bd7c814
quantization_config
fxmeng Apr 15, 2024
3486fad
quantization_config
fxmeng Apr 15, 2024
f7d0453
pissa_finetuning
fxmeng Apr 15, 2024
c2c8ac3
pissa_finetuning
fxmeng Apr 15, 2024
4d9b74b
lora config
fxmeng Apr 15, 2024
ae0c3e6
lora config
fxmeng Apr 15, 2024
606a692
lora config
fxmeng Apr 15, 2024
0e961ab
readme
fxmeng Apr 15, 2024
3893276
readme
fxmeng Apr 15, 2024
57f65ff
lora.md
fxmeng Apr 15, 2024
33e1a8c
make style
fxmeng Apr 15, 2024
51161a5
make style
fxmeng Apr 15, 2024
52110a5
readme
fxmeng Apr 15, 2024
30e7673
readme
fxmeng Apr 15, 2024
b76f82b
readme
fxmeng Apr 15, 2024
3917777
test_init
fxmeng Apr 15, 2024
2dc3ddd
test_init
fxmeng Apr 15, 2024
f0642c6
test_init
fxmeng Apr 15, 2024
7b8af8e
test
fxmeng Apr 15, 2024
f8889ee
TrainingArguments
fxmeng Apr 16, 2024
83deb5a
Merge branch 'main' of https://github.com/fxmeng/peft
fxmeng Apr 16, 2024
41ed2b3
TrainingArguments
fxmeng Apr 16, 2024
c05cdae
TrainingArguments
fxmeng Apr 16, 2024
c53bd85
save_init_and_ft
fxmeng Apr 20, 2024
3cbacb0
unload
fxmeng Apr 20, 2024
c4222ac
explain W^{res}
fxmeng Apr 20, 2024
ae69146
save_and_load
fxmeng Apr 22, 2024
082c120
save_and_load
fxmeng Apr 22, 2024
efbd5aa
make style
fxmeng Apr 22, 2024
26feaa2
save_as_lora
fxmeng Apr 22, 2024
a3613c5
save_and_load
fxmeng Apr 22, 2024
07eda02
save_and_load
fxmeng Apr 22, 2024
c9b683b
readme
fxmeng Apr 22, 2024
694e6de
readme
fxmeng Apr 22, 2024
629941d
readme
fxmeng Apr 22, 2024
6fbc20f
readme
fxmeng Apr 22, 2024
412effb
test quantization error
fxmeng Apr 22, 2024
bfb8e84
test
fxmeng Apr 22, 2024
964aee2
test
fxmeng Apr 22, 2024
0b47f9a
test
fxmeng Apr 23, 2024
1f44f8b
lora_model.subtract_pissa_init
fxmeng Apr 28, 2024
8608303
convert_pissa_to_lora
fxmeng Apr 28, 2024
49d1257
initial_adapter
fxmeng Apr 28, 2024
83ad3bd
initial_adapter
fxmeng Apr 28, 2024
c575450
test
fxmeng Apr 28, 2024
29f987c
make style
fxmeng Apr 28, 2024
e7c8fe6
Update tests/test_initialization.py
fxmeng May 1, 2024
3ded9bd
Update src/peft/tuners/lora/model.py
fxmeng May 1, 2024
1ff7d94
Update src/peft/tuners/lora/model.py
fxmeng May 1, 2024
dfa2179
Update docs/source/developer_guides/lora.md
fxmeng May 1, 2024
0234aee
Update src/peft/peft_model.py
fxmeng May 1, 2024
06505cc
save_pissa_as_lora as a sub-method
fxmeng May 1, 2024
fd01b79
convert_pissa_to_lora
fxmeng May 1, 2024
da49fb2
test quant
fxmeng May 1, 2024
9f25bc1
warning
fxmeng May 1, 2024
95bfbb1
preprocess
fxmeng May 3, 2024
45636e5
test_lora_pissa_linear_init_default
fxmeng May 3, 2024
bd76657
quantized_convert_errors
fxmeng May 3, 2024
61efdf7
solve conflicts
fxmeng May 3, 2024
e2a943d
deepcopy
fxmeng May 4, 2024
005087b
test_lora_pissa_conversion_same_output_after_loading_with_quantization
fxmeng May 4, 2024
52c4620
nuclear_norm_error
fxmeng May 6, 2024
274001c
nuclear_norm_error
fxmeng May 6, 2024
e0fd841
nuclear_norm_error
fxmeng May 6, 2024
c803300
fix a bug
fxmeng May 7, 2024
b1f0038
fix a bug
fxmeng May 7, 2024
232e3d2
ruff
fxmeng May 7, 2024
9482167
Merge branch 'main' of https://github.com/fxmeng/peft
fxmeng May 7, 2024
74f470b
doc-builder
MFX12138 May 8, 2024
8a6e429
quantization_test
MFX12138 May 8, 2024
60596ec
Update test_gpu_examples.py
fxmeng May 8, 2024
14 changes: 14 additions & 0 deletions docs/source/developer_guides/lora.md
@@ -40,6 +40,20 @@ from peft import LoraConfig
config = LoraConfig(init_lora_weights=False, ...)
```

### PiSSA
[PiSSA](https://arxiv.org/abs/2404.02948) initializes the LoRA adapter with the principal singular values and singular vectors of the original weight matrix. This straightforward modification allows PiSSA to converge faster than LoRA and ultimately achieve superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, yielding further improvements.

Set the initialization method to "pissa"; this may take several minutes because an SVD is performed on the pre-trained weight matrices:
```python
from peft import LoraConfig
config = LoraConfig(init_lora_weights="pissa", ...)
```
Alternatively, use fast (randomized) SVD, which completes in only a few seconds. The number of iterations controls the trade-off between approximation error and computation time:
```python
lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
```
For a detailed walkthrough of using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning).
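
As an end-to-end sketch (the model name is only illustrative, and `task_type` is set as it would be for any causal-LM LoRA run), the config is applied like any other LoRA config:
```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# "pissa" runs a full SVD; "pissa_niter_4" trades a little accuracy for a much faster initialization.
lora_config = LoraConfig(init_lora_weights="pissa_niter_4", task_type="CAUSAL_LM")
peft_model = get_peft_model(base_model, lora_config)
```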

### LoftQ

#### Standard approach
131 changes: 131 additions & 0 deletions examples/pissa_finetuning/README.md
@@ -0,0 +1,131 @@
# PiSSA: Principal Singular values and Singular vectors Adaptation
## Introduction ([Paper](https://arxiv.org/abs/2404.02948), [code](https://github.com/GraphPKU/PiSSA))
PiSSA represents a weight matrix $W\in\mathbb{R}^{m\times n}$ of the model as the product of two trainable matrices $A \in \mathbb{R}^{m\times r}$ and $B \in \mathbb{R}^{r\times n}$, where $r \ll \min(m, n)$, plus a residual matrix $W^{res}\in\mathbb{R}^{m\times n}$ for error correction. Singular value decomposition (SVD) is used to factorize $W$: the principal singular values and vectors initialize $A$ and $B$, while the residual singular values and vectors initialize the residual matrix $W^{res}$, which is kept frozen during fine-tuning. This straightforward modification allows PiSSA to converge faster than LoRA and ultimately achieve superior performance. Moreover, PiSSA reduces the quantization error compared to QLoRA, yielding further improvements.
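
Concretely, a brief sketch of the construction in the notation above: given the SVD $W = USV^{\top}$ with singular values in descending order,

```math
A = U_{[:,\,:r]}\, S_{[:r,\,:r]}^{1/2}, \qquad B = S_{[:r,\,:r]}^{1/2}\, V_{[:,\,:r]}^{\top}, \qquad W^{res} = U_{[:,\,r:]}\, S_{[r:,\,r:]}\, V_{[:,\,r:]}^{\top}.
```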

## Quick Start
```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = LoraConfig(
# init_lora_weights="pissa", # Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model.
init_lora_weights="pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few seconds.
)
peft_model = get_peft_model(model, lora_config)

peft_model.print_trainable_parameters()

dataset = load_dataset("imdb", split="train[:1%]")

trainer = SFTTrainer(
    model=peft_model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=128,
    tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("pissa-llama-2-7b")
```
When using fast SVD, reducing the rank and the number of iterations decreases the computation time, but leads to larger approximation errors in the computed matrices $A$ and $B$. To preserve the model's initial capabilities, we compute the residual matrix as $W^{res} = W - AB$. Even with approximation errors in $A$ and $B$, the sum of $W^{res}$ and $AB$ equals $W$ exactly.
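
A minimal numeric sketch of this property (the shapes, rank, and iteration count below are illustrative only):
```python
import torch

m, n, r = 64, 32, 4
W = torch.randn(m, n)
# Approximate top-r SVD via randomized SVD (the "fast SVD" behind pissa_niter_*).
U, S, V = torch.svd_lowrank(W, q=r, niter=4)
A = U @ torch.diag(S.sqrt())  # m x r
B = torch.diag(S.sqrt()) @ V.T  # r x n
W_res = W - A @ B  # the residual absorbs any approximation error in A and B
# Reconstruction is exact by construction, no matter how rough the approximation is.
assert torch.allclose(W_res + A @ B, W, atol=1e-5)
```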


To load and use the fine-tuned PiSSA modules, simply run the following code:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# Performs SVD again to initialize the residual model and loads the state_dict of the fine-tuned PiSSA modules.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b")
```

## Advanced Usage

### Access the preprocessed models
We recommend downloading decomposed models directly from the [Hugging Face Collections](https://huggingface.co/collections/fxmeng/pissa-661ce700721235e542a5d7a8) instead of performing SVD every time.
If the existing models do not meet your needs, apply PiSSA initialization to a pre-trained model and store the decomposed model locally:
```bash
python preprocess.py \
--base_model_name_or_path meta-llama/Llama-2-7b-hf \
--init_lora_weights pissa \
--output_dir pissa-llama-2-7b-r32-alpha-32 \
--lora_r 32 \
--lora_alpha 32 \
--lora_dropout 0 \
--bits bf16
```

### Convert PiSSA to LoRA
The main advantage of PiSSA lies in the training phase. After training, we recommend converting the PiSSA adapter into an equivalent LoRA adapter, which makes it easier to share and use.
```python
# The fine-tuned matrices $A$ and $B$ in the PiSSA adapter are saved and should be combined with the residual model.
peft_model.save_pretrained(output_dir)
# Given the untrained PiSSA-initialized matrices $A_0$ and $B_0$ and the trained matrices $A$ and $B$, the adapter can be
# converted to LoRA by setting $\Delta W = A B - A_0 B_0 = [A \mid A_0] \begin{bmatrix} B \\ -B_0 \end{bmatrix} = A' B'$.
peft_model.save_pretrained(output_dir, convert_pissa_to_lora="pissa_init")

```
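A small numeric check (shapes are illustrative) of the block-matrix identity used in the comment above:
```python
import torch

m, n, r = 16, 12, 4
A0, B0 = torch.randn(m, r), torch.randn(r, n)  # PiSSA initialization, untrained
A, B = torch.randn(m, r), torch.randn(r, n)  # after fine-tuning
delta_w = A @ B - A0 @ B0
A_prime = torch.cat([A, A0], dim=1)  # m x 2r
B_prime = torch.cat([B, -B0], dim=0)  # 2r x n
assert torch.allclose(delta_w, A_prime @ B_prime, atol=1e-5)
```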
This conversion enables the loading of LoRA on top of a standard base model:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# No SVD is performed during this step, and the base model remains unaltered.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")
```
Using the converted LoRA does not require modifying the parameters of the base model. When multiple converted LoRAs are needed at the same time, each adapter operates independently of the others, so adapters can be freely added or deleted.
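
For example, a sketch of using two converted adapters side by side (the second adapter path and the adapter names are hypothetical):
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# Load the first converted LoRA under an explicit adapter name.
peft_model = PeftModel.from_pretrained(base, "pissa-llama-2-7b-lora", adapter_name="math")
# Load a second, independently converted LoRA next to it.
peft_model.load_adapter("pissa-llama-2-7b-code-lora", adapter_name="code")
# Switch between adapters freely; the base weights are never modified.
peft_model.set_adapter("math")
```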



### Fine-tune in 4-bit or 8-bit
If quantization fine-tuning is desired, it is necessary to first decompose the original model at full precision and then reload the residual model in either 4-bit or 8-bit configurations.
```shell
python pissa_finetuning.py \
--residual_model_name_or_path fxmeng/pissa-llama-2-7b-r16-alpha-16 \
--output_dir output/pissa-llama-2-7b-r16-alpha-16-metamath-10k \
--bits nf4 \
--data_path meta-math/MetaMathQA \
--dataset_split train[:100000] \
--dataset_field query response \
--bf16 True \
--num_train_epochs 1 \
--per_device_train_batch_size 32 \
--gradient_accumulation_steps 4 \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--logging_steps 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--tf32 True \
--report_to none \
--convert_pissa_to_lora
```
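
For reference, a condensed sketch of what the script does (assuming the pre-decomposed residual model from the Hub used above): the residual model is loaded in 4-bit, while the PiSSA adapter itself stays in full precision.
```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, prepare_model_for_kbit_training

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Quantize only the residual model.
res_model = AutoModelForCausalLM.from_pretrained(
    "fxmeng/pissa-llama-2-7b-r16-alpha-16", quantization_config=quantization_config, low_cpu_mem_usage=True
)
res_model = prepare_model_for_kbit_training(res_model)
# Attach the full-precision PiSSA adapter stored in the "pissa_init" subfolder.
peft_model = PeftModel.from_pretrained(
    res_model, "fxmeng/pissa-llama-2-7b-r16-alpha-16", subfolder="pissa_init", is_trainable=True
)
```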

This approach preserves the high-frequency, out-of-distribution parameters in the full-precision, low-rank PiSSA modules, resulting in reduced errors when the residual model is quantized.

## Citation
```
@article{meng2024pissa,
title={PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models},
author={Meng, Fanxu and Wang, Zhaohui and Zhang, Muhan},
journal={arXiv preprint arXiv:2404.02948},
year={2024}
}
```
156 changes: 156 additions & 0 deletions examples/pissa_finetuning/pissa_finetuning.py
@@ -0,0 +1,156 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import List, Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from trl import SFTTrainer

from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


@dataclass
class TrainingArguments(TrainingArguments):
    # model configs
    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name or path of the fp32/16 base model."}
    )
    residual_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name or path of the fp32/16 residual model. (`['fxmeng/pissa-llama-2-7b-r16-alpha-16']`)"
        },
    )
    bits: str = field(default="fp32", metadata={"help": "(`['fp4', 'nf4', 'int8', 'bf16', 'fp16', 'fp32']`)"})
    init_lora_weights: str = field(default="pissa", metadata={"help": "(`['gaussian', 'pissa', 'pissa_niter_4']`)"})
    lora_r: int = field(default=16)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0)
    convert_pissa_to_lora: bool = field(default=False)
    merge_and_save: bool = field(default=False)
    # dataset configs
    data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
    dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`)"})
    dataset_field: List[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})
    max_seq_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


parser = HfArgumentParser(TrainingArguments)
script_args = parser.parse_args_into_dataclasses()[0]
print(script_args)

print(f"Load pre-processed residual model in {script_args.bits} bits.")
if script_args.bits in ["nf4", "fp4", "int8"]:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=(script_args.bits == "nf4" or script_args.bits == "fp4"),
        load_in_8bit=script_args.bits == "int8",
        bnb_4bit_quant_type=script_args.bits,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    res_model = AutoModelForCausalLM.from_pretrained(
        script_args.residual_model_name_or_path, quantization_config=quantization_config, low_cpu_mem_usage=True
    )
    res_model = prepare_model_for_kbit_training(res_model)
    print("Wrapping the residual model with PiSSA.")
    peft_model = PeftModel.from_pretrained(
        res_model, script_args.residual_model_name_or_path, subfolder="pissa_init", is_trainable=True
    )
    tokenizer = AutoTokenizer.from_pretrained(script_args.residual_model_name_or_path)

elif script_args.residual_model_name_or_path is not None:
    res_model = AutoModelForCausalLM.from_pretrained(
        script_args.residual_model_name_or_path,
        torch_dtype=(
            torch.float16
            if script_args.bits == "fp16"
            else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
        ),
        device_map="auto",
    )
    print("Wrapping the residual model with PiSSA.")
    peft_model = PeftModel.from_pretrained(
        res_model, script_args.residual_model_name_or_path, subfolder="pissa_init", is_trainable=True
    )
    tokenizer = AutoTokenizer.from_pretrained(script_args.residual_model_name_or_path)

elif script_args.base_model_name_or_path is not None:
    print(
        f"No available pre-processed model, manually initialize a PiSSA using {script_args.base_model_name_or_path}."
    )
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name_or_path,
        torch_dtype=(
            torch.float16
            if script_args.bits == "fp16"
            else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
        ),
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    lora_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        init_lora_weights=script_args.init_lora_weights,
        lora_dropout=script_args.lora_dropout,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, lora_config)

print(peft_model)
peft_model.print_trainable_parameters()

print(f"Training PiSSA with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.")
dataset = load_dataset(script_args.data_path, split=script_args.dataset_split)
dataset = dataset.map(
    lambda example: {
        "text": f"### USER: {example[script_args.dataset_field[0]]}\n### ASSISTANT: {example[script_args.dataset_field[1]]}"
    }
)

trainer = SFTTrainer(
    model=peft_model,
    args=script_args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=script_args.max_seq_length,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_state()
############################## Upon training completion, convert and save PiSSA in LoRA format ##############################
if script_args.convert_pissa_to_lora:
    peft_model.save_pretrained(
        os.path.join(script_args.output_dir, "pissa_lora"),
        convert_pissa_to_lora=os.path.join(script_args.residual_model_name_or_path, "pissa_init"),
    )
else:
    peft_model.save_pretrained(
        os.path.join(script_args.output_dir, "pissa_ft"),
    )

if script_args.merge_and_save:
    model = peft_model.merge_and_unload()
    model.save_pretrained(os.path.join(script_args.output_dir, "pissa_merged"))
    tokenizer.save_pretrained(os.path.join(script_args.output_dir, "pissa_merged"))
67 changes: 67 additions & 0 deletions examples/pissa_finetuning/preprocess.py
@@ -0,0 +1,67 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import LoraConfig, get_peft_model


parser = argparse.ArgumentParser(description="Apply PiSSA initialization to a base model and save the decomposed model.")
parser.add_argument(
    "--base_model_name_or_path", type=str, required=True, help="The name or path of the fp32/16 base model."
)
parser.add_argument("--output_dir", type=str, required=True, help="Directory in which to store the decomposed model.")
parser.add_argument("--bits", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument(
    "--init_lora_weights", type=str, default="pissa", help="(`['pissa', 'pissa_niter_[number of iters]']`)"
)
parser.add_argument("--lora_r", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=128)
parser.add_argument("--lora_dropout", type=float, default=0)
script_args = parser.parse_args()
print(script_args)

model = AutoModelForCausalLM.from_pretrained(
    script_args.base_model_name_or_path,
    torch_dtype=(
        torch.float16
        if script_args.bits == "fp16"
        else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
    ),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
lora_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    init_lora_weights=script_args.init_lora_weights,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)

# Save PiSSA modules:
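# Setting init_lora_weights to True prevents a fresh SVD from being triggered when this adapter is later
# reloaded on top of the (already decomposed) residual model.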
peft_model.peft_config["default"].init_lora_weights = True
peft_model.save_pretrained(os.path.join(script_args.output_dir, "pissa_init"))
# Save residual model:
peft_model = peft_model.unload()
peft_model.save_pretrained(script_args.output_dir)
# Save the tokenizer:
tokenizer.save_pretrained(script_args.output_dir)