[BUG] PEFT training with zero.Init() and Zero3 increases GPU memory on every forward step #3002
Comments
I have also tried the tohtana/nested_zero_init branch, which did not fix it.
@dumpmemory

```diff
$ git diff
diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index 1d1680d..97f0a4e 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -484,7 +484,7 @@ class Linear(nn.Linear, LoraLayer):
                 self.unmerge()
             result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
         elif self.r[self.active_adapter] > 0 and not self.merged:
-            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
+            result = torch.matmul(x, transpose(self.weight, not self.fan_in_fan_out)) + self.bias
             x = x.to(self.lora_A[self.active_adapter].weight.dtype)
```

Although Zero3 sets an empty tensor to
I will test this, thanks for your help. I will post the results later.
It worked! With peft commit 10a2a6db5dc9cabb63a36c0fb489aeb2b9a1e433 and the modification above, DeepSpeed 0.9.1 and torch 2.0. Thanks for your help.
I will try to open a PR on peft following your idea. Thanks again.
@dumpmemory This memory leak can also be fixed by setting … By default, DeepSpeed replaces PyTorch's linear with a different implementation, which might be what causes the memory leak. I will investigate what memory_efficient_linear does.
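The exact setting is cut off in the comment above. Assuming it refers to the memory_efficient_linear option the comment goes on to mention (a ZeRO option in the DeepSpeed config under zero_optimization), a minimal config sketch disabling it for a ZeRO-3 run might look like this. The batch-size value is a placeholder, and the effect on the leak is what the comment suggests, not something this sketch verifies.

```python
import json

# Hypothetical minimal DeepSpeed config for ZeRO stage 3; only keys relevant
# to this thread are shown, and the batch size is a placeholder.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        # Assumption: this flag controls DeepSpeed's replacement of PyTorch's
        # linear; False is meant to keep the stock implementation.
        "memory_efficient_linear": False,
    },
}

# Serialized form, e.g. for writing to a ds_config.json file.
config_json = json.dumps(ds_config, indent=2)
```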
Thanks for your work!
@dumpmemory, can you please try PR #3413 created by @tohtana? Thanks!
Yes, I can. Can I test it after my holiday? Thanks.
@dumpmemory, of course! By the way, the PR is merged, so you can use the master branch when you are ready. Happy holidays to you! Thanks for your help.
It worked with peft (commit 10a2a6db5dc9cabb63a36c0fb489aeb2b9a1e433) and peft 3.0.
Describe the bug
When I use PEFT LoRA to train a GPT-2 model with Zero3 and the zero.Init function, GPU memory increases with every forward step. When I disable zero.Init, it works normally.
To Reproduce
Expected behavior
Training runs without GPU memory increasing from step to step.
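One way to check the expected behavior is to record allocated GPU memory once per step (for example with torch.cuda.memory_allocated() after each forward pass) and test whether it keeps climbing. The helper below is a hypothetical, dependency-free sketch of that check; the function name and the tolerance parameter are illustrative assumptions, not part of DeepSpeed or PEFT.

```python
from typing import Sequence


def memory_keeps_growing(readings: Sequence[int], tolerance: int = 0) -> bool:
    """Return True if memory rises by more than `tolerance` bytes at every step.

    `readings` is assumed to hold one allocated-memory sample (in bytes) per
    training step, e.g. collected via torch.cuda.memory_allocated().
    """
    if len(readings) < 2:
        # Not enough samples to see a trend.
        return False
    # Compare each consecutive pair of samples.
    return all(later - earlier > tolerance
               for earlier, later in zip(readings, readings[1:]))
```

The first few steps usually allocate optimizer state and caches, so it is safer to start sampling after a short warm-up before applying this check.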
ds_report output
Please run ds_report to give us details about your setup.
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):