Adding additive attention mask support to Triton FlashAttention #307
Comments
Wonderful work on the Triton implementation, and very thoughtful suggestions here. Thanks @janEbert! Ideally some of the features implemented here should be upstreamed (option 2), but I'm not sure I'll have the bandwidth to do that. Fortunately, some of the features here have gradually been implemented in upstream Triton (causal & non-causal attention, parallelizing the backward pass across seqlen_k, and hopefully attention bias soon with your PR). I hope folks will contribute by sending PRs to upstream Triton.
I can take care of some of the integrations from upstream to here if you're fine with losing backward compatibility. However, I'm a bit worried about having to figure out the workarounds that had to be implemented here. Were they necessary to support the arbitrary sequence lengths, or were they used to work around compiler bugs which may have been fixed upstream? The attention mask/bias will probably not be integrated upstream due to the added complexity.
Alas, yours will have to be the featureful Triton FlashAttention kernel we need. ;)
That would be wonderful. Backward compatibility with old Triton versions isn't a concern for me.
Sorry, I've just edited the post above: my only worry is having to figure out the workarounds that had to be implemented here. Were they necessary to support the arbitrary sequence lengths, or were they used to work around compiler bugs which may have been fixed upstream? Other than that, thanks for the okay!
I had to add a bunch of
Hi @tridao, I am noticing that even the Triton fwd call is slower than the flash-cuda implementation when the head dimension is 128. I am using https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py for profiling, and I am returning the time instead of FLOPS. These are the results I get for head dimension 128 - fused-attention-batch4-head32-d128-fwd-causal=True: (time in ms)
This issue is not present when the head dimension is 64. These are the results for head dimension 64 - fused-attention-batch4-head32-d64-fwd-causal=True: (time in ms)
Since the Triton fwd call is said to be faster than the flash-cuda implementation, I wanted to know whether that holds only for a head dimension of 64, or whether I am missing some optimal setting of the hyperparameters (BLOCK_M, BLOCK_N, warps, stages).
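For anyone reproducing this kind of comparison, a minimal timing sketch along the lines of that tutorial could look as follows; the shapes mirror the batch4-head32 configuration above, and `F.scaled_dot_product_attention` is only a stand-in for whichever forward call (Triton kernel or flash-cuda) is actually being profiled.

```python
import torch
import torch.nn.functional as F
import triton

BATCH, N_HEADS, SEQ_LEN, HEAD_DIM = 4, 32, 4096, 128  # rerun with HEAD_DIM = 64 to compare

q, k, v = (
    torch.randn(BATCH, N_HEADS, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

def fwd():
    # stand-in forward pass; swap in the Triton kernel or flash-cuda call being compared
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# do_bench warms up, runs the function repeatedly, and reports the runtime in milliseconds
ms = triton.testing.do_bench(fwd)
print(f"head_dim={HEAD_DIM}: {ms:.3f} ms")
```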
Hey, I was working on adding optional additive attention masking support to Triton's FlashAttention kernel on Triton `main`, to support the latest version and its features. For example, the Triton version now supports both causal and full attention, uses the new block pointer API in the forward pass, is faster than before, and adds sequence parallelism to the backward pass (taken from here).

Triton upstream really changed a lot, which the implementation in here does not reflect yet. So I'm proposing to either merge the two implementations or copy-paste upstream in here, depending on how much work should be invested. A short pro/con list of the approaches is at the end.

As a warning, I haven't compared the two versions in detail, since you probably did some heavy performance optimizations for the Triton FlashAttention version in here, including adding several new features yourself.
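For context, "additive attention mask" here means a tensor that is added to the pre-softmax attention scores, which covers both hard masks (-∞ entries) and soft biases. A minimal non-fused PyTorch sketch of that semantics (my own illustration, not code from either kernel):

```python
import math
import torch

def attention_with_additive_mask(q, k, v, mask=None):
    # q, k, v: [batch, nhead, seq_len, head_dim]
    # mask: broadcastable to [batch, nhead, seq_len_q, seq_len_k];
    #       0 keeps a position, -inf removes it, finite values act as a bias
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask
    return torch.softmax(scores, dim=-1) @ v
```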
New Triton implementation
I'd like to highlight the addition of the `MODE` flag to Triton upstream, which is used to switch between full/causal attention (with two implementations of the causal calculation for efficiency, depending on the sequence length). My masking implementation transparently handles this `MODE` flag as well, meaning two things:

I'm also very happy about a broadcasting feature, where missing or 1-valued dimensions are expanded so the mask is of shape `[batch, nhead, seq_len_q, seq_len_k]`.
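To illustrate the broadcasting, here is a sketch of the intended semantics in PyTorch terms (the kernel itself would realize this via strides rather than materializing the expanded tensor):

```python
import torch

batch, nhead, seq_len_q, seq_len_k = 2, 8, 128, 128

# a key-padding-style mask with 1-valued head and query dimensions ...
mask = torch.zeros(batch, 1, 1, seq_len_k)
mask[0, ..., 100:] = float("-inf")  # drop the padded keys of sample 0

# ... is expanded (without copying) to the full attention-mask shape
full = mask.expand(batch, nhead, seq_len_q, seq_len_k)
print(full.shape)  # torch.Size([2, 8, 128, 128])
```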
Finally, just like here, usage of a mask is completely optional, so code will run as fast as before when a mask is not given, while still supporting causality. The difference between the bias implementation and the mask implementation is that masking has to handle -∞ (which also means placing the `IS_CAUSAL` query after the addition, to handle the transparent mode handling I mentioned above).
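A rough sketch of that ordering in plain PyTorch (variable names and structure are my own illustration, not the kernel code): the additive mask, which may itself contain -∞, is applied first, and the causal constraint afterwards, so the same code path serves both modes.

```python
import torch

def masked_scores(qk, mask=None, is_causal=False):
    # qk: raw attention scores of shape [..., seq_len_q, seq_len_k]
    if mask is not None:
        qk = qk + mask  # the additive mask/bias may already contain -inf
    if is_causal:
        # applied after the addition, so a user-provided mask and causality compose
        seq_q, seq_k = qk.shape[-2:]
        keep = torch.ones(seq_q, seq_k, dtype=torch.bool, device=qk.device).tril()
        qk = qk.masked_fill(~keep, float("-inf"))
    return qk
```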
Way forward

What do you think of updating the Triton FlashAttention kernel here to make use of the newest Triton upstream features (block pointers really add to readability) and to integrate attention mask support (which implies attention bias support)? I'm not sure what the best proposal here is, but I see two main options:
1. Mostly copy-paste the Triton PR version.
   - This would lose: …
   - It would win: … (the `TMP` buffer)
2. Integrate new upstream features piece-wise.
   - This would lose: …
   - It would win: …