
[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models #9559

Merged: 104 commits into vllm-project:main, Nov 2, 2024

Conversation

@sroy745 (Collaborator) commented Oct 21, 2024

This PR adds support for the FlashAttention kernel for encoder-decoder models. For encoder-decoder models with dtype=bfloat16, the default backend is now FlashAttention instead of XFormers. However, llama-3.2-11b-vision-instruct still uses the XFormers backend even with dtype=bfloat16, because the model implementation (models/mllama.py) depends on PagedAttention.
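As a rough sketch of that selection rule (the enum and function below are hypothetical, not vLLM's actual selector API):

```python
# Hypothetical sketch of the backend-selection rule described above;
# the names are illustrative, not vLLM internals.
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()


def pick_enc_dec_backend(dtype: str, model_arch: str) -> Backend:
    # mllama's model code depends on PagedAttention, so it stays on
    # XFormers even when dtype is bfloat16.
    if dtype == "bfloat16" and model_arch != "mllama":
        return Backend.FLASH_ATTN
    return Backend.XFORMERS


assert pick_enc_dec_backend("bfloat16", "bart") is Backend.FLASH_ATTN
assert pick_enc_dec_backend("bfloat16", "mllama") is Backend.XFORMERS
```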

To add this support, this PR makes the following changes:

  1. Updated flash_attn.py to support encoder-decoder models, and updated the tests in tests/kernels/test_encoder_decoder.py to exercise the FlashAttention backend alongside the existing XFormers backend (the mix of attention types this involves is sketched after this list).
  2. Updated test_bart.py, test_florence2.py, and encoder_decoder/test_e2e_correctness.py to run with both backends.
  3. Moved some methods from xformers.py to backend/utils.py so they can be reused by both xformers.py and flash_attn.py.
  4. Updated the checks in worker/enc_dec_model_runner.py to accept either the FlashAttention or the XFormers backend, rather than XFormers only.
  5. Updated models/bart.py to invoke attention.forward with a query of shape [num_tokens, hidden_size]. Previously it passed a query of shape [num_tokens, num_heads, head_size], which is not the default layout (see the reshape sketch after this list).
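For context on item 1: encoder-decoder models mix attention types, so the backend has to distinguish causal decoder self-attention from non-causal encoder self-attention and cross-attention (where query and key lengths also differ). A minimal PyTorch sketch of that distinction, not vLLM's actual flash_attn.py code:

```python
import torch
import torch.nn.functional as F

batch, num_heads, head_size = 1, 4, 32
enc_len, dec_len = 10, 6  # arbitrary example lengths

q_dec = torch.randn(batch, num_heads, dec_len, head_size)
k_dec = torch.randn(batch, num_heads, dec_len, head_size)
v_dec = torch.randn(batch, num_heads, dec_len, head_size)
k_enc = torch.randn(batch, num_heads, enc_len, head_size)
v_enc = torch.randn(batch, num_heads, enc_len, head_size)

# Decoder self-attention masks out future positions (causal) ...
dec_self = F.scaled_dot_product_attention(q_dec, k_dec, v_dec, is_causal=True)

# ... while cross-attention lets every decoder token attend to the full
# encoder output (non-causal, and key/value length differs from query length).
cross = F.scaled_dot_product_attention(q_dec, k_enc, v_enc, is_causal=False)

print(dec_self.shape, cross.shape)  # both torch.Size([1, 4, 6, 32])
```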
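Similarly, a small illustration of the shape convention change in item 5 (the tensor names here are made up for the example):

```python
import torch

num_tokens, num_heads, head_size = 8, 16, 64
hidden_size = num_heads * head_size

# Old call site: query passed as [num_tokens, num_heads, head_size].
q_old = torch.randn(num_tokens, num_heads, head_size)

# New call site: heads flattened so the query is [num_tokens, hidden_size],
# the default layout expected by attention.forward.
q_new = q_old.view(num_tokens, hidden_size)
assert q_new.shape == (num_tokens, hidden_size)
```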

#7366

sroy745 added 30 commits May 28, 2024 20:39
@sroy745 (Collaborator, Author) commented Nov 1, 2024

Thanks for the review. Addressed comments. PTAL

mergify bot commented Nov 1, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @sroy745 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Nov 1, 2024
@heheda12345 (Collaborator) left a comment


LGTM. Thanks for your hard work on this. Looking forward to the follow-up PRs for test_encoder_decoder_attention and mllama support.

Also CC @WoosukKwon. You may need to sync this PR to v1 later.

mergify bot removed the needs-rebase label Nov 1, 2024
@sroy745 (Collaborator, Author) commented Nov 1, 2024

@ywang96 PTAL when you get a chance. The PR has been LGTM'ed by @heheda12345, is synced to head, and all tests are passing.

mergify bot commented Nov 1, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @sroy745 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Nov 1, 2024
mergify bot removed the needs-rebase label Nov 1, 2024
@ywang96 (Member) left a comment


Thanks for this great work!

ywang96 merged commit a78dd33 into vllm-project:main on Nov 2, 2024
61 of 62 checks passed
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
richardsliu pushed a commit to richardsliu/vllm that referenced this pull request Nov 4, 2024
bigPYJ1151 pushed a commit to bigPYJ1151/vllm that referenced this pull request Nov 5, 2024
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Labels: ready