
Flash attention #173

Open

oliverdutton wants to merge 1 commit into main

Conversation


@oliverdutton commented on Apr 21, 2024

Implements FlashAttention, similar to google-deepmind/alphafold#931.

For a 759-residue protein and model_5 this improves runtime 2.2x on an L4 (37.3 $\rightarrow$ 16.9 seconds, with minibatching of 256 for non-flash attention to avoid OOM).

Here's a colab link showing the runtime improvement and no significant change in prediction output on visual inspection. I didn't want to rerun all the input prep, so I've used a colab with AlphaFold input preparation and made the fixes needed for colabdesign.

Notes

Key variations from a reference flash attention kernel (sketched in plain JAX after this list) are:

  • Attention logit biasing is supported
  • Gating is supported
  • Some heads have only 8 channels; they're padded up to 16 within the kernel (a requirement of pl.dot). We still see a performance improvement relative to non-flash attention, and overall AlphaFold2 memory requirements stay linear
  • Broadcast masks in the batch, q and head dimensions are supported (they're often size 1 and implicitly broadcast in the AlphaFold2 einsums)
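
A minimal plain-JAX sketch of these semantics (not the Pallas kernel from this PR; the [batch, heads, seq, channels] layout, the argument names and the `min_head_dim` default are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v, bias, gate=None, min_head_dim=16):
    # q, k, v: [batch, heads, seq, channels]; bias broadcasts against the
    # [batch, heads, q_len, k_len] logits (its batch/head/q dims are often 1).
    d = q.shape[-1]
    if d < min_head_dim:
        # Heads with only 8 channels are zero-padded up to 16 (pl.dot minimum);
        # zero padding leaves q·kᵀ unchanged and the padded output channels
        # are sliced off again below.
        pad = [(0, 0)] * (q.ndim - 1) + [(0, min_head_dim - d)]
        q, k, v = (jnp.pad(x, pad) for x in (q, k, v))
    logits = jnp.einsum('...qc,...kc->...qk', q, k) / d ** 0.5
    weights = jax.nn.softmax(logits + bias, axis=-1)
    out = jnp.einsum('...qk,...kc->...qc', weights, v)[..., :d]
    if gate is not None:
        # Output gating as used in the AlphaFold2 attention modules.
        out = out * jax.nn.sigmoid(gate)
    return out
```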

There are guards against the kernel being called for sequence lengths shorter than the q and k block sizes; in that case it exits to the reference kernel (see the sketch below).
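
Continuing the sketch above, that guard could look roughly like this (`flash_kernel` is a stand-in for the Pallas kernel, which isn't reproduced here; the signature and defaults are illustrative, not the PR's exact API):

```python
def attention(q, k, v, bias, gate=None, flash_kernel=None, *,
              block_q=128, block_k=128):
    # Exit to the reference path when either sequence length is shorter than
    # its block size (or when no flash kernel is available/enabled).
    q_len, k_len = q.shape[-2], k.shape[-2]
    if flash_kernel is None or q_len < block_q or k_len < block_k:
        return reference_attention(q, k, v, bias, gate)
    return flash_kernel(q, k, v, bias, gate, block_q=block_q, block_k=block_k)
```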

Comments

  • I think the runtime improvement also benefits from the triangular fusion you've previously implemented, as on an A100 with flash attention and bfloat16 I saw that start to become significant.
  • I haven't done correctness/speed checks with multimer models or models using templates. If you have a test suite for that, that would be wonderful.
  • When you said 'fused attention' you meant shifting the mask to a bias so XLA lowers it to a fused kernel, right? I've moved that mask $\rightarrow$ bias conversion into the Attention module itself and kept it in the reference_kernel (so reference_kernel now differs from the one in google-deepmind/alphafold#931). So with use_flash_attention=False I haven't changed behaviour: here's a colab showing the same 37.3s runtime as the main branch. (A small sketch of the conversion follows this list.)
  • Fix for use_dgram, which seemed to access the wrong config keys
  • Fix for models not containing a PAE head
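
A small sketch of that mask $\rightarrow$ bias conversion (the function name is illustrative; this mirrors the usual AlphaFold2-style `1e9 * (mask - 1)` trick rather than quoting the PR's code):

```python
import jax.numpy as jnp

def mask_to_bias(mask, dtype=jnp.float32):
    # mask broadcasts against the [batch, heads, q_len, k_len] logits (its
    # batch/q/head dims are often size 1). Positions with mask == 0 get a
    # large negative bias, so the softmax assigns them ~zero weight; both the
    # reference path and the flash kernel then consume the same additive bias.
    return (1e9 * (mask - 1.0)).astype(dtype)
```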

fix: fix access to global config
fix: allow lack of predicted_aligned_error head
@oliverdutton (Author) commented:

@sokrypton I think this is ready for merging.

It's still strictly opt-in (as Pallas with Triton is only available for Ampere-architecture GPUs and up).

You could improve performance a bit more by tuning block sizes and the number of warps in an input-shape-dependent manner; similarly, the `subbatch_size` global config setting could default to a memory-usage heuristic that selects the subbatch size (an illustrative sketch follows).
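
For illustration, such a heuristic could look like the sketch below; the block sizes, warp counts and memory budget are placeholders rather than tuned values:

```python
def choose_kernel_params(q_len, k_len, head_dim):
    # Placeholder shape-dependent tuning: smaller blocks for short sequences,
    # more warps for wider heads. Real values would come from benchmarking.
    block_q = min(128, max(16, q_len))
    block_k = min(128, max(16, k_len))
    num_warps = 4 if head_dim <= 64 else 8
    return dict(block_q=block_q, block_k=block_k, num_warps=num_warps)

def choose_subbatch_size(n_items, bytes_per_item, memory_budget_bytes=2 * 1024**3):
    # Placeholder memory-based default for `subbatch_size`: the largest
    # subbatch that fits in the budget, instead of one fixed global value.
    return int(max(1, min(n_items, memory_budget_bytes // max(1, bytes_per_item))))
```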
