[Deepspeed ZeRO-3] Broken model save on fresh Transformers branch #10789
I'm getting a similar problem after training BERT with MLM using DeepSpeed, where all the saved weights are of size 1. The same
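For reference, a quick way to confirm this symptom is to inspect the tensor shapes in the saved weights file. A minimal sketch; the output path is a placeholder, not from the original report:

```python
import torch

# inspect the tensors in the weights file the Trainer wrote; under the bug described
# above they show up as (near-)empty placeholders instead of full matrices
state_dict = torch.load("output_dir/pytorch_model.bin", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```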
Since this is using DeepSpeed, maybe @stas00 has an idea?
Just tried loading a model trained with
It seems the model saving might not be happening properly for these two integrations? I also noticed that only the config and weights were being saved when using

UPDATE: It's actually the checkpoint saving getting stuck that's causing this issue. Started another run to confirm and it got stuck while saving as well.

UPDATE 2: This only happens with
@samsontmr I have changed the DeepSpeed stage to 2 and it seems to work well - checkpoints are saved properly. I used DeepSpeed stage 3 before, so it seems the problems are in the stage 3 integration. Maybe @stas00 could help; he did the previous integration of DeepSpeed into the Trainer.
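For context, the stage switch described here is just the `zero_optimization.stage` value in the DeepSpeed config file passed to the Trainer. A minimal sketch, not the poster's actual config:

```python
import json

# minimal DeepSpeed config sketch: the only change discussed here is stage 3 -> 2
ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},  # was 3; with 2 the checkpoints saved properly
}

with open("ds_config_zero2.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```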
The DeepSpeed Stage 3 integration is not finished yet; a wip PR is here if you'd like to try it, though it still has a ton of debug statements and a few more features are still missing. Make sure you are using the latest deepspeed, since zero3 had problems with saving checkpoints, but the 0.3.13 release should be good.

But I am pretty sure the issue is different, as I literally merged the code that generated the error you quoted 2 days ago.

The problem with DeepSpeed is that it doesn't currently have a way to save an fp32 checkpoint that can be loaded normally and not via DeepSpeed (microsoft/DeepSpeed#800), so when you save a model you only get an fp16 version. However, its special checkpoint (see e.g.

So if you want to resume training you can't use

Let me know if any of this makes sense and let's see how we can make your code work with what we have. I'd be happy to adapt my recent changes to meet your needs.
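To illustrate the fp16 point: the weights the Trainer writes out under DeepSpeed still load with `from_pretrained`, they just carry only fp16 precision. A minimal sketch; the model class and path are placeholders, not from the original comment:

```python
from transformers import AutoModelForMaskedLM

# the weights saved by the Trainer under DeepSpeed load fine for inference, but they
# are derived from the fp16 copies; the fp32 master weights live only inside
# DeepSpeed's own checkpoint files and cannot be recovered from here
model = AutoModelForMaskedLM.from_pretrained("output_dir")
```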
Thanks for the detailed reply, @stas00! Is the issue with the fp32 checkpoint saving only happening with zero3, or also with stage 2? My fine-tuning step started with no issues when I used the checkpoint from a stage 2 training run (it hasn't completed yet, so I'm not sure how it'll end up).
It's an issue with any ZeRO stage under deepspeed. Are you saying that the problem emerged once you switched to the zero3 config? I'm not at all sure it can resume from a zero2 checkpoint into a zero3 config - those are quite different setups. So we really need to get the fp32 saving sorted out. Let's see if we can ask to make this a higher priority at #10789
Yup, I didn't try going from zero2 to zero3; I just restarted my training using zero2, then fine-tuned the model without deepspeed... which somehow managed to load just by using
As I tried to explain, you were getting only fp16 weights when using from

So let's lay out a test that I need to work on to reproduce your issues. Could you please lay out a sequence of events - ideally in code, but pseudo-code will work too - and then I will try to see where the breakage is.

The PR I referred to includes several save/resume tests, so the saving is normal, and resume uses
You shouldn't get that error if you're not trying to resume from a checkpoint; the relevant code is in transformers/src/transformers/integrations.py, line 452 at 008672e.
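Roughly, the guard around that line does something like the following. A paraphrased sketch, not the exact source:

```python
import glob

def maybe_resume_from_deepspeed_checkpoint(deepspeed_engine, resume_from_checkpoint: str):
    # DeepSpeed keeps its checkpoint in global_step* sub-folders; only attempt a
    # resume when one exists, otherwise raise the error quoted in this thread
    checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))
    if not checkpoint_dirs:
        raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
    deepspeed_engine.load_checkpoint(
        resume_from_checkpoint,
        load_optimizer_states=True,
        load_lr_scheduler_states=True,
    )
```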
As for me: I fixed my problem with the unnecessary checkpoint load, where I got a load error, but I still have a save error in DeepSpeed stage 3 mode. If you could help me, @stas00, I would appreciate it. Here are the steps to reproduce my error with model save:
Thank you for the detailed instructions, @exelents. Let me first adapt the existing test to zero3 so I am sure it's working, and then I will try your sequence. I will keep you posted.
I can reproduce the saved model size problem.
but as I mentioned, Deepspeed doesn't currently provide a proper way to save a model on its own. It saves the model state in its own sub-folder, e.g., in your case:
As you can see, the optimizer states dict has everything in it, so you should be able to resume from it.

Your script is a bit old and based on an old example, so it doesn't support the current mechanism of resuming from the command line: https://github.com/huggingface/transformers/blob/master/examples/README.md#resuming-training

So for resume to work currently, you either need to bring your script up-to-date, probably by checking the latest version of the example you used as a base for your work. The key is

To help you, the new script in your case is this one, and I pointed to where the critical part is:
(this is on master)

So if you could bring your script up-to-date with the current way, it'd automatically work, or you can adapt it manually as I suggested above. If any of my comments are unclear, please don't hesitate to ask for clarifications.

Meanwhile I will investigate why the model state_dict is almost empty under zero3 - this looks like a bug - and making it work might help you move on w/o needing to change your code. I will get back to you.
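As a rough illustration of that resume mechanism, here is a sketch assuming the current `Trainer.train(resume_from_checkpoint=...)` API; the paths and config file name are placeholders:

```python
from transformers import Trainer, TrainingArguments

def resume_training(model, train_dataset):
    # with an up-to-date script, resuming is just a matter of handing the checkpoint
    # folder to Trainer.train(); under DeepSpeed the integration then calls
    # load_checkpoint() on that folder internally
    args = TrainingArguments(output_dir="output_dir", deepspeed="ds_config.json")
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train(resume_from_checkpoint="output_dir/checkpoint-500")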
I investigated, and

So until we find a way to reconstruct it, I suggest sticking to zero2, otherwise you will remain locked into DeepSpeed data files; that is, you should be able to continue training but not be able to use the model w/o deepspeed.
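For anyone wondering why the state_dict comes out nearly empty: under ZeRO-3 each parameter is partitioned across GPUs, so a plain `state_dict()` call only sees placeholder tensors. A sketch of how the full weights can in principle be gathered, assuming DeepSpeed's `GatheredParameters` context manager; this is not the mechanism the thread ultimately settled on:

```python
import deepspeed

def consolidated_state_dict(model):
    # under ZeRO-3 parameters are partitioned across ranks; GatheredParameters
    # temporarily reassembles them so state_dict() sees full tensors. This is
    # memory-hungry for large models, which is why it isn't done by default.
    with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=None):
        return {name: param.detach().cpu().clone() for name, param in model.state_dict().items()}
```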
While the Deepspeed team is sorting out the addition of a method to extract model weights from its checkpoint, here is an update for you.

Deepspeed stores the model weights in its checkpoint file (a file per gpu), which at the moment can only be loaded via its

The new method we are discussing will be able to convert the deepspeed checkpoint into model weights consolidated from multiple gpus. This is quite expensive, so it shouldn't happen on each checkpoint save, and it definitely shouldn't be the default, because there might not be enough memory to do the consolidation (e.g. a model spread out over dozens of gpus).

Bottom line: should you choose to use deepspeed zero-3, things aren't as straightforward, and we will work out a solution in this case.

I suppose it's a similar story with fairscale Sharded DDP, but I am working on DeepSpeed only at the moment and can't comment on the former. Unless @sgugger, who did the initial integration of fairscale, beats me to it, I will be able to look at it once I complete the integration of DeepSpeed ZeRO-3, which is coming along nicely but requires changes on the DeepSpeed side - so it'll take some time.
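For reference, the DeepSpeed-native checkpoint round trip described above looks roughly like this at the engine level. A sketch, not the Trainer integration; `ds_config` is assumed to define the fp16/zero/optimizer settings:

```python
import deepspeed

def checkpoint_roundtrip(model, ds_config: dict, checkpoint_dir: str):
    # each rank writes its own shard under checkpoint_dir/global_step*/, and those
    # shards can only be read back through load_checkpoint(), not with torch.load()
    # or from_pretrained()
    engine, _, _, _ = deepspeed.initialize(
        model=model, model_parameters=model.parameters(), config=ds_config
    )
    engine.save_checkpoint(checkpoint_dir)
    engine.load_checkpoint(checkpoint_dir)
```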
@exelents, here is how to solve your specific problem of:

with DeepSpeed zero-3. If you don't mind continuing training and not being able to retrieve the final weights until microsoft/DeepSpeed#872 is addressed, here is what you can do immediately to be able to move forward:

Do the above only when starting "cold"; when resuming from a checkpoint, don't do that and instead let

Once we get the method to extract the model weights out of the DeepSpeed checkpoint, you can then recover both sub-model weights if you want to upload them to the hub or take them elsewhere.

Please let me know if this solution resonates with you, or if you run into any hiccups I haven't considered.

Note that currently under zero-2 you're only recovering fp16 weights, so it is not ideal either. So you want to use this solution for both cases.
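If it helps, the suggested flow reads roughly like this. A sketch with placeholder paths and a single T5 encoder standing in for the dual-encoder model; none of the names below are from the original thread:

```python
import os
from transformers import T5EncoderModel, Trainer, TrainingArguments

def train(train_dataset, output_dir="output_dir"):
    args = TrainingArguments(output_dir=output_dir, deepspeed="ds_config_zero3.json")

    # "cold" start: build the model from the original pretrained weights; the fp16
    # weights previously written next to the checkpoints are deliberately not reloaded
    model = T5EncoderModel.from_pretrained("t5-small")
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

    last_checkpoint = os.path.join(output_dir, "checkpoint-500")  # placeholder path
    if os.path.isdir(last_checkpoint):
        # when resuming, let the DeepSpeed integration restore the real weights and
        # optimizer state from its own checkpoint files via load_checkpoint()
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()
```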
@samsontmr, would you kindly open a separate issue, since while this is related, the use case is quite different. Please tag me and we will work on solving your use case there. Thank you!

p.s. also when you test, please make sure you are using the
@stas00 Thank you for the explanation. So, to load a stage-3 checkpoint I should do a "cold load" from the original T5 weights, and then load the actual weights via
I haven't tested it, but I can't think of any reason why it won't work. If you run into problems that I haven't considered please let me know.
Yes, of course. Just note that if you use the notebook directly and don't launch an external process that sets up the distributed environment, you will be limited to 1 gpu, and you will have to emulate the distributed environment like so:
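The emulation typically amounts to setting the environment variables the distributed launcher would otherwise set. A sketch for a single GPU, not necessarily the exact snippet from the original comment; the port is an arbitrary free one:

```python
import os

# pretend to be rank 0 of a world of size 1, as the launcher would normally set up
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "9994"  # any free port
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
```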
and please make sure you're on master or a very recent

But if you just use the notebook to open a shell with the
I'm not sure what you're asking here, as I don't know whether you refer to the
Ok, thank you for the explanation. I'm not sure I can test these changes on my code soon, but I'll do it sooner or later.
I just proposed yet another API in microsoft/DeepSpeed#872:
So if that gets added, then your current code would also work by just adding this newly proposed API. Let's see.
@stas00 thanks! My problem is solved for now: since I'm also using fp16 during fine-tuning, the current stage2 saves are good enough for me.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
Hello, @stas00. I have created an issue about problems with converting the model to fp32. Can you say something about it?
I have my own model, which utilizes two T5 encoders, and I train it via DeepSpeed. It has its own save_pretrained() and from_pretrained() methods, which implement custom load/save logic:
https://github.com/exelents/try_t5_siamese/blob/4140194978ac113c45e7370f40b3d9b932d0b35b/siamese_model.py#L80
When I run training and the trainer starts to save a checkpoint, something strange happens: the weights file for every saved encoder ends up being a few kilobytes - the weights are not actually saved.
At the start of training the trainer tries to load a checkpoint using model.load_checkpoint(), but it seems this function has its own loading logic, because it cannot execute my model-loading logic and throws an error:
ValueError: [deepspeed] failed to resume from checkpoint ./templates/siamese-t5-small-v1_1-template
I can comment out the code that loads the checkpoint, but then I get the previously described problem with saving the checkpoint...
What should I do to save my own custom model properly? It worked a month ago, but today I refreshed my Transformers repo and everything broke.
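For readers without access to the linked script, the setup is roughly of this shape. A simplified sketch, not the actual code; the class name and sub-folder layout are assumptions:

```python
import os
import torch.nn as nn
from transformers import T5EncoderModel

class SiameseT5(nn.Module):
    """Simplified sketch of a dual-encoder model with custom save/load logic."""

    def __init__(self, encoder_q: T5EncoderModel, encoder_a: T5EncoderModel):
        super().__init__()
        self.encoder_q = encoder_q
        self.encoder_a = encoder_a

    def save_pretrained(self, save_directory: str):
        # each encoder goes into its own sub-folder; under ZeRO-3 this is where the
        # few-kilobyte weight files described above show up
        self.encoder_q.save_pretrained(os.path.join(save_directory, "encoder_q"))
        self.encoder_a.save_pretrained(os.path.join(save_directory, "encoder_a"))

    @classmethod
    def from_pretrained(cls, load_directory: str):
        return cls(
            T5EncoderModel.from_pretrained(os.path.join(load_directory, "encoder_q")),
            T5EncoderModel.from_pretrained(os.path.join(load_directory, "encoder_a")),
        )
```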