Add support for Flash attention #725
Conversation
@VHellendoorn -- Thanks so much for adding this support! I'll run some tests myself sometime this week. I'll report back here whether I can reproduce your speedup.
Here is a side-by-side comparison of a 1.3B model on an 80GB A100 with and without Flash Attention. So far it seems pretty stable at 180 vs. 130 TFLOPS, respectively.
This is awesome, thanks @VHellendoorn for the integration and @dashstander for the comparison runs!
Thanks for testing this out @dashstander! @VHellendoorn -- I'll review and approve this once you have run our pre-commit checks. Can you run them and commit the formatting updates with:
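A typical way to run the repository's pre-commit hooks (assuming the standard pre-commit tooling; the exact invocation may differ) is:

```bash
# Install the pre-commit tool and run all configured hooks across the repo,
# then commit whatever formatting changes they produce.
pip install pre-commit
pre-commit run --all-files
```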
That is quite impressive! Let's make sure to leave the runs going and do some downstream evals once they've finished training, @dashstander, but this looks compelling enough to me to merge.
Agreed, thanks @dashstander! Out of curiosity, what sequence length did you run this experiment with? @Quentin-Anthony: I ran the pre-commit script and pushed the corresponding fixes. It made changes to quite a few files outside of the ones I changed; I hope you don't mind that I left those out.
@VHellendoorn you can find the configs for the run here. The sequence length is 2048, which is EleutherAI's default value.
Great, thanks! FWIW, I've been training with both 4K and 8K sequence lengths and it's really very fast there too -- the 8K model is barely slower than the 4K one for a 2.7B model. No formal benchmarks to share, but just noting that for others looking for info. All props to @tridao for making this possible, of course!
Ha, I've been training 8K models as well (GPT 1.3B and 2.7B on the Pile)! I will put out some speed numbers this week in a blog post.
@VHellendoorn @tridao this is very interesting info! In your experience, do longer sequence lengths noticeably improve performance in areas you care about? My recollection (though I would have to go hunt down papers) is that many people have found that 2048 tokens is sufficient for most NLP tasks.
I can't speak for regular NLP tasks, but in my current experiments on code, the difference between 2K and 4K (and above) is surprisingly large, at least in terms of loss. In several of my runs, the larger context window reduces loss by 10-20%, especially past 100B tokens or so, at least for smaller models. Bigger model sweeps are still going, but as a first datapoint, the 8K model I mention above has just passed the loss of a concurrently running 2K context-window equivalent in less than half the steps (~23B and ~52B tokens in, respectively). It seems on track to converge quite a bit lower still. This might be a code-specific phenomenon (though @tridao can hopefully tell us otherwise!). Code files are typically quite large -- 4K tokens is not a bad guess for the median in many languages. I'll also add the caveat that I'm still working on this sweep, but the results have been pretty consistent so far. Definitely makes me see why Codex was trained with 4K tokens.
I think most current NLP benchmark tasks don't have long sequences.
@Quentin-Anthony are there any concerns about this PR? It looks good to merge to me.
Thanks for adding this support. I just tried it out, and in my case (an 8xA100 40GB server) I was able to get a 15-50% improvement in TFLOPS with Flash Attention, depending on the model size. Also, in my tests the memory usage seems to be roughly the same with and without Flash Attention. This study seems to suggest there is about a 15% memory reduction for some of the models. I wonder if anyone can shed light on how much improvement to expect from adding Flash Attention to GPT-NeoX models.
@chuanli11 what model sizes did you use for your tests? I ran some tests at 2.7B and 6.7B parameters. For 2.7B there was a ~20% improvement in peak memory usage, but for 6.7B there wasn't. I'm not sure why that might be, though. Also pinging @Quentin-Anthony for visibility.
Hey @dashstander, these are some of the tests I ran (all with 8xA100 40GB):
Further data points for non-flash +
Do you have activation checkpointing on? In that case, for the non-flash runs, even though the attention matrix is materialized, it's not saved for the backward pass (but instead recomputed). So you're trading off compute to reduce memory.
Thanks for the tip! Indeed, activation checkpointing was on and that had a huge impact. I've updated my original post.
@tridao and is this why gradient checkpointing is obviated by FlashAttention?
I'd say FlashAttention reduces the need for gradient checkpointing, since attention no longer takes up quadratic memory. Of course, there are cases (GPUs with small memory, large models) where one would still need gradient checkpointing regardless.
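For context, this compute-for-memory trade-off is what the activation-checkpointing flags in the GPT-NeoX-style YAML configs control. A minimal sketch, with key names assumed from the repository's example configs (verify against your version):

```yaml
# Sketch: activation-checkpointing flags in a GPT-NeoX-style config
"checkpoint-activations": true   # recompute activations in the backward pass instead of storing them
"checkpoint-num-layers": 1       # granularity: number of layers per checkpointed block
```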
Now that this code is available, are there any plans to train Pythia etc. with longer sequence lengths, e.g. 32K (using sparse flash attention)?
We intend to finetune some of the Pythia models on longer context lengths as part of the INCITE grant: https://twitter.com/BlancheMinerva/status/1593725723352539136?s=20&t=E-NvaQqiMS7IgmN3uS-zpg If you're interested in contributing, please join the Discord at https://discord.gg/hrcJTaSDeC
@chuanli11 Hi. Does this test enable fp16 or bf16? Is ZeRO used?
This PR adds Tri Dao's Flash Attention as an optional backend for the global attention operation, enabled by setting the `attention_config` to `[[["flash"], ...]]`. I've tested the changes in my own environment and consistently see a 2x boost for 4K sequence lengths in models ranging from 100M to 3B parameters. Maybe relevant: @tridao @lucidrains
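As a concrete illustration, enabling the flash backend in a GPT-NeoX-style YAML config might look like the sketch below; the layer count of 24 is only an example value, and key spellings should be checked against the repository's example configs:

```yaml
# Sketch: use the Flash Attention backend for every layer of a 24-layer model
# (the second element of the pair is typically the number of layers the pattern covers)
"num-layers": 24
"attention_config": [[["flash"], 24]]
```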