[Kernel] Use flash-attn for decoding #3648
Conversation
Force-pushed from 04f8c75 to 0e45f5d.
We will try to get this in on a best-effort basis by tomorrow; if not, it will be slated for the next release.
@skrider Thanks for the great work! Can I directly fix this PR for faster integration?
vllm/attention/ops/flash_attn.py (outdated)
    max_subquery_len: int,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    raise NotImplementedError
@skrider Could you elaborate on why it is tricky to implement prefix-enabled attention?
The flash-attn API does not support attention over dense KV and paged KV in the same launch. We can either cache the dense KV before invoking forward_prefix, or compute attention separately and merge the results with the online softmax trick. Not sure which would be best.
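For reference, a minimal sketch of the second option, assuming each attention launch can also return its per-query log-sum-exp (LSE); the function below and its tensor shapes are illustrative, not part of this PR:

```python
import torch

# Merge two partial attention outputs (one over the paged KV cache, one over
# the dense, not-yet-cached KV) using the online-softmax / log-sum-exp trick.
# Assumed shapes: out_*: [num_tokens, num_heads, head_dim],
#                 lse_*: [num_tokens, num_heads].
def merge_attention_outputs(out_cache: torch.Tensor, lse_cache: torch.Tensor,
                            out_new: torch.Tensor,
                            lse_new: torch.Tensor) -> torch.Tensor:
    # Combined softmax normalizer in log space.
    lse = torch.logaddexp(lse_cache, lse_new)
    # Each partial output is rescaled by its share of the combined normalizer.
    w_cache = torch.exp(lse_cache - lse).unsqueeze(-1)
    w_new = torch.exp(lse_new - lse).unsqueeze(-1)
    return w_cache * out_cache + w_new * out_new
```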
@skrider Isn't this possible with flash_attn_varlen_func? It seems like it supports a paged KV cache (https://github.com/Dao-AILab/flash-attention/pull/831/files). IIUC, you can just pass the key cache and value cache as k and v to flash_attn_varlen_func?
Yes, this is correct; however, we need to compute attention over k_cache and k in one operation, which the kernel cannot do without first copying k and v into the KV cache.
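For illustration, a rough sketch of the "copy first" alternative: scatter the new k/v into the paged cache, then issue a single paged attention call. The cache write shown as an index_copy_ and the exact flash_attn_varlen_func signature (block_table support landed in Dao-AILab/flash-attention#831) are assumptions here, not this PR's final code:

```python
import torch
from flash_attn import flash_attn_varlen_func  # paged-KV support per PR #831


def prefix_attention(q, new_k, new_v, key_cache, value_cache, block_table,
                     slot_mapping, cu_seqlens_q, cu_seqlens_k,
                     max_seqlen_q, max_seqlen_k, scale):
    # 1) Scatter the dense (uncached) K/V into their cache slots so the cache
    #    holds the full context. vLLM uses a reshape_and_cache kernel for
    #    this; a plain index_copy_ stands in for it here.
    #    Assumed cache layout: [num_blocks, block_size, num_kv_heads, head_size];
    #    slot_mapping: int64 flat slot index per new token.
    key_cache.view(-1, *key_cache.shape[2:]).index_copy_(0, slot_mapping, new_k)
    value_cache.view(-1, *value_cache.shape[2:]).index_copy_(0, slot_mapping, new_v)

    # 2) One attention call over the paged cache, covering both the cached
    #    prefix and the just-written tokens.
    return flash_attn_varlen_func(
        q=q,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=scale,
        causal=True,
        block_table=block_table,
    )
```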
> We can either cache the dense KV before invoking forward_prefix
@skrider This is what our current Triton kernel is doing. Can't we do the same thing with flash-attn? Please take a look at the change I made in this PR.
My mistake - we should be good then
QQ: Do we really need a new backend implementation? Why don't we just add an env var to the existing flash-attn impl's decoding path?
@rkooo567 I think that because the KV cache layout is different, it makes sense to have a separate backend.
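For context, roughly how the two layouts differ (shapes as in vLLM around this time; treat the exact dimensions as illustrative):

```python
import torch

# Example sizes only.
num_blocks, block_size, num_kv_heads, head_size, x = 1024, 16, 8, 128, 8

# Existing PagedAttention layout: keys are tiled along a vector width x,
# values are stored [head_size, block_size] per head.
pa_key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x)
pa_value_cache = torch.empty(num_blocks, num_kv_heads, head_size, block_size)

# flash-attn paged layout: one [block_size, num_kv_heads, head_size] tile per
# block, identical for keys and values.
fa_key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)
fa_value_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)
```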
@skrider I just edited this PR: 1) I removed the dependency on your FlashAttention repo (let's add it in the next PR), 2) I enabled the prefix attention, and 3) I moved this to
A unit test is needed.
vllm/worker/model_runner.py (outdated)
if self.attn_backend.get_name() == "flash-attn":
    block_table = seq_group_metadata.block_tables[seq_id]
else:
    block_table = computed_block_nums
This suggests to me that we should use some abstraction here instead of if/else branching; a method on attn_backend, perhaps.
Good point, agreed. For now, I added a note on why we need this if, and a TODO asking for a better abstraction.
I think the simplest solution here is to fix the Triton kernel so that it aligns with the other backends' APIs. I will leave this for a future PR, though.
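A hypothetical sketch of the abstraction discussed above: let each backend pick the block table it needs, so the model runner does not branch on the backend name. The class and method names are illustrative, not vLLM's actual API:

```python
from typing import List, Optional


class AttentionBackendBase:
    """Illustrative base class; not vLLM's actual interface."""

    def get_prefill_block_table(
        self,
        block_table: List[int],
        computed_block_nums: Optional[List[int]],
    ) -> List[int]:
        # Default: backends whose prefix kernel only reads the already
        # computed blocks (e.g. the Triton path) take computed_block_nums.
        return computed_block_nums or []


class FlashAttnBackend(AttentionBackendBase):
    def get_prefill_block_table(
        self,
        block_table: List[int],
        computed_block_nums: Optional[List[int]],
    ) -> List[int]:
        # flash-attn attends over the whole paged sequence, so it needs the
        # full block table rather than just the computed prefix blocks.
        return block_table
```

The model runner would then call attn_backend.get_prefill_block_table(...) unconditionally instead of checking the backend name.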
Oh yay! I will run the benchmark on an A100 today.
This reverts commit 1356df5. The Lora 3 & 4 tests seem to have an illegal memory access failure after this commit: [2024-05-14 23:51:18,182 E 22 22] logging.cc:101: Unhandled exception: N3c105ErrorE. what(): CUDA error: an illegal memory access was encountered. Example: https://buildkite.com/vllm/ci/builds/7382#018f793d-1527-4e1c-ab59-c3a34ec55241
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
…lm-project#4820): reverts commit 1356df5 (see the revert message above).
Great, thanks for the PR! @skrider It seems that this PR has introduced a constraint on supported head sizes. I've been using vLLM to test a model with a head size not in the supported list. From now on, can't I run my model on vLLM with the flash_attn option? I look forward to hearing from you.
I think if the head size is not supported, you cannot use flash-attn (it is a limitation of flash-attn).
(To make it work, you would probably have to make flash-attn support the additional head sizes.)
@rkooo567 Oh, I didn't know that flash-attn has this limitation. Thanks for the information!
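For illustration, the kind of guard this constraint implies when choosing a backend; the supported-head-size list reflects flash-attn 2.x at the time and may differ by version, and the fallback name is just an example:

```python
# Supported head sizes as of flash-attn 2.x at the time; may differ by version.
FLASH_ATTN_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]


def pick_attention_backend(head_size: int) -> str:
    if head_size in FLASH_ATTN_SUPPORTED_HEAD_SIZES:
        return "flash-attn"
    # Fall back to the existing PagedAttention/xFormers path otherwise.
    return "xformers"
```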
This PR vendors flash-attention from Dao-AILab/flash-attention#824, prunes out the backward-pass operator for faster compile times, adds a reshape-and-cache kernel for the flash-attention KV cache layout, and adds logic for selecting the KV cache manager / attention backend based on the temporary environment variable VLLM_TEMP_USE_FLASH_DECODE. Tested on a single GPU with opt-125m and llama-7b.
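A minimal sketch of how such a temporary environment-variable gate might look; the exact semantics in this PR may differ (this assumes any value other than empty or "0" enables the flash-attn decode path):

```python
import os


def use_flash_decode() -> bool:
    # Temporary opt-in flag; assumed semantics: set VLLM_TEMP_USE_FLASH_DECODE=1
    # to select the flash-attn KV cache layout and decode backend.
    value = os.environ.get("VLLM_TEMP_USE_FLASH_DECODE", "0")
    return value not in ("", "0")
```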