Running out of memory when resume training. #12680

Closed

thies1006 opened this issue Jul 13, 2021 · 9 comments

thies1006 commented Jul 13, 2021

This might be a similar problem to #11317: the node runs out of CPU memory (512GB).

To reproduce:

(i)

deepspeed --hostfile myhostfile \
  ${_PATH}/examples/pytorch/summarization/run_summarization.py \
  --model_name_or_path hyunwoongko/blenderbot-9B \
  --do_train \
  --do_eval \
  --dataset_name cnn_dailymail \
  --dataset_config "3.0.0" \
  --source_prefix "summarize: " \
  --output_dir /tmp/tst-summarization \
  --per_device_train_batch_size 8 \
  --per_device_eval_batch_size 8 \
  --deepspeed ${_PATH}/tests/deepspeed/ds_config_zero3.json \
  --logging_steps 1 \
  --fp16 \
  --overwrite_output_dir \
  --save_steps 10 \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy="steps" \
  --max_train_samples 10024 \
  --max_eval_samples 32 \
  --max_source_length 128 \
  --max_target_length 128 \
  --eval_steps 5

(ii)
Afterwards, in order to resume, I add the option --resume_from_checkpoint /tmp/tst-summarization/checkpoint-10 to the command above.

A workaround is to export the FP32 weights using the zero_to_fp32.py script, as described in https://huggingface.co/transformers/master/main_classes/deepspeed.html#getting-the-model-weights-out, and to restart directly from the resulting pytorch_model.bin. Nevertheless, it would be better to resume directly from the DeepSpeed checkpoint, if possible.
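For reference, a minimal sketch of that workaround in Python, assuming a DeepSpeed version that provides deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint (the checkpoint folder also ships a standalone zero_to_fp32.py script that does the same from the command line):

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

ckpt_dir = "/tmp/tst-summarization/checkpoint-10"
# Consolidate the partitioned ZeRO shards into a single fp32 state dict on CPU
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
# Save it as a regular weights file so training can be restarted from these weights
torch.save(state_dict, f"{ckpt_dir}/pytorch_model.bin")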

torch: 1.8.1+cu111
transformers: 4.9.0.dev0
deepspeed: 0.4.4+d1a7a55

log: log.txt

@stas00

stas00 self-assigned this Jul 13, 2021
stas00 (Contributor) commented Jul 13, 2021

Thank you for the detailed report, @thies1006

I suspect that at some point we have the model allocated more than once.

I will profile the memory usage and get back to you with the findings.

I'm glad to hear that meanwhile you have a workaround.

stas00 (Contributor) commented Jul 14, 2021

So, first, I can see that our non-DeepSpeed checkpoint loading is inefficient CPU memory-wise:

# save
export BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06  --do_train --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 500 --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size $BS --predict_with_generate --sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" --source_prefix "translate English to Romanian: " --max_train_samples 50 --save_steps 1 --skip_memory_metrics 0

# load:
export BS=16; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 examples/pytorch/translation/run_translation.py --model_name_or_path t5-small --output_dir output_dir --adam_eps 1e-06  --do_train --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 500 --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size $BS --predict_with_generate --sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" --source_prefix "translate English to Romanian: " --max_train_samples 50 --save_steps 1 --skip_memory_metrics 0 --resume_from_checkpoint output_dir/checkpoint-1
# save
***** train metrics *****
  epoch                      =        1.0
  init_mem_cpu_alloc_delta   =     -153MB
  init_mem_cpu_peaked_delta  =      152MB
  init_mem_gpu_alloc_delta   =      230MB
  init_mem_gpu_peaked_delta  =        0MB
  train_loss                 =     2.9967
  train_mem_cpu_alloc_delta  =     1324MB
  train_mem_cpu_peaked_delta =      125MB
  train_mem_gpu_alloc_delta  =      933MB
  train_mem_gpu_peaked_delta =      355MB
  train_runtime              = 0:00:03.47
  train_samples              =         50
  train_samples_per_second   =     14.386
  train_steps_per_second     =      0.575
# load
***** train metrics *****
  epoch                      =        1.0
  init_mem_cpu_alloc_delta   =     -153MB
  init_mem_cpu_peaked_delta  =      152MB
  init_mem_gpu_alloc_delta   =      230MB
  init_mem_gpu_peaked_delta  =        0MB
  train_loss                 =     1.4817
  train_mem_cpu_alloc_delta  =     1552MB
  train_mem_cpu_peaked_delta =      124MB
  train_mem_gpu_alloc_delta  =      931MB
  train_mem_gpu_peaked_delta =      228MB
  train_runtime              = 0:00:03.45
  train_samples              =         50
  train_samples_per_second   =     14.472
  train_steps_per_second     =      0.579

As you can see, resuming from the checkpoint allocates ~228MB more CPU memory:

-  train_mem_cpu_alloc_delta  =     1324MB
+  train_mem_cpu_alloc_delta  =     1552MB

which is roughly the size of the t5-small model (230MB).

That is, at some point two full copies of the model are kept in CPU memory.

cc: @sgugger

So the issue might not be in DeepSpeed, but I will check that next.

sgugger (Collaborator) commented Jul 14, 2021

Oh, that is weird. Off the top of my head, the first culprit could be the state_dict we load, which is not released by the Trainer for some reason. If you add a del state_dict on this line, does it release that copy? (I can't fully test right now, which is why I'm asking you.)
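A minimal sketch of the pattern being suggested (a hypothetical helper for illustration, not the exact Trainer code):

import os
import torch

def load_weights_from_checkpoint(model, checkpoint_dir):
    # Load the saved weights onto CPU and copy them into the live model
    state_dict = torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"), map_location="cpu")
    model.load_state_dict(state_dict)
    # Without this, the state_dict stays referenced and CPU RAM briefly holds
    # two full copies of the weights (the model's parameters plus the dict)
    del state_dict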

stas00 (Contributor) commented Jul 14, 2021

Yes, that did the trick! It's the same memory usage now. Applied here: #12718

stas00 (Contributor) commented Jul 15, 2021

So, back to the DeepSpeed side of this issue. I wasn't able to see the problem with t5-small, but I can see it clearly with t5-base:

# save
BS=16; rm -r output_dir; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus 2 examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir --overwrite_output_dir --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 1 --per_device_train_batch_size $BS --learning_rate 3e-3 --logging_steps 0  --dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --max_train_samples 50 --deepspeed tests/deepspeed/ds_config_zero3.json --save_steps 1 --skip_memory_metrics 0

# load:
BS=16; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus 2 examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir --overwrite_output_dir --max_source_length 128 --max_target_length 128 --val_max_target_length 128 --do_train --num_train_epochs 1 --per_device_train_batch_size $BS --learning_rate 3e-3 --logging_steps 0  --dataset_name wmt16 --dataset_config ro-en --source_lang en --target_lang ro --source_prefix "translate English to Romanian: " --max_train_samples 50 --deepspeed tests/deepspeed/ds_config_zero3.json --save_steps 1 --skip_memory_metrics 0 --resume_from_checkpoint output_dir/checkpoint-1
# save
***** train metrics *****
  train_mem_cpu_alloc_delta  =     5542MB
  train_mem_cpu_peaked_delta =      424MB
  train_mem_gpu_alloc_delta  =     -394MB
  train_mem_gpu_peaked_delta =     1259MB
# load
***** train metrics *****
  train_mem_cpu_alloc_delta  =     5109MB
  train_mem_cpu_peaked_delta =     1944MB
  train_mem_gpu_alloc_delta  =     -394MB
  train_mem_gpu_peaked_delta =      804MB

So it's easy to see that at some point there is a temporary jump of ~1.1GB compared to the normal run, while t5-base itself is only about 850MB. This most likely means that several copies of the model are loaded into CPU memory at some point.

stas00 (Contributor) commented Jul 15, 2021

OK, so I did some profiling with an even larger model: t5-large (2.7GB) so it's easier to see what's happening.

We need to take into account that DeepSpeed needs to load the optimizer states, which a non-DeepSpeed run doesn't do, and that makes a huge difference.

So our model has close to 0.75B params:

$ python -c 'from transformers import T5ForConditionalGeneration; model = T5ForConditionalGeneration.from_pretrained("t5-large"); print(sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values()))'
737,668,096 # 737M params

Now, the checkpoint stores 4 bytes per parameter for the fp32 weights and 8 bytes per parameter for the optimizer states, 12 bytes per parameter in total:

python -c 'print(f"{737668096*12 / 2**30 :0.2f}GB")'
8.24GB

Indeed if we check the checkpoint folder:

du -sh output_dir/checkpoint-1/global_step1/
8.3G    output_dir/checkpoint-1/global_step1/

And this is what accounts for the huge CPU RAM peak that is temporarily needed while the checkpoint is being loaded.
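The same back-of-envelope estimate as a tiny hypothetical helper (a sketch: it assumes fp32 master weights plus the two Adam optimizer states, i.e. 12 bytes per parameter, and ignores smaller extra buffers):

def zero_checkpoint_size_gb(num_params: int, bytes_per_param: int = 12) -> float:
    # 4 bytes/param fp32 weights + 8 bytes/param Adam optimizer states
    return num_params * bytes_per_param / 2**30

print(f"{zero_checkpoint_size_gb(737_668_096):.2f}GB")  # t5-large -> 8.24GB, matching the du output above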

So, as you indeed figured out, if you bypass the checkpoint loading and load just the weights you extracted with zero_to_fp32.py, you avoid temporarily needing more CPU memory than the normal run requires.

In general this should be possible to fix by not allocating the model until checkpoint-loading time (see #12274 - this was just made available in pytorch), and probably doing something similar for the optimizer. But I can't promise you if and when this will happen, even though I think it's very important.
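As a generic illustration of the "don't materialize the weights before loading them" idea (this shows only the concept, using features of recent PyTorch, roughly >= 2.1, which postdate the versions in this thread; it is not the approach taken in #12274):

import torch
from torch import nn

# Pretend this file is the checkpoint we want to resume from
reference = nn.Linear(4096, 4096)
torch.save(reference.state_dict(), "/tmp/weights.pt")

# Build the model on the meta device: parameters have shapes but no real storage
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

# assign=True lets load_state_dict adopt the loaded tensors directly instead of
# copying them into freshly allocated storage, so only one full copy of the
# weights ever lives in CPU RAM
state_dict = torch.load("/tmp/weights.pt", map_location="cpu")
model.load_state_dict(state_dict, assign=True)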

Perhaps a simpler solution until then would be to allocate some swap memory on an NVMe drive?

Please let me know if this is helpful.

thies1006 (Author) commented
Thank you very much for the insights, @stas00! I just wanted to bring this up because the order of magnitude was surprising to me. As I understand you, the model and the optimizer states each get allocated twice (once at model init and once at checkpoint loading).

My checkpoint (for Blenderbot-9B) has the following size:

du -sh /tmp/tst-summarization/checkpoint-10/global_step10/
106G	/tmp/tst-summarization/checkpoint-10/global_step10/
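As a rough cross-check of that number (assuming Blenderbot-9B really holds ~9.4B parameters, at the 12 bytes per parameter estimated above):

print(f"{9.4e9 * 12 / 2**30:.0f}GB")  # ~105GB, in the same ballpark as the 106G folder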

I also tried with Blenderbot-3B; there the checkpoint folder is 61GB and the CPU RAM consumption peaks at about 330GB (a short peak, as you said).

So, in summary, I'm still wondering about the numbers. But as I understand you, this is normal and already addressed. I'll try with the nvme btw, thanks for the hint!

I think we can close this for now.

stas00 (Contributor) commented Jul 15, 2021

The main issue is loading the optimizer states, which are 2x bigger than the fp32 model weights.

Actually, I thought of a possible solution last night: staggered checkpoint loading.

So if you have 4 GPUs on a node, currently the whole checkpoint folder gets loaded into CPU memory by all processes at once. However, what if we loaded one GPU at a time? That would require only 1/4 of the extra CPU memory, since once a GPU finishes loading it returns its CPU memory to the pool.

I think this approach should solve your limitation. Let me try to implement this on the deepspeed side.
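A minimal sketch of that staggering idea with torch.distributed (a hypothetical helper for illustration, not DeepSpeed's actual checkpoint-loading code):

import torch
import torch.distributed as dist

def staggered_load(load_fn, group_size: int = 1):
    # Ranks take turns calling load_fn so that only `group_size` processes hold
    # the temporary load-time buffers at any moment, trading speed for peak CPU RAM
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    result = None
    for first in range(0, world_size, group_size):
        if first <= rank < first + group_size:
            result = load_fn()  # only this group reads the checkpoint now
        dist.barrier()          # everyone waits before the next group starts
    return result

# usage sketch: state = staggered_load(lambda: torch.load("ckpt.pt", map_location="cpu"))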

stas00 (Contributor) commented Jul 15, 2021

After trying to implement staggered loading, I discovered that in DeepSpeed each process loads the ZeRO checkpoints for all ranks.
Let's continue this discussion over at DeepSpeed, since it's not really a transformers issue:
microsoft/DeepSpeed#1236
