Adding PiSSA as an optional initialization method of LoRA #1626

Merged: 112 commits merged into huggingface:main on May 15, 2024

Conversation

fxmeng
Contributor

@fxmeng fxmeng commented Apr 7, 2024

In the paper https://arxiv.org/pdf/2404.02948.pdf, we introduce a parameter-efficient fine-tuning (PEFT) method, Principal Singular values and Singular vectors Adaptation (PiSSA), which optimizes a significantly reduced parameter space while matching or surpassing the performance of full-parameter fine-tuning.

PiSSA is inspired by Intrinsic SAID, which suggests that pre-trained, over-parametrized models inhabit a space of low intrinsic dimension. Consequently, PiSSA represents a matrix $W\in\mathbb{R}^{m\times n}$ within the model as the product of two trainable matrices $A \in \mathbb{R}^{m\times r}$ and $B \in \mathbb{R}^{r\times n}$, where $r \ll \min(m, n)$, plus a residual matrix $W^{res}\in\mathbb{R}^{m\times n}$ for error correction. Singular value decomposition (SVD) is employed to factorize $W$: the principal singular values and vectors of $W$ are used to initialize $A$ and $B$, while the residual singular values and vectors initialize the residual matrix $W^{res}$, which remains frozen during fine-tuning.

Notably, PiSSA shares the same architecture as Low-Rank Adaptation (LoRA), which hypothesizes that the change in model parameters $\Delta W$ forms a low-rank matrix. However, LoRA approximates $\Delta W$ through the product of two matrices, $A$, initialized with Gaussian noise, and $B$, initialized with zeros, whereas PiSSA initializes $A$ and $B$ with the principal singular values and singular vectors of the original matrix $W$. Given that the principal singular values and vectors capture the essence of a low-rank matrix, PiSSA can better approximate full-parameter fine-tuning from the very beginning by updating the essential parts while freezing the "noisy" parts. In comparison, LoRA freezes the original matrix and updates the "noise". This distinction enables PiSSA to converge much faster than LoRA and to achieve better performance in the end.

On five common benchmarks, PiSSA outperforms LoRA on all of them using exactly the same setup except for the initialization. On GSM8K, Mistral-7B fine-tuned with PiSSA achieves an accuracy of 72.86%, outperforming LoRA's 67.7% by 5.16 percentage points.
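Concretely, the initialization can be summarized as follows (splitting the principal singular values evenly between $A$ and $B$ is one convention; any split gives the same product $AB$):

$$W = U S V^{\top}, \qquad A = U_{[:, :r]}\, S_{[:r, :r]}^{1/2}, \qquad B = S_{[:r, :r]}^{1/2}\, V_{[:, :r]}^{\top},$$

$$W^{res} = U_{[:, r:]}\, S_{[r:, r:]}\, V_{[:, r:]}^{\top}, \qquad W = W^{res} + AB.$$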

Because it shares the same architecture, PiSSA inherits many of LoRA's advantages, such as parameter efficiency and compatibility with quantization. Leveraging a fast SVD method, the initialization of PiSSA takes only a few seconds, so the cost of switching from LoRA to PiSSA is negligible.
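For users, switching is meant to be a one-line change in the LoRA configuration. A minimal usage sketch (the model name and hyperparameters are placeholders; the commented `pissa_niter_...` value illustrates the fast-SVD variant with a limited number of subspace iterations):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    # "pissa" runs an exact SVD on each target weight;
    # "pissa_niter_4" would use the fast randomized SVD with 4 subspace iterations.
    init_lora_weights="pissa",
)
model = get_peft_model(base_model, config)
model.print_trainable_parameters()
```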

@wtmlon

wtmlon commented Apr 8, 2024

Hello, I noticed that the PR submitted for your PiSSA work differs from the original paper in one place: https://github.com/huggingface/peft/blob/ec15cafd929bef508412848fc4e3bfdba46355d7/src/peft/tuners/lora/layer.py#L178. Why does the computation of lora A and lora B here not match the paper? In the paper, the lora A matrix is Ur@Sr, but in the code it has become Sr @ Vr. Which of these two ways of computing LoRA A and B were the experiments in the paper based on?

@fxmeng
Contributor Author

fxmeng commented Apr 8, 2024

Hello, I noticed that the PR submitted for your PiSSA work differs from the original paper in one place: https://github.com/huggingface/peft/blob/ec15cafd929bef508412848fc4e3bfdba46355d7/src/peft/tuners/lora/layer.py#L178. Why does the computation of lora A and lora B here not match the paper? In the paper, the lora A matrix is Ur@Sr, but in the code it has become Sr @ Vr. Which of these two ways of computing LoRA A and B were the experiments in the paper based on?

Hello, this is because the weight matrix of a linear layer torch.nn.Linear(in_channel, out_channel) is actually stored transposed, i.e., W really has shape out_channel x in_channel. Normally, one would have to transpose W, perform the SVD and initialize A and B, and then transpose again before assigning the results to the newly inserted linear layers.
However, if the order of Ur and Vhr is swapped, the transpose operations before and after the SVD can be avoided entirely.
The two computations are equivalent, but the latter is more efficient.
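In PEFT's weight layout the idea looks roughly like this (a simplified sketch, not the exact code in the PR; here `weight` is the `nn.Linear` weight of shape `(out_features, in_features)` and the adapter update is `lora_B @ lora_A`):

```python
import torch

def pissa_init(weight: torch.Tensor, r: int):
    # SVD directly on the stored (out_features, in_features) weight, no transposes needed.
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    sqrt_S = torch.sqrt(S[:r])
    lora_B = U[:, :r] * sqrt_S              # (out_features, r)
    lora_A = sqrt_S[:, None] * Vh[:r, :]    # (r, in_features)
    residual = weight - lora_B @ lora_A     # frozen residual replaces the base weight
    return lora_A, lora_B, residual
```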

@BenjaminBossan
Member

Let me know when this is ready for review. Also, please run make style on the code.

@fxmeng
Contributor Author

fxmeng commented Apr 9, 2024

Let me know when this is ready for review. Also, please run make style on the code.

I've run make style on the code, following your advice, and believe it's now ready for review. Please let me know if there's anything else needed.

Member

@BenjaminBossan left a comment

Thanks a lot for providing this useful method. The results from the paper are very promising, so I would be happy to have this added to PEFT.

I added a couple of comments to the PR. Please check them out.

There is one very big issue still, which is about saving and loading the model (check out my last comment). This is a very similar issue to what we have with LoftQ initialization (which itself is very similar to PiSSA IIUC). That has to be addressed most urgently or else users cannot correctly load LoRA weights trained with PiSSA.

Edit: It seems I have been mistaken here. But let's add a test that involves saving and loading to ensure that this works correctly.

Let's also add some documentation. This is important so that users can discover this new method and understand what it does and when to use it. For this, could you please:

  1. Add a section to the LoRA docs here: . Give a short explanation with a link to the paper and provide a code snippet. Don't forget to mention the possibility to pass multiple iterations.
  2. The docstring of the LoraConfig here. Extend the type annotation to include Literal["gaussian", "loftq", "pissa"] and add a sentence or two to the description. Don't forget to mention the possibility to pass multiple iterations.
  3. The help of LoraConfig: You can use the same explanation as above.

Furthermore, let's add some testing in this file. Specifically, let's check init_lora_weights="pissa", the variant with multiple iterations, usage with bnb, and the error cases. If you need help with writing the tests, let me know.

(5 review comments on src/peft/tuners/lora/layer.py, all outdated and resolved)
@BenjaminBossan
Member

Hey @fxmeng, after some internal discussion, we had some concerns about this line:

https://github.com/huggingface/peft/pull/1626/files#diff-24a141c266b7b714ae8fcc470f31bc283f7b0f5a671bbf6d5f092741fc374104R194

The issue here is that the model base weights are modified when initializing with PiSSA. This can have side-effects for the user. For example, when they disable all adapters, they would normally expect the model output to be the same as the base model, but here it's not the case. Or when a user loads a PiSSA-LoRA adapter and another LoRA adapter, that other adapter will not work correctly because it was trained on the unmodified base weight.

It would be possible to add a lot of checks everywhere and raise errors if we detect that PiSSA is used and a user wants to disable the adapter or switch to another adapter. But that's very complicated and error prone, and at the end of the day also not very user friendly. What I wonder is: How much performance would we lose if we keep the base weights unmodified? If this works almost as well, maybe we can keep the base weights and not have to add all those complications. Did you run experiments to test that?
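To make the expectation concrete, this is roughly the kind of check I have in mind (illustrative only; `peft_model`, `base_model`, and `inputs` are placeholders):

```python
import torch

def disabling_adapter_recovers_base(peft_model, base_model, inputs) -> bool:
    # With a plain LoRA adapter this returns True. If PiSSA has replaced the base
    # weights by the residual matrices, it returns False, which is the side effect
    # we are worried about.
    with torch.no_grad(), peft_model.disable_adapter():
        disabled_logits = peft_model(**inputs).logits
    with torch.no_grad():
        base_logits = base_model(**inputs).logits
    return torch.allclose(disabled_logits, base_logits, atol=1e-5)
```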

@fxmeng
Contributor Author

fxmeng commented Apr 11, 2024

Hey @fxmeng, after some internal discussion, we had some concerns about this line:

https://github.com/huggingface/peft/pull/1626/files#diff-24a141c266b7b714ae8fcc470f31bc283f7b0f5a671bbf6d5f092741fc374104R194

The issue here is that the model base weights are modified when initializing with PiSSA. This can have side-effects for the user. For example, when they disable all adapters, they would normally expect the model output to be the same as the base model, but here it's not the case. Or when a user loads a PiSSA-LoRA adapter and another LoRA adapter, that other adapter will not work correctly because it was trained on the unmodified base weight.

It would be possible to add a lot of checks everywhere and raise errors if we detect that PiSSA is used and a user wants to disable the adapter or switch to another adapter. But that's very complicated and error prone, and at the end of the day also not very user friendly. What I wonder is: How much performance would we lose if we keep the base weights unmodified? If this works almost as well, maybe we can keep the base weights and not have to add all those complications. Did you run experiments to test that?

Hi @BenjaminBossan,
It's really a good question.
In fact, a trained PiSSA adapter can be converted into a standard LoRA adapter without any loss in performance, so the converted LoRA can be shared and used without any special checks while still benefiting from the training-efficiency improvements that PiSSA brings. We provide a function for this conversion in this script (https://github.com/fxmeng/peft/blob/c679a504d0fe581b0ea213f121f4918c875c8c43/examples/pissa_finetuning/convert_pissa_to_lora.py), along with a complete end-to-end example. We are compiling more tips on using PiSSA into a document.
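The conversion is possible because, relative to the original weight $W = W^{res} + B_0A_0$ (with $B_0, A_0$ the PiSSA initialization), the trained factors $B, A$ correspond to an update $\Delta W = BA - B_0A_0$, which can be stored as a rank-$2r$ LoRA adapter on the untouched base model. A minimal sketch with plain tensors (ignoring the lora_alpha scaling; this is not the linked script itself):

```python
import torch

def pissa_to_lora(lora_A0, lora_B0, lora_A, lora_B):
    # delta_W = lora_B @ lora_A - lora_B0 @ lora_A0
    #         = [lora_B | -lora_B0] @ [lora_A ; lora_A0],
    # i.e. a rank-2r LoRA adapter relative to the original base weight.
    new_A = torch.cat([lora_A, lora_A0], dim=0)    # (2r, in_features)
    new_B = torch.cat([lora_B, -lora_B0], dim=1)   # (out_features, 2r)
    return new_A, new_B
```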

@BenjaminBossan
Member

In fact, a trained PiSSA adapter can be converted into a standard LoRA adapter without any loss in performance, so the converted LoRA can be shared and used without any special checks while still benefiting from the training-efficiency improvements that PiSSA brings.

Oh nice, thanks, I think it would be great to integrate this functionality into PEFT. To be sure I understand: We first load the base model, then initialize the PEFT model with PiSSA turned on, then train the PiSSA-LoRA adapter, and then we can convert it to a normal LoRA adapter and share it with others. When someone loads this converted PiSSA-LoRA adapter, it works like a normal LoRA adapter, so there is no need to adjust the base model weights. This means we can disable it, combine it with other LoRA adapters, etc. Is that right?

Regarding the linked script, can you explain this line (or refer to the part of the paper that explains it):

https://github.com/fxmeng/peft/blob/c679a504d0fe581b0ea213f121f4918c875c8c43/examples/pissa_finetuning/convert_pissa_to_lora.py#L26

We are compiling more tips on using PiSSA into a document.

Looking forward to this.

@fxmeng
Contributor Author

fxmeng commented Apr 12, 2024

In fact, a trained PiSSA adapter can be converted into a standard LoRA adapter without any loss in performance, so the converted LoRA can be shared and used without any special checks while still benefiting from the training-efficiency improvements that PiSSA brings.

Oh nice, thanks, I think it would be great to integrate this functionality into PEFT. To be sure I understand: We first load the base model, then initialize the PEFT model with PiSSA turned on, then train the PiSSA-LoRA adapter, and then we can convert it to a normal LoRA adapter and share it with others. When someone loads this converted PiSSA-LoRA adapter, it works like a normal LoRA adapter, so there is no need to adjust the base model weights. This means we can disable it, combine it with other LoRA adapters, etc. Is that right?

Regarding the linked script, can you explain this line (or refer to the part of the paper that explains it):

https://github.com/fxmeng/peft/blob/c679a504d0fe581b0ea213f121f4918c875c8c43/examples/pissa_finetuning/convert_pissa_to_lora.py#L26

We are compiling more tips on using PiSSA into a document.

Looking forward to this.

We have explained the line you mentioned on https://github.com/fxmeng/peft/blob/7fabf84375092cc9b2d870188953602a02b9d8db/examples/pissa_finetuning/convert_pissa_to_lora.py#L26. We will include detailed instructions for converting PiSSA to LoRA in our document and the next draft of the paper. Additionally, we have fixed a bug and conducted tests on the combination of the converted LoRA and the base model to ensure its accuracy.

@BenjaminBossan
Member

When I run your test above, I get the same or very similar values, except for T5 + 8bit:

(tensor(0.1253, device='cuda:0'), tensor(0.0223, device='cuda:0'), tensor(0.1440, device='cuda:0'), tensor(0.0288, device='cuda:0'))
(tensor(1.6214, device='cuda:0'), tensor(3.5510, device='cuda:0'), tensor(0.6988, device='cuda:0'), tensor(0.7377, device='cuda:0'))
(tensor(7.4336e-05, device='cuda:0'), tensor(8.8446e-09, device='cuda:0'), tensor(2.3471e-05, device='cuda:0'), tensor(8.9277e-10, device='cuda:0'))
(tensor(0.0004, device='cuda:0'), tensor(2.2412e-07, device='cuda:0'), tensor(0.0003, device='cuda:0'), tensor(1.3223e-07, device='cuda:0'))

Not sure why that is, perhaps it's best to just remove that specific test (in that case, add a comment that this combination may fail on some machines).

I used the nuclear norm of the error matrix between each layer of the quantized pissa model and the base model to evaluate the magnitude of the quantization error in my paper. This implementation is not affected by factors such as random seeds, and the error calculated for each model is a fixed value. If the test_t5_pissa_8bit test still cannot pass on your local machine, how do you feel about replacing this test with the one used in the paper?

I admit that calculating MAE/MSE of logits is a bit flawed as a measure, this was chosen more from a practical viewpoint. I don't know this measure that you proposed and would need to read a bit more, but if you think it's superior, feel free to use it instead. But as mentioned, it would also be fine to remove this one specific test.

It is quite strange that I can pass the make style test locally

Maybe it's the ruff version? The version that the CI uses is ruff-0.2.2. If this doesn't solve the issue for you, let me know and I can send you a patch.

bnb.nn.Params4bit requires the use of CUDA, so where should I put

Ah yes, good point, then it should go to tests/test_gpu_examples.py. You could add a comment that references the other test in test_initialization.py so that we know that the two belong together.

@fxmeng
Contributor Author

fxmeng commented May 6, 2024

When I run your test above, I get the same or very similar values, except for T5 + 8bit:

(tensor(0.1253, device='cuda:0'), tensor(0.0223, device='cuda:0'), tensor(0.1440, device='cuda:0'), tensor(0.0288, device='cuda:0'))
(tensor(1.6214, device='cuda:0'), tensor(3.5510, device='cuda:0'), tensor(0.6988, device='cuda:0'), tensor(0.7377, device='cuda:0'))
(tensor(7.4336e-05, device='cuda:0'), tensor(8.8446e-09, device='cuda:0'), tensor(2.3471e-05, device='cuda:0'), tensor(8.9277e-10, device='cuda:0'))
(tensor(0.0004, device='cuda:0'), tensor(2.2412e-07, device='cuda:0'), tensor(0.0003, device='cuda:0'), tensor(1.3223e-07, device='cuda:0'))

Not sure why that is, perhaps it's best to just remove that specific test (in that case, add a comment that this combination may fail on some machines).

I used the nuclear norm of the error matrix between each layer of the quantized pissa model and the base model to evaluate the magnitude of the quantization error in my paper. This implementation is not affected by factors such as random seeds, and the error calculated for each model is a fixed value. If the test_t5_pissa_8bit test still cannot pass on your local machine, how do you feel about replacing this test with the one used in the paper?

I admit that calculating MAE/MSE of logits is a bit flawed as a measure, this was chosen more from a practical viewpoint. I don't know this measure that you proposed and would need to read a bit more, but if you think it's superior, feel free to use it instead. But as mentioned, it would also be fine to remove this one specific test.

It is quite strange that I can pass the make style test locally

Maybe it's the ruff version? The version that the CI uses is ruff-0.2.2. If this doesn't solve the issue for you, let me know and I can send you a patch.

bnb.nn.Params4bit requires the use of CUDA, so where should I put

Ah yes, good point, then it should go to tests/test_gpu_examples.py. You could add a comment that references the other test in test_initialization.py so that we know that the two belong together.

I have changed the method for measuring quantization error from the MAE/MSE of the logits (chosen from a practical viewpoint) to the nuclear norm of all per-layer error matrices. This measure yields a fixed value for each model, and the test passes in my local environment.
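For reference, the per-layer measurement is roughly the following (a sketch assuming the quantized residual can be dequantized back to a dense tensor; the names are placeholders, not the actual test code):

```python
import torch

def layer_quantization_error(weight, dequantized_residual, lora_A, lora_B) -> float:
    # Error between the original weight and the quantized PiSSA reconstruction
    # (dequantized residual plus the full-precision adapter).
    err = weight - (dequantized_residual + lora_B @ lora_A)
    # Nuclear norm = sum of singular values; deterministic for a given model,
    # so the test does not depend on random seeds.
    return torch.linalg.matrix_norm(err.float(), ord="nuc").item()
```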

@BenjaminBossan
Member

@fxmeng Did you check your local ruff version?

@fxmeng
Contributor Author

fxmeng commented May 7, 2024

Hi @BenjaminBossan,
I have installed ruff==0.2.2 and formatted the files.
Next, I installed hf-doc-builder==0.5.0.
When executing the command: doc-builder style src/peft tests docs/source --max_len 119 --check_only,
it raises a ValueError: 3 files should be restyled!
How can I find out which files need to be restyled, and how should I restyle these three files?
Here is the complete error message:

$ make
ruff src tests examples docs scripts docker
ruff format --check src tests examples docs scripts docker
157 files already formatted
doc-builder style src/peft tests docs/source --max_len 119 --check_only
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/bin/doc-builder", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.12/site-packages/doc_builder/commands/doc_builder_cli.py", line 47, in main
    args.func(args)
  File "/home/ubuntu/miniconda3/lib/python3.12/site-packages/doc_builder/commands/style.py", line 28, in style_command
    raise ValueError(f"{len(changed)} files should be restyled!")
ValueError: 3 files should be restyled!
make: *** [Makefile:11:quality] Error 1

@fxmeng
Contributor Author

fxmeng commented May 8, 2024

Hi @BenjaminBossan, I have installed ruff==0.2.2 and formatted the files. Next, I installed hf-doc-builder==0.5.0. When executing the command: doc-builder style src/peft tests docs/source --max_len 119 --check_only, it raises a ValueError: 3 files should be restyled! How can I find out which files need to be restyled, and how should I restyle these three files? Here is the complete error message:

$ make
ruff src tests examples docs scripts docker
ruff format --check src tests examples docs scripts docker
157 files already formatted
doc-builder style src/peft tests docs/source --max_len 119 --check_only
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/bin/doc-builder", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.12/site-packages/doc_builder/commands/doc_builder_cli.py", line 47, in main
    args.func(args)
  File "/home/ubuntu/miniconda3/lib/python3.12/site-packages/doc_builder/commands/style.py", line 28, in style_command
    raise ValueError(f"{len(changed)} files should be restyled!")
ValueError: 3 files should be restyled!
make: *** [Makefile:11:quality] Error 1

Could you please run make style to make the CI pass?

I had misunderstood your suggestion to use make style to format the code; I was using VSCode's reformat feature instead.
After running the make style command, the three files were restyled, and the entire project now passes all the tests locally!

Member

@BenjaminBossan left a comment

Thanks for running ruff and for adjusting the test. There are a few issues that resulted from moving the test, but they should be an easy fix; please check my comments.

(5 review comments on tests/test_gpu_examples.py, all outdated and resolved)
@fxmeng
Contributor Author

fxmeng commented May 8, 2024

Thanks for running ruff and for adjusting the test. There are a few issues that resulted from moving the test, but they should be an easy fix; please check my comments.

Thank you for your comments. I have fixed these points. If there are any other issues, please let me know.

@fxmeng fxmeng requested a review from BenjaminBossan May 8, 2024 15:56
Member

@BenjaminBossan left a comment

Fantastic, just one small issue that made the new test fail for me.

(1 review comment on tests/test_gpu_examples.py, outdated and resolved)
Co-authored-by: Benjamin Bossan <[email protected]>
@fxmeng fxmeng requested a review from BenjaminBossan May 9, 2024 02:35
Member

@BenjaminBossan left a comment

Thanks a lot for this excellent work @fxmeng. This now all looks good. I'll ask if one of the co-maintainers also wants to review before merging.

I also got email notifications with comments you added about experimenting without adjusting the base model weights; thanks for running these tests. However, when I click the links in the emails, the comments are not shown. Maybe that is because the original comment is on outdated code? Anyway, could you be so kind as to paste those comments into the main thread of this PR so that everyone who's interested in your results can read them?

@fxmeng
Contributor Author

fxmeng commented May 10, 2024

Hi @BenjaminBossan,
Thank you for your professional, patient, and friendly support throughout this process. We believe that with your help, the code quality of PiSSA has significantly improved and the feature has become easier to use.

Out of curiosity, did you run experiments where the base weights were not updated?

Here is the experiment in which the PiSSA adapter is trained on top of the unmodified base model.
Because adding the PiSSA adapter on top of the unmodified base weights changes the model's behavior right at initialization, the first steps show very high training loss and grad_norm, compared with initial values of roughly 0.7 and 7, respectively, for the PiSSA+residual model. However, after 100 iterations the loss surprisingly converged to around 5.

1-10 steps:

{'loss': 13.4825, 'grad_norm': 11509.810546875, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.0}                                                                                                                                                                         
{'loss': 13.4192, 'grad_norm': 16469.279296875, 'learning_rate': 1.6666666666666667e-06, 'epoch': 0.0}                                                                                                                                                                        
{'loss': 13.457, 'grad_norm': 15556.3232421875, 'learning_rate': 2.5e-06, 'epoch': 0.0}                                                                                                                                                                                       
{'loss': 13.3041, 'grad_norm': 3012.66943359375, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.01}                                                                                                                                                                      
{'loss': 12.9197, 'grad_norm': 1206.873291015625, 'learning_rate': 4.166666666666667e-06, 'epoch': 0.01}                                                                                                                                                                      
{'loss': 12.8678, 'grad_norm': 953.5204467773438, 'learning_rate': 5e-06, 'epoch': 0.01}                                                                                                                                                                                      
{'loss': 12.6281, 'grad_norm': 2949.16162109375, 'learning_rate': 5.833333333333334e-06, 'epoch': 0.01}                                                                                                                                                                       
{'loss': 11.5261, 'grad_norm': 2978.842529296875, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.01}                                                                                                                                                                      
{'loss': 10.7645, 'grad_norm': 1403.374755859375, 'learning_rate': 7.500000000000001e-06, 'epoch': 0.01}                                                                                                                                                                      
{'loss': 17.2494, 'grad_norm': 2027.790283203125, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.01}   

90-100 steps:

{'loss': 10.1056, 'grad_norm': 14.725829124450684, 'learning_rate': 1.961591468063112e-05, 'epoch': 0.12}                                                                                                                                                                     
{'loss': 5.5941, 'grad_norm': 9.888932228088379, 'learning_rate': 1.9604440636594863e-05, 'epoch': 0.12}                                                                                                                                                                      
{'loss': 5.6372, 'grad_norm': 27.609088897705078, 'learning_rate': 1.9592801175825435e-05, 'epoch': 0.12}                                                                                                                                                                     
{'loss': 8.1496, 'grad_norm': 13.072076797485352, 'learning_rate': 1.9580996498788602e-05, 'epoch': 0.12}                                                                                                                                                                     
{'loss': 5.4474, 'grad_norm': 10.157075881958008, 'learning_rate': 1.9569026808795647e-05, 'epoch': 0.12}                                                                                                                                                                     
{'loss': 5.5057, 'grad_norm': 16.48522186279297, 'learning_rate': 1.955689231199986e-05, 'epoch': 0.12}                                                                                                                                                                       
{'loss': 5.3535, 'grad_norm': 11.122950553894043, 'learning_rate': 1.954459321739298e-05, 'epoch': 0.12}                                                                                                                                                                      
{'loss': 5.4009, 'grad_norm': 16.74958038330078, 'learning_rate': 1.9532129736801616e-05, 'epoch': 0.13}                                                                                                                                                                      
{'loss': 5.2926, 'grad_norm': 12.336542129516602, 'learning_rate': 1.9519502084883585e-05, 'epoch': 0.13}                                                                                                                                                                     
{'loss': 5.3891, 'grad_norm': 25.544628143310547, 'learning_rate': 1.9506710479124212e-05, 'epoch': 0.13} 

The final results comparison is shown in the table below. As we can observe, there is a significant drop in performance when PiSSA is used on top of the unmodified base weights instead of the residual model.

|                      | loss               | gsm8k               | math   |
|----------------------|--------------------|---------------------|--------|
| base model+PiSSA     | 1.776114435263083  | 0.33965125094768767 | 0.0538 |
| residual model+PiSSA | 0.2408599260521912 | 0.7149355572403336  | 0.2416 |

@BenjaminBossan BenjaminBossan merged commit b5acf5d into huggingface:main May 15, 2024
14 checks passed
@BenjaminBossan
Member

Thanks a lot @fxmeng for this great PR. After internal discussion, we decided this is good to be merged.

@fxmeng
Contributor Author

fxmeng commented May 16, 2024

I truly appreciate the constructive feedback and effort from all your team members. I am grateful for the opportunity to contribute to the valuable PEFT project.

@Con6924

Con6924 commented May 24, 2024

Thanks for your great PR!
I'm trying out this new feature and found that it is not applicable to Conv2d and Embedding layers. Is it possible to support these two layer types with PiSSA?

@skyshine102

Hi @fxmeng. May I know if PiSSA initialization is compatible with DeepSpeed ZeRO?

@fxmeng
Contributor Author

fxmeng commented Jun 17, 2024

Thanks for your great PR! I'm trying out this new feature and found that it is not applicable to Conv2d and Embedding layers. Is it possible to support these two layer types with PiSSA?

Hi @Con6924,
PiSSA is implemented for Conv2d and Embedding layers in this code: https://github.com/fxmeng/peft/blob/main/src/peft/tuners/lora/layer.py.
We will also be submitting a PR for this to the peft repository soon.

@BenjaminBossan
Member

We will also be submitting a PR for this to the peft repository soon.

Cool, looking forward to it!
