
llama : improve batched decoding performance #3479

Closed
ggerganov opened this issue Oct 4, 2023 · 12 comments · Fixed by #3749
Labels
Nvidia GPU (Issues specific to Nvidia GPUs), performance (Speed related topics)

Comments

@ggerganov
Owner

ggerganov commented Oct 4, 2023

Based on info from the following post, vLLM can achieve the following speeds for parallel decoding on an A100 GPU:

https://docs.mystic.ai/docs/mistral-ai-7b-vllm-fast-inference-guide

| Batch size | Tokens/s |
|------------|----------|
| 1          | 46       |
| 10         | 400      |
| 60         | 1.8k     |

(thanks to @wsxiaoys for bringing my attention to this)

Even though llama.cpp's single-batch inference is faster (~72 t/s), we currently don't seem to scale well with batch size. At batch size 60, for example, the performance is roughly 5x slower than what is reported in the post above.

We should understand where the bottleneck is and try to optimize the performance.

# batch size 1
./parallel -m ~/f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 1 -ns 128 -n 100 -cb

# batch size 10
./parallel -m ~/f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 10 -ns 128 -n 100 -cb

# batch size 60
./parallel -m ~/f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 60 -ns 128 -n 100 -cb

As discussed with @slaren, the discrepancy is likely due to the lack of Flash Attention and CUDA tensor core utilization in llama.cpp. Still, I wouldn't be surprised if there is some low-hanging fruit that would improve the performance, similar to #3412.

At the very least, we should profile things and get a better understanding of where to focus in the future.
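
For some intuition on why batch size should help this much, here is a back-of-the-envelope weight-bandwidth estimate (a rough sketch: the ~13.5 GB weight size and ~1.5 TB/s A100 bandwidth are assumed figures, and KV cache / activation traffic is ignored):

#include <cstdio>

int main() {
    const double weight_bytes = 13.5e9;  // ~7B params * 2 bytes F16 (assumption)
    const double hbm_bw       = 1.5e12;  // rough A100 HBM bandwidth (assumption)
    const double step_s       = weight_bytes / hbm_bw;  // time to stream the weights once
    const int batches[] = {1, 10, 60};
    for (int b : batches) {
        // every sequence in the batch shares the same weight reads
        printf("batch %2d -> weight-bandwidth bound ~%.0f t/s\n", b, b / step_s);
    }
    return 0;
}

This puts the batch-1 decode bound at roughly 110 t/s, the same order as the ~69 t/s eval speed measured below, while the batch-60 bound is in the thousands, so weight bandwidth alone cannot explain the flat scaling.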


Here are some results with llama.cpp on A100 (48edda3) using OpenLLaMA 7B F16.

To measure this, I've removed the system prompt from the parallel example to better match the vLLM test above.
We count both the prompt and the generated tokens.

| Batch size | Tokens/s |
|------------|----------|
| 1          | 108.29   |
| 8          | 247.30   |
| 10         | 296.58   |
| 16         | 368.59   |
| 32         | 422.33   |
| 60         | 489.99   |
| 64         | 481.83   |
# single batch
LLAMA_CUBLAS=1 make -j && CUDA_VISIBLE_DEVICES=5 ./parallel -m models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 1 -ns 128 -n 100 -cb

Total prompt tokens:   2011, speed: 53.51 t/s
Total gen tokens:      2059, speed: 54.79 t/s
Total speed (AVG):           speed: 108.29 t/s
main: clearing the KV cache
Client   0, seq  126, started decoding ...
Client   0, seq  126, prompt   18 t, response   13 t, time  0.25 s, speed 126.04 t/s, cache miss 0  

Input:    If you could have any superpower, what would it be?
Response: If you could have any superpower, what would it be?

main: clearing the KV cache
Client   0, seq  127, started decoding ...
Client   0, seq  127, prompt   23 t, response   23 t, time  0.40 s, speed 113.95 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: I have a question. Are you familiar with the Special Theory of Relativity and can you explain it to me?

main: clearing the KV cache


Total prompt tokens:   2011, speed: 53.51 t/s
Total gen tokens:      2059, speed: 54.79 t/s
Total speed (AVG):           speed: 108.29 t/s
Cache misses:             0



llama_print_timings:        load time =  3377.87 ms
llama_print_timings:      sample time =  1735.54 ms /  2187 runs   (    0.79 ms per token,  1260.13 tokens per second)
llama_print_timings: prompt eval time =  5227.17 ms /  2011 tokens (    2.60 ms per token,   384.72 tokens per second)
llama_print_timings:        eval time = 29932.81 ms /  2060 runs   (   14.53 ms per token,    68.82 tokens per second)
llama_print_timings:       total time = 37582.41 ms
# n_parallel = 8
LLAMA_CUBLAS=1 make -j && CUDA_VISIBLE_DEVICES=5 ./parallel -m models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 8 -ns 128 -n 100 -cb

Total prompt tokens:   2011, speed: 124.95 t/s
Total gen tokens:      1969, speed: 122.34 t/s
Total speed (AVG):           speed: 247.30 t/s
Client   7, seq  119, prompt   12 t, response   38 t, time  2.34 s, speed 21.33 t/s, cache miss 0  

Input:    What is the meaning of life?
Response: Hello. This is the United States Army, and we need your help! You’ve been drafted to fight in a war against an army of zombies that have taken over the world.

Client   3, seq  117, prompt   15 t, response   46 t, time  2.82 s, speed 21.66 t/s, cache miss 0  

Input:    Tell me an interesting fact about llamas.
Response: I don't know of any interesting facts about llamas, so I searched for "interesting facts about llama" on the internet. (Search engine). I found a couple of websites and read some of them.

Client   6, seq  120, prompt   13 t, response   44 t, time  2.47 s, speed 23.06 t/s, cache miss 0  

Input:    How to get a job at Google?
Response: The job is to make sure that Google search works as intended by organizing and maintaining the database. They are also responsible for making sure that everything is running smoothly, updating the website and keeping it up-to-date.

main: clearing the KV cache


Total prompt tokens:   2011, speed: 124.95 t/s
Total gen tokens:      1969, speed: 122.34 t/s
Total speed (AVG):           speed: 247.30 t/s
Cache misses:             0



llama_print_timings:        load time =  3436.27 ms
llama_print_timings:      sample time =  1684.62 ms /  2097 runs   (    0.80 ms per token,  1244.79 tokens per second)
llama_print_timings: prompt eval time = 13690.16 ms /  3975 tokens (    3.44 ms per token,   290.35 tokens per second)
llama_print_timings:        eval time =    94.53 ms /     6 runs   (   15.75 ms per token,    63.47 tokens per second)
llama_print_timings:       total time = 16093.98 ms
# n_parallel = 10
LLAMA_CUBLAS=1 make -j && CUDA_VISIBLE_DEVICES=5 ./parallel -m models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 10 -ns 128 -n 100 -cb

Total prompt tokens:   2011, speed: 153.91 t/s
Total gen tokens:      1864, speed: 142.66 t/s
Total speed (AVG):           speed: 296.58 t/s
Client   7, seq  127, prompt   23 t, response   19 t, time  1.06 s, speed 39.77 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: We can try! If we go back in time, everything will be the same, right?

Client   5, seq  112, prompt   13 t, response   59 t, time  3.26 s, speed 22.08 t/s, cache miss 0  

Input:    How to get a job at Google?
Response: “I’ve been with Google for seven years. I started as a summer intern and have worked in a variety of roles, including Search Ads Product Marketing Manager and now Senior Manager of Product Management, Search Ads Strategy. For me, the most memorable aspect of working at Google is the people.

main: clearing the KV cache


Total prompt tokens:   2011, speed: 153.91 t/s
Total gen tokens:      1864, speed: 142.66 t/s
Total speed (AVG):           speed: 296.58 t/s
Cache misses:             0



llama_print_timings:        load time =  3420.25 ms
llama_print_timings:      sample time =  1693.70 ms /  1992 runs   (    0.85 ms per token,  1176.12 tokens per second)
llama_print_timings: prompt eval time = 10678.86 ms /  3870 tokens (    2.76 ms per token,   362.40 tokens per second)
llama_print_timings:        eval time =    96.14 ms /     6 runs   (   16.02 ms per token,    62.41 tokens per second)
llama_print_timings:       total time = 13064.91 ms
# n_parallel = 16
LLAMA_CUBLAS=1 make -j && CUDA_VISIBLE_DEVICES=5 ./parallel -m models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 16 -ns 128 -n 100 -cb

Total prompt tokens:   2011, speed: 181.94 t/s
Total gen tokens:      2063, speed: 186.65 t/s
Total speed (AVG):           speed: 368.59 t/s

Input:    What is the best way to learn a new language?
Response: The easiest way to learn any language is to live with someone that speaks that language. However, if that isn’t an option, the best way to learn any language is to use a program that uses a combination of verbal learning and verbal reinforcement to help you learn. When I first started studying Russian, I used programs like Rosetta Stone (which is great for beginners), but what worked best for me was a method

Client   9, seq   90, prompt   15 t, response   71 t, time  4.76 s, speed 18.08 t/s, cache miss 0  

Input:    What is the best way to cook a steak?
Response: The best way to cook a steak is to first preheat your oven to 425 degrees. Then, lightly season both sides of the steak with salt and pepper. Put it on a baking sheet lined with aluminum foil, drizzle with olive oil, and bake it in the oven for 10 minutes, or until medium-rare.

Client  13, seq  111, prompt   15 t, response   58 t, time  3.22 s, speed 22.69 t/s, cache miss 0  

Input:    I want to learn how to play the piano.
Response: I think you are a good piano player and I can teach you all about the piano. You will learn how to play all the songs that you like on the piano in no time. I can teach you how to improve your piano playing so that you can become an even better piano player.

main: clearing the KV cache


Total prompt tokens:   2011, speed: 181.94 t/s
Total gen tokens:      2063, speed: 186.65 t/s
Total speed (AVG):           speed: 368.59 t/s
Cache misses:             0



llama_print_timings:        load time =  3391.46 ms
llama_print_timings:      sample time =  1843.20 ms /  2191 runs   (    0.84 ms per token,  1188.69 tokens per second)
llama_print_timings: prompt eval time =  8358.01 ms /  4063 tokens (    2.06 ms per token,   486.12 tokens per second)
llama_print_timings:        eval time =   200.03 ms /    12 runs   (   16.67 ms per token,    59.99 tokens per second)
llama_print_timings:       total time = 11052.24 ms
# n_parallel = 32
LLAMA_CUBLAS=1 make -j && CUDA_VISIBLE_DEVICES=5 ./parallel -m models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 32 -ns 128 -n 100 -cb

Total prompt tokens:   2011, speed: 186.50 t/s
Total gen tokens:      2543, speed: 235.83 t/s
Total speed (AVG):           speed: 422.33 t/s
Input:    How to get a job at Google?
Response: Job Description. As an assistant, you will support the people who work at Google and our partners. This includes supporting some of the most senior leaders as they run their teams. You will have a wide variety of responsibilities, including scheduling meetings, booking travel and supporting senior leadership in planning events.

Client  19, seq   87, prompt   13 t, response   87 t, time  7.09 s, speed 14.11 t/s, cache miss 0  

Input:    How to get a job at Google?
Response: Google is a search engine for the Internet and one of the most visited sites on the Internet. However, it has not been easy to work at Google since its creation, as it has taken more than ten years to find it. At the beginning, Larry Page and Sergey Brin were looking for employees who were as intelligent as possible. They did not really understand how to work well or where to search for good workers. They simply thought

Client  25, seq  127, prompt   23 t, response   75 t, time  4.29 s, speed 22.83 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: Yes. The Special Theory of Relativity (SR) is a theory that, in essence, says that the speed of light is constant for all observers. For example, if you have three observers at rest with respect to one another, who are moving towards each other and who have different speeds, the three will measure the same speed of light for any object that they view.

main: clearing the KV cache


Total prompt tokens:   2011, speed: 186.50 t/s
Total gen tokens:      2543, speed: 235.83 t/s
Total speed (AVG):           speed: 422.33 t/s
Cache misses:             0



llama_print_timings:        load time =  3420.38 ms
llama_print_timings:      sample time =  2267.36 ms /  2671 runs   (    0.85 ms per token,  1178.02 tokens per second)
llama_print_timings: prompt eval time =  7318.15 ms /  4535 tokens (    1.61 ms per token,   619.69 tokens per second)
llama_print_timings:        eval time =   412.22 ms /    20 runs   (   20.61 ms per token,    48.52 tokens per second)
llama_print_timings:       total time = 10782.65 ms
# n_parallel = 60
LLAMA_CUBLAS=1 make -j && CUDA_VISIBLE_DEVICES=5 ./parallel -m models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 60 -ns 128 -n 100 -cb

Total prompt tokens:   2011, speed: 235.90 t/s
Total gen tokens:      2166, speed: 254.09 t/s
Total speed (AVG):           speed: 489.99 t/s
Client  33, seq   78, prompt   13 t, response   72 t, time  6.70 s, speed 12.69 t/s, cache miss 0  

Input:    How to get a job at Google?
Response: Assistant role at Google is one of the most important jobs in the organization. The job requires candidates who are passionate, enthusiastic and well-versed with the latest technology in the market. The candidates must be passionate and able to understand and solve problems on their own. They should also be able to collaborate with others, communicate effectively, and have a strong work ethic

Client  26, seq   77, prompt   23 t, response   77 t, time  7.00 s, speed 14.29 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: “No, sir. It is not my specialty. I only know that the theory was first put forward by Einstein, it was quite an influential theory of his and it has been used in a lot of scientific experiments and measurements. There is a whole bunch of experiments that have been done to prove it, but I cannot explain them to you. You should speak to one of my colleagues.”

Client  29, seq  102, prompt   16 t, response   79 t, time  6.41 s, speed 14.83 t/s, cache miss 0  

Input:    What is the best way to learn a new language?
Response: Well I do know that you have to know the grammar, you have to know vocabulary, and you have to get a feel for the sounds and the way it is pronounced. You also have to know the culture of where the language is spoken. And you also have to have friends that are natives of the country to practice with, and that’s really the best way to do it.

main: clearing the KV cache


Total prompt tokens:   2011, speed: 235.90 t/s
Total gen tokens:      2166, speed: 254.09 t/s
Total speed (AVG):           speed: 489.99 t/s
Cache misses:             0



llama_print_timings:        load time =  3407.33 ms
llama_print_timings:      sample time =  1923.99 ms /  2294 runs   (    0.84 ms per token,  1192.31 tokens per second)
llama_print_timings: prompt eval time =  5760.76 ms /  4170 tokens (    1.38 ms per token,   723.86 tokens per second)
llama_print_timings:        eval time =   159.77 ms /     8 runs   (   19.97 ms per token,    50.07 tokens per second)
llama_print_timings:       total time =  8524.06 ms
# n_parallel = 64
LLAMA_CUBLAS=1 make -j && CUDA_VISIBLE_DEVICES=5 ./parallel -m models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 8192 -b 512 -s 1 -np 64 -ns 128 -n 100 -cb

Total prompt tokens:   2011, speed: 228.04 t/s
Total gen tokens:      2238, speed: 253.78 t/s
Total speed (AVG):           speed: 481.83 t/s
Client  61, seq   61, prompt   23 t, response   77 t, time  8.09 s, speed 12.36 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: Sure. The Special Theory of Relativity is very simply understood by the layman. It concerns the speed of light and how to measure distance. You can imagine a room with a large light bulb at one end, a meter stick on the floor and a tape measure, a ruler, etc. at the other end of the room. When we go to that far end of the room

Client  15, seq   82, prompt   23 t, response   74 t, time  7.03 s, speed 13.79 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: Yes, you can ask me about the Special Theory of Relativity. This theory states that the speed of light in vacuum is constant and independent of the source or the observer in a coordinate system moving relative to the source. Einstein's relativity theory also states that gravity is not a force but that it can be described as the curvature of space-time.

Client  47, seq  127, prompt   23 t, response   77 t, time  5.48 s, speed 18.24 t/s, cache miss 0  

Input:    Are you familiar with the Special Theory of Relativity and can you explain it to me?
Response: I’m sure you have heard about the Special Theory of Relativity by now, although it is not very often brought up in the classroom. It is a theory developed by the famous physicist Albert Einstein that explains how space and time are interrelated. For example, if you travel fast enough across space, you would experience time as speeding up. On the other hand, in general rel

main: clearing the KV cache


Total prompt tokens:   2011, speed: 228.04 t/s
Total gen tokens:      2238, speed: 253.78 t/s
Total speed (AVG):           speed: 481.83 t/s
Cache misses:             0



llama_print_timings:        load time =  3401.75 ms
llama_print_timings:      sample time =  1976.50 ms /  2366 runs   (    0.84 ms per token,  1197.06 tokens per second)
llama_print_timings: prompt eval time =  5806.75 ms /  4234 tokens (    1.37 ms per token,   729.15 tokens per second)
llama_print_timings:        eval time =   335.70 ms /    16 runs   (   20.98 ms per token,    47.66 tokens per second)
llama_print_timings:       total time =  8817.67 ms
ggerganov added the performance (Speed related topics) and Nvidia GPU (Issues specific to Nvidia GPUs) labels Oct 4, 2023
ggerganov moved this to Todo in ggml : roadmap Oct 4, 2023
@slaren
Collaborator

slaren commented Oct 4, 2023

I mentioned flash attention in the context of prompt processing speed, but I think that batch processing may have different requirements. I suspect that we waste quite a few FLOPS during attention by multiplying with masked KV blocks. If that's the case, we should look into paged attention.
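
To illustrate the concern with a toy sketch (scalar keys, single head; not the actual ggml graph): with one unified KV buffer, each query is multiplied against every cached position of every sequence and the KQ mask zeroes out the cross-sequence scores afterwards, while a per-sequence kernel never computes them at all.

#include <vector>

struct KVEntry { int seq_id; float k; };   // toy scalar "key" per cached position

// unified buffer: FLOPS are spent on every cached position, then masked away
float kq_unified(float q, const std::vector<KVEntry> & cache, int seq_id) {
    float acc = 0.0f;
    for (const auto & e : cache) {
        float s = q * e.k;                 // computed for every sequence's KV
        if (e.seq_id != seq_id) s = 0.0f;  // thrown away by the mask
        acc += s;
    }
    return acc;
}

// per-sequence kernel: only this sequence's blocks are ever touched
float kq_per_seq(float q, const std::vector<KVEntry> & cache, int seq_id) {
    float acc = 0.0f;
    for (const auto & e : cache) {
        if (e.seq_id != seq_id) continue;
        acc += q * e.k;
    }
    return acc;
}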

@ggerganov
Owner Author

ggerganov commented Oct 4, 2023

I would be very surprised if the masked KV blocks are the reason for the low performance. If I remember correctly from my experiments, even when running without continuous batching, but with many batches, the performance was still subpar, even though in this mode there are no wasted KV blocks.

My understanding is that paged attention only improves memory utilization. I don't think it can improve the performance, correct?

Edit: At least it cannot improve the performance so much as to amount to a factor of 5x or more. Maybe a few percent (<10%) at most.

@bobqianic
Contributor

bobqianic commented Oct 4, 2023

only improves memory utilization

Maybe a few percent (<10%) at most

This is not true.

When we use a KV cache, the transformer rapidly becomes memory-bound. With finer control over memory, we can ensure that the most frequently accessed data stays on-chip (e.g. registers or SRAM) rather than in global memory (e.g. GDDR or HBM).


@bobqianic
Contributor

bobqianic commented Oct 4, 2023

It seems that some people in this community found out about this method months ago, but didn't give it much attention.

@JohannesGaessler said this:

(screenshots omitted)

@vikigenius said this:

(screenshot omitted)

@bobqianic
Contributor

bobqianic commented Oct 4, 2023

It is quite understandable that vLLM provides much better performance.

  1. Less memory wasted → Smaller KV cache → Reduced stress on global memory
  2. KV cache is paged → Fine grained memory control → More on-chip cache hits → Further relief on global memory

See this https://huggingface.co/docs/text-generation-inference/conceptual/paged_attention

@ggerganov
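
For reference, the core of the paged approach is an indirection table from a sequence's logical KV positions to fixed-size physical blocks. A minimal sketch of the idea (illustrative only, not vLLM's actual data structures; block size and head_dim are arbitrary):

#include <utility>
#include <vector>

constexpr int BLOCK_TOKENS = 16;     // tokens per physical KV block (assumption)

struct KVBlock {
    float k[BLOCK_TOKENS][128];      // head_dim = 128 assumed; real caches use F16
    float v[BLOCK_TOKENS][128];
};

struct SequenceKV {
    std::vector<int> block_table;    // logical block index -> physical block id
    int n_tokens = 0;

    // where token `pos` of this sequence lives in the global block pool
    std::pair<int, int> locate(int pos) const {
        return { block_table[pos / BLOCK_TOKENS], pos % BLOCK_TOKENS };
    }
};

The attention kernel then walks only the blocks listed in a sequence's table, which is what lets vLLM pack memory tightly and skip other sequences' KV entirely.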

@JohannesGaessler
Collaborator

At least it cannot improve the performance so much as to amount to a factor of 5x or more. Maybe a few percent (<10%) at most

This is not true.

Just measure the actual runtime taken up by the KV cache with nsys profile; that will automatically give you an upper bound on how much you can reduce the runtime via any sort of optimization. For a batch size of 1 it definitely is not much, but for larger batch sizes it will be more if the KV caches are not the same.
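
One way to get that attribution is to wrap the attention-related ops in NVTX ranges so that an nsys profile groups their GPU time separately from the rest of the graph. A rough sketch; the function and range names are hypothetical, and the CUDA toolkit's NVTX header is assumed to be available:

#include <nvToolsExt.h>

// Hypothetical instrumentation around the KV-cache/attention part of the
// graph evaluation; nsys will report the time spent inside each named range.
void eval_attention_instrumented() {
    nvtxRangePushA("kq_matmul_and_mask");
    // ... launch the KQ matmul, mask and softmax kernels here ...
    nvtxRangePop();

    nvtxRangePushA("kqv_matmul");
    // ... launch the KQV matmul ...
    nvtxRangePop();
}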

@ggerganov
Owner Author

@bobqianic The test case discussed in this issue is not memory-bound, so paged attention does not bring any benefit in this situation. All the KV cache data for up to a batch size of 60 and a sequence length of 100 fits comfortably in an A100. If anything, paged attention is more likely to reduce the performance in this case.
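
A quick back-of-the-envelope size check supports that, assuming LLaMA-7B-class dimensions (32 layers, 32 heads, head_dim 128) and an F16 cache; the numbers are estimates, not measurements:

#include <cstdio>

int main() {
    // bytes of K+V stored per cached token position (assumed dims, F16)
    const double bytes_per_pos = 32 /*layers*/ * 2.0 /*K and V*/ * 32 /*heads*/ * 128 /*head_dim*/ * 2 /*bytes per F16*/;
    const int n_seq = 60, n_tokens = 100;   // roughly the workload discussed here
    const double kv_total = n_seq * n_tokens * bytes_per_pos;
    printf("~%.1f MB per cached token, ~%.1f GB of KV data in total\n",
           bytes_per_pos / 1e6, kv_total / 1e9);
    return 0;
}

That is on the order of 3 GB of KV data next to ~13.5 GB of F16 weights, which indeed fits comfortably in A100 memory.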

@slaren
Collaborator

slaren commented Oct 5, 2023

I have been looking a bit more into how paged attention works, and what their kernel does is compute attention for each sequence separately, fetching each block as needed without requiring the blocks to be contiguous in memory. So the KQ and KQV multiplications are only done with the KV blocks relevant to each sequence. What we do instead is calculate attention for all the blocks of all the sequences at the same time and throw away what we don't need with the attention mask. I still suspect that this is the cause of the difference in performance. Whether this is memory-bound or compute-bound is the minor issue.

Paged attention does not reduce memory utilization compared to our implementation, only compared to implementations where all the sequences must be contiguous in the KV cache and which, for example, reserve n_ctx KV blocks for each sequence in advance; that can be very inefficient when most sequences are shorter than that.

It's also worth pointing out that vLLM themselves attribute their performance to paged attention, and the paged attention paper shows a 2-4x throughput improvement over the previous SotA implementations.
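
As a rough check on how much work the mask throws away in the batch-60 run above (assumed dims: 32 layers, 32 heads, head_dim 128, ~100 cached tokens per sequence; estimates only):

#include <cstdio>

int main() {
    const int n_seq = 60, n_kv_per_seq = 100, head_dim = 128, n_head = 32, n_layer = 32;
    // KQ dot products per decode step (one new token per sequence):
    const double unified = double(n_seq) * n_seq * n_kv_per_seq;  // scores against everyone's KV
    const double per_seq = double(n_seq) * n_kv_per_seq;          // scores against own KV only
    const double unified_gflops = unified * 2.0 * head_dim * n_head * n_layer * 1e-9;
    printf("KQ work ratio ~%.0fx, unified KQ ~%.1f GFLOPs per decode step\n",
           unified / per_seq, unified_gflops);
    return 0;
}

That is roughly 60x the necessary attention score work, but in absolute terms (~94 GFLOPs for KQ, plus a similar amount for KQV) it is still small next to the roughly 800 GFLOPs of weight GEMMs for a 60-token decode step (2 FLOPs * ~7B params * 60 tokens), so profiling is still needed to see how much of the gap it explains.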

@cmp-nct
Contributor

cmp-nct commented Oct 5, 2023

My latest status is from a few months back (basically the ggllm codebase).
One significant slowdown I notice is as soon as I use more than one GPU; I don't think the code is optimized for multiple capable GPUs.
In my case I have a 4090 + 3090, and as soon as the second GPU is active in any way, the performance drops significantly.
Batched CUDA uses loops which all have synchronization in between.
So multi-GPU by splitting tensors reduces performance significantly instead of boosting it. That might be different if one of the cards is very old.

I've done a few broadcasting "dry tests": I did not fully implement it, but I tested the performance.
The result was that if I did the K/V calculations broadcast on CUDA instead of the CPU, performance would be orders of magnitude slower.
The "current" op() routine (for batched processing as well as for broadcasting) would do thousands of loops per multiplication (128 * batch), and each loop synchronizes, transforms data, etc.
So a totally different routine would be needed in ggml-cuda to do that quicker.

One more thing that's currently not easy to add into ggml-cuda is cuBLAS-based batched GEMM (because the batching and broadcasting are done in the combined CUDA operation wrapper).
I really like the new kernels Johannes made; they are a significant step up, but when it comes to batched processing I'd go with FP8 (or, on older cards, FP16) cuBLAS in batched mode. I currently lack the time to do that, but the performance should be significantly better that way.
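
For reference, the cuBLAS entry point meant here is the strided batched GEMM, which replaces the per-slice loop (and its synchronization) with a single launch. A minimal sketch, assuming the slices are laid out with a constant stride; the wrapper function name and shapes are made up for illustration, and error checking is omitted:

#include <cublas_v2.h>
#include <cuda_fp16.h>

// One strided-batched F16 GEMM launch instead of `batch` separate, synchronizing calls.
// Column-major layout, as cuBLAS expects.
static void batched_matmul_f16(cublasHandle_t handle,
                               const half * A, const half * B, half * C,
                               int m, int n, int k, int batch) {
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);
    cublasGemmStridedBatchedEx(
        handle, CUBLAS_OP_N, CUBLAS_OP_N,
        m, n, k,
        &alpha,
        A, CUDA_R_16F, m, (long long) m * k,   // stride between consecutive A slices
        B, CUDA_R_16F, k, (long long) k * n,   // stride between consecutive B slices
        &beta,
        C, CUDA_R_16F, m, (long long) m * n,   // stride between consecutive C slices
        batch,
        CUBLAS_COMPUTE_16F,                    // tensor cores are used where available
        CUBLAS_GEMM_DEFAULT);
}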

@wsxiaoys
Contributor

wsxiaoys commented Oct 6, 2023

Adding some slides on the vLLM implementation shared at their meetup:

(slide photo omitted)

The prompt is processed with fast attention; generation is done with PagedAttention.

Full deck: https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit#slide=id.p

@ggerganov
Owner Author

After #3749, the performance is now comparable.

ggerganov moved this from In Progress to Done in ggml : roadmap Oct 24, 2023
@lxrite

lxrite commented Oct 25, 2023

Batched decoding on 2 GPUs is slower than on 1 GPU.

Environment and Context

GPU: NVIDIA A10 * 2
Model: Mistral-7B F16

$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10          On   | 00000000:8E:00.0 Off |                    0 |
|  0%   34C    P8    20W / 150W |      0MiB / 22731MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A10          On   | 00000000:8F:00.0 Off |                    0 |
|  0%   34C    P8    21W / 150W |      0MiB / 22731MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

1 GPU

$ CUDA_VISIBLE_DEVICES=0 ./batched-bench mistral-7b-instruct-v0.1.f16.gguf 18432 0 99 1 100 128 1,2,3,4,5,6,7,8,16,32,64
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|------|-------|
| 100 | 128 | 1 | 228 | 0.056 | 1793.05 | 4.109 | 31.15 | 4.165 | 54.75 |
| 100 | 128 | 2 | 456 | 0.085 | 2340.63 | 4.489 | 57.02 | 4.575 | 99.68 |
| 100 | 128 | 3 | 684 | 0.137 | 2187.56 | 4.551 | 84.38 | 4.688 | 145.91 |
| 100 | 128 | 4 | 912 | 0.169 | 2373.21 | 4.611 | 111.05 | 4.779 | 190.83 |
| 100 | 128 | 5 | 1140 | 0.199 | 2506.53 | 4.697 | 136.25 | 4.897 | 232.81 |
| 100 | 128 | 6 | 1368 | 0.250 | 2404.76 | 4.774 | 160.87 | 5.024 | 272.31 |
| 100 | 128 | 7 | 1596 | 0.259 | 2707.85 | 4.895 | 183.06 | 5.153 | 309.72 |
| 100 | 128 | 8 | 1824 | 0.291 | 2746.52 | 4.977 | 205.74 | 5.268 | 346.21 |
| 100 | 128 | 16 | 3648 | 0.654 | 2447.28 | 6.386 | 320.69 | 7.040 | 518.18 |
| 100 | 128 | 32 | 7296 | 1.584 | 2020.73 | 8.131 | 503.78 | 9.714 | 751.07 |
| 100 | 128 | 64 | 14592 | 4.586 | 1395.54 | 17.639 | 464.43 | 22.225 | 656.56 |

2 GPUs

$ CUDA_VISIBLE_DEVICES=0,1 ./batched-bench mistral-7b-instruct-v0.1.f16.gguf 18432 0 99 1 100 128 1,2,3,4,5,6,7,8,16,32,64
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|----|----|---|------|--------|----------|--------|----------|------|-------|
| 100 | 128 | 1 | 228 | 0.259 | 386.04 | 3.250 | 39.39 | 3.509 | 64.98 |
| 100 | 128 | 2 | 456 | 0.559 | 357.86 | 11.837 | 21.63 | 12.396 | 36.79 |
| 100 | 128 | 3 | 684 | 0.807 | 371.73 | 12.022 | 31.94 | 12.829 | 53.32 |
| 100 | 128 | 4 | 912 | 1.016 | 393.83 | 12.123 | 42.23 | 13.139 | 69.41 |
| 100 | 128 | 5 | 1140 | 0.982 | 509.41 | 12.556 | 50.97 | 13.538 | 84.21 |
| 100 | 128 | 6 | 1368 | 1.516 | 395.84 | 12.627 | 60.82 | 14.143 | 96.73 |
| 100 | 128 | 7 | 1596 | 1.444 | 484.87 | 13.095 | 68.42 | 14.539 | 109.77 |
| 100 | 128 | 8 | 1824 | 1.606 | 498.25 | 13.447 | 76.15 | 15.053 | 121.17 |
| 100 | 128 | 16 | 3648 | 3.479 | 459.93 | 15.962 | 128.30 | 19.441 | 187.65 |
| 100 | 128 | 32 | 7296 | 6.871 | 465.75 | 20.175 | 203.03 | 27.045 | 269.77 |
| 100 | 128 | 64 | 14592 | 14.141 | 452.58 | 36.306 | 225.64 | 50.447 | 289.25 |

I also tested vLLM; it's faster on 2 GPUs than on 1 GPU.

1 GPU

$ cd vllm/benchmarks
$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 32 --tensor-parallel-size 1
Throughput: 1.03 requests/s, 422.29 tokens/s
$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 64 --tensor-parallel-size 1
Throughput: 1.81 requests/s, 824.27 tokens/s

2 GPUs

$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 32 --tensor-parallel-size 2
Throughput: 1.40 requests/s, 572.78 tokens/s
$ python3 benchmark_throughput.py --model Mistral-7B-Instruct-v0.1 --dataset ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 64 --tensor-parallel-size 2
Throughput: 2.24 requests/s, 1019.82 tokens/s
