FP16 overflow with GPT-Neo when using sequence lengths of 2048. #11076
Comments
Thank you for the report, @LouisCastricato. I think it's pretty safe to take DeepSpeed out of the equation for now: since, as you're saying, the problem is due to mixed precision, let's deal with AMP first. How was GPT-Neo pre-trained? Is it by chance the case that GPT-Neo was pre-trained with |
The 1.3b model was pretrained on TPUs in mesh-tf using fp16. |
You mean mixed precision fp16, correct? As I haven't used mesh-tf, what would be the equivalent of this setup in PyTorch land? If we find the exact equivalent, and the model was ported correctly and is used under the same setup, this problem shouldn't exist. Does that make sense? So let's find out what is different here (assuming the porting was done correctly). |
OK so I just wrote about it today: https://discuss.huggingface.co/t/mixed-precision-for-bfloat16-pretrained-models/5315 I will try to look at it tomorrow, this is probably the same story as t5/mt5 then. |
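For background, the two 16-bit formats trade range for precision very differently, which is the crux of the bf16-pretrained-model problem. A quick illustrative PyTorch comparison (not part of the original thread):

```python
import torch

# Dynamic range of the two 16-bit formats.
for dtype in (torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max:.2e}, smallest normal={info.tiny:.2e}, eps={info.eps:.1e}")
# torch.float16:  max=6.55e+04, smallest normal=6.10e-05, eps=9.8e-04
# torch.bfloat16: max=3.39e+38, smallest normal=1.18e-38, eps=7.8e-03

# Values that a bf16-pretrained model may contain simply vanish in fp16:
print(torch.tensor(1e-8).to(torch.bfloat16))  # ~1e-08, still representable
print(torch.tensor(1e-8).to(torch.float16))   # 0.0 - below fp16's subnormal range
```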
It'd help to save time if you had a ready way to reproduce the problem. I tried:
It hardly fits onto a 24GB card with a tiny block size, and fp16 OOMs right away. I don't suppose you have a smaller model to experiment with? Straightforward
Thanks. |
We are working on producing a minimal example for you currently. After checking our internal documents we realized that 1.3b is bf16 whereas 2.7b is fp32. |
If you need an A100 to test on, let us know. |
Hi! As we're doing a few changes to the implementation to make it cleaner over in #10985, we ran a quick training to ensure that the model could still train. We leveraged @Xirider's script detailed in https://github.com/Xirider/finetune-gpt2xl in order to fine-tune the 1.3B checkpoint, and we did see a decrease in the loss over this small sample: We didn't investigate further, but this allows fine-tuning the 1.3B variant on a single V100 GPU. cc @patil-suraj |
That was sequence length 2048? |
It's 1024 on wikitext |
Hm... Maybe our project is just cursed then. Thanks for the pointer, I'll go through installations and see if anything is weird. |
I ran the fine-tuning on the recent branch so I thought this might be it; but I just tested on |
I'm running this on 24GB rtx-3090 and while it's not converging it's not getting NaNs:
|
It looks like the version of DeepSpeed we are running (0.3.11) prevents us from running that example on our hardware. We are in the process of updating DeepSpeed to a newer version (>0.3.12) so that it is not caught by line 287 of |
I'm able to reproduce
around step 24/174, except I'm using the 2 uncommitted branches mentioned in #11044. I will try to reduce it to something smaller. P.S. For reproducibility purposes, here is the config I used:
|
Thank you! |
Does that happen with zero-2? |
Oddly enough it's fine with zero-2 in this particular setup, but the configurations aren't the same so we aren't comparing the same things. But also, if I do the same zero-3 training on one GPU there is no NaN either. That doesn't matter though: as long as we have a way to reproduce NaNs it's good enough to start working on understanding the cause and then fixing it. @samyam from DeepSpeed suggested an idea to try, so I'm going to go back to mt5, which gets a NaN on the very first step, and experiment with it first, since it's much faster than dealing with step 21 of this really heavy model. And then if it works I will come back to gpt-neo. If meanwhile you find a much faster way to get to NaNs that would be helpful. |
Also I don't know if this is somehow related, but this message looks alarming:
I think there is a bug somewhere, but it might be unrelated. edit: I verified this is just a misplaced warning, not a problem |
@stas00 an update: We were able to run your code on both the 125M and 1.3B models without issue. The loss goes down, we get Shakespearean language, all is good. Unfortunately, we cannot use your code for our task. We are seeking to train a dual-objective model with two completely different datasets. We have two datasets that we are mashing together and trying to train via contrastive loss. Unfortunately, it appears that using the HF Trainer class makes that more or less impossible. Is there a better way to do the pipelining, so we can evade whatever bug we are running into? We tried to run it on sequence length 1024, but it ended up going to NaN anyway after a thousand steps or so. |
The model was trained in one numerical range and you're trying to run it in a different range that it wasn't trained for - there is not too much that can be done here. It's the same problem for any bfloat16-pretrained model, which includes t5/mt5/pegasus to name a few. The fine-tuning/inference should be done in the same environment it was trained in, or an environment that the model can numerically translate to. This is not the case with bf16 vs fp16 - please refer to my commentary at https://discuss.huggingface.co/t/mixed-precision-for-bfloat16-pretrained-models/5315 What we are trying to do now is to find a workaround that will not provide a full mixed precision regime, but a partial one. For that we need to find which operations are safe to run in fp16 and which aren't. And unfortunately, as you can see, some of these "runaways" happen after thousands of steps.
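The general shape of such a partial-precision workaround is to pull the numerically sensitive submodules out of autocast and run them in fp32. A minimal sketch of the pattern (the `SafeBlock` wrapper is hypothetical, not the eventual fix in transformers):

```python
import torch

class SafeBlock(torch.nn.Module):
    """Hypothetical wrapper: run a numerically sensitive submodule in fp32 under AMP."""
    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # Leave autocast for this block only; upcast the input, downcast the result.
        with torch.cuda.amp.autocast(enabled=False):
            out = self.inner(x.float())
        return out.to(x.dtype)
```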
Oh, fantastic! I just discovered https://huggingface.co/EleutherAI/gpt-neo-125M after you mentioned it - it'd be much easier to debug with. Thank you for that! Which "your code" are you referring to? Trainer + run_clm.py? I hear you that the HF Trainer is not suitable for your task. But if you have your own Trainer that works, why not use that instead? In other words, how can we support you in this situation? |
What I'm going to do next is:
This proved not to be possible at the moment
|
The dual objective code we refer to can be found here: https://github.com/EleutherAI/visual-grounding And OK, sounds good. The offer for A100s still stands btw; fp32 might be a nightmare on an RTX 3090. |
Thank you for your offer, @LouisCastricato - You're very likely correct - I may take you up on that offer at a later time. I'm not planning on finetuning gpt-neo in fp32 on an rtx-3090; I just want to test that deepspeed can even run in fp32 on a small model. Because if it works you could at least do that.
Yes, but I'm not sure what to do with this information. My guess is that you developed your own trainer and you're trying to integrate deepspeed into it and are running into issues? What is it specifically that you need to move forward with your project or what is blocking you? |
Oh, apologies. I shared it so that you could see the configuration we're using. I think I might have accidentally deleted that part though (big thumbs and touchscreens). Yes, we're trying to integrate DeepSpeed directly with our training code. Both ds_config.json and amp_config.json produce the same NaN error, strictly on autoregressive batches - before the forward step. We have not seen the NaN error on the backwards step, and we do not see it on the other component of our dual objective (in this case Google's WIT dataset), which has sequence lengths of at most 128 tokens. We can see NaNs beginning to appear at sequence length 768, and once we get to 2048 it's every batch that has NaNs. |
Thank you for clarifying that, @LouisCastricato. Understood. I will have to get to know this model as I have never worked with it, so I will comment once I've had a chance to sit with it after I install all kinds of debug hooks into it. Regarding your config, it looks good.
these might be too small for efficient operation. You want these to be in the 2e8 to 5e8 range according to Samyam. I also recommend you switch to e-notation - it's too easy to miss a zero; in zero3 they have a param with 14 zeros! You may want to enable cpu-offload if you have extra RAM. Otherwise there isn't that much to configure in zero-2 - there is a lot more to tune in zero-3. As I mentioned a few comments up, DeepSpeed makes efficient use of the hardware, but if the model itself is the issue there is not much that changing the DeepSpeed configuration can do. |
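For illustration, the bucket sizes in e-notation in a zero-2 `ds_config.json` could look roughly like this (values picked from the suggested range above, not a tuned config):

```json
{
  "zero_optimization": {
    "stage": 2,
    "allgather_bucket_size": 2e8,
    "reduce_bucket_size": 2e8,
    "cpu_offload": true
  }
}
```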
Hi, I was curious if there was any update on this? |
I was busy working on the DeepSpeed ZeRO-3 integration with If I knew it was a quick fix I'd have done it right away, but this kind of problem is a long process, so I need to have uninterrupted time to work on it. Moreover, fixing it in AMP won't necessarily fix it in DeepSpeed (but it'd surely help). I started working on the checklist. I'm aware that this is a big problem for you guys and I thought that perhaps at least you could run DeepSpeed in fp32, but, alas, currently it's not possible - you can disable I doubt the DeepSpeed developers will take care of this any time soon as they have no resources to do so, so if you want to help, one task that might move things forward a bit is making DeepSpeed work with fp32. Then the next stage would be to optimize the parts that can be done in fp16 w/o overflow, leaving most of it in fp32. Samyam suggested the matmuls in the FF layers would be the best part to do in fp16, as I mentioned some comments earlier. Just to give you an idea, the process goes like this: I find something that doesn't work or is missing for Let's hope they manage to expand their team with the recent job openings they posted and have more resources to support the projects that integrate their work. I also asked them, and all the models they have been working with were trained in mixed fp16 precision, so they had no reason to sort out bfloat16 (yet). So priorities-wise, will having full fp32 support be useful to you or not really? |
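The config knob being referred to is the `fp16` block; disabling it looks like this (a sketch - at the time other parts of DeepSpeed still assumed fp16 internally, which is why this alone wasn't enough):

```json
{
  "fp16": {
    "enabled": false
  }
}
```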
Yes, but large logits are a potential symptom of what's going on in the network. I've just created a new debug tool that helps diagnose the activation overflow issue; it's just waiting for review to complete, but if you want to try it sooner please grab this branch: #11274 and add |
which could indicate an issue in the model design I think. If the original model can't even do the math in Or alternatively perhaps those were high precision numbers that |
|
Let me know if I can provide any more traces for you. |
So this looks more like an underflow, rather than overflow, as activations are tiny and you got In this case I will modify the code to print abs_min as well - we are probably going to see tiny-tiny numbers there. How do I reproduce this? |
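Conceptually the detector being discussed is just a set of forward hooks that record per-module activation statistics and flag inf/nan. A simplified sketch of the idea (the real tool lives in #11274; `attach_overflow_detector` is a made-up name for this illustration):

```python
import torch

def attach_overflow_detector(model: torch.nn.Module):
    """Simplified sketch: report abs_min/abs_max of each module's output and flag inf/nan."""
    def hook(module, inputs, output):
        if not isinstance(output, torch.Tensor):
            return  # skip modules returning tuples/dicts in this simplified version
        vals = output.detach().float().abs()
        bad = bool(torch.isinf(output).any() or torch.isnan(output).any())
        print(f"{module.__class__.__name__}: abs_min={vals.min():.3e} "
              f"abs_max={vals.max():.3e} inf/nan={bad}")
    for module in model.modules():
        module.register_forward_hook(hook)
```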
We are trying to find a minimal example for you that can be run on a 3090. |
The lowest we could make the memory requirement was 32GB. We sent you a login for an instance with 6x A100s. The command you need to run, under ~/visual-grounding/Training/ is
It should output NaN information after the first batch. It is a (semi) minimal example that uses a custom AR trainer, but it crashes before the first optimizer step. The code is (relatively) easy to follow without reading through any of the custom data loaders. We've already confirmed it works with GPT2-XL. Transformers was not installed from source as editable, but I assume you wanted to use a custom branch for this, so I just installed it from PyPI for you. |
Thank you, @LouisCastricato! I needed to install my own branch, but I was able to reproduce with the updated detector, which now gives a much better picture. So with your custom code getting:
|
I ran the detector under the fp32 and the deepspeed/fp16 modes and, as I suspected, we are having an underflow here - a serious underflow. Attached are 2 traces. frames_overflow.txt: from the very first forward we have a problem with the embeddings. fp32:
fp16:
As you can see, some weights are immediately As shown here: https://github.com/stas00/ml-ways/blob/master/numbers/bfloat16-vs-float16-study.ipynb fp16 can barely handle
Deepspeed runs in |
OK, trying to force fp32 mode in deepspeed by editing its engine to skip
Your I stopped it after 5482 steps (2h). |
Yeah the minimal example removes all evaluation. In FP32 it does work though. I tested a checkpoint the other day. |
I'm asking the DeepSpeed devs if they have some ideas on how to overcome this; I will keep you posted if we find a good intermediate solution. But at the very least we now know why the model fails under fp16. I wonder if pre-training processes targeted for mixed precision use should have a loss penalty component that forces the model to remain within the fp16 dynamic range, both upper and lower. |
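Such a penalty could look something like the following - a purely speculative sketch of the musing above (the function name and thresholds are invented for illustration; neither project uses this):

```python
import torch

def fp16_range_penalty(activations: torch.Tensor,
                       upper: float = 6.0e4,
                       lower: float = 1e-4) -> torch.Tensor:
    """Hypothetical auxiliary loss: discourage activations that would overflow
    (above fp16's ~6.5e4 max) or underflow (below its ~6e-5 normal range) when cast to fp16."""
    a = activations.abs()
    overflow = torch.relu(a - upper).mean()                                    # too large
    underflow = torch.where(a > 0, torch.relu(lower - a),                      # too small,
                            torch.zeros_like(a)).mean()                        # ignoring exact zeros
    return overflow + underflow
```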
microsoft/DeepSpeed#974 (comment) This could be relevant. |
OP is asking to support It'd be awesome for deepspeed to support |
I meant the changes they recommended making could also help resolve our FP16 issues. They outlined what would need to be changed for bf16 |
Currently the very first problem is deepspeed calling
Therefore I can't see how any of the suggestions directed to support Chances are that deepspeed will need a new mode, which is not all- |
OK, please have a look at the current setup on your instance, try:
It also currently requires a hardcoded change:
which is already applied under so this is zero3. zero2 still needs some work. |
But looking closer at your code, I see now that we have been trying to solve the wrong problem all along. Why is your code using "EleutherAI/gpt-neo-2.7B", when one of you said earlier it was pre-trained in full fp32? How could you possibly expect it to train or eval in fp16? Or did you just want deepspeed in fp32 mode? Please clarify. One of you said it's the 1.3B checkpoint that was trained in |
OK, zero2 now works too.
So Samyam explained that this new deepspeed branch enables full FP32 mode. But since your setup is running on A100, pytorch uses TF32, so you're getting speed equivalent to fp16 on a V100. RTX-3090 should also be able to get this performance. All kudos go to @samyam. |
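For reference, TF32 use on Ampere GPUs (A100, RTX-30xx) is controlled by PyTorch flags that were enabled by default in the releases current at the time, which is why full fp32 runs at near-fp16 speed there. An illustrative snippet:

```python
import torch

# TF32 keeps fp32's dynamic range but runs matmuls/convolutions with a reduced
# mantissa on Ampere tensor cores.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```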
We've been having the nan issue with both the bf16 1.3B checkpoint and the fp32 2.7B checkpoint; we were under the assumption that as both have the same dynamic range, both would have the same under/overflow problems. I'm also pretty sure that the bf16 1.3B checkpoint was trained with bf16 activations with fp32 master weights quantized to bf16 (the quantization was a mistake by one of our devs). Our main problem is that with fp32, 1.3B, and no deepspeed, we can't even fit a single full batch without OOM, and we can't turn on any deepspeed optimizations without fp16 being on (interestingly, it seems the OOM doesn't happen with Samyam's branch). Of course, we would like to train our model using mixed-precision (using fp32 for the parts that are underflowing) for the obvious memory savings, so we thought it would be much easier to just make our model work with mixed-precision and also get those memory savings than to make deepspeed work with fp32. We would also be fine with making deepspeed work with fp32 or bf16 if it's significantly easier. Thanks for all your time in helping us with this issue. |
In general, if you want users to be able to use fp16 mixed precision for fine-tuning and inference, you need to pre-train the model using this mode. For some models we find workarounds that localize switching to fp32 for the specific submodules that lead to underflow/overflow under fp16, but often users still get NaNs during long training. Bottom line: if you pre-train in bf16, be prepared to tell users to use fp32 or bf16 in their fine-tuning/inference processes. As new hardware supporting the bf16/tf32 formats emerges (rtx-3090 + a100) this will become the simple go-to solution in the future. Now that deepspeed will have a full-fp32 mode this is great. So to summarize, at this moment with Samyam's branch if you use:
|
How would one use this special fp32 mode without zero? |
You mean w/o deepspeed (or fairscale)? Just don't enable mixed precision in the training, i.e. in Unless you're asking how to use deepspeed w/o ZeRO - why would you want to do that? ZeRO is the core of deepspeed, and if you are not using it you don't really need deepspeed. If I misunderstood your question please clarify. |
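With the HF Trainer that just means leaving the fp16 flag off. A minimal sketch (the output directory name is arbitrary):

```python
from transformers import TrainingArguments

# fp16 defaults to False; simply not enabling it keeps training in full fp32
# (on Ampere GPUs the matmuls may still run as TF32 under the hood).
args = TrainingArguments(output_dir="output", fp16=False)
```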
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
You wrote: Did you mean to write fp16 or bf16? According to the detector tool I'm working on, it is most likely fp16. It'd be super helpful if you could check how it was trained. Thank you! If you have other published model checkpoints and their dtypes, that would be very helpful too, as I'm trying to gather that information. |
Talked to Stella and she confirmed Louis meant to write bf16 for the 1.3B model. |
Environment info
transformers version: 4.5.0.dev0
Who can help
@stas00
Models:
Library:
Information
Model I am using (Bert, XLNet ...):
The problem arises when using:
The tasks I am working on is:
To reproduce
Steps to reproduce the behavior:
Also reproducible using AMP or DeepSpeed. There appears to be code in the GPT-Neo implementation intended to circumvent this, where q, k, v are cast to fp32 in the attention block.
When the max_length is shorter (512) this overflow does not occur.
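The circumvention referred to looks roughly like this - a simplified sketch of upcasting the attention math to fp32, not the actual GPT-Neo source:

```python
import torch

def attention_fp32(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Simplified sketch: compute attention in fp32 even when the inputs are fp16."""
    orig_dtype = q.dtype
    q, k, v = q.float(), k.float(), v.float()
    scores = torch.matmul(q, k.transpose(-1, -2))   # large dot products stay in fp32 range
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)
    return out.to(orig_dtype)                        # cast back for the rest of the fp16 graph
```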
Expected behavior
I expected no overflows.
Aside
I'm reaching out on behalf of EleutherAI; Lysandre told us to create an issue about this.