
Add support for Flash attention #725

Merged: 5 commits into EleutherAI:main on Dec 10, 2022

Conversation

VHellendoorn
Contributor

This PR adds Tri Dao's FlashAttention as an optional backend for the global attention operation, enabled by setting attention_config to [[["flash"], ...]]. I've tested the changes in my own environment and consistently see a 2x speed boost at 4K sequence length for models ranging from 100M to 3B parameters.
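
For illustration, here is a hedged sketch of the relevant config fields (the 12-layer depth below is only a placeholder, and the Python dict stands in for the corresponding entries in a NeoX YAML config):

```python
# Illustrative only: the config keys this PR cares about, written as a Python
# dict. In an actual GPT-NeoX config these appear as top-level entries.
flash_attention_overrides = {
    "num_layers": 12,                       # placeholder depth for this example
    "attention_config": [[["flash"], 12]],  # apply the "flash" backend to all 12 layers
}
```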

Maybe relevant: @tridao @lucidrains

@Quentin-Anthony
Member

Quentin-Anthony commented Nov 30, 2022

@VHellendoorn -- Thanks so much for adding this support! I'll run some tests myself sometime this week. I'll report back here whether I can reproduce your speedup.

@dashstander
Contributor

Here is a side-by-side comparison of a 1.3B model on an 80GB A100 with and without FlashAttention. So far it seems pretty stable at 180 vs. 130 TFLOPS, respectively.

@tridao

tridao commented Dec 7, 2022

This is awesome, thanks @VHellendoorn for the integration and @dashstander for the comparison runs!

@Quentin-Anthony
Member

Thanks for testing this out @dashstander!

@VHellendoorn -- I'll review and approve this once you have run our pre-commit checks. Can you run them and commit the formatting updates with:

pip install pre-commit
cd /path/to/gpt-neox
pre-commit install
pre-commit run --all-files # You will probably have to run twice so that formatting changes can be automatically applied. Make sure all checks pass.
<git commit and push>

@StellaAthena
Member

That is quite impressive! Let’s make sure to leave the runs going and do some downstream evals once they’ve finished training @dashstander, but this looks compelling enough to me to merge.

@VHellendoorn
Contributor Author

Agree, thanks @dashstander! Out of curiosity, what sequence length did you run this experiment with?

@Quentin-Anthony: I ran the pre-commit checks and pushed the corresponding fixes. They made changes to quite a few files outside of the ones I changed; hope you don't mind that I left those out.

@StellaAthena
Member

@VHellendoorn you can find the configs for the run here. The sequence length is 2048, which is EleutherAI’s default value.

@VHellendoorn
Contributor Author

Great, thanks! FWIW, I've been training with both 4K and 8K sequence lengths and it's really very fast there too -- the 8K model is barely slower than the 4K one for a 2.7B model. No formal benchmarks to share, but just noting that for others looking for info. All props to @tridao for making this possible, of course!

@tridao

tridao commented Dec 7, 2022

> Great, thanks! FWIW, I've been training with both 4K and 8K sequence lengths and it's really very fast there too -- the 8K model is barely slower than the 4K one for a 2.7B model. No formal benchmarks to share, but just noting that for others looking for info. All props to @tridao for making this possible, of course!

Ha I've been training 8k models as well (GPT 1.3B and 2.7B on the Pile)! Will put out some speed numbers this week in a blogpost.

@StellaAthena
Member

@VHellendoorn @tridao this is very interesting info! Have you found that longer sequence lengths noticeably increase performance in the areas you care about? My recollection (though I would have to go hunt down papers) is that many people have found that 2048 tokens is sufficient for most NLP tasks.

@VHellendoorn
Contributor Author

I can't speak for regular NLP tasks, but in my current experiments on code, the difference between 2K and 4K (and above) is surprisingly large, at least in terms of loss. In several of my runs, the larger context window reduces loss by 10-20%, especially past 100B tokens or so, at least for smaller models. Bigger model sweeps are still going, but as a first datapoint, the 8K model I mention above has just passed the loss of a concurrently running 2K context-window equivalent in less than half the steps (~23B and ~52B tokens in, resp.). Seems on track to converge quite a bit lower still.

This might be a code-specific phenomenon (though @tridao can hopefully tell us otherwise!). Code files are typically quite large -- 4K tokens is not a bad guess for the median in many languages. I'll also add the caveat that I'm still working on this sweep, but the results have been pretty consistent so far. Definitely makes me see why Codex was trained with 4K tokens.

@tridao

tridao commented Dec 8, 2022

I think most current NLP benchmark tasks don't have long sequences.
Right now I'm interested in how to equip models with memory (e.g., ChatGPT claims to "remember what the user said earlier in the conversation" and is rumored to have a context length of 8k).
I've also been working with some collaborators on multi-turn user interaction, and they said they need/want 8k or even 16k context length.

@StellaAthena
Member

@Quentin-Anthony are there any concerns about this PR? It looks good to merge to me.

Quentin-Anthony merged commit efd5911 into EleutherAI:main on Dec 10, 2022
StellaAthena added this to the Release V2 milestone on Dec 20, 2022
@chuanli11
Contributor

chuanli11 commented Jan 3, 2023

Thanks for adding this support. I just tried it out, and in my case (an 8xA100 40GB server) I was able to get a 15-50% improvement in TFLOPS with FlashAttention, depending on the model size.

Also, in my tests the memory usage seems to be roughly the same with and without FlashAttention. This study seems to suggest there is about a 15% memory reduction for some of the models.

I wonder if anyone can shed light on how much improvement to expect from adding FlashAttention to GPT-NeoX models.

@dashstander
Contributor

@chuanli11 what model sizes did you use for your tests? I ran some tests at 2.7B and 6.7B parameters. For 2.7B there was a ~20% improvement in peak memory usage, but for 6.7B there wasn't. I'm not sure why that might be though.

Also pinging @Quentin-Anthony for visibility.

@chuanli11
Contributor

chuanli11 commented Jan 6, 2023

Hey @dashstander, these are some of the tests I ran (all on 8xA100 40GB):

| Model | non-flash (activation checkpointing = false) | non-flash (activation checkpointing = true) | flash | pipeline parallel | model parallel | micro_bs |
|---|---|---|---|---|---|---|
| 19M_pythia | OOM | 80 TFLOPS / 35 GB | 90 TFLOPS / 30 GB | 1 | 1 | 32 |
| 2.7B | OOM | 96.2 TFLOPS / 31 GB | 146.2 TFLOPS / 31 GB | 1 | 1 | 4 |
| 6.7B | OOM | 67.4 TFLOPS / 40 GB | 84 TFLOPS / 40 GB | 2 | 1 | 4 |

Activation checkpointing has a huge impact on memory usage; see the comment here.

Further data points for non-flash with activation checkpointing = false, at different micro_bs_per_gpu values:

| Model | micro_bs = 1 | micro_bs = 2 | micro_bs = 4 | micro_bs = 8 | micro_bs = 16 | micro_bs = 32 | pipeline parallel | model parallel |
|---|---|---|---|---|---|---|---|---|
| 19M_pythia | 48 TFLOPS / 5 GB | 74 TFLOPS / 7 GB | 102 TFLOPS / 9.5 GB | 121.5 TFLOPS / 15 GB | 64.6 TFLOPS / 27 GB | OOM | 1 | 1 |
| 2.7B | 90 TFLOPS / 40 GB | OOM | OOM | OOM | OOM | OOM | 1 | 1 |
| 6.7B | OOM | OOM | OOM | OOM | OOM | OOM | 2 | 1 |

@tridao

tridao commented Jan 6, 2023

Do you have activation checkpointing on? In that case, for the non-flash runs, even though the attention matrix is materialized, it's not saved for the backward pass (but instead recomputed), so you're trading extra compute for reduced memory.
With FlashAttention we found we didn't need activation checkpointing (e.g., for 2.7B or 6.7B models on 80GB cards).
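
To make the trade-off concrete, here is a minimal PyTorch sketch (illustrative only, not the NeoX code path), using a toy attention block that materializes the full score matrix:

```python
import torch
from torch.utils.checkpoint import checkpoint

def naive_attention(x):
    # Stand-in attention block: materializes an (L, L) score matrix,
    # the quadratic activation that FlashAttention never stores.
    scores = (x @ x.transpose(-2, -1)) / (x.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ x

x = torch.randn(4096, 256, requires_grad=True)

# Without checkpointing: the (4096, 4096) softmax output is kept for backward.
y_plain = naive_attention(x)

# With checkpointing: intermediates are dropped after the forward pass and
# recomputed during backward, trading extra compute for lower memory.
y_ckpt = checkpoint(naive_attention, x, use_reentrant=False)
```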

@chuanli11
Contributor

> Do you have activation checkpointing on? In that case, for the non-flash runs, even though the attention matrix is materialized, it's not saved for the backward pass (but instead recomputed), so you're trading extra compute for reduced memory. With FlashAttention we found we didn't need activation checkpointing (e.g., for 2.7B or 6.7B models on 80GB cards).

Thanks for the tip! Indeed, activation checkpointing was on and it had a huge impact. I've updated my original post.

@dashstander
Contributor

@tridao and this is why gradient checkpointing is obviated by FlashAttention?

@tridao

tridao commented Jan 6, 2023

I'd say FlashAttention reduces the need for gradient checkpointing, since attention no longer takes up quadratic memory. Of course, there are cases (GPUs with small memory, large models) where one would still need gradient checkpointing regardless.
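
As a rough back-of-the-envelope (all numbers assumed: 32 heads, fp16, batch size 1), the per-layer memory of the materialized attention matrix grows quadratically with sequence length:

```python
# Assumption-laden estimate of the per-layer attention-matrix memory that a
# non-flash backend materializes: 32 heads, fp16 (2 bytes), batch size 1.
def attn_matrix_gib(seq_len: int, n_heads: int = 32, bytes_per_elem: int = 2) -> float:
    return n_heads * seq_len * seq_len * bytes_per_elem / 1024**3

for seq_len in (2048, 4096, 8192):
    print(f"{seq_len:>5} tokens -> {attn_matrix_gib(seq_len):.2f} GiB per layer")
# ~0.25 GiB at 2K, ~1 GiB at 4K, ~4 GiB at 8K, per layer and per sample, which
# is why long-context non-flash runs lean on gradient checkpointing.
```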

@greg1232

Now that this code is available, are there any plans to train Pythia etc. with longer sequence lengths, e.g. 32k (using sparse FlashAttention)?

@Quentin-Anthony
Member

> Now that this code is available, are there any plans to train Pythia etc. with longer sequence lengths, e.g. 32k (using sparse FlashAttention)?

We intend to finetune some of the Pythia models on longer context lengths as part of the INCITE grant: https://twitter.com/BlancheMinerva/status/1593725723352539136?s=20&t=E-NvaQqiMS7IgmN3uS-zpg

If you're interested in contributing, please join the discord at https://discord.gg/hrcJTaSDeC

@guozhiyao

> Hey @dashstander, these are some of the tests I ran (all on 8xA100 40GB):
>
> | Model | non-flash (activation checkpointing = false) | non-flash (activation checkpointing = true) | flash | pipeline parallel | model parallel | micro_bs |
> |---|---|---|---|---|---|---|
> | 19M_pythia | OOM | 80 TFLOPS / 35 GB | 90 TFLOPS / 30 GB | 1 | 1 | 32 |
> | 2.7B | OOM | 96.2 TFLOPS / 31 GB | 146.2 TFLOPS / 31 GB | 1 | 1 | 4 |
> | 6.7B | OOM | 67.4 TFLOPS / 40 GB | 84 TFLOPS / 40 GB | 2 | 1 | 4 |
>
> Activation checkpointing has a huge impact on memory usage; see the comment here.
>
> Further data points for non-flash with activation checkpointing = false, at different micro_bs_per_gpu values:
>
> | Model | micro_bs = 1 | micro_bs = 2 | micro_bs = 4 | micro_bs = 8 | micro_bs = 16 | micro_bs = 32 | pipeline parallel | model parallel |
> |---|---|---|---|---|---|---|---|---|
> | 19M_pythia | 48 TFLOPS / 5 GB | 74 TFLOPS / 7 GB | 102 TFLOPS / 9.5 GB | 121.5 TFLOPS / 15 GB | 64.6 TFLOPS / 27 GB | OOM | 1 | 1 |
> | 2.7B | 90 TFLOPS / 40 GB | OOM | OOM | OOM | OOM | OOM | 1 | 1 |
> | 6.7B | OOM | OOM | OOM | OOM | OOM | OOM | 2 | 1 |

@chuanli11 Hi. Do these tests use fp16 or bf16? Is ZeRO used?
