[Kernel] Use flash-attn for decoding #3648

Merged: 98 commits into vllm-project:main on May 13, 2024

Conversation

@skrider (Contributor) commented Mar 26, 2024:

Vendors flash-attention from Dao-AILab/flash-attention#824, prunes out the backward-pass operator for faster compile times, adds a reshape-and-cache kernel for the flash-attention KV cache layout, and adds logic for selecting the KV cache manager / attention backend based on the temporary environment variable VLLM_TEMP_USE_FLASH_DECODE. Tested on a single GPU with opt-125m and llama-7b.
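For context, a minimal sketch of how an env-var switch like the one described above could gate backend selection. This is illustrative only, not the PR's actual selector; the function name and backend labels are assumptions.

```python
import os


def select_attention_backend() -> str:
    """Illustrative sketch: pick the decode backend from an env var.

    The real selection logic lives in vLLM's backend selector; the
    backend labels returned here are placeholders.
    """
    if os.environ.get("VLLM_TEMP_USE_FLASH_DECODE", "0") == "1":
        # Use the flash-attn decode kernel and its KV-cache layout.
        return "FLASH_ATTN_DECODE"
    # Otherwise keep the existing PagedAttention decode path.
    return "PAGED_ATTENTION"
```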

@skrider force-pushed the flash-attention-decode branch from 04f8c75 to 0e45f5d on March 27, 2024 06:26
@WoosukKwon added the release-blocker label (This PR/issue blocks the next release, therefore deserves highest priority) on Mar 27, 2024
@simon-mo (Collaborator) commented:

We will try to get this in on a best-effort basis by tomorrow; if not, it will be slated for the next release.

@WoosukKwon self-assigned this on Mar 27, 2024
@WoosukKwon (Collaborator) commented:

@skrider Thanks for the great work! Can I push fixes directly to this PR for faster integration?

max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
raise NotImplementedError
Collaborator:

@skrider Could you elaborate on why it is tricky to implement prefix-enabled attention?

@skrider (Author):

The flash API does not support attention to dense KV and paged KV in the same launch. We can either cache the dense KV before invoking forward_prefix or compute attention separately and merge with the online softmax trick. Not sure which would be best.
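For reference, the "online softmax trick" mentioned here combines two partial attention results using their log-sum-exp statistics. Below is a minimal PyTorch sketch (not the kernel code in question), assuming each partial result comes with its per-query log-sum-exp.

```python
import torch


def merge_attention(out1: torch.Tensor, lse1: torch.Tensor,
                    out2: torch.Tensor, lse2: torch.Tensor) -> torch.Tensor:
    """Merge attention outputs computed over two disjoint KV chunks.

    out1, out2: [num_tokens, num_heads, head_size] partial softmax(QK^T)V
    lse1, lse2: [num_tokens, num_heads] log-sum-exp of the scores per chunk
    """
    lse = torch.logaddexp(lse1, lse2)          # combined softmax normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # rescale factor for chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # rescale factor for chunk 2
    return w1 * out1 + w2 * out2
```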

@rkooo567 (Collaborator) commented Mar 28, 2024:

@skrider isn't this possible with flash_attn_varlen_func? It seems like it supports paged kv cache (https://github.com/Dao-AILab/flash-attention/pull/831/files).

IIUC, you can just pass k cache and v cache to k and v for flash_attn_varlen_func?

@skrider (Author) commented Mar 28, 2024:

Yes, this is correct; however, we need to compute attention over k_cache and k in one operation, which the kernel cannot do without first copying k, v into the KV cache.

Collaborator:

> We can either cache the dense KV before invoking forward_prefix

@skrider This is what our current Triton kernel is doing. Can't we do the same thing with flash-attn? Please take a look at the change I made in this PR.
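For readers following along, "cache the dense KV first" boils down to scattering the new key/value tokens into the paged cache before the prefix kernel runs; the PR description mentions a reshape-and-cache kernel for this. A pure-PyTorch reference sketch under assumed shapes (the flash-attn-style layout [num_blocks, block_size, num_heads, head_size]); the helper name is illustrative, not the actual kernel:

```python
import torch


def reshape_and_cache_ref(key: torch.Tensor, value: torch.Tensor,
                          k_cache: torch.Tensor, v_cache: torch.Tensor,
                          slot_mapping: torch.Tensor) -> None:
    """Reference sketch: write new K/V rows into a paged KV cache.

    key, value:   [num_tokens, num_heads, head_size]
    k/v_cache:    [num_blocks, block_size, num_heads, head_size]
    slot_mapping: [num_tokens], flat slot = block_idx * block_size + offset
    """
    block_size = k_cache.shape[1]
    block_idx = slot_mapping // block_size   # which block each token lands in
    block_off = slot_mapping % block_size    # position inside that block
    k_cache[block_idx, block_off] = key
    v_cache[block_idx, block_off] = value
```

After this write, the prefix attention only has to attend over the paged cache, so no mixed dense+paged launch is needed.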

@skrider (Author):

My mistake - we should be good then

@rkooo567 (Collaborator) left a review comment:

QQ: Do we really need a new backend implementation? Why don't we just add an env var to the existing flash-attn impl's decoding path?

@skrider (Author) commented Mar 28, 2024:

@rkooo567 I think that because the KV cache layout is different it makes sense to have a different backend.
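To make the layout difference concrete, here is a rough shape comparison of the two KV cache formats as I understand them for this era of vLLM; treat the exact shapes as illustrative rather than authoritative.

```python
import torch

num_blocks, block_size, num_heads, head_size = 1024, 16, 32, 128
x = 8  # PagedAttention packs keys in groups of 16 bytes / element_size (fp16 -> 8)

# Existing PagedAttention backend layout:
pa_key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x,
                           dtype=torch.float16)
pa_value_cache = torch.empty(num_blocks, num_heads, head_size, block_size,
                             dtype=torch.float16)

# flash-attn paged layout used by this PR's decode path:
fa_key_cache = torch.empty(num_blocks, block_size, num_heads, head_size,
                           dtype=torch.float16)
fa_value_cache = torch.empty(num_blocks, block_size, num_heads, head_size,
                             dtype=torch.float16)
```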

@WoosukKwon (Collaborator) commented Mar 28, 2024:

@skrider I just edited this PR: 1) I removed the dependency on your FlashAttention repo (let's add it in the next PR), 2) I enabled prefix-attention, and 3) I moved this to FlashAttentionBackend as @rkooo567 suggested, since that backend already existed for this integration; the current implementation was just a placeholder. Could you please take a look?

@zhaoyang-star (Contributor) commented:

A unit test is needed.

Comment on lines 269 to 272
if self.attn_backend.get_name() == "flash-attn":
block_table = seq_group_metadata.block_tables[seq_id]
else:
block_table = computed_block_nums
@Yard1 (Collaborator) commented May 13, 2024:

This suggests to me that we should use some abstraction here instead of if...else branching; a method on attn_backend, perhaps.
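A hypothetical sketch of what such a method could look like, replacing the if...else at the call site. The class and method names here are made up for illustration and are not what the PR ships.

```python
class AttentionBackendSketch:
    """Hypothetical base: each backend decides which block table it needs."""

    def get_block_table_for_prefill(self, seq_group_metadata, seq_id,
                                    computed_block_nums):
        # Default: backends that only attend over the prefix-cached blocks.
        return computed_block_nums


class FlashAttentionBackendSketch(AttentionBackendSketch):
    def get_block_table_for_prefill(self, seq_group_metadata, seq_id,
                                    computed_block_nums):
        # flash-attn attends over the whole paged cache, so it needs the
        # full per-sequence block table.
        return seq_group_metadata.block_tables[seq_id]
```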

Collaborator:

Good point, agreed. For now, I added a note on why we need this if, and added a TODO asking for a better abstraction.

Collaborator:

I think the simplest solution here is to fix the Triton kernel so that it aligns with the other backends' APIs. I will leave this for a future PR, though.

@WoosukKwon merged commit 1356df5 into vllm-project:main on May 13, 2024
43 of 51 checks passed
@rkooo567 (Collaborator) commented:

Oh yay! I will run the benchmark on A100 today.

rkooo567 added a commit to rkooo567/vllm that referenced this pull request May 15, 2024
rkooo567 added a commit that referenced this pull request May 15, 2024
Lora 3 & 4 tests seem to have an illegal memory access failure after this commit:

[2024-05-14 23:51:18,182 E 22 22] logging.cc:101: Unhandled exception: N3c105ErrorE. what(): CUDA error: an illegal memory access was encountered

Example: https://buildkite.com/vllm/ci/builds/7382#018f793d-1527-4e1c-ab59-c3a34ec55241

This reverts commit 1356df5.
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
…lm-project#4820)

@wooyeonlee0 (Contributor) commented:

Great! Thanks for the PR, @skrider.

It seems that this PR has introduced the following constraint:
_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]

I've been using vLLM to test a model whose head size is not in the above list.

From now on, can't I run my model on vLLM with the flash_attn option?
Does vLLM before this PR (v0.4.2) work properly with a model whose head size is not in that list?

I look forward to hearing from you.
Thank you in advance!

@rkooo567 (Collaborator) commented:

I think if the head size is not supported, you cannot use flash-attn (it is a limitation of flash-attn).

@rkooo567 (Collaborator) commented:

(To make it work, you would probably need to make flash-attn itself support those head sizes.)
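One workaround sometimes used (not part of this PR) is to zero-pad the head dimension up to the nearest supported size and slice the output back afterwards. A hedged sketch; the helper name is made up, and the key caveat is that softmax_scale must be computed from the original head size:

```python
import torch
import torch.nn.functional as F

_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]


def pad_head_size(t: torch.Tensor) -> torch.Tensor:
    """Zero-pad the last (head_size) dim to the next supported size.

    Zero padding leaves the QK^T dot products unchanged and produces zeros
    in the padded output channels, so slicing the output back to the
    original head size recovers the exact result, provided softmax_scale
    is set to 1/sqrt(original_head_size) rather than the kernel default.
    """
    head_size = t.shape[-1]
    target = next(s for s in _SUPPORTED_HEAD_SIZES if s >= head_size)
    return F.pad(t, (0, target - head_size))
```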

@wooyeonlee0 (Contributor) commented:

@rkooo567 Oh, I didn't know that flash attn has this limitation. Thanks for the information!

dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
…lm-project#4820)

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
…lm-project#4820)
