[v2] Attention Masking #352

Open
MikeynJerry opened this issue Jul 20, 2023 · 20 comments

@MikeynJerry

MikeynJerry commented Jul 20, 2023

Is there any plan to add attention masking support? PyTorch's version of flash attention v1 included the ability to provide an attention mask in its implementation, and it would be very useful to have this feature in v2.

@leizhao1234

In fact, when you pass an attention mask to PyTorch's implementation, flash attention doesn't work.

@balachandarsv

Yes, I'm facing the same issue. @tridao, can you please take a look at this and respond when you are available?

@tridao
Contributor

tridao commented Jul 21, 2023

Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.

@PeterL1n

PeterL1n commented Aug 3, 2023

I thought masking was supported through flash_attn_varlen_func:

https://github.com/Dao-AILab/flash-attention/blob/d30f2e1cd50185c98ed88c0684b4a603f15bee37/flash_attn/flash_attn_interface.py#L454C21-L454C21
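For context on what the varlen path covers: per its docstring, flash_attn_varlen_func takes q/k/v with all sequences packed along the token dimension, plus cu_seqlens_q and cu_seqlens_k, int32 tensors of shape (batch_size + 1,) holding cumulative sequence lengths. That handles padding, but the mask it implies is block-diagonal (each token attends only within its own sequence, causally if causal=True), not an arbitrary mask. A minimal pure-Python sketch of computing cu_seqlens from a padding mask and of the dense mask packing implies; both helper names are hypothetical and not part of flash_attn:

```python
from itertools import accumulate


def cu_seqlens_from_padding(padding_mask):
    """Cumulative sequence lengths from a per-batch padding mask.

    padding_mask[b][t] is True for real tokens and False for padding.
    Returns [0, len_0, len_0 + len_1, ...] with batch_size + 1 entries, the
    layout expected for cu_seqlens_q / cu_seqlens_k (there as an int32 tensor).
    """
    lengths = [sum(row) for row in padding_mask]  # True counts as 1
    return [0] + list(accumulate(lengths))


def block_diagonal_mask(cu_seqlens):
    """Dense boolean mask implied by packing: token i may attend to token j
    only if both fall inside the same [start, end) segment."""
    total = cu_seqlens[-1]
    segment = [0] * total
    for seq_idx, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        for pos in range(start, end):
            segment[pos] = seq_idx
    return [[segment[i] == segment[j] for j in range(total)] for i in range(total)]


if __name__ == "__main__":
    pad = [
        [True, True, True, False],   # sequence of length 3
        [True, True, False, False],  # sequence of length 2
    ]
    cu = cu_seqlens_from_padding(pad)
    print(cu)  # [0, 3, 5]
    for row in block_diagonal_mask(cu):
        print("".join("1" if x else "." for x in row))
```

A mask that differs per query in any other way (prefix-LM, arbitrary patterns) cannot be written as cumulative lengths, which is the gap this issue is about.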

@zhipeng93

I have tested v1.0.7 and v2.0.4. It turns out that neither of them supports an attention mask:

  • A: using flash attention with attention mask
  • B: not using flash attention, with attention mask

The results of A and B are different.

@samvanstroud

@defei-coder

@tridao Hello, I plan to add a bias mask to FlashAttention-2. I noticed that, in order to fuse the scale and add operations in scale_apply_exp2, the scaling is deferred until after the row maximum is computed. I plan to support the bias mask in the apply_mask_causal function. If a bias mask is supported, it seems the FFMA optimization in scale_apply_exp2 would have to be dropped, although applying the scale and the bias could still benefit from FFMA. Do you have any suggestions?
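As the comment describes, FlashAttention-2 keeps the softmax scale out of the scores until the exponent, so each element costs one multiply-add: exp2(s * c - m * c) with c = scale * log2(e) and m the row max of the unscaled scores. One possible way to keep that shape with an additive bias (an assumption for illustration, not what FlashAttention does) is to fold the bias in as bias / scale before taking the row max, since scale * (s + b / scale) = scale * s + b and a positive scale does not change which entry is largest. A pure-Python check of that equivalence; the function names are hypothetical:

```python
import math

LOG2E = math.log2(math.e)


def softmax_reference(scores, bias, scale):
    """Row softmax of scale * scores + bias, computed the straightforward way."""
    logits = [scale * s + b for s, b in zip(scores, bias)]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_deferred_scale(scores, bias, scale):
    """Same softmax with the scale deferred to the exp2 step (hypothetical).

    The bias is folded in as b * (1 / scale); in a kernel that fold could
    itself be one FMA per element. The row max is then taken on unscaled
    values, and each exponent keeps the single multiply-add form
    x * c - m * c. Requires scale > 0 so the missing scale cannot change
    which entry is the maximum.
    """
    inv_scale = 1.0 / scale
    shifted = [s + b * inv_scale for s, b in zip(scores, bias)]
    m = max(shifted)
    c = scale * LOG2E
    exps = [2.0 ** (x * c - m * c) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]
```

The cost of this variant is one extra bias load and multiply-add per element; whether that fits the kernel's register and bandwidth budget is a separate question.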

@zhangyipin

zhangyipin commented Nov 6, 2023

flash_attn/flash_attn_triton.py supports a bias input; you can use bias=-inf for the positions you want to mask.
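Using an additive bias as a mask works because exp(-inf) = 0, so a key with bias -inf receives zero weight after the softmax. A minimal pure-Python reference for a single query row (an illustration, not the Triton kernel's interface). One caveat: if every key in a row is masked, a naive max-subtracted softmax computes -inf - (-inf) = NaN, so this sketch returns zeros for such rows as an assumed convention:

```python
import math


def masked_attention_row(scores, values, bias):
    """Attention output for one query: softmax(scores + bias) @ values.

    scores[j] is the (already scaled) q.k_j logit, values[j] is a list of
    floats, and bias[j] is 0.0 to keep key j or -inf to mask it out.
    """
    logits = [s + b for s, b in zip(scores, bias)]
    m = max(logits)
    dim = len(values[0])
    if m == -math.inf:
        # Every key is masked: -inf - (-inf) would give NaN, so return zeros.
        return [0.0] * dim
    weights = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


if __name__ == "__main__":
    out = masked_attention_row(
        scores=[1.0, 2.0, 3.0],
        values=[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]],
        bias=[0.0, 0.0, -math.inf],  # mask out the third key
    )
    print(out)
```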

@wehos

wehos commented Feb 29, 2024

> flash_attn/flash_attn_triton.py support bias input you can use bias=-inf

This is a good point, but the example itself does not work with PyTorch 2.0+ (i.e., with Triton 2.0+) 😭

@jaanli

jaanli commented Mar 6, 2024

Does anyone have tips on using custom masks with flash attention for training?

(I need this to train encoder-decoder models with variable-length sequences using non-causal masks.)

This came up in a recent article: https://www.yitay.net/blog/training-great-llms-entirely-from-ground-zero-in-the-wilderness

> The other striking thing is how little support these codebases have for large scale encoder-decoder training or even prefixLM training. To that end, even flash attention has consistently declined to provide support for prefixLM training (i.e., custom masks) despite reasonable demand on their github issues for whatever reason.

Curious what this would take or if it is still out of scope for the flash attention library?

Really grateful that this exists!! Just posting for visibility in case others have solved this problem :)
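For reference, the prefix-LM mask mentioned in the quoted article is bidirectional over the first prefix_len positions and causal after them. Neither causal=True nor removing padded tokens can express it, which is why it needs custom-mask support. A sketch of the dense boolean mask (True means the query may attend to the key); the function name is illustrative:

```python
def prefix_lm_mask(seq_len, prefix_len):
    """Dense prefix-LM mask: mask[q][k] is True if query q may attend to key k.

    Every position can see the whole prefix (bidirectional attention there);
    positions after the prefix additionally see earlier positions causally.
    """
    return [
        [k < prefix_len or k <= q for k in range(seq_len)]
        for q in range(seq_len)
    ]


if __name__ == "__main__":
    for row in prefix_lm_mask(seq_len=5, prefix_len=2):
        print("".join("1" if x else "." for x in row))
```

With prefix_len = 0 this reduces to the ordinary causal mask.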

@tridao
Contributor

tridao commented Mar 7, 2024

> Curious what this would take or if it is still out of scope for the flash attention library?

Not out of scope, it's just someone needs to go implement it :D

@jaanli

jaanli commented Mar 7, 2024

Understood — thank you!! Will try using the varlen functions for now :)

@ardagoreci

I was wondering whether there have been any updates on this. AlphaFold3 uses a lot of attention pair biasing, and it would be tremendously useful to computational biology if flash attention supported attention biasing!

@tridao
Contributor

tridao commented May 26, 2024

> I was wondering if there has been any updates on this? AlphaFold3 uses a lot of attention pair biasing and it would be tremendously useful to computational biology if flash attention supported attention biasing!

Right, we still need someone to implement it.

@alexzhang13

@tridao I was wondering what needs to be done for this to be implemented (efficiently, I assume? Otherwise it seems quite simple).

I need a similar feature (arbitrary attention masks) but I figured I might take a stab at just implementing it if it still needs to be done.
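On the "efficiently" part: one common approach for arbitrary masks in a tiled kernel is to classify each (query block, key block) tile ahead of time. Tiles with no allowed entries are skipped entirely, fully allowed tiles run with no masking, and only partial tiles load and apply the mask elementwise. This is roughly the idea behind FlexAttention's block masks; it is not necessarily how any particular implementation in this thread works. A pure-Python sketch with hypothetical names and a block size of 2 for readability (real kernels use tiles such as 64 or 128):

```python
from enum import Enum


class Tile(Enum):
    SKIP = "skip"        # no query in the tile may attend to any key in it
    FULL = "full"        # every entry allowed: run the tile with no mask
    PARTIAL = "partial"  # mixed: load the mask and apply it elementwise


def classify_tiles(mask, block_q, block_k):
    """Classify each (query block, key block) tile of a dense boolean mask."""
    n_q, n_k = len(mask), len(mask[0])
    tiles = []
    for q0 in range(0, n_q, block_q):
        row = []
        for k0 in range(0, n_k, block_k):
            entries = [
                mask[q][k]
                for q in range(q0, min(q0 + block_q, n_q))
                for k in range(k0, min(k0 + block_k, n_k))
            ]
            if not any(entries):
                row.append(Tile.SKIP)
            elif all(entries):
                row.append(Tile.FULL)
            else:
                row.append(Tile.PARTIAL)
        tiles.append(row)
    return tiles


if __name__ == "__main__":
    n = 6
    causal = [[k <= q for k in range(n)] for q in range(n)]
    for row in classify_tiles(causal, 2, 2):
        print(" ".join(t.value for t in row))
```

For a causal mask only the diagonal tiles need elementwise masking and the tiles above the diagonal are skipped; the same bookkeeping applies to any mask, so sparser masks do less work.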

@alexzhang13

I've implemented a version of custom masking for FA2 in Triton: https://github.com/alexzhang13/flashattention2-custom-mask

It suffices for my use case, but if something comes up where it's necessary to touch the FA3 code I may re-visit this.

@amyxlu

amyxlu commented Aug 21, 2024

> Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.

It seems the FlashAttention class does take a key_padding_mask argument in its forward method. What would be the difference between this and the attention mask to be implemented? Cc @tridao. Thanks!

@tridao
Contributor

tridao commented Aug 21, 2024

As you can see in the code, key_padding_mask just removes elements from keys and values before passing to the flash attention kernel. There's no attention mask passed to the kernel.
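To make the distinction concrete: dropping padded keys gives the same result as masking them with -inf, because masked keys receive zero weight either way. That works for key_padding_mask only because the same keys are dropped for every query; a per-query attention mask cannot be expressed by removing elements, so it needs support inside the kernel. A pure-Python check for one query (function names are illustrative; at least one key is assumed to be kept):

```python
import math


def attend(scores, values):
    """softmax(scores) @ values for one query (values: list of float vectors)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def attend_with_removal(scores, values, key_padding_mask):
    """Drop padded keys (mask False) before attending, as unpadding does."""
    kept = [(s, v) for s, v, keep in zip(scores, values, key_padding_mask) if keep]
    return attend([s for s, _ in kept], [v for _, v in kept])


def attend_with_mask(scores, values, key_padding_mask):
    """Keep every key but give padded keys a -inf logit."""
    masked = [s if keep else -math.inf for s, keep in zip(scores, key_padding_mask)]
    return attend(masked, values)
```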

@krejciadam

> As you can see in the code, key_padding_mask just removes elements from keys and values before passing to the flash attention kernel. There's no attention mask passed to the kernel.

Is there any plan to support key_padding_mask in MHA in v2? My understanding is that this was supported in v1 (in flash_attn.flash_attention.FlashMHA), but in v2 one can only use key_padding_mask when use_flash_attn is False (in flash_attn.modules.mha.MHA). Thank you.

@agshar96

Hi everyone,
I recently published a paper at the ENLSP Workshop @ NeurIPS 2024 that addresses this problem; it can be found here: https://arxiv.org/pdf/2409.15097

I have the code, but it's currently in a private repository, as I am still cleaning it up. If anyone wants access to the repo, just send an email to: [email protected]

Meanwhile, I realised that the PyTorch team has already implemented a change that uses pretty much the same method as mine. (I came up with my method independently for a university project; the PyTorch blog post came out around half a month after that project.)

Anyway, TL;DR: PyTorch has now enabled custom masking for flash attention; you can find it here: https://pytorch.org/blog/flexattention/
(And I am a sad man, as my method will never be used.)
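For reference, FlexAttention (torch.nn.attention.flex_attention, PyTorch 2.5+) expresses a custom mask as a mask_mod(b, h, q_idx, kv_idx) predicate that returns True where attention is allowed; create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN) turns it into a sparse block mask that is passed as flex_attention(q, k, v, block_mask=block_mask). The sketch below writes one such predicate, a causal mask restricted to each packed document, and evaluates it on plain Python ints so it runs without PyTorch. In real use q_idx and kv_idx are tensors and document_id would be a tensor on the same device; the predicate name and the document layout here are illustrative:

```python
# Token -> document index for one packed sequence of 6 tokens:
# tokens 0-2 belong to document 0, tokens 3-5 to document 1 (illustrative).
document_id = [0, 0, 0, 1, 1, 1]


def causal_document_mask(b, h, q_idx, kv_idx):
    """mask_mod-style predicate: attend causally, and only within a document.

    Uses & rather than `and` so the same function also works when the
    indices are tensors, as they are inside FlexAttention.
    """
    same_doc = document_id[q_idx] == document_id[kv_idx]
    return same_doc & (q_idx >= kv_idx)


def materialize(mask_mod, q_len, kv_len):
    """Evaluate a mask_mod on every (q, kv) pair to inspect it as a dense mask."""
    return [[bool(mask_mod(0, 0, q, k)) for k in range(kv_len)] for q in range(q_len)]


if __name__ == "__main__":
    for row in materialize(causal_document_mask, 6, 6):
        print("".join("1" if x else "." for x in row))
```

With PyTorch installed, the same predicate would be passed as create_block_mask(causal_document_mask, B=None, H=None, Q_LEN=6, KV_LEN=6), where B=None and H=None broadcast the mask over batch and heads.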
