
[Deepspeed ZeRO-3] Broken model save on fresh Transformers branch #10789

Closed
exelents opened this issue Mar 18, 2021 · 23 comments

@exelents

exelents commented Mar 18, 2021

I have my own model, which utilizes two T5 encoders, and I train it via DeepSpeed. It has its own save_pretrained() and from_pretrained() methods, which implement custom load/save logic:
https://github.com/exelents/try_t5_siamese/blob/4140194978ac113c45e7370f40b3d9b932d0b35b/siamese_model.py#L80

When I run training and the trainer starts to save a checkpoint, something strange happens: the weights file for every saved encoder is only a few kilobytes, i.e. the weights are not actually saved.
At the start of training the trainer tries to load a checkpoint using model.load_checkpoint(), but it seems this function has its own loading logic: it cannot execute my model-loading logic and throws an error:
ValueError: [deepspeed] failed to resume from checkpoint ./templates/siamese-t5-small-v1_1-template
I can comment out the code that loads the checkpoint, but then I get the checkpoint-saving problem described above...

What should I do to save my own custom model properly? It worked a month ago, but today I refreshed my Transformers repo and everything broke.

@samsontmr

I'm getting a similar problem after training BERT with MLM using DeepSpeed where all the saved weights are of size 1. The same run_mlm script worked as expected if I didn't use DeepSpeed.

RuntimeError: Error(s) in loading state_dict for BertForSequenceClassification:
    size mismatch for bert.embeddings.word_embeddings.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([119547, 768]).
    size mismatch for bert.embeddings.position_embeddings.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([512, 768]).
    size mismatch for bert.encoder.layer.0.attention.self.query.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([768, 768]).

@LysandreJik
Member

Since this is using DeepSpeed, maybe @stas00 has an idea?

@samsontmr

samsontmr commented Mar 18, 2021

Just tried loading a model trained with sharded_ddp and got a different error:

Traceback (most recent call last):
  File "/export/proj/code/transformers/src/transformers/modeling_utils.py", line 1057, i
n from_pretrained
    state_dict = torch.load(resolved_archive_file, map_location="cpu")
  File "/export/proj/env_cuda11_1/lib/python3.7/site-packages/torch/serialization.py", line 593, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/export/proj/env_cuda11_1/lib/python3.7/site-packages/torch/serialization.py", line 762, in _legacy_load
    magic_number = pickle_module.load(f, **pickle_load_args)
EOFError: Ran out of input

It seems the model saving might not be happening properly for these two integrations? I also noticed that only the config and weights were being saved when using --sharded_ddp.

UPDATE: It's actually the checkpoint saving getting stuck that's causing this issue. Started another run to confirm and it got stuck while saving as well.

UPDATE 2: This only happens with zero_dp_2 and zero_dp_3. simple appears to work fine. For DeepSpeed, using stage 2 appears to fix the problem (I was previously using stage 3).

@exelents
Author

exelents commented Mar 18, 2021

@samsontmr I changed the DeepSpeed stage to 2 and it seems to work well - checkpoints are saved properly. I was also using DeepSpeed stage 3 before.

It seems the problems are in the Stage 3 integration. Maybe @stas00 could help; he did the previous integration of DeepSpeed into the trainer.

@stas00
Contributor

stas00 commented Mar 18, 2021

DeepSpeed Stage 3 integration is not finished yet; a WIP PR is here if you'd like to try it - though it still has a ton of debug statements and a few features are still missing.
#10753

Make sure you are using the latest DeepSpeed, since zero3 had problems with saving checkpoints; the 0.3.13 release should be good.

But I am pretty sure the issue is different, as I literally merged the code that generates the error you quoted just 2 days ago.
If it worked before, please roll back to any SHA before #10760 and let me know if it works.

The problem with DeepSpeed is that it doesn't currently have a way to save an fp32 checkpoint that can be loaded normally rather than via DeepSpeed (microsoft/DeepSpeed#800), so when you save a model you only get an fp16 version. However, its special checkpoint (see e.g. the global_step10 folder inside the checkpoint folder) contains all the right data, and thus if you want to load a DeepSpeed model you need to use train(resume_from_checkpoint) instead.

So if you want to resume training you can't use from_pretrained() at the moment, unless fp16 weights are sufficient for your work. And it sounds like that is broken at the moment.
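
To make the resume path concrete, here is a minimal sketch, assuming a Trainer instance has already been built with a DeepSpeed config and that output_dir/checkpoint-6 is a checkpoint folder written by a previous run (the path is illustrative):

# Resume a DeepSpeed run from an existing checkpoint folder.
# Trainer.train() hands the folder to DeepSpeed, which restores the sharded
# model/optimizer states that from_pretrained() cannot see.
checkpoint_dir = "output_dir/checkpoint-6"  # illustrative path
train_result = trainer.train(resume_from_checkpoint=checkpoint_dir)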

Let me know if any of this makes sense and let's see how we can make your code work with what we have.

I'd be happy to adapt my recent changes to meet your needs.

@samsontmr

Thanks for the detailed reply @stas00! Is the issue with the fp32 checkpoint saving only happening with zero3 or also with stage 2? My fine-tuning step started with no issues when I used the checkpoint from a stage 2 training run (hasn't completed yet so I'm not sure how it'll end up).

@stas00
Contributor

stas00 commented Mar 18, 2021

Is the issue with the fp32 checkpoint saving only happening with zero3 or also with stage 2?

It's an issue with any zero stage under deepspeed.

Are you saying that the problem emerged once switching to the zero3 config? I'm not at all sure it can resume from a zero2 checkpoint to a zero3 config - those are quite different setups. So we really need to get the fp32 saving sorted out.

Let's see if we can ask to make this a higher priority at #10789

@samsontmr

Are you saying that the problem emerged once switching to the zero3 config? I'm not at all sure it can resume from a zero2 checkpoint to a zero3 config - those are quite different setups. So we really need to get the fp32 saving sorted out.

Yup, I didn't try going from zero2 to zero3; I just restarted my training using zero2, then fine-tuned the model without deepspeed... which somehow managed to load just by using .from_pretrained

@stas00
Contributor

stas00 commented Mar 18, 2021

As I tried to explain, you were getting only fp16 weights when using from_pretrained, which may or may not be good enough for your needs. It should mostly be OK, except some metrics or features may break under fp16 if they weren't coded for it.
e.g. #10674

So let's lay out a test that I can work with to reproduce your issues. Could you please lay out a sequence of events - ideally in code, but pseudo-code will work too - and then I will try to see where the breakage is.

The PR I referred to includes several save/resume tests, so saving works normally, and resume uses train(resume_from_checkpoint) and works too. Though I need to add a zero3 test as well - only zero2 is tested so far. The resume test is here:

def test_can_resume_training(self):

You shouldn't get:

ValueError: [deepspeed] failed to resume from checkpoint ./templates/siamese-t5-small-v1_1-template

if you're not trying to do train(resume_from_checkpoint), you can see where it gets triggered:

if resume_from_checkpoint is not None: # and os.path.isdir(resume_from_checkpoint):

@exelents
Author

exelents commented Mar 18, 2021

As for me: I fixed my problem with the unnecessary checkpoint load, where I was getting the load error, but there is still a save error in DeepSpeed stage 3 mode. If you could help me, @stas00, I would appreciate it.

Here are the steps to reproduce my error with model saving:

  • Clone this repo:
    https://github.com/exelents/try_t5_siamese

  • Extract folder "qasc" from this archive:
    https://drive.google.com/file/d/1gwvFiPzWW0JLr0XLS25PuG2Br5S4fPbR/view?usp=sharing

  • Go to the cloned repo folder and run ./create-siamese-template.sh - it will create a siamese NN from two t5-small encoders in the folder ./templates/siamese-t5-small-template

  • Then you can run ./run-siamese-small.sh - you will see the normal behaviour: in the folder ./siamese_train_deepspeed/output_dir/ checkpoints are stored every 3 steps, and you can see a sign that the weights are stored:
    weights files like ./siamese_train_deepspeed/output_dir/checkpoint-6/left/pytorch_model.bin will be around a hundred megabytes in size.

  • Then, to see the problem, open ./run-siamese-small.sh, change "ds_config.json" to "ds_config_stage3.json", and rerun training. You will see that weights files like ./siamese_train_deepspeed/output_dir/checkpoint-6/left/pytorch_model.bin are only a few kilobytes in size, and you cannot load the model from that checkpoint. That is the problem, and it appears only if I turn on "stage 3" mode in the config (a small size check is sketched below).
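
A quick way to confirm which case you are in is to check the size of the saved weights file under each config. This is just a diagnostic sketch; the path is the one from the steps above:

import os

# Under the broken ZeRO-3 save, pytorch_model.bin is only a few kilobytes
# instead of ~100 MB, because it contains placeholders rather than real weights.
weights_file = "./siamese_train_deepspeed/output_dir/checkpoint-6/left/pytorch_model.bin"
size_mb = os.path.getsize(weights_file) / 2**20
print(f"{weights_file}: {size_mb:.1f} MB")
if size_mb < 1:
    print("Weights look like placeholders - the real parameters were not saved.")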

@stas00
Contributor

stas00 commented Mar 18, 2021

Thank you for the detailed instructions, @exelents.

Let me first adapt the existing test to zero3 so I am sure it's working, and then I will try your sequence. I will keep you posted.

@stas00
Contributor

stas00 commented Mar 18, 2021

I can reproduce the saved model size problem. pytorch_model.bin with:

  • zero2 135M
  • zero3 38K

but, as I mentioned, DeepSpeed currently doesn't provide a proper way to save a model on its own.

It saves the model state in its own sub-folder, e.g., in your case:

ls -l output_dir/checkpoint-6/global_step6/
total 809M
-rw-rw-r-- 1 stas stas  53K Mar 18 14:03 zero_pp_rank_0_mp_rank_00_model_states.pt
-rw-rw-r-- 1 stas stas 809M Mar 18 14:03 zero_pp_rank_0_mp_rank_00_optim_states.pt

as you can see the optimizer states dict has everything in it. So you should be able to resume from it.

Your script is a bit old and based on an old example - so it doesn't support the current mechanism of doing resume from command line using https://github.com/huggingface/transformers/blob/master/examples/README.md#resuming-training

So for resume to currently work, you either need to bring your script up-to-date, probably by checking the latest version of the example you used as a base for your work, or adapt it manually as described next.

The key is train(resume_from_checkpoint): if you pass it output_dir/checkpoint-6, DeepSpeed reloads from where it left off and continues on its merry way.

To help you, I think the new script in your case is this one, and I have pointed to where the critical part is:

train_result = trainer.train(resume_from_checkpoint=checkpoint)

(this is on master)

So if you could bring your script up-to-date with the current approach it would work automatically, or you can adapt it manually as I suggested above.
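
For reference, the pattern in the current example scripts looks roughly like the sketch below. It is simplified and assumes trainer and training_args are already set up and that your transformers version provides trainer_utils.get_last_checkpoint:

import os
from transformers.trainer_utils import get_last_checkpoint

# Find the most recent checkpoint-* folder in output_dir, if any.
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)

# Passing the folder lets DeepSpeed reload its own sharded states and resume.
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)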

If any of my comments are unclear please don't hesitate to ask for clarifications.

Meanwhile I will investigate why the model state_dict is almost empty under zero3 - this looks like a bug - making it work might help you move on without needing to change your code.

I will get back to you.

@stas00
Contributor

stas00 commented Mar 18, 2021

I investigated, and model.state_dict() returns some sort of placeholder with tensor([1.]) for each weight and no real data; that's why pytorch_model.bin is tiny. Filed a request: microsoft/DeepSpeed#872

So until we find a way to reconstruct it, I suggest sticking to zero2; otherwise you will remain locked into DeepSpeed data files, that is, you should be able to continue training but not to use the model without DeepSpeed.
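
A quick way to see the placeholder effect described above is to load the saved state dict directly and inspect the tensor shapes (a diagnostic sketch; the path is hypothetical):

import torch

# Under the broken ZeRO-3 save, every entry is a 1-element placeholder
# such as tensor([1.]) instead of the real parameter tensor.
state_dict = torch.load("output_dir/checkpoint-6/left/pytorch_model.bin", map_location="cpu")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))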

@stas00 stas00 self-assigned this Mar 21, 2021
@stas00 stas00 changed the title Broken model save on fresh Transformers branch [Deepspeed ZeRO-3] Broken model save on fresh Transformers branch Mar 21, 2021
@stas00
Contributor

stas00 commented Mar 21, 2021

While the DeepSpeed team is sorting out the addition of a method to extract model weights from its checkpoint, here is an update for you.

DeepSpeed stores the model weights in its checkpoint files (one file per GPU), which at the moment can only be loaded via DeepSpeed's own load_checkpoint. Therefore please adapt your code to rely on that to save and resume your custom models. Do not rely on save_pretrained and then expect from_pretrained to work, since the model weights won't be there.
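
If you drive DeepSpeed directly rather than through the Trainer, the save/resume cycle goes through the engine's own checkpoint methods. A minimal sketch, assuming model, optimizer and a ds_config dict are already defined and that you are on a recent DeepSpeed release:

import deepspeed

# Wrap the model in a DeepSpeed engine (model, optimizer and ds_config assumed defined elsewhere).
engine, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config_params=ds_config)

# Save: every rank writes its shard of the model/optimizer states under the given folder.
engine.save_checkpoint("output_dir/checkpoint-6")

# Resume: all ranks must call this; it restores the sharded states saved above.
load_path, client_state = engine.load_checkpoint("output_dir/checkpoint-6")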

The new method we are discussing will be able to convert the DeepSpeed checkpoint into model weights consolidated from multiple GPUs. This is quite expensive, so it shouldn't happen on each checkpoint save and definitely shouldn't be the default, because there might not be enough memory to do the consolidation (e.g. for a model spread out over dozens of GPUs).

Bottom line: should you choose to use DeepSpeed ZeRO-3, things aren't as straightforward, and we will work out a solution for this case.

I suppose it's a similar story with fairscale Sharded DDP, but I am working on DeepSpeed only at the moment and can't comment on the former. Unless @sgugger, who did the initial integration of fairscale, beats me to it, I will be able to look at it once I complete the integration of DeepSpeed ZeRO-3, which is coming along nicely but requires changes on the DeepSpeed side, so it'll take some time.

@stas00
Contributor

stas00 commented Mar 21, 2021

@exelents, here is how to solve your specific problem of:

class T5Siamese(T5PreTrainedModel):
[....]
    def init_from_base_t5_model(model_name_or_path='t5-base', output_root='./'):
        [...]
        model_left = T5EncoderModel.from_pretrained(MODEL)
        model_right = T5EncoderModel.from_pretrained(MODEL)

with DeepSpeed zero-3.

If you don't mind continuing training and not being able to retrieve the final weights until microsoft/DeepSpeed#872 is addressed, here is what you can do immediately to be able to move forward:

Do the above only when starting "cold"; when resuming from a checkpoint, don't do that, and instead let T5Siamese be restored from the DeepSpeed checkpoint all at once.
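
In other words, the branching could look roughly like this. This is only a sketch: T5Siamese and init_from_base_t5_model come from the linked repo, while build_trainer and the template/checkpoint paths are hypothetical placeholders for your own wiring:

import os

checkpoint_dir = "output_dir/checkpoint-6"                  # hypothetical resume point
template_dir = "./templates/siamese-t5-small-template"      # hypothetical template location

if os.path.isdir(checkpoint_dir):
    # Resuming: build the T5Siamese skeleton only; its saved pytorch_model.bin may hold
    # placeholders under zero3, which is fine because DeepSpeed restores the real
    # weights from its checkpoint during train().
    model = T5Siamese.from_pretrained(template_dir)
    trainer = build_trainer(model)
    trainer.train(resume_from_checkpoint=checkpoint_dir)
else:
    # Cold start: initialize both encoders from the base t5 model, as in the original code.
    T5Siamese.init_from_base_t5_model(model_name_or_path="t5-small", output_root="./templates")
    model = T5Siamese.from_pretrained(template_dir)
    trainer = build_trainer(model)
    trainer.train()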

Once we get the method to extract the model weights out of the DeepSpeed checkpoint, you can then recover both sub-model weights if you want to upload them to the hub or to take them elsewhere.

Please let me know if this solution resonates with you. Or if you run into any hiccups I haven't considered.

Note that currently under zero-2 you're only recovering fp16 weights, so it is not ideal either. So you may want to use this solution for both cases.

@stas00
Contributor

stas00 commented Mar 21, 2021

@samsontmr, would you kindly open a separate issue since while this is related the use-case is quite different. Please tag me and we will work on solving your use case there. Thank you!

P.S. Also, when you test, please make sure you are using the transformers and deepspeed master branches, since fixes are constantly being merged into them.

@exelents
Author

@stas00 Thank you for the explanation. So, to load a stage-3 checkpoint I should do a "cold load" from the original T5 weights, and then load the actual weights via deepspeed.load_checkpoint. The question is: is it possible to use this model in a usual Jupyter notebook, or a usual Python script, if I load the model weights using the DeepSpeed function? Or, if I trained the model via DeepSpeed once, will I be bound to its runner forever?

@stas00
Contributor

stas00 commented Mar 21, 2021

So, to load a stage-3 checkpoint I should do a "cold load" from the original T5 weights, and then load the actual weights via deepspeed.load_checkpoint.

I haven't tested it, but I can't think of any reason why it won't work. If you run into problems that I haven't considered please let me know.

The question is: is it possible to use this model in a usual Jupyter notebook, or a usual Python script, if I load the model weights using the DeepSpeed function?

Yes, of course.

Just note that if you use the notebook directly and don't launch an external process which launches the distributed environment, you will be limited to 1 gpu and you will have to emulate the distributed environment like so:

import os
# Emulate a one-process distributed environment so DeepSpeed can initialize with a single GPU.
dist_env_1_gpu = dict(MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1")
for k, v in dist_env_1_gpu.items():
    os.environ[k] = v

and please make sure you're on the master or very recent transformers version for this to work.

But if you just use the notebook to open a shell with the deepspeed launcher then you have no limitation of one gpu, e.g. see: https://github.com/stas00/porting/blob/master/transformers/deepspeed/DeepSpeed_on_colab_CLI.ipynb

Or, if I trained the model via DeepSpeed once, will I be bound to its runner forever?

I'm not sure what you're asking here, as I don't know whether you're referring to the deepspeed launcher or something else.

  1. The deepspeed launcher is a more elaborate equivalent of python -m torch.distributed.launch. In simple cases of a single node you can use the latter. Here all DeepSpeed needs is to have a dedicated process per gpu and the distributed env set up (even in the case of one gpu).

  2. If you're asking whether your data will be locked into the deepspeed checkpoints, then at the moment the answer is yes.
    Once "[zero3] how to get the model reconstructed for saving?" (microsoft/DeepSpeed#872) is resolved, you will be able to recover the consolidated weights and use them in any way you want.

@exelents
Author

OK, thank you for the explanation. I'm not sure whether I'll be able to test these changes on my code soon, but I'll do it sooner or later.

@stas00
Contributor

stas00 commented Mar 21, 2021

I just proposed yet another API in microsoft/DeepSpeed#872:

being able to call deepspeed.consolidate_weights() in the rank0 process, which would give users the full weights back (perhaps with a bool arg for whether they want the fp16 or fp32 version). Then they could just save the model as they do with any other PyTorch tools. This would only be practical for small-ish models. The key here is that, while this would be somewhat costly, they would be able to use their code almost without any change whether they train with deepspeed or in other ways.

So if that were added, your current code would also work by just adding a call to this newly proposed API. Let's see.

@samsontmr

@stas00 thanks! My problem is solved for now since I'm also using fp16 during fine-tuning so the current stage2 saves are good enough for me.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@exelents
Author

Hello, @stas00. I have created an issue about problems with converting the model to fp32. Can you say something about it?
microsoft/DeepSpeed#1009
