
Reconstruction of fp32 weights on stage3 doesn't work #1009

Closed
exelents opened this issue Apr 27, 2021 · 15 comments · Fixed by #1017

Comments

@exelents

In #892 @stas00 proposed a new script which can consolidate fp32 weights from an fp16 model checkpoint produced by stage 3 training.
Unfortunately, I have found that the t5-11b model can't be consolidated due to an error:

└──>$ ./zero_to_fp32.py global_step3250/ pytorch_model1.bin
Processing zero checkpoint 'global_step3250/'
Detected checkpoint of type zero stage 3, world_size: 1
Traceback (most recent call last):
  File "./zero_to_fp32.py", line 151, in <module>
    convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file)
  File "./zero_to_fp32.py", line 122, in convert_zero_chkpt_to_fp32_consolid_state_dict
    tuple(fp32_flat_groups[i].narrow(0,
  File "./zero_to_fp32.py", line 122, in <genexpr>
    tuple(fp32_flat_groups[i].narrow(0,
RuntimeError: start (32899072) + length (16777216) exceeds dimension size (32899072).
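
For context, the failing call boils down to something like the snippet below (the sizes come from the traceback above; the flat buffer itself is just a stand-in):

import torch

# Stand-in for the flattened fp32 partition loaded from the checkpoint;
# only its length matters here (taken from the traceback above).
fp32_flat_group = torch.zeros(32899072)

# The script asks for a slice starting right at the end of the buffer,
# i.e. for a parameter the buffer does not actually contain.
offset, numel = 32899072, 16777216
try:
    fp32_flat_group.narrow(0, offset, numel)
except RuntimeError as err:
    print(err)  # start (32899072) + length (16777216) exceeds dimension size (32899072)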

Maybe @stas00 could say what the problem is, and how it can be fixed?

@exelents changed the title from "Reconstruction of" to "Reconstruction of fp32 weights on stage3 doesn't work" on Apr 27, 2021
@exelents (Author) commented Apr 27, 2021

Here are the steps to reproduce my error with model reconstruction:

  • Clone this repo:
    https://github.com/exelents/try_t5_siamese

  • Extract folder "qasc" from this archive:
    https://drive.google.com/file/d/1gwvFiPzWW0JLr0XLS25PuG2Br5S4fPbR/view?usp=sharing

  • Go to the cloned repo folder and run ./create-siamese-template.sh - it will create a siamese NN from two t5-11b encoders in the folder ./templates/siamese-t5-11b-template

  • Then run ./run-siamese-small.sh and wait until a checkpoint is saved. If you don't want to wait for 1625 steps, you can change the checkpoint save period via the --save_steps parameter in ./run-siamese-small.sh

  • Go to ./output_dir/<your_latest_checkpoint> and run
    ./zero_to_fp32.py ./<global_step_path> pytorch_model1.bin

@stas00 (Collaborator) commented Apr 28, 2021

Could you please save one checkpoint and share it - on my dev machine I won't be able to train a 45GB model.

Alternatively, can you reproduce the same problem with a much smaller model, say t5-small? Or do you get it only with t5-11b?

@exelents (Author) commented:

Yes, I can reproduce it on t5-small:

└──>$ ./zero_to_fp32.py global_step10/ pytorch_model1.bin 
Processing zero checkpoint 'global_step10/'
Detected checkpoint of type zero stage 3, world_size: 1
Traceback (most recent call last):
  File "./zero_to_fp32.py", line 151, in <module>
    convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file)
  File "./zero_to_fp32.py", line 122, in convert_zero_chkpt_to_fp32_consolid_state_dict
    tuple(fp32_flat_groups[i].narrow(0,
  File "./zero_to_fp32.py", line 122, in <genexpr>
    tuple(fp32_flat_groups[i].narrow(0,
RuntimeError: start (16449536) + length (262144) exceeds dimension size (16449536).

Here is a checkpoint of the siamese model based on t5-small:
https://drive.google.com/file/d/1PjWEPWHXiLw1vKHSE6_r6YgkCfuweGgb/view?usp=sharing

@stas00 (Collaborator) commented Apr 28, 2021

This is odd, as it works for me:

./zero_to_fp32.py global_step6 out.bin
Processing zero checkpoint 'global_step6'
Detected checkpoint of type zero stage 2, world_size: 1
Saving fp32 state dict to out.bin (total_numel=70661632)

This was generated by ./run-siamese-small.sh

Unrelated, but I had to downgrade transformers to the last official release to make your script work. This is because the deepspeed config in transformers master has changed; please refer to huggingface/transformers#11464 so you're in sync with the changes.

If I download your checkpoint, indeed it fails:

./zero_to_fp32.py global_step10 1
Processing zero checkpoint 'global_step10'
Detected checkpoint of type zero stage 3, world_size: 1
Traceback (most recent call last):
  File "./zero_to_fp32.py", line 151, in <module>
    convert_zero_chkpt_to_fp32_consolid_state_dict(args.checkpoint_dir, args.output_file)
  File "./zero_to_fp32.py", line 122, in convert_zero_chkpt_to_fp32_consolid_state_dict
    tuple(fp32_flat_groups[i].narrow(0,
  File "./zero_to_fp32.py", line 122, in <genexpr>
    tuple(fp32_flat_groups[i].narrow(0,
RuntimeError: start (16449536) + length (262144) exceeds dimension size (16449536).

Any ideas on what might be different in our envs? Do you somehow copy the checkpoint via some OS/way that possibly mangles it? Can you resume from those checkpoints?

I'm on pytorch-nightly, deepspeed master and downgraded to transformers==4.5.1 to be able to run your script.

@stas00 (Collaborator) commented Apr 28, 2021

Oddly enough, your t5-small checkpoint looks very different from mine:

Mine:

-rw-rw-r-- 1 stas stas 141357368 Apr 28 10:18 mp_rank_00_model_states.pt
-rw-rw-r-- 1 stas stas 847950410 Apr 28 10:18 zero_pp_rank_0_mp_rank_00_optim_states.pt

Yours:

-rw-rw-r-- 1 stas stas     52397 Apr 28 10:02 zero_pp_rank_0_mp_rank_00_model_states.pt
-rw-rw-r-- 1 stas stas 847983800 Apr 28 10:02 zero_pp_rank_0_mp_rank_00_optim_states.pt

The optim states are slightly different and the model states are very different. (We don't care about the model states for this purpose, but it's an indicator that we use different envs.)

I used your scripts unmodified so we should be getting the same results.

Could you update your deepspeed and transformers to the latest official versions and try again?

pip install deepspeed  transformers -U

@exelents (Author) commented Apr 28, 2021

I have reinstalled deepspeed and transformers to the latest versions from the pip repo. I also installed everything in a fresh python environment with pytorch-nightly:
pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
And I still have the same problem.

The sizes of the weight files are really different from yours.

-rw-rw-r-- 1 fellow fellow     52254 Apr 28 21:12 zero_pp_rank_0_mp_rank_00_model_states.pt
-rw-rw-r-- 1 fellow fellow 847983555 Apr 28 21:12 zero_pp_rank_0_mp_rank_00_optim_states.pt

Can this problem be caused by CUDA?

@exelents (Author) commented Apr 28, 2021

After I installed all packages in the fresh env I got warnings while training, although training ends successfully:

got UNKNOWN type <class 'siamese_model.T5SiameseModelOutput'>
convert output to [tensor(0.4226, device='cuda:0', dtype=torch.float16), tensor([[-0.1069,  0.0889,  0.0464, 
.....

T5SiameseModelOutput is my custom output class. Can it cause problems? I don't even know what went wrong.

@exelents (Author) commented Apr 28, 2021

I have found an interesting behavior. This checkpoint can be successfully loaded via
model_engine.load_checkpoint(ENC_MODEL)
where model_engine is deepspeed's model engine object. I can even make predictions with it, although I still haven't done a full check with a search index using this model.
Here is a notebook showing how I did this. You should have my repo in the directory with it and configure ENC_MODEL (the path to the checkpoint) in the second cell.
https://drive.google.com/file/d/1u_v5LrVCxg3qWCYj70cvqxcA3M-b4cV2/view?usp=sharing
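
For reference, the load path described above amounts to roughly this (a sketch only; argument names may differ across DeepSpeed versions, and model, ds_config and ENC_MODEL are assumed to be defined earlier in the notebook):

import deepspeed

# model, ds_config and ENC_MODEL (checkpoint directory) are assumed to be
# defined earlier in the notebook; this only sketches the call sequence.
model_engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

# load_checkpoint returns (load_path, client_state); with no explicit tag
# it picks the global step recorded in the checkpoint's `latest` file.
load_path, client_state = model_engine.load_checkpoint(ENC_MODEL)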

@stas00 (Collaborator) commented Apr 28, 2021

The last one is a buglet in deepspeed==0.3.15 - they forgot to remove a debug print. If you upgrade to deepspeed master it will go away.

I think the slight difference in file size is OK, since the file stores all kinds of keys. I will try to save just the data entries we want from your checkpoint to compare.

But why your model states file is so much smaller is strange. Basically your model states contain placeholders for the model weights and not the model itself. In ZeRO-3 the model is partitioned over GPUs, and deepspeed uses a placeholder tensor of Size([1]) for each weight and reconstructs the real weights on the fly just before each forward.
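
One way to see this is to peek into the model states file directly (a sketch, assuming the checkpoint stores the module weights under a 'module' key; the file name is taken from the listing above):

import torch

# ZeRO-3 model states file from the listing above; under ZeRO-3 the module
# weights stored here should be Size([1]) placeholders, not real tensors.
sd = torch.load("zero_pp_rank_0_mp_rank_00_model_states.pt", map_location="cpu")
for name, tensor in list(sd["module"].items())[:5]:
    print(name, tuple(tensor.shape))  # expect (1,) placeholders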

Let me think some more about this strange discrepancy a bit later and I will get back to you.

@stas00 (Collaborator) commented Apr 28, 2021

Oh, I completely missed the crucial point - I got a zero2 checkpoint and you sent me a zero3 checkpoint. Thank you, @tjruwase, for noticing that! So we aren't testing the same thing.

Ah yes, your ./run-siamese-small.sh is configured to use zero2, so it saves a zero2 checkpoint. That explains it.

@stas00 (Collaborator) commented Apr 28, 2021

Yes, OK, I changed your script to use a ZeRO-3 config and I can now reproduce the problem! I will look at it later today.

@exelents (Author) commented:

Oh, that's my fault, I hadn't committed the changes to the *.sh scripts in my repo...

@stas00 (Collaborator) commented Apr 28, 2021

That's no problem. The zero2 checkpoint looks significantly different from the zero3 one. We will sort it out.

@stas00 (Collaborator) commented Apr 29, 2021

OK, I see the weights are stored differently for your model. Could you please try this version of the script?

https://gist.github.com/stas00/479a6ae2fac070e866440d8a5211f5cd

Please ignore all the debug noise, just watch for the final:

Saving fp32 state dict to pytorch_model.bin (total_numel=70661632)

It should be around 270MB for the t5-small in your double model.

I did only a quick validation, so I'm not 100% sure it reconstructs it correctly.
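
If it does complete, one way to sanity-check the result would be to load it back into the model (a sketch; model here stands for the same siamese t5-small module your training script builds):

import torch

# pytorch_model.bin is the file written by the script above; `model` is
# assumed to be the same siamese t5-small module built by the repo.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
print("total_numel:", sum(t.numel() for t in state_dict.values()))  # expect 70661632

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)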

Meanwhile, let me do the same against the staple example script.

@stas00 (Collaborator) commented Apr 29, 2021

So the problem was that this model had multiple param groups, and they were stored differently.

@exelents, please try this PR: #1017
You will need to re-run the training to get the script updated, or you can just manually copy it from there.

Please let me know the outcome. Thank you.
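
For the curious, the gist of the fix is roughly the following (an illustrative sketch only, not the actual PR code; the names fp32_flat_groups and param_shapes mirror the ones used in the script):

import torch

def rebuild_fp32_state_dict(fp32_flat_groups, param_shapes):
    # One flat fp32 buffer per optimizer param group, so the per-parameter
    # offset is tracked within each group instead of across a single
    # global buffer (which is what overflowed before).
    state_dict = {}
    for flat, shapes in zip(fp32_flat_groups, param_shapes):
        offset = 0
        for name, shape in shapes.items():
            numel = int(torch.Size(shape).numel())
            state_dict[name] = flat.narrow(0, offset, numel).view(shape).clone()
            offset += numel
    return state_dict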
