-
-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Performance] Support MQA/GQA in decode phase by using FlashAttention #2744
base: main
Are you sure you want to change the base?
Conversation
It appears that the throughput test for qwen-7b has a negative improvement, and it seems that the GPU blocks have decreased significantly compared to before. Maybe there is a bug? No Flash Attention: With Flash Attention: |
@zhaoyang-star you may be interested in FlashInfer's PagedAttention which is 3x faster than vLLM's version. It supports GQA https://github.com/flashinfer-ai/flashinfer/ |
@ttbachyinsda Thanks for your feedback. As paged KV cache block size in Flash Attention must be divisible by 256, the block size is set to 256 by default in https://github.com/vllm-project/vllm/pull/2744/files#diff-ea8b8ff63961713ccb62d78e53e96404b587b7828cb9fee08a9e5576bf563673R54 So GPU Blocks will decrease significantly compared to before. It is not a bug. I have no Qwen model on my hands. Below is the throughput benchmark under CodeLLaMA-7B on A100-40GB-PCIE and speedup is ~1.07x. Could you please use this model to benchmark throughput? Note that this PR is mainly for MQA/GQA model. Qwen and CodeLLaMA-7B/13B both are MHA so they will not gain much speedup based on this PR.
|
Thanks for your valuable feedback. Yes, I am very interested in kernel optimization and I will dive on FlashInfer soon. |
Thank you for the guidance. I was testing with an RTX 3090, which might not be suitable for the changes in this PR. I will try to test the throughput benchmark under codellama-7b next. |
FlashAttention and FlashInfer are both SOTA solutions to speedup the decode phase in vLLM. We may do more research on it to decide which one to use. @WoosukKwon @zhuohan123 @casper-hansen I am glad to hear your opinions? |
In the current stage of development, it looks like #2772 will be 44.5% faster than the main branch, probably due to FlashInfer being better than FlashAttention for PagedAttention. |
@skrider is working on supporting small page sizes in FlashAttention. The block_size will be 16 after #824 is merged in FlashAttention. @zhuohan123 |
Hello! Yes, I have been working on that in flash attention, it is almost ready to be merged, just one small issue to deal with (fused RoPE). It could be vendored right now. However flashinfer is still slightly faster than flash attention 2, and then there is the issue of the fp8 kv cache and Turing support. I have talked to @simon-mo and the plan is to use flashinfer. |
@skrider @simon-mo Thanks for you informantion. Glad to see FlashInfer has better performance than FA. I noticed #2772 is working on this, but still has some issues on it. I think it is still helpful to merge this pr as an option of high performance cuda kernel. Once FlashInfer is ready we could set FlashInfer as default. |
Some comments (from PR #3010):
|
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you! |
This pull request has merge conflicts that must be resolved before it can be |
As shown in issue #1880, vLLM's current paged attention kernel does not leverage the benefits of MQA/GQA. FlashAttention supports MQA/GQA at kernel level and has supported paged kv cache since v2.5.0.
To enjoy the benefit of MQA/GQA, this PR replaces
_paged_attention
withflash_attn_with_kvcache
.Notes:
cache_ops.cache
will cost more time than original version (block_size=16). So this PR is mainly for small batch size cases.Kernel latency
Env:
Below are results by running
benchmark_paged_attention.py
.context_len=1024 :
context_len=4096 :
E2E latency
E2E throughput
Env:
python benchmarks/benchmark_throughput.py --input-len 512 --output-len 512 --model /CodeLlama-34b-hf/ --tokenizer /CodeLlama-34b-hf/ --trust-remote-code --tensor-parallel-size 4 --enforce-eager
python benchmarks/benchmark_throughput.py --input-len 512 --output-len 512 --model /CodeLlama-34b-hf/ --tokenizer /CodeLlama-34b-hf/ --trust-remote-code --tensor-parallel-size 4 --enforce-eager --use-flash-attn
Below are results by running
benchmark_throughput.py
.