llama : improve batched decoding performance #3479
Comments
I mentioned flash attention in the context of prompt processing speed, but I think that batch processing may have different requirements. I suspect that we waste quite a few FLOPs during attention by multiplying with masked KV blocks. If that's the case, we should look into paged attention.
I would be very surprised if the masked KV blocks are the reason for the low performance. If I remember correctly from my experiments, even when running without continuous batching but with many batches, the performance was still subpar, even though in this mode there are no wasted KV blocks. My understanding is that paged attention only improves memory utilization. I don't think it can improve the performance, correct?
Edit: At least it cannot improve the performance enough to amount to a factor of 5x or more. Maybe a few percent (<10%) at most.
This is not true. When we use a KV cache, the transformer rapidly becomes memory-bound. With finer control over memory, we can ensure that the most frequently accessed data remains on the chip (e.g. registers or SRAM) rather than in global memory (e.g. GDDR or HBM).
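As a rough illustration of the memory-bound claim (my own back-of-envelope sketch, not from the thread; the 7B parameter count and the A100 bandwidth/compute figures are approximate assumptions):

```python
# Back-of-envelope estimate: why single-batch decoding tends to be memory-bound.
# Assumed figures (approximate): ~7e9 parameters, F16 weights, A100 with roughly
# 2e12 bytes/s of HBM bandwidth and ~3.1e14 FP16 tensor-core FLOP/s.
params = 7e9
bytes_per_param = 2            # F16 weights
flops_per_token = 2 * params   # ~one multiply-add per parameter per token

weight_bytes = params * bytes_per_param                  # read once per token at batch size 1
arithmetic_intensity = flops_per_token / weight_bytes    # ~1 FLOP per byte

hbm_bw = 2.0e12      # bytes/s (approximate)
peak_fp16 = 3.1e14   # FLOP/s (approximate)

tokens_per_s_bandwidth = hbm_bw / weight_bytes   # bandwidth-limited ceiling
tokens_per_s_compute = peak_fp16 / flops_per_token

print(f"arithmetic intensity ~{arithmetic_intensity:.1f} FLOP/byte")
print(f"bandwidth-limited ceiling ~{tokens_per_s_bandwidth:.0f} tok/s")
print(f"compute-limited ceiling   ~{tokens_per_s_compute:.0f} tok/s")
```

Under these assumptions the bandwidth ceiling is on the order of ~140 tok/s while the compute ceiling is in the tens of thousands, i.e. at batch size 1 the GPU is waiting on memory, not math.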
It seems that some people in this community found out about this method months ago, but didn't give it much attention. @JohannesGaessler said this, and @vikigenius said this.
It is quite understandable that
See this: https://huggingface.co/docs/text-generation-inference/conceptual/paged_attention
Just measure the actual runtime taken up by the KV cache with |
@bobqianic The test case discussed in this issue is not memory-bound, so paged attention does not bring any benefit in this situation. All KV cache data for up to a batch size of 60 and sequence length of 100 fits comfortably in an A100. If anything, paged attention is more likely to reduce the performance in this case.
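As a quick sanity check on that claim (my own arithmetic, not from the thread; the model dimensions assumed are the standard 7B ones: 32 layers, 4096 embedding width, F16 cache):

```python
# KV cache size for the scenario above: batch size 60, sequence length 100.
# Assumed model dimensions (typical 7B config): 32 layers, 4096 embedding width.
n_layers = 32
n_embd = 4096
bytes_per_elem = 2  # F16 cache
kv_per_token = 2 * n_layers * n_embd * bytes_per_elem   # K and V for every layer

n_seqs, seq_len = 60, 100
total = kv_per_token * n_seqs * seq_len
print(f"{kv_per_token / 1024:.0f} KiB per token, {total / 2**30:.2f} GiB total")
# -> ~512 KiB per token, ~2.9 GiB total: far below A100 memory capacity.
```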
I have been looking a bit more into how paged attention works, and what their kernel does is compute attention for each sequence separately, fetching each block as needed without requiring them to be contiguous in memory. So the KQ and KQV multiplications are only done with the KV blocks relevant to each sequence. What we do instead is calculate attention for all the blocks of all the sequences at the same time, and throw away what we don't need with the attention mask. I still suspect that is the cause of the difference in performance. Whether this is memory-bound or compute-bound is a minor issue. Paged attention does not reduce memory utilization compared to our implementation, only compared to implementations where all the sequences must be contiguous in the KV cache, so for example they reserve
It's also worth pointing out that vLLM themselves attribute their performance to paged attention, and the paged attention paper shows a 2-4x throughput improvement over the previous SotA implementations.
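A minimal NumPy sketch of the two approaches described above (illustrative only, not ggml or vLLM code; shapes, sequence lengths, and the missing 1/sqrt(d) scaling are my simplifications):

```python
# "Unified" attention: one matmul over the whole packed KV buffer, with a mask
# throwing away scores that belong to other sequences.
# "Per-sequence" attention (what a paged-attention kernel does): touch only
# each sequence's own KV blocks.
import numpy as np

head_dim = 128
seq_lens = [30, 100, 55]        # current lengths of three sequences
total = sum(seq_lens)

q = np.random.randn(len(seq_lens), head_dim)   # one query token per sequence
k = np.random.randn(total, head_dim)           # all KV blocks packed together
v = np.random.randn(total, head_dim)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Unified: every query multiplies against the full buffer; masked work is wasted.
mask = np.full((len(seq_lens), total), -np.inf)
start = 0
for i, n in enumerate(seq_lens):
    mask[i, start:start + n] = 0.0
    start += n
out_unified = softmax(q @ k.T + mask) @ v      # scores shape (n_seq, total)

# Per-sequence: only the relevant KV blocks are read and multiplied.
out_paged = np.empty_like(q)
start = 0
for i, n in enumerate(seq_lens):
    s = softmax(q[i] @ k[start:start + n].T)   # shape (n,)
    out_paged[i] = s @ v[start:start + n]
    start += n

assert np.allclose(out_unified, out_paged)     # same result, very different FLOP count
```

Both paths produce identical outputs; the unified path simply performs (and then discards) the score computations for every other sequence's blocks.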
My latest status is a few months back (basically the ggllm codebase). I've done a few broadcasting "dry tests", so I did not fully implement that but tested the performance. One more thing that's currently not possible to easily add into ggml-cuda is cuBLAS-based batched GEMM (because the batching and broadcasting is done in the combined CUDA operation wrapper).
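For readers unfamiliar with the term, here is a small NumPy sketch of what a batched GEMM covers (my illustration, not ggml-cuda code; the head/sequence dimensions are assumptions). cuBLAS exposes this pattern through its strided batched GEMM routines, which run all the small matmuls in one call instead of one launch per head:

```python
# The per-head KQ multiplications of a decoding step are many small,
# independent matmuls of identical shape.
import numpy as np

n_head, n_tokens, head_dim, kv_len = 32, 1, 128, 100
q = np.random.randn(n_head, n_tokens, head_dim)
k = np.random.randn(n_head, kv_len, head_dim)

# Loop version: one small GEMM per head (analogous to separate kernel launches).
scores_loop = np.stack([q[h] @ k[h].T for h in range(n_head)])

# Batched version: a single operation over the leading (batch) dimension,
# which is what a strided-batched GEMM computes on the GPU.
scores_batched = q @ k.transpose(0, 2, 1)

assert np.allclose(scores_loop, scores_batched)
```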
Adding some slides for the vLLM implementation shared in their meetup.
Full deck: https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit#slide=id.p
After #3749 the performance is now comparable.
Batched decoding run on 2 GPUs is slower than on 1 GPU.
Environment and Context
GPU: NVIDIA A10 * 2
1 GPU
$ CUDA_VISIBLE_DEVICES=0 ./batched-bench mistral-7b-instruct-v0.1.f16.gguf 18432 0 99 1 100 128 1,2,3,4,5,6,7,8,16,32,64
2 GPUs
$ CUDA_VISIBLE_DEVICES=0,1 ./batched-bench mistral-7b-instruct-v0.1.f16.gguf 18432 0 99 1 100 128 1,2,3,4,5,6,7,8,16,32,64
I also tested vLLM, and it's faster on 2 GPUs than on 1 GPU.
1 GPU
$ cd vllm/benchmarks
$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 32 --tensor-parallel-size 1
Throughput: 1.03 requests/s, 422.29 tokens/s
$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 64 --tensor-parallel-size 1
Throughput: 1.81 requests/s, 824.27 tokens/s
2 GPUs
$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 32 --tensor-parallel-size 2
Throughput: 1.40 requests/s, 572.78 tokens/s
$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 64 --tensor-parallel-size 2
Throughput: 2.24 requests/s, 1019.82 tokens/s
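For reference, the scaling factors implied by the vLLM numbers quoted above (my arithmetic only, computed directly from those figures):

```python
# Multi-GPU scaling implied by the quoted vLLM throughputs.
one_gpu = {32: 422.29, 64: 824.27}    # tokens/s, tensor-parallel-size 1
two_gpu = {32: 572.78, 64: 1019.82}   # tokens/s, tensor-parallel-size 2

for n in (32, 64):
    speedup = two_gpu[n] / one_gpu[n]
    print(f"{n} prompts: {speedup:.2f}x throughput on 2 GPUs "
          f"({speedup / 2:.0%} scaling efficiency)")
# -> roughly 1.36x and 1.24x: well short of linear scaling, but still faster
#    than 1 GPU, unlike the batched-bench result above.
```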
Based on info from the following post, vLLM can achieve the following speeds for parallel decoding on A100 GPU:
https://docs.mystic.ai/docs/mistral-ai-7b-vllm-fast-inference-guide
(thanks to @wsxiaoys for bringing my attention to this)
Even though llama.cpp's single-batch inference is faster (~72 t/s), we currently don't seem to scale well with batch size. At batch size 60, for example, the performance is roughly 5x slower than what is reported in the post above. We should understand where the bottleneck is and try to optimize the performance.
As discussed with @slaren, the discrepancy is likely due to the lack of Flash Attention and CUDA tensor core utilization in llama.cpp. Still, I wouldn't be surprised if there is some low-hanging fruit that would improve the performance, similar to #3412. At the very least, we should profile things and get a better understanding of where to focus in the future.
Here are some results with llama.cpp on A100 (48edda3) using OpenLLaMA 7B F16.
To measure this, I've removed the system prompt from the parallel example to better match the vLLM test above. We count both the prompt and the generated tokens.
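For clarity, a minimal sketch of the throughput metric implied by that counting convention (variable names and the example numbers are mine, not taken from the parallel example's code):

```python
# Throughput as counted above: prompt tokens plus generated tokens,
# summed over all parallel sequences, divided by total wall time.
def throughput_tok_per_s(n_parallel: int, n_prompt: int, n_gen: int,
                         wall_time_s: float) -> float:
    total_tokens = n_parallel * (n_prompt + n_gen)
    return total_tokens / wall_time_s

# e.g. 60 sequences with 100-token prompts and 128 generated tokens each,
# finishing in 20 s, would be (60 * 228) / 20 = 684 tok/s (made-up timing).
print(throughput_tok_per_s(60, 100, 128, 20.0))
```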