ENH Remove redundant initialization layer calls #887
Conversation
The documentation is not available anymore as the PR was closed or merged.
Short update: I tested this on a small model, "bloomz-1b7". On main, creating the PEFT model takes 16 sec; on this branch, 13 sec. Not a huge difference; maybe the model is too small for a more noticeable gap, or there is another bottleneck.
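For context, timing such a change could look roughly like the sketch below (hypothetical, not the benchmark script from #872; the model id and target modules are assumptions):

```python
import time

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Hypothetical benchmark sketch: measure how long it takes to wrap a base model
# with LoRA adapters. The model id and target_modules below are assumptions.
base_model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-1b7")
config = LoraConfig(r=8, lora_alpha=16, target_modules=["query_key_value"])

start = time.perf_counter()
peft_model = get_peft_model(base_model, config)
print(f"Creating the PEFT model took {time.perf_counter() - start:.1f} sec")
```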
In my tests this PR makes LoRA init much faster than
Also, adjust a comment that was added in a previous commit.
@poedator We decided to proceed with this approach to avoid unnecessary inits altogether. Still, you made a huge contribution, so I would like to add you as a co-author if you'd like. If you provide me your credentials in the co-author format, I will add them to the commit message (I'm not sure how to figure that out on my own).
@BenjaminBossan, I like your approach even better than my own and support picking this PR. Now it makes sense to expand it by adding other layer types. I'd be glad to contribute by running benchmarks or in any other way.
Thank you @poedator and @BenjaminBossan for working on this, impactful! 🤗
Thanks a lot @BenjaminBossan !
Partly resolves huggingface#872

Description

Following the faster initialization of the LoRA Linear layer, initialization of Conv2d and Embedding is now sped up as well.

Implementation

The approach to achieving the speed-up has slightly changed compared to last time. To refresh memory: in huggingface#887, we avoided the unnecessary initialization of the full weight matrix by completely skipping nn.Linear.__init__. Although it is possible to do the same for Embedding and Conv2d, we run into some trouble here. The issue is that the __init__ methods of these classes have quite a few more arguments and some custom logic (i.e. not only self.foo = foo but more on top). If we wanted to skip __init__ entirely, we would basically have to copy all of that into our code. Although that is possible, it is brittle (e.g. the logic could differ between PyTorch versions or change over time).

For that reason, I opted to implement this differently, using a suggestion we had discussed earlier. The approach is to call __init__ of the parent class but enforce empty weights (this is what torch.nn.utils.skip_init does, although we cannot use that function directly). This way, we avoid having to copy the __init__ code while still avoiding expensive initialization of the weights; see the sketch after this description.

I did not change the code for Linear to also use this approach, because the logic inside Linear.__init__ is quite simple (at least for now), so the existing approach works fine there. However, I was curious how changing the approach for Linear would affect initialization speed, so I ran the script from huggingface#872 again, 3 times each.

Current approach:

- test 1 with model bert-base took 0.021 sec.
- test 1 with model bert-base took 0.020 sec.
- test 1 with model bert-base took 0.020 sec.
- test 2 with model bloomz-1b7 took 0.030 sec.
- test 2 with model bloomz-1b7 took 0.030 sec.
- test 2 with model bloomz-1b7 took 0.030 sec.

New approach if applied to Linear:

- test 1 with model bert-base took 0.038 sec.
- test 1 with model bert-base took 0.039 sec.
- test 1 with model bert-base took 0.038 sec.
- test 2 with model bloomz-1b7 took 0.072 sec.
- test 2 with model bloomz-1b7 took 0.048 sec.
- test 2 with model bloomz-1b7 took 0.048 sec.

This shows that the new approach is indeed a bit slower than the existing one, though still a lot faster than what we had before. IMHO, we're safe to leave the code inside Linear as is and benefit from the slightly better performance at the cost of slightly more fragile code. But please let me know if you prefer that:

1. The new approach should also be applied to Linear
2. The existing approach should also be applied to Embedding and Conv2d
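For illustration, here is a minimal sketch of the "empty weights" idea for Embedding (hypothetical code, not the actual PEFT implementation; it only mirrors the device="meta" trick that torch.nn.utils.skip_init relies on):

```python
import torch.nn as nn

# Hypothetical sketch: run the parent __init__ with its full argument handling,
# but place the weight on the "meta" device so no real memory is allocated or
# randomly initialized. Afterwards, allocate empty storage on the target device;
# the values get overwritten by the original module's weights anyway.
class EmbeddingWithSkippedInit(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, device=None, **kwargs):
        super().__init__(num_embeddings, embedding_dim, device="meta", **kwargs)
        # to_empty() materializes uninitialized tensors on the real device.
        self.to_empty(device=device if device is not None else "cpu")
```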
See discussion in #872
Description
This is an attempt at speeding up initialization of LoRA layers. As mentioned in this comment, it seems that we currently create the weights of the linear layer that is to be adapted with LoRA twice (!), only to overwrite them later with the original weights of the target layer. This PR attempts to remove that redundancy.
Before this can be merged, there should be a thorough study of potential implications, since this could be BC breaking if we miss something, and the impact on initialization speed should be measured.
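To make the redundancy concrete, here is a rough sketch of the idea (assumed and simplified, not the exact diff): skip nn.Linear.__init__ so that no full weight matrix is allocated and randomly initialized, since it would be replaced by the target layer's weight right afterwards.

```python
import torch.nn as nn

# Simplified sketch (assumed, not the actual PEFT code): a LoRA Linear that
# bypasses nn.Linear.__init__ and its expensive reset_parameters() call.
class LoraLinearSketch(nn.Linear):
    def __init__(self, in_features: int, out_features: int):
        # Only run nn.Module.__init__; do not allocate or initialize a weight.
        nn.Module.__init__(self)
        self.in_features = in_features
        self.out_features = out_features
        # weight/bias are attached later from the module being adapted, e.g.:
        #     lora_layer.weight = target_module.weight
        #     lora_layer.bias = target_module.bias
```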
Road not taken
I think an even cleaner design would be to pass the target module to the LoRA layer, so that the LoRA layer can hold a reference to it. During forward, it would call self.original_module.forward(x), then add the LoRA stuff on top. When unloading, we would just return the original_module instead of having to create a completely new nn.Linear.

Note that this approach is basically already implemented for QuantLinear. However, I think making that change on Linear would indeed break BC because it would not allow loading existing weights (correct me if I'm wrong), so I didn't explore that approach further.
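A hypothetical sketch of what that alternative design could look like (illustrative only; class name and hyperparameters are made up, not PEFT's API):

```python
import torch
import torch.nn as nn

# Hypothetical "road not taken": keep a reference to the original module
# instead of copying its weights into a new nn.Linear subclass.
class LoraWrapper(nn.Module):
    def __init__(self, original_module: nn.Linear, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        self.original_module = original_module
        self.lora_A = nn.Linear(original_module.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, original_module.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # B starts at zero, as in standard LoRA
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the wrapped layer, then add the low-rank update on top.
        return self.original_module(x) + self.lora_B(self.lora_A(x)) * self.scaling

    def unload(self) -> nn.Linear:
        # Unloading just hands back the untouched original module.
        return self.original_module
```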