cuda : improve text-generation and batched decoding performance #3776
Conversation
cuBLAS should still be optional since it increases memory usage significantly, and reduces the number of layers that can be offloaded. What is the reason for adding a new function instead of just using the same
If this PR does what I think it does, I very much do not agree with it. Half the motivation behind mmq was to reduce the VRAM usage by avoiding the allocation of a temporary buffer for the dequantized weight matrix. Even on relatively recent hardware where the speed is now comparatively worse, this can be worthwhile if VRAM is tight. Also, as far as I can tell, there is no check for CUDA architectures without tensor cores, where mmq should be universally faster, or for AMD GPUs (where I don't know which one is faster).
Likely no reason, but I wanted to eliminate the convoluted logic of
I will try to add an elegant way to accommodate old video cards and fall back to MMQ (should be easy), but for sure we are not going to hold back performance on modern hardware, where there is plenty of VRAM, just to be able to offload a few more layers. So this change in one way or another is going to make it to
Let me be frank: I currently do not have the time to work on llama.cpp but I would consider this unacceptable. Even on modern hardware VRAM is a considerable bottleneck, especially on consumer products. If and when I return to working on llama.cpp this would be enough reason for me to fork the project.
What do you propose? With this change:
We gain:
We lose:
Firstly, I propose keeping some version of the current mmq behavior, regardless of what the default is. Secondly, the mmq functions are templates. So I suggest compiling additional variants that are optimized for small batch sizes and switching between the two variants at runtime based on batch size.
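As an aside, a minimal sketch of the kind of runtime switch being proposed, with a placeholder kernel and made-up tile constants (the actual llama.cpp MMQ kernels and their `MMQ_X`/`MMQ_Y`/`NWARPS` values are not shown here):

```cpp
#include <cuda_runtime.h>

// Placeholder kernel: each block computes one output row and a TILE_N-wide
// strip of output columns. TILE_N is a compile-time constant, so two
// instantiations can be tuned independently.
template <int TILE_N>
__global__ void matmul_f32(const float * A, const float * B, float * C,
                           int M, int K, int N) {
    const int row  = blockIdx.y;
    const int col0 = blockIdx.x * TILE_N;
    const int colN = (col0 + TILE_N < N) ? (col0 + TILE_N) : N;
    for (int col = col0 + threadIdx.x; col < colN; col += blockDim.x) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
            acc += A[row*K + k] * B[k*N + col];
        }
        C[row*N + col] = acc;
    }
}

// Pick the variant at runtime based on the batch size (number of columns of B).
static void launch_matmul(const float * A, const float * B, float * C,
                          int M, int K, int N, cudaStream_t stream) {
    if (N <= 32) {
        constexpr int TILE_N = 8;                       // small-batch variant
        const dim3 grid((N + TILE_N - 1)/TILE_N, M);
        matmul_f32<TILE_N><<<grid, 32, 0, stream>>>(A, B, C, M, K, N);
    } else {
        constexpr int TILE_N = 64;                      // large-batch variant
        const dim3 grid((N + TILE_N - 1)/TILE_N, M);
        matmul_f32<TILE_N><<<grid, 128, 0, stream>>>(A, B, C, M, K, N);
    }
}
```

The point is only that both instantiations exist in the binary and the launcher picks one per call based on the batch size.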
In practice, it is probably worse than this. In theory, we should only need enough memory to store a copy of the largest tensor in F16. When offloading the output tensor this requires at least 260MB for 7B and 327MB for 13B, excluding the activations. However, the CUDA memory pool is far from optimal and will waste a lot of memory in some cases. We could improve on this somewhat. Ultimately, the best solution would be to store activations as F16 and implement our own MMQ kernels using tensor cores. This would remove the need for these buffers entirely.
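For reference, those buffer sizes are roughly consistent with the output weight matrix stored as F16, i.e. n_vocab × n_embd × 2 bytes, assuming the standard LLaMA dimensions (n_vocab = 32000, n_embd = 4096 for 7B and 5120 for 13B):

7B: 32000 × 4096 × 2 bytes ≈ 262 MB
13B: 32000 × 5120 × 2 bytes ≈ 327 MB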
I agree that long-term this would be the best solution. However, using tensor cores on their own does not seem to be sufficient. Several weeks ago I did a prototype mmq implementation using tensor cores and there was no speedup because the tensor cores were dramatically underutilized. In fact, even without the use of tensor cores, the ALU is underutilized with the current mmq kernels. The problem is most likely related to mmq not utilizing asynchronous data loading (available on Ampere or newer). For an mmq implementation that outperforms FP16 cuBLAS for quantized models, I think the following changes would be necessary:
Currently it looks like I will be busy until the end of December. Afterwards I wanted to start working on this unless someone else does something similar in the meantime.
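For context, the "use tensor cores from a custom kernel" part boils down to the `nvcuda::wmma` fragment API; a minimal, self-contained example of a single 16x16x16 F16 tile multiply (illustrative only, not the prototype mentioned above):

```cpp
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

// One warp computes a single 16x16 tile of C = A*B with F16 inputs and F32
// accumulation. Launch with <<<1, 32>>>; requires compute capability >= 7.0.
// A: 16x16 row-major, B: 16x16 col-major, C: 16x16 row-major.
__global__ void wmma_tile_16x16x16(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16);   // leading dimension = 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}
```

Keeping enough of these fragment operations in flight per SM, and overlapping them with the data loads, is exactly the utilization problem described above.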
I agree that this is the right direction and we will do it eventually. It's just that this change is so easy (if you remove the duplicated code it's just a few lines) that I don't see any reason not to get the benefits from it.
I just realized the output tensor currently does not go through the new cublas GEMM branch because I use
@JohannesGaessler When I wrote earlier that I will try to accommodate older cards and fall back to MMQ, I had something similar in mind to what you propose.
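For what it's worth, the runtime side of such a fallback can be as small as a per-device compute-capability check; a rough sketch (a hypothetical helper, not the actual llama.cpp code, and not meaningful for AMD/HIP builds):

```cpp
#include <cuda_runtime.h>

// Hypothetical helper: only take the cuBLAS/tensor-core path on Volta
// (compute capability 7.0) or newer, otherwise fall back to MMQ.
static bool device_has_tensor_cores(int device) {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        return false;                 // be conservative if the query fails
    }
    return prop.major >= 7;           // Volta, Turing, Ampere, Ada, Hopper
}
```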
Using
I built a custom version of llama-cpp-python using this PR branch and re-did the tests in this post for
This is a +58% increase in the tokens/second during evaluation, and a +70% increase in the tokens/second during processing, |
Also, if you are measuring VRAM by what is printed to the console, that is not reflective of the actual VRAM usage. The print only includes the VRAM allocated initially and does not account for the additional VRAM potentially needed later when a weight matrix may be dequantized to either FP16 or FP32 (i.e. when mmq is not used). Instead something like nvidia-smi should be used to monitor the VRAM usage.
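If polling from inside the process is more convenient than watching nvidia-smi, querying the free/total device memory gives the same device-wide picture; a small sketch (hypothetical helper; other processes on the GPU are included in the numbers):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Print device-wide VRAM usage; call before and after the operation of interest
// to see the real high-water mark, including pool allocations.
static void print_vram_usage(const char * tag) {
    size_t free_bytes = 0, total_bytes = 0;
    if (cudaMemGetInfo(&free_bytes, &total_bytes) != cudaSuccess) {
        return;
    }
    printf("%s: %.1f / %.1f MiB used\n", tag,
           (total_bytes - free_bytes) / (1024.0 * 1024.0),
           total_bytes / (1024.0 * 1024.0));
}
```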
I will give my input here. Mind you, this is before I got the time to test this PR: With MMQ I am able to run 7B Q4_K_S models with GQA like Mistral fully offloaded on my RTX 2060 with 6 GB VRAM and at 4K context, and I still have some VRAM left to spare for other applications. With cuBLAS on master, it will slow down fast as it overflows into system RAM. So for me, the VRAM savings MMQ in master provides are essential. I hope MMQ will continue to stay VRAM efficient. I do hope tensor cores can be used for it eventually and be super VRAM efficient at the same time. cuBLAS is noticeably faster (when the VRAM is not spilling over).
I did one more llama-cpp-python build with the master branch instead of the PR branch, to better separate the changes in this PR from other changes in llama.cpp over the past few weeks (the llama.cpp in llama-cpp-python 0.2.11 is a bit outdated). These are the results:
So the bulk of the speed gains are due to this PR.
@JohannesGaessler all these VRAM values come from
The VRAM increase will only happen once a weight matrix is actually dequantized, i.e. when the model is being evaluated with a large enough batch size (32 I think it was with this PR). Also, I forgot: unless someone else has worked on this, the VRAM allocated for temporary buffers is not freed until the process exits, which may be relevant for measurements.
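To illustrate why the peak sticks around: a lookaside pool along these general lines (a toy sketch, not ggml's actual pool) only hands memory back to the driver at teardown, so the high-water mark stays visible in nvidia-smi for the lifetime of the process:

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Toy lookaside pool: "freeing" only marks a buffer as reusable; cudaFree is
// called in pool_teardown(), so device memory usage stays at its high-water mark.
struct PoolSlot { void * ptr = nullptr; size_t size = 0; bool in_use = false; };
static PoolSlot g_pool[16];

static void * pool_alloc(size_t size) {
    for (auto & s : g_pool) {                 // reuse a large-enough free slot
        if (s.ptr && !s.in_use && s.size >= size) { s.in_use = true; return s.ptr; }
    }
    for (auto & s : g_pool) {                 // otherwise allocate a new one
        if (s.ptr == nullptr) {
            if (cudaMalloc(&s.ptr, size) != cudaSuccess) return nullptr;
            s.size = size; s.in_use = true; return s.ptr;
        }
    }
    return nullptr;                           // pool exhausted (toy limitation)
}

static void pool_free(void * ptr) {
    for (auto & s : g_pool) { if (s.ptr == ptr) { s.in_use = false; return; } }
}

static void pool_teardown() {
    for (auto & s : g_pool) { if (s.ptr) { cudaFree(s.ptr); s = PoolSlot{}; } }
}
```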
There are tons of inference backends with super fast Ampere+ support. A big draw of llama.cpp is the wide HW compatibility, especially low end. It's the last one left with decent Pascal speeds. Stuff like Falcon for under $800 of GPUs is IMO more worth it than a few extra t/s on already well-supported platforms.
@Ph0rk0z you have one use case, but to seriously use this software in a business context, good batched performance is necessary for quite a few use cases. I do hope they maintain good compatibility though.
I was testing this PR and I didn't notice a difference in terms of VRAM usage compared to MMQ in master with a 3600 token prompt, while being indeed a lot faster (batch size was set to the default, so 512). I'm very happy with it. From my point of view, this PR is safe to merge. Great job!
To clarify my perspective: if the increase in VRAM usage is ~1% as previously suggested, I am completely fine with this PR and do not think a compilation option for mmq only is necessary. However, to my understanding mmq is currently still used for the output tensor, which is by far the largest and therefore requires the most VRAM. So prior to merging I would like there to be a final VRAM measurement.

Also the multi GPU performance should be checked. Currently with mmq the hidden state is converted to q8_1 prior to being distributed from the main GPU to the other GPUs. This significantly reduces the latency and the required bandwidth and therefore improves performance. So cuBLAS may still be slower in multi GPU settings even with the presence of tensor cores. Although for batched decoding in a setting with one server and multiple clients, a different parallelization scheme where the GPUs run sequentially rather than in parallel would be much more efficient anyway (it is extremely unlikely that I will implement this because it is not a use case that I care about).
It hobbles a lot of accessible hardware that people invested money into. I'm not the only one. There are no cheaper 24G cards available. Enjoy running a lot of tiny ineffectual models really really fast, I guess. The V100/A100 folks will be using vllm and TGI as they currently do. You could say "stay with the old version" if the format didn't change so often, but that hasn't been the case. So much for good ML being accessible rather than big-business oriented.
NVLink isn't supported on the 4090. It will use faux NVLink via PCIe though.
RTX 4090s to my knowledge do not have NVLink support.
It's due to the overhead when transferring data between GPUs. Also note that when the batch size increases, the computation time per token decreases but the data transfer time per token does not decrease nearly as much. So as the batch size increases the generation becomes increasingly bottlenecked by the interconnect speed/latency. Also comparatively faster GPUs are more bottlenecked, especially when using small models. There is no easy fix but one thing you could do is convert
How much speedup could I expect from NVLink? I posted some more results in #3814 for 2x RTX A6000 with PCIe AFAICT. Also, it seems TG at 7B F16 now actually benefits slightly from 2x cards compared to one, even with just this PR (#3776). It's small, but at least it is not slower now as it was before. For quantized models there is not much difference - 2x GPUs are still slower than 1.
Well, for this.. with my NVLink-ed 3090s I went from 18.86 tokens/s to 17.5 t/s, so this PR is not a net benefit. That's for 70B. Setting the flag returns the old performance. I have yet to test on long contexts (this is only 22 tokens) though. Didn't test what happens with the P40s or mixed archs. I find when using CPP vs exllama, cpp beats it at first but then it falls short once you get up to 2-3k. As for NVLink, it gives a gain from 0.5 t/s to 5 t/s depending on implementation. Once peer access was enabled my t/s went up and I still see people post lower speeds. For under $100 it's worth it, for more probably not. If you train, it will have larger gains there.
* master: (350 commits)
speculative : ensure draft and target model vocab matches (ggerganov#3812)
llama : correctly report GGUFv3 format (ggerganov#3818)
simple : fix batch handling (ggerganov#3803)
cuda : improve text-generation and batched decoding performance (ggerganov#3776)
server : do not release slot on image input (ggerganov#3798)
batched-bench : print params at start
log : disable pid in log filenames
server : add parameter -tb N, --threads-batch N (ggerganov#3584) (ggerganov#3768)
server : do not block system prompt update (ggerganov#3767)
sync : ggml (conv ops + cuda MSVC fixes) (ggerganov#3765)
cmake : add missed dependencies (ggerganov#3763)
cuda : add batched cuBLAS GEMM for faster attention (ggerganov#3749)
Add more tokenizer tests (ggerganov#3742)
metal : handle ggml_scale for n%4 != 0 (close ggerganov#3754)
Revert "make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)"
issues : separate bug and enhancement template + no default title (ggerganov#3748)
Update special token handling in conversion scripts for gpt2 derived tokenizers (ggerganov#3746)
llama : remove token functions with `context` args in favor of `model` (ggerganov#3720)
Fix baichuan convert script not detecing model (ggerganov#3739)
make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)
...
I got around to testing this PR on my 6GB RTX 2060, and it's a mixed bag. For models that I'm able to fully offload, it is indeed an improvement in speed - for 7B Q4_K_S I am able to get a decent boost in prompt processing speed. However, when testing a 13B Q4_K_M model, I must now offload 1 layer less than before - and speeds are slightly slower. Philosophically, I do agree with the idea that llama.cpp should cater more towards hobbyist hardware like my crappy 6GB VRAM card. It fills an excellent niche for the home user with old/inferior cards, since all the alternatives are unfeasible. Supporting modern hardware is good, yes, but vllm & TGI already cater to high end cards; llama.cpp should play to its unique strengths.
The change from this PR combined with the idea from #3457 will most certainly improve the performance with 1 layer less compared to what was on
So long term, I'm pretty sure that users with old/inferior cards will be very happy with this change.
Don't use Q4_K_M, use K_S instead. The perplexity is almost the same and the speed is noticeably better.
Token generation results on my Tesla P40:
@ggerganov How can this change be adapted to not cripple cards that don't have tensor cores? |
Building with
So I'm not sure this PR is at fault, but given how much has changed it likely is the culprit. Previously, I was just testing fully offloaded 7B models on my RTX 2060. But running 13B with 25 layers offloaded, generation speed is absolutely atrocious, while prompt processing is performing as expected.
Normally I would get a generation speed of around 250 ms per token, not 1300. There is also no RAM swapping involved (I've disabled this behavior with the new driver and there's enough VRAM left anyway).
@slaren @ggerganov Can you please test if you can reproduce the significant slowdown using partial offloading? Generation speed using 25 layers on a 13B model is 5x slower in builds after this PR compared to ones before this PR. And my GPU has tensor cores. This is related to #3860. All testing in this PR was done using full GPU offloading (ngl=999), so it might be possible this slipped under the radar. Please always test partial offloading as well.
@Dampfinchen what model are you using, quant, etc?
This one. https://huggingface.co/TheBloke/Mythical-Destroyer-V2-L2-13B-GGUF/tree/main Q4_K_S.
@slaren Since LostRuin's hardware, which is the same as mine, is unaffected, I suspect there might be an incompatibility between the latest driver 546.01 and partial offloading at play. While I previously mentioned I didn't have that issue in earlier builds, do note I meant koboldcpp with that, which is built on llama.cpp. However, after testing various llama.cpp builds, old and new, I can confirm the issue is not due to tensor core support or batched CUDA processing nor cuBLAS in general, as it happens with FORCE_MMQ as well. Could you perhaps test the latest driver 546.01 in combination with partial offloading? Even if I'm using partial offloading with a 7B model using 28 layers, it's much, much slower than expected. However, full GPU offloading performs as expected.
I am already using the driver 546.01, but I can't reproduce this under WSL. If you are building it yourself, are you sure that you are building with AVX? It is not done by default anymore. Otherwise, can you bisect exactly the commit that caused your issue?
@slaren You are right, I did not compile it with AVX2 support. I did not notice it was compiling without it anymore; I just ticked cuBLAS and nothing else. I've now compiled with AVX2 and speed is exactly as expected again. Thank you!
…ganov#3776)
* cuda : prints wip
* cuda : new cublas gemm branch for multi-batch quantized src0
* cuda : add F32 sgemm branch
* cuda : fine-tune >= VOLTA params + use MMQ only for small batches
* cuda : remove duplicated cuBLAS GEMM code
* cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros
* build : add compile option to force use of MMQ kernels
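The commit list above mentions the `CUDA_USE_TENSOR_CORES` and `GGML_CUDA_FORCE_MMQ` macros; as a rough illustration (not the actual llama.cpp source), a build-time override of this kind usually reduces to something like:

```cpp
// Rough illustration of a GGML_CUDA_FORCE_MMQ-style build option: if the macro
// is defined, the integer MMQ kernels are always used; otherwise the runtime
// checks (tensor cores present, batch size) decide.
static bool should_use_mmq(bool have_tensor_cores, int batch_size) {
#if defined(GGML_CUDA_FORCE_MMQ)
    (void) have_tensor_cores; (void) batch_size;
    return true;                                    // forced at build time
#else
    return !have_tensor_cores || batch_size <= 32;  // cuBLAS only where it helps
#endif
}
```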
…ce (ggerganov#3776)"
This commit introduces a performance regression on my Tesla P40.
ref #3479
ref #3771
Description
This PR should significantly improve the text-generation, prompt processing and batched decoding speed for all models on NVIDIA cards with tensor cores (i.e. VOLTA, AMPERE, etc.).
Prompt processing
By default `llama.cpp` uses `MMQ=1`, which means that the matrix-matrix multiplications for quantized models are performed with a custom kernel for integer multiplications. Recently (#3412), we found out that for a large batch dimension (which is the case when processing prompts), `MMQ=0` offers a significant performance boost by first dequantizing `src0` to F16 and performing the GEMM using cuBLAS. This PR essentially enables the same optimization for `MMQ=1` by not using the custom kernel for batch size > 32.

Batched decoding
In this mode, the batch size is larger than 1, but typically small (for example, not more than 32). In #3545 we found out that the currently used constants `MMQ_X`, `MMQ_Y` and `NWARPS` are not optimal for small batch sizes. Probably they have been optimized for prompt processing. However, since we now fall back to cuBLAS for prompt processing, the constants can be adjusted for small batch sizes.

Text-generation
So far, for the KV cache related ops (`KQ` and `KQV`) we have been using custom matrix-vector kernels. For small sequence lengths (~128) and no prompt, these kernels are quite efficient. However, as the KV cache grows with the sequence length, it becomes more efficient to use the tensor cores via cuBLAS GEMM. This PR applies this change to achieve TG improvements for all models when the context is big.

In summary, we now have the following strategy for matrix multiplications with quantized `src0`:
- batch size 1 (text generation): use the custom matrix-vector kernel
- small batch sizes (<= 32, batched decoding): use the MMQ kernels with the adjusted constants
- batch size > 32 (prompt processing): dequantize `src0` to F16 and use cuBLAS GEMM
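A simplified sketch of this decision logic (illustrative names and threshold, not the actual ggml-cuda source):

```cpp
enum class mul_mat_path { MAT_VEC, MMQ, DEQUANT_CUBLAS };

// Simplified dispatch for a quantized src0: batch size 1 -> custom
// matrix-vector kernel, small batches -> MMQ with the small-batch constants,
// large batches -> dequantize to F16 and run the cuBLAS GEMM on tensor cores.
static mul_mat_path choose_mul_mat_path(int batch_size, bool have_tensor_cores) {
    if (batch_size == 1) {
        return mul_mat_path::MAT_VEC;        // text generation
    }
    if (!have_tensor_cores || batch_size <= 32) {
        return mul_mat_path::MMQ;            // batched decoding
    }
    return mul_mat_path::DEQUANT_CUBLAS;     // prompt processing
}
```

The tensor-core check reflects the fallback for pre-Volta cards discussed in the comments above.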
RTX 3090
LLAMA_CUBLAS=1 make -j batched batched-bench && ./batched-bench ./models/codellama-7b/ggml-model-q4_0.gguf 8704 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64
master and PR runs: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1 (results tables omitted)
LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1
master and PR runs: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1 (results tables omitted)
RTX 4090
LLAMA_CUBLAS=1 make -j batched batched-bench && ./batched-bench ./models/codellama-7b/ggml-model-q4_0.gguf 8704 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64
master and PR runs: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1 (results tables omitted)
LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1
master and PR runs: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1 (results tables omitted)
V100
LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/openllama-7b-v2/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1
master and PR runs: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1 (results tables omitted)
A100 80GB
LLAMA_CUBLAS=1 make -j batched batched-bench && ./batched-bench ./models/codellama-7b/ggml-model-q4_0.gguf 8704 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32,64
master and PR runs: n_kv_max = 8704, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1 (results tables omitted)
LLAMA_CUBLAS=1 make -j batched-bench && ./batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,800 1
master and PR runs: n_kv_max = 4096, is_pp_shared = 1, n_gpu_layers = 99, mmq = 1 (results tables omitted)
TODO
`BACKEND_SPLIT` support for `src0`