Masking + biasing #17
We plan to support additive biases (e.g. ALiBi). The bias should ideally take linear memory and not quadratic memory. Re FP32: we use tensor cores for matrix multiply, which support fp16 (and bf16 in the near future). The xformers team has some implementation of memory-efficient attention in fp32 (here and here).
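For context, here is a minimal sketch of what an ALiBi-style additive bias looks like and why it only needs linear memory: the learnable-free state is one slope per head, and the full quadratic bias is just slope × relative position, so a fused kernel could compute it on the fly. Names and shapes below are illustrative, not the planned API; the slope formula assumes a power-of-two head count as in the ALiBi paper.

```python
import math
import torch

def alibi_slopes(nheads: int) -> torch.Tensor:
    # One scalar slope per head (linear memory), geometric series from the
    # ALiBi paper for power-of-two head counts.
    start = 2 ** (-8 / nheads)
    return torch.tensor([start ** (i + 1) for i in range(nheads)])

def alibi_bias(nheads: int, seqlen: int) -> torch.Tensor:
    # Materialized here only for illustration; a fused kernel would compute
    # slope * (j - i) inside the attention loop instead of storing (H, N, N).
    slopes = alibi_slopes(nheads)                                        # (H,)
    rel = torch.arange(seqlen)[None, :] - torch.arange(seqlen)[:, None]  # (N, N)
    return slopes[:, None, None] * rel[None, :, :]                       # (H, N, N)

# Usage: scores has shape (B, H, N, N) = Q @ K^T / sqrt(d), and the bias is
# simply added before the softmax:
#   attn = torch.softmax(scores + alibi_bias(H, N), dim=-1)
```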
Thanks for the speedy response! I'm thinking of applying FlashAttention to our implementation of AlphaFold 2, which has a number of different attention modules with different biases for the pre-softmax quadratic attention matrix. BTW, is there a reason the existing attn_mask parameter isn't used?
Can you point me to the code / paper on how these biases are applied? Are they just added to QK^T before the softmax?
See, e.g., our implementation of what DeepMind calls "triangular attention"; pseudocode is on page 19 here. It uses both a fixed mask and a trainable bias.
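For reference, the pre-softmax computation in that kind of module boils down to adding two extra terms to the attention logits. A minimal PyTorch sketch (the tensor names mask_bias and triangle_bias are illustrative, not taken from the implementation):

```python
import math
import torch

def biased_attention(q, k, v, mask_bias, triangle_bias):
    """q, k, v: (B, H, N, d).
    mask_bias: large negative values at masked positions, broadcastable to (B, H, N, N).
    triangle_bias: a trainable bias, also broadcastable to (B, H, N, N)."""
    d = q.shape[-1]
    logits = q @ k.transpose(-2, -1) / math.sqrt(d)   # (B, H, N, N)
    logits = logits + mask_bias + triangle_bias       # both terms added pre-softmax
    attn = torch.softmax(logits, dim=-1)
    return attn @ v                                   # (B, H, N, d)
```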
Thanks for the pointers!
In this particular case, the
Sorry, I forgot to ask yesterday: do you have an approximate ETA for the feature? Thanks!
Do you mean a key padding mask (of shape [B, 1, 1, N]), or an arbitrary attention mask of shape [1, 1, N, N]?
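Concretely, both variants just broadcast against the (B, H, N, N) logits. A small illustrative sketch of the two shapes being discussed (standard PyTorch, not this repo's API):

```python
import torch

B, H, N = 2, 4, 16
scores = torch.randn(B, H, N, N)           # Q @ K^T / sqrt(d)

# Key padding mask: one boolean per key position, shape (B, 1, 1, N).
key_padding = torch.ones(B, 1, 1, N, dtype=torch.bool)
key_padding[0, ..., 10:] = False           # e.g. sequence 0 has length 10

# Arbitrary attention mask shared across the batch, shape (1, 1, N, N).
attn_mask = torch.tril(torch.ones(1, 1, N, N, dtype=torch.bool))  # e.g. causal

# Either (or both) can be applied by masking the logits before softmax;
# the combined mask broadcasts to (B, 1, N, N) and then to (B, H, N, N).
masked = scores.masked_fill(~(key_padding & attn_mask), float("-inf"))
probs = torch.softmax(masked, dim=-1)
```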
I tried playing around with the
Oh awesome. Thanks!
We want customizable masking & biasing as well! Adding these two features would make FlashAttention suitable for a lot more models.
Hello @tridao,
Yes, it'll be there eventually. I just haven't had as much bandwidth to work on it recently (conference deadline). By inference, do you mean text generation? Would Q have seqlen 1 most of the time, due to the KV cache?
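For what it's worth, that seqlen-1 query pattern comes from incremental decoding with a KV cache. A rough sketch of the shapes involved (plain PyTorch, not this repo's API):

```python
import math
import torch

def decode_step(q_new, k_new, v_new, k_cache, v_cache):
    """One generation step.
    q_new, k_new, v_new: (B, H, 1, d) for the newly generated token.
    k_cache, v_cache:   (B, H, T, d) for all previously generated tokens."""
    k_cache = torch.cat([k_cache, k_new], dim=2)                # (B, H, T+1, d)
    v_cache = torch.cat([v_cache, v_new], dim=2)
    d = q_new.shape[-1]
    scores = q_new @ k_cache.transpose(-2, -1) / math.sqrt(d)   # (B, H, 1, T+1)
    out = torch.softmax(scores, dim=-1) @ v_cache               # (B, H, 1, d)
    return out, k_cache, v_cache
```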
Refer to #57.
Hi @tridao, congratulations on your great work! I have the same issue when using FlashAttention in Swin Transformer, especially in shifted-window attention. To be more specific, the final attention weights (before softmax) in window attention are the scaled QK^T plus a relative position bias and, for shifted windows, an attention mask.
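A minimal sketch of those pre-softmax terms, roughly following the shapes used in the Swin Transformer reference code; the function name and shapes here are illustrative:

```python
import math
import torch

def window_attention_logits(q, k, rel_pos_bias, shift_mask=None):
    """q, k: (B * num_windows, num_heads, N, d), where N = window_size ** 2.
    rel_pos_bias: (num_heads, N, N), a trainable relative position bias.
    shift_mask: (num_windows, N, N) or None; large negative values block
    attention across shifted-window boundaries."""
    Bw, nH, N, d = q.shape
    attn = q @ k.transpose(-2, -1) / math.sqrt(d)       # (B*nW, nH, N, N)
    attn = attn + rel_pos_bias.unsqueeze(0)             # broadcast over all windows
    if shift_mask is not None:
        nW = shift_mask.shape[0]
        # Reshape so the per-window mask broadcasts over batch and heads.
        attn = attn.view(Bw // nW, nW, nH, N, N) + shift_mask[None, :, None, :, :]
        attn = attn.view(Bw, nH, N, N)
    return torch.softmax(attn, dim=-1)
```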
Refer to #76.
Hello @tridao, flash-attention is amazing! Thank you so much for making it! If possible, I'd also like to request a fully customisable attention bias. Thank you for all of your hard work!
Actually, you can simply use the Triton implementation to achieve this. But for the backward pass you have to modify the original code, and performance will decrease because you have to save the bias gradient in HBM during the backward pass.
I did try this at some point, but I was getting errors (I'm not sure whether it was my code being wrong, or just Triton bugs). I'll probably try again once Triton is more stable.
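In case it helps anyone reading along, the bias path being discussed would look roughly like this. It assumes the Triton implementation in flash_attn/flash_attn_triton.py exposes flash_attn_func with a bias argument broadcastable to (batch, nheads, seqlen_q, seqlen_k); treat the exact import path and signature as assumptions and check the source.

```python
import torch
# Assumed import path / signature -- verify against flash_attn_triton.py.
from flash_attn.flash_attn_triton import flash_attn_func

B, S, H, D = 2, 1024, 8, 64
q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Additive bias broadcastable to (B, H, S, S). Making it trainable means its
# gradient has to be written back to HBM in the backward pass, which is the
# performance cost mentioned above, and the stock backward may need modifying.
bias = torch.randn(1, H, S, S, device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_func(q, k, v, bias)   # (B, S, H, D)
out.sum().backward()                   # bias.grad only if the backward supports it
```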
Where can we find the example?
You can look at our BERT implementation.
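To spell out the padding pattern the BERT implementation relies on: a key padding mask is converted to variable-length (unpadded) form before calling the attention kernel. The helpers below are from flash_attn/bert_padding.py; treat the exact signatures as approximate and check that file.

```python
import torch
# From this repo; see flash_attn/bert_padding.py for the exact signatures.
from flash_attn.bert_padding import unpad_input, pad_input

B, S, Hdim = 4, 128, 768
hidden = torch.randn(B, S, Hdim, device="cuda", dtype=torch.float16)

# Key padding mask: True for real tokens, False for padding. Only trailing
# padding is supported -- as noted below, a mask in the middle of a sequence
# is not handled by this path.
attention_mask = torch.zeros(B, S, dtype=torch.bool, device="cuda")
attention_mask[:, :100] = True

hidden_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden, attention_mask)
# hidden_unpad: (total_tokens, Hdim); cu_seqlens: (B + 1,) cumulative sequence
# lengths; max_seqlen: longest sequence in the batch. These feed the
# unpadded/varlen FlashAttention functions. Afterwards, scatter back:
hidden_repad = pad_input(hidden_unpad, indices, B, S)   # (B, S, Hdim)
```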
If we convert a sparse mask to cu_seqlens and max_seqlen using unpad_input, we get incorrect results. How should we solve this problem? Thanks!
If the mask is in the middle of the sequence, that's not supported right now.
Can the
How difficult would it be to modify the kernels to support arbitrary masks/additive biases for the attention logits, especially w/ support for tensor broadcasting? Is there any fundamental reason why that wouldn't work? I noticed that the FlashAttention class has an "attn_mask" parameter, but it doesn't let you specify it.
On an unrelated note, would it be worth adding FP32 support for inference?
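For concreteness, here is a reference (non-fused) version of the semantics being asked for: a boolean mask and/or additive bias, both broadcastable against the attention logits. This is just standard attention in PyTorch, not FlashAttention's API.

```python
import math
import torch

def reference_attention(q, k, v, attn_mask=None, attn_bias=None):
    """q, k, v: (B, H, N, d).
    attn_mask: boolean, broadcastable to (B, H, N, N); True = keep.
    attn_bias: float, broadcastable to (B, H, N, N); added to the logits.
    A fused kernel would apply both inside the softmax loop instead of
    materializing the (B, H, N, N) logits."""
    d = q.shape[-1]
    logits = q @ k.transpose(-2, -1) / math.sqrt(d)
    if attn_bias is not None:
        logits = logits + attn_bias
    if attn_mask is not None:
        logits = logits.masked_fill(~attn_mask, float("-inf"))
    return torch.softmax(logits, dim=-1) @ v
```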