Allow more granular KV cache settings #6561
Conversation
Merge dev branch
Merge dev branch
Merge dev branch
Looks good, a single flag for all kv cache types and a dropdown menu are an ideal solution. It would be good to change

    if shared.args.cache_4bit and shared.args.loader.lower() in ['exllamav2', 'exllamav2_hf']:
        shared.args.kv_cache_type = 'q4'

...

Feel free to continue expanding this PR; I'll merge it when you say it's ready.
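For context, a minimal sketch of what that kind of legacy-flag shim could look like, assuming it runs once at startup after argument parsing. Only the q4 mapping comes from the snippet above; the function name and the cache_8bit handling are illustrative:

```python
# Minimal sketch of a legacy-flag shim, assumed to run once after argument
# parsing. Only the q4 mapping is taken from the comment above.
def convert_legacy_cache_flags(args):
    if getattr(args, 'cache_4bit', False) and args.loader and args.loader.lower() in ['exllamav2', 'exllamav2_hf']:
        args.kv_cache_type = 'q4'
    # getattr(args, 'cache_8bit', False) would be mapped analogously once
    # the target value name is settled.
    return args
```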
Yeah, this is why I had marked it as draft; I figured we'd want to work out the little details before I got too far. Sure, it makes sense to transform the legacy KV cache quant options elsewhere. Should that go in, I'll make this change a little later today. While I have you: are you interested in getting the Transformers Quanto loader going too, as in #6126, or should I pass on that for now and focus on ExLlama and llama.cpp? |
That's what I would personally do. Something simple and explicit, as this is temporary code (although I will probably keep the old flags for a long while, given that they have existed for many months).
If the same flag can be reused, that would be a great addition, yes. Does that work out of the box with transformers, or does a new requirement have to be added? |
A very good question. I don't get to work with Transformers much since I'm GPU-poor, but I'll evaluate this on a smaller model today and report back. My current leaning is: if there are no additional requirements I'll probably hammer it in; otherwise we'll circle back and pick it up in another pass. |
modules/llamacpp_model.py (outdated)
@@ -11,6 +11,32 @@
from modules.text_generation import get_max_prompt_length
Are you sure all of these quant types work with llama.cpp?
I did some searching here: #6168
From what I saw, llama.cpp's list of supported cache types is shorter than the list of all quant types (there were no K-quants among the supported cache quantizations).
I checked the llama.cpp code a bit more now and found this: common.cpp
From what I understand, the function kv_cache_type_from_str determines the type for the KV cache in llama.cpp.
It seems to allow fewer types; it accepts only: "f32", "f16", "bf16", "q8_0", "q4_0", "q4_1", "iq4_nl", "q5_0", "q5_1".
In other cases it fails with "Unsupported cache type".
Did you check whether all the cache types for llama.cpp added in this PR work? Especially those not on the list in the llama.cpp code (like q6_k or q4_k)?
EDIT: somehow I misclicked and the comment was added above the relevant line in the code...
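To make the implication concrete, here is a hedged sketch of the kind of guard this suggests on the webui side; the accepted strings are the ones quoted above, while the function name and call site are assumptions:

```python
# Sketch: the set of strings accepted by llama.cpp's kv_cache_type_from_str,
# as quoted above, used as a guard before the value is passed down.
LLAMACPP_CACHE_TYPES = {
    'f32', 'f16', 'bf16', 'q8_0', 'q4_0', 'q4_1', 'iq4_nl', 'q5_0', 'q5_1',
}

def validate_llamacpp_cache_type(cache_type: str) -> str:
    cache_type = cache_type.lower()
    if cache_type not in LLAMACPP_CACHE_TYPES:
        raise ValueError(
            f'Unsupported llama.cpp cache type: {cache_type!r}; '
            f'expected one of {sorted(LLAMACPP_CACHE_TYPES)}'
        )
    return cache_type
```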
Right, thanks for pointing that out; part of finalizing this PR is checking to ensure the KV cache matrix is correct. Additionally, many of these are only supported if you build with GGML_CUDA_FA_ALL_QUANTS, which I'm not sure we have enabled on our wheel. I'll ensure this is all up to code before I mark it ready.
UPDATE:
Unsupported KV type combination for head_size 128.
Supported combinations:
- K == q4_0, V == q4_0, 4.50 BPV
- K == q8_0, V == q8_0, 8.50 BPV
- K == f16, V == f16, 16.00 BPV
Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.
So it looks like we're clamping these values to either q4_0 or q8_0, and disallowing mixing types. Bummer.
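Concretely, without a wheel built with GGML_CUDA_FA_ALL_QUANTS, the practical mapping collapses to something like the sketch below. The three supported types come from the error message above; the fallback policy is a guess, not the PR's actual behaviour:

```python
# Sketch: collapse a requested cache type onto the combinations the
# prebuilt wheel reports as supported (K and V forced to the same type).
def clamp_cache_type(requested: str) -> str:
    supported = {'f16', 'q8_0', 'q4_0'}
    requested = requested.lower()
    if requested in supported:
        return requested
    # Map anything else to the nearest supported width (guessed policy).
    if requested.startswith(('q4', 'iq4')):
        return 'q4_0'
    if requested.startswith(('q5', 'q6', 'q8')):
        return 'q8_0'
    return 'f16'
```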
An update: transformers throws with
OK, that'll do it for the last round of checks. I'm opening this for real review, since everything seems to be humming along nicely now. A note: I don't love the command line arguments, but llamacpp and exllama have pretty different options for KV cache quant. Happy to make changes to those as needed. |
You know, thinking on this for the day, I think I'm going to collapse this down to a single command line argument. EDIT: OK, done. I think this is ready to go. |
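For reference, a minimal sketch of what a single unified flag might look like with argparse; the flag name --cache_type, its default, and the example values are assumptions for illustration (and this unified approach gets rolled back a few comments later):

```python
import argparse

# Sketch: one shared flag whose string value each loader interprets in its
# own vocabulary (e.g. 'q4' for ExLlamaV2, 'q4_0' for llama.cpp).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--cache_type',
    type=str,
    default='fp16',
    help='KV cache quantization, interpreted per loader.',
)

if __name__ == '__main__':
    print(parser.parse_args(['--cache_type', 'q4_0']).cache_type)  # q4_0
```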
Force-pushed from d41d2a6 to 31641ef
Got it, I'll roll it back tomorrow. Thanks for the feedback! |
Force-pushed from 31641ef to 037caec
@oobabooga OK, I've backed out the parameter unification patch and moved back to per-loader, string-based quantization specification. I've confirmed everything is still cooking nicely. Thanks again for your patience! |
Force-pushed from b9d1128 to 9f9c13f
…port. Disallow type mixing.
Force-pushed from 9f9c13f to 82ced8e
Thanks @dinerburger, I have made some small final changes:
Do things look right after those changes? @dinerburger @GodEmperor785 |
Best of all worlds! LGTM! ✌️ |
Looks good to me too. |
Thanks @dinerburger @GodEmperor785! @dinerburger, if you feel like implementing quantized cache for Transformers, that would be a nice addition. I assume the additional requirement is pure Python/PyTorch, without compiled wheels, so it should be an easy addition. |
@oobabooga I actually implemented it in this (now out of date) branch. Problem was: it was unusably bad. Generation was pure garbage. Not sure if it was because I missed setting |
Interesting, I see that you tried even the HQQ one. I found this source, which you have probably seen already. int2 and int4 precision is probably round-to-nearest and too low a precision; the HQQ int8 one would be the most likely to perform well. |
It’s funny, I actually forgot to try the HQQ one at all, favoring Quanto. I’ll get this branch rebased tonight and give it a shot. |
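For anyone following along, the Transformers quantized-cache path under discussion looks roughly like the sketch below, based on my reading of the transformers cache_implementation='quantized' API; the model name, nbits, and backend choice are placeholders, and the quanto (or HQQ) package has to be installed separately:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch: quantized KV cache via transformers' cache_implementation switch.
# Model and settings are placeholders, not the PR's configuration.
model_id = 'facebook/opt-125m'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer('Hello, my name is', return_tensors='pt').to(model.device)
output = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation='quantized',
    cache_config={'backend': 'quanto', 'nbits': 8},
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```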
Checklist:
This PR adds more granular support for KV cache settings, allowing:
This PR should be expanded to allow for new Quanto types as mentioned in #6126, but before I go too far I wanted to make sure this structure was appropriate.
NOTE: This should probably supersede or complement #6280 somehow.