
Adding additive attention mask support to Triton FlashAttention #307

Open
janEbert opened this issue Jul 13, 2023 · 7 comments
Comments

@janEbert
Contributor

janEbert commented Jul 13, 2023

Hey, I was working on adding optional additive attention masking support to Triton's FlashAttention kernel on Triton main, in order to support the latest version and its features. For example, the Triton version now supports both causal and full attention, uses the new block pointer API in the forward pass, is faster than before, and adds sequence parallelism to the backward pass (taken from here).
Triton upstream has changed a lot, which the implementation in here does not reflect yet. So I'm proposing to either merge the two implementations or copy-paste upstream into here, depending on how much work should be invested. A short pro/con list of the two approaches is at the end.
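For anyone unfamiliar with the block pointer API, here is a tiny, self-contained toy example (a plain 2D copy, not the attention kernel; all names are mine and purely illustrative) of the tl.make_block_ptr / tl.advance / boundary_check pattern the new forward pass uses:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_2d_kernel(X, Y, M, N, stride_m,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    # Describe a (BLOCK_M, BLOCK_N) tile of a row-major (M, N) matrix.
    x_ptr = tl.make_block_ptr(base=X, shape=(M, N), strides=(stride_m, 1),
                              offsets=(pid * BLOCK_M, 0),
                              block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    y_ptr = tl.make_block_ptr(base=Y, shape=(M, N), strides=(stride_m, 1),
                              offsets=(pid * BLOCK_M, 0),
                              block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    for _ in range(0, tl.cdiv(N, BLOCK_N)):
        # Bounds handling comes from boundary_check instead of manual masks.
        tile = tl.load(x_ptr, boundary_check=(0, 1), padding_option="zero")
        tl.store(y_ptr, tile, boundary_check=(0, 1))
        x_ptr = tl.advance(x_ptr, (0, BLOCK_N))
        y_ptr = tl.advance(y_ptr, (0, BLOCK_N))


x = torch.randn(300, 500, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.shape[0], 64),)
copy_2d_kernel[grid](x, y, x.shape[0], x.shape[1], x.stride(0),
                     BLOCK_M=64, BLOCK_N=128)
assert torch.equal(x, y)
```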

As a warning, I haven't compared the two versions in detail since you probably did some heavy performance optimizations for the Triton FlashAttention version in here, including adding several new features yourself.

New Triton implementation

I'd like to highlight the addition of the MODE flag to Triton upstream, which is used to switch between full and causal attention (with two implementations of the causal calculation, chosen for efficiency depending on the sequence length). My masking implementation transparently handles this MODE flag as well, meaning two things:

  1. If causal attention is selected, adding a mask on top will never make it non-causal.
  2. Similar to q/k/v, we only load mask blocks where necessary (i.e., we ignore the upper triangle for causal attention).

I'm also very happy about a broadcasting feature, where missing or size-1 dimensions are expanded so the mask ends up with shape [batch, nhead, seq_len_q, seq_len_k].
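In PyTorch terms, the broadcasting is roughly equivalent to the following sketch (the helper and its exact semantics are just my illustration here, not the PR's code):

```python
import torch

def broadcast_mask(mask, batch, nhead, seq_len_q, seq_len_k):
    # Prepend missing dimensions, then expand size-1 dimensions.
    while mask.dim() < 4:
        mask = mask.unsqueeze(0)
    return mask.expand(batch, nhead, seq_len_q, seq_len_k)

# e.g. an additive mask given only over the key dimension:
mask = torch.zeros(1, 1, 1, 128)
print(broadcast_mask(mask, 4, 16, 128, 128).shape)  # torch.Size([4, 16, 128, 128])
```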

Finally, just like here, usage of a mask is completely optional, so code will run as fast as before when a mask is not given, while still supporting causality. The difference between the bias implementation and the mask implementation is that masking has to handle -∞ (which also means placing the IS_CAUSAL query after the mask addition, to preserve the transparent MODE handling I mentioned above).
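As a reference for the ordering, here is roughly what the semantics look like in plain PyTorch (a sketch only, not the Triton kernel; it assumes seq_len_q == seq_len_k for the causal case):

```python
import math
import torch

def attn_probs(q, k, attn_mask=None, is_causal=True):
    # q, k: [batch, nhead, seq_len, head_dim]
    # attn_mask: additive, broadcast to [batch, nhead, seq_len, seq_len], or None
    qk = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if attn_mask is not None:
        qk = qk + attn_mask                       # the mask may contain -inf
    if is_causal:                                 # applied *after* the addition
        s = qk.shape[-1]
        causal = torch.ones(s, s, dtype=torch.bool, device=qk.device).tril()
        qk = qk.masked_fill(~causal, float("-inf"))
    return torch.softmax(qk, dim=-1)
```

One consequence of allowing -∞ is that a fully masked row would turn into NaNs after the softmax in this reference, which is part of what the kernel has to handle explicitly.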

Way forward

What do you think of updating the Triton FlashAttention kernel here to make use of the newest Triton upstream features (block pointers really add to readability) and to integrate attention mask support (which implies attention bias support)? I'm not sure what the best proposal here is, but I see two main options:

  1. Mostly copy-paste the Triton PR version. This would lose:

    • cross-attention
    • arbitrary seq lens
    • speedups from your side
    • support for old Triton versions

    It would win:

    • not much work to do
    • remove the current hacky workarounds (e.g. TMP buffer)
    • usage of new Triton APIs → speed, maintainability, readability
    • optimizations from the Triton side
    • support for current Triton version
  2. Integrate new upstream features piece-wise. This would lose:

    • a lot of work to do

    It would win:

    • anything from upstream that you decide to integrate (IMO most important: handling of -∞ in attention bias, support for current Triton, new API usage, and speedups from the Triton side)
    • still able to remove old baggage
    • ability to retain Triton backward-compatibility
    • careful considerations for each new integration → not much work if only a few, small integrations are considered
@tridao
Contributor

tridao commented Jul 13, 2023

Wonderful work on the Triton implementation, and very thoughtful suggestions here. Thanks @janEbert!
Yes, I'd love to stay up to date with upstream Triton, I just haven't had time to update the Triton implementation in this repo.

Ideally some of the features implemented here should be upstreamed (option 2), but I'm not sure I'll have the bandwidth to do that. Fortunately some of the features here have gradually been implemented in upstream Triton (causal & non-causal attention, parallelizing the backward pass across seqlen_k, and hopefully attention bias soon with your PR). I hope folks will contribute by sending PRs to upstream Triton.

@janEbert
Contributor Author

janEbert commented Jul 13, 2023

I can take care of some of the integrations from upstream to here if you're fine with losing backward-compatibility. However, I'm a bit worried about having to figure out the workarounds that had to be implemented here. Were they necessary to support the arbitrary sequence lengths, or were they used to work around compiler bugs which may have been fixed upstream?

The attention mask/bias will probably not be integrated upstream due to the added complexity.

@janEbert
Contributor Author

Alas, yours will have to be the featureful Triton FlashAttention kernel we need. ;)

@tridao
Contributor

tridao commented Jul 13, 2023

I can take care of some of the integrations from upstream to here if you're fine with losing backward-compatibility. The attention mask/bias will probably not be integrated upstream due to the added complexity.

That would be wonderful. Backward-compatibility with old Triton versions isn't a concern for me.

@janEbert
Contributor Author

Sorry, I've just edited the post above: my only worry is having to figure out the workarounds that had to be implemented here. Were they necessary to support the arbitrary sequence lengths, or were they used to work around compiler bugs which may have been fixed upstream?

Other than that thanks for the okay!

@tridao
Contributor

tridao commented Jul 13, 2023

Sorry, I've just edited the post above: my only worry is having to figure out the workarounds that had to be implemented here. Were they necessary to support the arbitrary sequence lengths, or were they used to work around compiler bugs which may have been fixed upstream?

I had to add a bunch of tl.debug_barrier() to deal with compiler bugs, which might have been fixed upstream (I haven't checked). To support arbitrary seqlen, ideally one just needs to set a mask when calling tl.load.
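For illustration, this is the tl.load masking pattern I mean, in a tiny self-contained kernel (a toy elementwise op, not the attention kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def double_kernel(X, Y, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements                     # guards the ragged last block
    x = tl.load(X + offs, mask=mask, other=0.0)  # out-of-range lanes read 0.0
    tl.store(Y + offs, 2 * x, mask=mask)         # and are never written back


x = torch.randn(1000, device="cuda")             # length not divisible by BLOCK
y = torch.empty_like(x)
double_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)
assert torch.allclose(y, 2 * x)
```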

@skejriwal44

Hi @tridao, I am noticing that even the Triton fwd call is slower than the flash-cuda implementation when the head dimension is 128. I am using https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py for profiling, and I am returning the time instead of FLOPS. These are the results I get for head dimension 128:

fused-attention-batch4-head32-d128-fwd-causal=True: (Time in ms)

     N_CTX  Triton [FP16]  Triton [FP8]    Flash-2
0   1024.0       0.285561      0.447643   0.241351
1   2048.0       0.894371      1.475016   0.778228
2   4096.0       3.183826      5.124862   2.779694
3   8192.0      12.040288     19.008390  10.662884
4  16384.0      46.750145     73.111710  41.799503

This issue is not present when the head dimension is 64. These are the results for head dimension 64:

fused-attention-batch4-head32-d64-fwd-causal=True: (Time in ms)

     N_CTX  Triton [FP16]  Triton [FP8]    Flash-2
0   1024.0       0.148868      0.186288   0.146833
1   2048.0       0.450962      0.601676   0.457882
2   4096.0       1.598519      2.149870   1.628497
3   8192.0       6.023782      8.119765   6.196142
4  16384.0      23.513256     31.372011  24.149960

Since the Triton fwd call is said to be faster than the flash-cuda implementation, I wanted to know if that is only the case for a head dimension of 64, or if I am missing some optimal setting of the hyperparameters (BLOCK_M, BLOCK_N, warps, stages).
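For reference, this is roughly how I am getting the time in ms on the flash-attn side (a minimal sketch assuming the flash-attn 2 interface flash_attn_func; the shapes just mirror the benchmark configuration above, and the Triton numbers come from the tutorial's benchmark with the FLOPS conversion removed):

```python
import torch
import triton
from flash_attn import flash_attn_func

batch, nheads, seqlen, d = 4, 32, 4096, 128      # mirrors batch4-head32-d128, N_CTX=4096
q, k, v = [torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.float16)
           for _ in range(3)]

ms = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True))
print(f"flash-attn CUDA fwd: {ms:.3f} ms")
```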
