FSDP + bf16 mixed precision on Llama2-7b shows weird behavior. #2127
@tmabraham can you provide a full repro with:
I did:
This is the output of:
transformers was version 4.35.0.dev0. Let me know if you need any more info...
Hello,
Above, both the code snippets are the same 😅. I observed the same behaviour a couple of days ago. Don't pass …, as you have passed ….
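For context, a minimal, hypothetical sketch of the two places bf16 usually enters an FSDP setup like this one (the checkpoint name and the use of Accelerate below are assumptions, not taken from the snippets in this thread):

```python
# Hypothetical sketch, not the code from this issue: it only illustrates the
# two places bf16 can be specified when training with FSDP via Accelerate.
import torch
from accelerate import Accelerator
from transformers import LlamaForCausalLM

MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # assumption: the 7B checkpoint from the title

# bf16 as a *mixed-precision policy*: master weights stay fp32, and under an
# FSDP accelerate config this maps onto FSDP's bf16 MixedPrecision settings.
accelerator = Accelerator(mixed_precision="bf16")
model = LlamaForCausalLM.from_pretrained(MODEL_NAME, use_cache=False)

# bf16 as the *weight dtype* at load time (on top of, or instead of, the above):
# model = LlamaForCausalLM.from_pretrained(
#     MODEL_NAME, use_cache=False, torch_dtype=torch.bfloat16
# )

model = accelerator.prepare(model)
```

The two are not equivalent: the former keeps fp32 master weights (and hence fp32 optimizer state), while the latter makes the stored weights themselves bf16.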
@tmabraham, there are the following bugs in the code you shared:
Other changes:
Code: https://github.com/pacman100/hf_fsdp
Dataset: 1% of a 7M-sample dataset => smangrul/pubmed_smol
Output logs:
Unless I'm missing something, you're getting a much worse loss curve than @tmabraham here -- you've divided the loss by 2; @tmabraham had the loss down beneath 0.3 after making the same adjustment.
Hello @jph00, are you using the exact same code, command and setup? I am using 8 GPUs with 2 gradient accumulation steps.
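For anyone trying to line the runs up, a toy sketch of the loop shape this implies, assuming Accelerate's accumulate() API; the model and data below are stand-ins, not the Llama-2 setup:

```python
# Toy sketch of gradient accumulation with Accelerate (assumption: the repro
# uses accelerator.accumulate(); the tiny model/data here are placeholders).
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator(gradient_accumulation_steps=2)

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 1)), batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        # On accumulation micro-steps the prepared optimizer's step() is a
        # no-op and gradient sync is skipped; the real update happens every
        # gradient_accumulation_steps batches.
        optimizer.step()
        optimizer.zero_grad()
```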
For reference, I get the exact same loss curve as above using the …
Commands:
Both these issues are clarified. Feel free to close the issue.
I also see the same loss curve using the Transformers FA2 integration instead of my FA2 monkey-patch (cc @younesbelkada). Changes in the ...
```diff
- replace_llama_attn_with_flash_attn()
+ #replace_llama_attn_with_flash_attn()
  # Create fresh LlamaForCausalLM model
  model = LlamaForCausalLM.from_pretrained(
      MODEL_NAME,
      use_cache=False,
+     use_flash_attention_2=True,
  )
```
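A small aside for readers on newer versions: as far as I know, the `use_flash_attention_2=True` flag used above was later replaced in transformers by `attn_implementation`; roughly along these lines, with the checkpoint name being an assumption:

```python
# Assumption: newer transformers releases use attn_implementation instead of
# the use_flash_attention_2 flag shown in the diff above.
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # assumption: the checkpoint behind MODEL_NAME
    use_cache=False,
    attn_implementation="flash_attention_2",
)
```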
Sorry, I think one of us is missing something here @pacman100 -- you've shown that you're getting a loss curve which is not training correctly. As @tmabraham showed, the loss should reduce roughly 3x from the initial batch to the 500th. Your loss flattens out after 50-100 steps and isn't showing much reduction at all. Your model doesn't seem to be training correctly, AFAICT.
That was on 0.1% of the data, repeated multiple times, so the loss going down as the model sees the same samples again is expected. I am not repeating any data and am taking a 1% subset. Could you both please run it again with the exact steps I have used and get back with the results? I have tried the Accelerate integration as well as Meta's llama-recipes and get the same results. Also, note that the perplexity would already be 2-3 as per my runs; it is hard to imagine it going down much further, given that the largest Llama model had a training perplexity of e^1.5 (≈4.48). It would help to get the exact code, command, distributed setup and library versions.
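For reference, the perplexity figures in that argument are just exp(loss); a quick check:

```python
# Quick check of the loss -> perplexity figures quoted above (perplexity = exp(loss)).
import math

for loss in (1.5, 1.1, 0.9):
    print(f"loss={loss:.2f} -> perplexity={math.exp(loss):.2f}")
# loss=1.50 -> perplexity=4.48
# loss=1.10 -> perplexity=3.00
# loss=0.90 -> perplexity=2.46
```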
Got it - yup, I'll work with @tmabraham to make that happen.
To further add: in the offline information Zach provided regarding this issue, there was a plot that showed the loss going down from 1.45 to 1.3 over 8000 steps. So there is no proper information on what is actually being run.
Hello @pacman100, thank you for looking into this! I am trying the changes you have made and will report back soon. Note that I observed issues even without FlashAttention 2.0, as you can see here where I am using BetterTransformer, which IIUC implements Flash Attention 1.0. I did try FA2 through Transformers, and a separate FA2 Llama2 implementation from Together, all resulting in the same behavior. I will try again, though, with these other changes.
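For completeness, the BetterTransformer path mentioned there is roughly the following; a sketch under the assumption that the optimum package is installed and the gated Llama-2 checkpoint is available:

```python
# Sketch of the BetterTransformer route (SDPA attention instead of FA2).
# Assumptions: optimum is installed and the checkpoint name matches the
# gated Llama-2 7B weights discussed in this thread.
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", use_cache=False)
model = model.to_bettertransformer()  # routes attention through torch's scaled_dot_product_attention
```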
FSDP with mixed precision shows weird behavior. For example, if in the model definition we have:
this performs worse than:
Note that this is on 0.1% of the dataset I am using. Another weird issue: just increasing from 0.1% to 1% of the dataset causes it to fail to train (the loss stays effectively constant), even for the configuration that previously showed a decreasing loss.
Code for reproduction: https://github.com/tmabraham/hf_fsdp/