[v2] Attention Masking #352
In fact, when you pass an attention mask to PyTorch's implementation, its flash attention backend doesn't work.
Yes, I'm facing the same issue. @tridao, could you please take a look at this and respond when you're available?
Attention masks aren't supported (in either v1 or v2). I might implement it at some point, but there are other priorities right now.
I thought masking is supported through
I have tested v1.0.7 and v2.0.4. It turns out that neither of them supports an attention mask.
The results of A and B are different.
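When checking whether a kernel actually honours a mask, a naive, non-fused reference is a convenient ground truth to compare against. Below is a minimal pure-Python sketch for a single head; the function name `masked_attention` and the convention that `True` means "may attend" are assumptions for illustration, not the flash-attn API.

```python
import math

def masked_attention(q, k, v, mask=None, scale=None):
    """Naive single-head attention used as a ground truth (hypothetical helper, not a fused kernel).

    q, k, v: lists of vectors (lists of floats), one per token.
    mask: optional rows of bools; mask[i][j] is True if query i may attend to key j.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if mask is not None and not mask[i][j]:
                scores.append(-math.inf)  # excluded key: its weight becomes exp(-inf) == 0
            else:
                scores.append(scale * sum(a * b for a, b in zip(qi, kj)))
        m = max(scores)
        if m == -math.inf:
            # Every key is masked for this query; emit zeros instead of NaNs.
            out.append([0.0] * len(v[0]))
            continue
        weights = [math.exp(s - m) for s in scores]  # subtract the max for numerical stability
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out
```

Comparing a kernel's output against this reference (with a tolerance) is one way to tell whether a mask was applied or silently ignored.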
This paper might be relevant: https://arxiv.org/abs/2306.01160. There are several related issues:
I believe PyTorch 2.1 will have a memory-efficient attention implementation that supports arbitrary masks: pytorch/pytorch#96099
@tridao Hello, I plan to add a bias mask to FlashAttention-2. I noticed that, in order to fuse the scale and add operations in `scale_apply_exp2`, the scaling is delayed until after the maximum value has been computed. I plan to support a bias mask in the `apply_mask_causal` function, but if a bias mask is supported, it seems the FFMA optimization in `scale_apply_exp2` would have to be dropped. Could applying both the scale and the bias still benefit from FFMA? Do you have any suggestions?
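The arithmetic behind this question can be sketched in plain Python. Without a bias, a positive scale preserves the argmax, so the max can be taken on raw scores and the scale folded into one multiply-add per element. With an additive bias (a large-negative mask or a pair bias), the max used for stabilization should be taken over the biased scores, so the scale can't be delayed; pre-multiplying the bias by log2(e) still leaves one multiply-add per score, plus a subtract. This is a hypothetical model of the math only (the function names are made up, Python does not actually emit FMAs, and it says nothing about GPU performance); `scale_apply_exp2` is referenced only as described in the comment above.

```python
import math

LOG2E = math.log2(math.e)

def softmax_delayed_scale(s, scale):
    """No bias: take the max on raw scores, then fold the scale into an FMA-shaped expression."""
    m = max(s)                       # valid because scale > 0 preserves the ordering
    scale_log2 = scale * LOG2E
    p = [2.0 ** (x * scale_log2 - m * scale_log2) for x in s]  # x * a + c per element
    z = sum(p)
    return [x / z for x in p]

def softmax_with_bias(s, b, scale):
    """With a bias: the max must see the biased scores, so scale (and bias) come first."""
    scale_log2 = scale * LOG2E
    t = [x * scale_log2 + bi * LOG2E for x, bi in zip(s, b)]  # still x * a + c if b is pre-scaled
    m = max(t)
    p = [2.0 ** (x - m) for x in t]  # a subtract here instead of a fused multiply-add
    z = sum(p)
    return [x / z for x in p]

def softmax_reference(s, b, scale):
    """Plain softmax(scale * s + b) for comparison."""
    y = [x * scale + bi for x, bi in zip(s, b)]
    m = max(y)
    p = [math.exp(x - m) for x in y]
    z = sum(p)
    return [x / z for x in p]
```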
`flash_attn/flash_attn_triton.py` supports a bias input.
This is a good point, but the example itself doesn't work with PyTorch 2.0+ (i.e., Triton 2.0+) 😭
Anyone have tips on custom masks with flash attention for training? (I need this to train encoder-decoder models with variable-length sequences using non-causal masks.) This came up in a recent article: https://www.yitay.net/blog/training-great-llms-entirely-from-ground-zero-in-the-wilderness
Curious what this would take, or whether it is still out of scope for the flash attention library? Really grateful that this exists!! Just posting for visibility in case others have solved this problem :)
Not out of scope, it's just that someone needs to go implement it :D
Understood, thank you!! Will try using the varlen functions for now :)
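The varlen functions sidestep padding masks by packing all real tokens into one sequence and describing sequence boundaries with cumulative lengths (the `cu_seqlens_q` / `cu_seqlens_k` arguments of `flash_attn_varlen_func`). A minimal sketch of deriving those offsets from a boolean padding mask follows; `packed_offsets` is a hypothetical helper, `True` marking a real token is an assumed convention, and real code would build int32 tensors (flash-attn has its own helpers in `flash_attn.bert_padding`).

```python
def packed_offsets(padding_mask):
    """Hypothetical helper. padding_mask: one row per sequence, True marks a real token.

    Returns (indices, cu_seqlens, max_seqlen):
      indices    - flat positions of real tokens in the padded [batch * seqlen] layout
      cu_seqlens - cumulative sequence lengths, starting at 0 (length batch + 1)
      max_seqlen - length of the longest sequence
    """
    seqlen = len(padding_mask[0])
    indices, cu_seqlens, max_seqlen = [], [0], 0
    for b, row in enumerate(padding_mask):
        n = 0
        for t, keep in enumerate(row):
            if keep:
                indices.append(b * seqlen + t)  # gather index into the flattened batch
                n += 1
        cu_seqlens.append(cu_seqlens[-1] + n)  # end offset of sequence b in the packed layout
        max_seqlen = max(max_seqlen, n)
    return indices, cu_seqlens, max_seqlen
```

For encoder-decoder cross-attention, queries and keys would be packed separately, giving separate `cu_seqlens_q` and `cu_seqlens_k`. This only removes padding: an arbitrary mask that differs per query position (beyond options the kernel already provides, such as causal masking) still can't be expressed this way.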
I was wondering if there have been any updates on this? AlphaFold3 uses a lot of attention pair biasing, and it would be tremendously useful to computational biology if flash attention supported attention biasing!
Right, we still need someone to implement it.
@tridao I was wondering what needs to be done for this to be implemented (efficiently, I'm assuming? Otherwise it seems quite simple). I need a similar feature (arbitrary attention masks), so I figured I might take a stab at implementing it if it still needs to be done.
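At the algorithmic level, the change is in where the bias enters the online-softmax recurrence that FlashAttention-style kernels use: each key tile's scores must be biased (or masked to -inf) before the running max is updated. The sketch below processes one query row tile by tile in pure Python and includes a naive reference to check against. Function names and the `block` parameter are hypothetical; a real kernel would additionally have to load a bias tile from memory for every block, which is presumably where most of the engineering effort lies and which this sketch does not model.

```python
import math

def online_softmax_row(scores, bias, values, block=2):
    """One query row, processed key-tile by key-tile (hypothetical sketch).

    scores: scaled q.k scores; bias: additive bias per key (-inf = masked);
    values: one value vector per key. Returns the attention output vector.
    """
    dim = len(values[0])
    m = -math.inf          # running max of biased scores
    l = 0.0                # running softmax denominator
    acc = [0.0] * dim      # running unnormalized output
    for start in range(0, len(scores), block):
        tile = range(start, min(start + block, len(scores)))
        s = [scores[j] + bias[j] for j in tile]  # bias/mask applied BEFORE the max update
        m_new = max(m, max(s))
        if m_new == -math.inf:
            continue       # every key seen so far is masked; nothing to accumulate
        alpha = math.exp(m - m_new)               # rescales earlier partial sums (0.0 if m was -inf)
        p = [math.exp(x - m_new) for x in s]
        l = l * alpha + sum(p)
        acc = [a * alpha + sum(pj * values[j][c] for pj, j in zip(p, tile))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc] if l > 0 else [0.0] * dim

def naive_row(scores, bias, values):
    """Direct softmax(scores + bias) @ values for comparison."""
    y = [s + b for s, b in zip(scores, bias)]
    m = max(y)
    if m == -math.inf:
        return [0.0] * len(values[0])
    p = [math.exp(x - m) for x in y]
    z = sum(p)
    return [sum(pj * vj[c] for pj, vj in zip(p, values)) / z for c in range(len(values[0]))]
```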
I've implemented a version of custom masking for FA2 in Triton (https://github.com/alexzhang13/flashattention2-custom-mask). It suffices for my use case, but if something comes up where I need to touch the FA3 code, I may revisit this.
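One reason arbitrary masks can still be fast is that most masks are structured at the tile level. A common technique (used, for example, by block-sparse attention and by FlexAttention's block masks) is to summarize the dense mask per (query-block, key-block) tile ahead of time, so a kernel can skip fully masked tiles and skip per-element masking on fully visible ones. This is a sketch of the general idea with a hypothetical `classify_blocks` helper; it is not taken from the linked repository.

```python
def classify_blocks(mask, block_q, block_k):
    """Summarize a dense boolean mask (True = attend) per tile (hypothetical helper).

    Returns {(q_block, k_block): "empty" | "full" | "partial"}. A kernel could skip
    "empty" tiles entirely and apply element-wise masking only on "partial" ones.
    """
    n_q, n_k = len(mask), len(mask[0])
    kinds = {}
    for qb in range(0, n_q, block_q):
        for kb in range(0, n_k, block_k):
            vals = [mask[i][j]
                    for i in range(qb, min(qb + block_q, n_q))   # edge tiles may be smaller
                    for j in range(kb, min(kb + block_k, n_k))]
            if not any(vals):
                kind = "empty"
            elif all(vals):
                kind = "full"
            else:
                kind = "partial"
            kinds[(qb // block_q, kb // block_k)] = kind
    return kinds
```

For a causal mask, only the diagonal tiles come out "partial"; tiles above the diagonal are "empty" and tiles below are "full".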
It seems like the `FlashAttention` class does take in a
As you can see in the code, `key_padding_mask` just removes elements from the keys and values before passing them to the flash attention kernel. No attention mask is passed to the kernel.
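Removing padded keys works because a key excluded for every query contributes exp(-inf) = 0 to every softmax row, so deleting it gives the same result as masking it. The minimal sketch below checks that equivalence in pure Python; `attend` and `drop_padded_keys` are hypothetical helpers, and `True` marking a real token is an assumed convention (check the convention of whichever module you use).

```python
import math

def attend(q, k, v, allowed=None):
    """Naive single-head attention; allowed[i][j] False excludes key j for query i.

    Assumes every query can see at least one key.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        s = [scale * sum(a * b for a, b in zip(qi, kj))
             if allowed is None or allowed[i][j] else -math.inf
             for j, kj in enumerate(k)]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / z for c in range(len(v[0]))])
    return out

def drop_padded_keys(k, v, key_padding_mask):
    """Keep only keys/values whose key_padding_mask entry is True (real tokens)."""
    kept = [j for j, keep in enumerate(key_padding_mask) if keep]
    return [k[j] for j in kept], [v[j] for j in kept]
```

This equivalence holds only because a padding mask is the same for every query; a mask that varies per query row, such as the non-causal encoder-decoder masks asked about above, can't be reduced to deleting keys.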
Is there any plan to support
Hi everyone, I have the code, but it's in a private repository for now, as I am still cleaning it up. If anyone wants access to the repo, just send an email to: [email protected]. Meanwhile, I realised that the PyTorch team has already implemented a change that uses pretty much the same method I used. (I came up with my method independently for a university project; the PyTorch blog post came out about half a month after my project.) Anyway, TL;DR: PyTorch now supports custom masking of flash attention; you can find it here: https://pytorch.org/blog/flexattention/
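FlexAttention expresses custom masking through two user callbacks: `score_mod(score, b, h, q_idx, kv_idx)`, which returns a modified score, and `mask_mod(b, h, q_idx, kv_idx)`, which returns whether a position is visible. To show what those callbacks mean without requiring PyTorch, here is a naive pure-Python emulation of that convention for a single batch and head. It is not the library's implementation (which compiles the callbacks into a fused kernel and uses a block mask to skip empty tiles); `reference_flex` and `pair_bias` are hypothetical names, and this sketch applies the 1/sqrt(d) scale before `score_mod`.

```python
import math

def reference_flex(q, k, v, score_mod=None, mask_mod=None, b=0, h=0):
    """Naive emulation of FlexAttention-style callbacks for one (batch, head)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, qi in enumerate(q):
        s = []
        for kv_idx, kj in enumerate(k):
            if mask_mod is not None and not mask_mod(b, h, q_idx, kv_idx):
                s.append(-math.inf)  # invisible position
                continue
            score = scale * sum(x * y for x, y in zip(qi, kj))
            if score_mod is not None:
                score = score_mod(score, b, h, q_idx, kv_idx)
            s.append(score)
        m = max(s)
        if m == -math.inf:
            out.append([0.0] * len(v[0]))  # fully masked row
            continue
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / z for c in range(len(v[0]))])
    return out

def causal(b, h, q_idx, kv_idx):
    """mask_mod: a query may attend to itself and earlier positions."""
    return q_idx >= kv_idx

def pair_bias(table):
    """Hypothetical factory for a score_mod adding a per-(query, key) bias, as in pair-biased attention."""
    def score_mod(score, b, h, q_idx, kv_idx):
        return score + table[q_idx][kv_idx]
    return score_mod
```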
Is there any plan to add attention masking support? PyTorch's version of flash attention v1 included the ability to provide an attention mask, and it would be very useful to have this feature in v2.