Running out of memory when resuming training. #12680
Comments
Thank you for the detailed report, @thies1006. I suspect that at some point we have the model allocated more than once. I will profile the memory usage and get back to you with the findings. I'm glad to hear that, meanwhile, you have a workaround.
So first, I see that our non-DeepSpeed checkpoint loading is inefficient CPU-memory-wise.
As you can see, the checkpoint loading takes ~225MB more:
which is exactly the size of the t5-small model (230MB). That is, at some point it keeps two full copies of the model in CPU memory. cc: @sgugger. So the issue might not be in DeepSpeed, but I will check that next.
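A toy illustration of that pattern (a sketch, not the actual Trainer code; the checkpoint file name is hypothetical): the model is first materialized with random weights, and the checkpoint is then loaded into a second set of CPU tensors before being copied in, so two full copies coexist until the state dict is freed.

```python
import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_config(config)                 # copy 1: randomly initialized weights in CPU RAM
state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # copy 2: checkpoint weights in CPU RAM
model.load_state_dict(state_dict)                                 # values are copied into the model
del state_dict                                                    # only now is the second copy released
```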
Oh, that is weird. Off the top of my head, the first culprit could be the
Yes, that did the trick! It's the same memory usage now. Applied here: #12718
So, back to the DeepSpeed side of this issue. I wasn't able to see the problem with
So it's easy to see that at some point there is a temporary jump of 1.1GB compared to the normal run (t5-base is about 850MB), which most likely means there are several copies of it loaded into CPU memory at some point.
OK, so I did some profiling with an even larger model, t5-large (2.7GB), so it's easier to see what's happening. We need to take into account that DeepSpeed needs to load optimizer states, which the non-DeepSpeed run doesn't do! And that makes a huge difference. So our model has close to 0.75B params:
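For reference, a quick way to check that number (a minimal sketch, assuming `transformers` is installed and there is enough RAM to instantiate t5-large):

```python
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-large")
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e9:.2f}B parameters")  # roughly 0.74B for t5-large
```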
Now the checkpoint contains 4 bytes per parameter for the fp32 weights and 8 bytes per parameter for the optimizer states, 12 bytes in total:
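Rough back-of-the-envelope arithmetic for that (assuming an Adam-style optimizer, i.e. two fp32 states per parameter):

```python
n_params = 0.75e9             # ~0.75B parameters for t5-large
fp32_weights = 4              # bytes per parameter
optimizer_states = 8          # bytes per parameter (e.g. Adam momentum + variance in fp32)
total_gb = n_params * (fp32_weights + optimizer_states) / 2**30
print(f"~{total_gb:.1f} GB")  # ~8.4 GB expected for weights + optimizer states
```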
Indeed if we check the checkpoint folder:
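One way to check the folder size programmatically (a sketch; the path is the hypothetical checkpoint folder from the reproduction in this issue):

```python
from pathlib import Path

ckpt_dir = Path("/tmp/tst-summarization/checkpoint-10")
total_bytes = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file())
print(f"{total_bytes / 2**30:.1f} GB on disk")
```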
And this is what accounts for the huge peak of CPU RAM that gets used temporarily while the checkpoint is loaded. So, as you indeed figured out, if you bypass the checkpoint loading and load just the weights you extracted with `zero_to_fp32.py`, you avoid that peak.

In general this should be possible to fix by not allocating the model until the checkpoint is loaded (see #12274, which was just made available in pytorch), and probably something similar for the optimizer. But I can't promise you if and when this will happen. This is very important, I think!

Perhaps a simpler solution until then would be to allocate some swap memory on an nvme drive? Please let me know if this is helpful.
Thank you very much for the insights, @stas00!! I just wanted to bring this up because the order of magnitude was surprising to me. As I understand you, memory for the model and optimizer states is allocated twice (at model init and at checkpoint loading). My checkpoint has this size (for Blenderbot-9B):
I also tried with Blenderbot-3B; there I get a 61GB checkpoint folder, and CPU RAM consumption peaks at about 330GB (a short peak, as you said). So, in summary, I'm still wondering about the numbers. But as I understand you, this is normal and already being addressed. I'll try with the nvme, btw, thanks for the hint! I think we can close this for now.
The main issue is loading the optimizer states, which are 2x bigger than the fp32 model. Actually, I thought of a possible solution last night: staggered checkpoint loading. If you have 4 gpus on a node, currently the whole checkpoint folder gets loaded into CPU memory at once. However, what if we loaded one gpu at a time? That would require only 1/4 of the extra CPU memory, since when one gpu finishes loading it returns its CPU memory back to the pool. I think this approach should solve your limitation. Let me try to implement this on the DeepSpeed side.
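A minimal sketch of the staggered-loading idea (not the actual DeepSpeed implementation; the environment variable names follow torchrun conventions, the per-rank shard file is hypothetical, and the process group is assumed to be initialized already):

```python
import os
import torch
import torch.distributed as dist

def staggered_load(shard_path: str, model: torch.nn.Module) -> None:
    """Let local ranks take turns loading, so only one shard sits in CPU RAM at a time."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    for turn in range(local_world_size):
        if turn == local_rank:
            state_dict = torch.load(shard_path, map_location="cpu")  # this rank's shard only
            model.load_state_dict(state_dict, strict=False)          # partial shard, hence strict=False
            del state_dict                                           # return CPU memory before the next rank starts
        dist.barrier()                                               # everyone waits for the current rank to finish
```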
After trying to implement staggered loading, I discovered that each process loads the zero checkpoints for all ranks in DeepSpeed,
Might be a similar problem to #11317: the node runs out of CPU memory (512GB).
To reproduce:
(i)
(ii)
Afterwards, in order to resume, I use the option `--resume_from_checkpoint /tmp/tst-summarization/checkpoint-10`. A workaround is to export the FP32 weights using the script `zero_to_fp32.py` as described in https://huggingface.co/transformers/master/main_classes/deepspeed.html#getting-the-model-weights-out and restart directly from `pytorch_model.bin`; nevertheless, it would be better to resume directly from the DeepSpeed checkpoint, if possible.

torch: 1.8.1+cu111
transformers: 4.9.0.dev0
deepspeed: 0.4.4+d1a7a55
log: log.txt
@stas00