From 31641ef5890af755c4adcd778cf2d914da06cb35 Mon Sep 17 00:00:00 2001 From: Diner Burger Date: Tue, 10 Dec 2024 15:47:27 -0500 Subject: [PATCH] Collapse everything behind 'cache-bits' argument. --- modules/exllamav2.py | 18 ++++++++---------- modules/exllamav2_hf.py | 18 ++++++++---------- modules/llamacpp_hf.py | 8 ++++---- modules/llamacpp_model.py | 15 ++++++++++++--- modules/loaders.py | 8 ++++---- modules/shared.py | 23 +++++++++++------------ modules/ui.py | 3 +-- modules/ui_model_menu.py | 9 ++++++--- 8 files changed, 54 insertions(+), 48 deletions(-) diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 7b9b2ce94a..5932b31cf4 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -59,22 +59,20 @@ def from_pretrained(self, path_to_model): model.load(split) # Determine the correct cache type - kv_cache_type = 'fp16' - if shared.args.exl_cache_type: - kv_cache_type = shared.args.exl_cache_type.lower() + cache_bits = 16 + if shared.args.cache_bits: + cache_bits = shared.args.cache_bits - if kv_cache_type == 'fp16': + if cache_bits == 16: cache_type = ExLlamaV2Cache - elif kv_cache_type == 'fp8': - cache_type = ExLlamaV2Cache_8bit - elif kv_cache_type == 'q8': + elif cache_bits == 8: cache_type = ExLlamaV2Cache_Q8 - elif kv_cache_type == 'q6': + elif cache_bits == 6: cache_type = ExLlamaV2Cache_Q6 - elif kv_cache_type == 'q4': + elif cache_bits == 4: cache_type = ExLlamaV2Cache_Q4 else: - raise ValueError(f"Unknown cache kv type: {kv_cache_type}") + raise ValueError(f"Invalid kv cache bit width: {cache_bits}") # Use TP if specified if shared.args.enable_tp: diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index 88ec772f17..288ea0e3cb 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -47,22 +47,20 @@ def __init__(self, config: ExLlamaV2Config): self.ex_model.load(split) # Determine the correct cache type - kv_cache_type = 'fp16' - if shared.args.exl_cache_type: - kv_cache_type = shared.args.exl_cache_type.lower() + cache_bits = 16 + if shared.args.cache_bits: + cache_bits = shared.args.cache_bits - if kv_cache_type == 'fp16': + if cache_bits == 16: cache_type = ExLlamaV2Cache - elif kv_cache_type == 'fp8': - cache_type = ExLlamaV2Cache_8bit - elif kv_cache_type == 'q8': + elif cache_bits == 8: cache_type = ExLlamaV2Cache_Q8 - elif kv_cache_type == 'q6': + elif cache_bits == 6: cache_type = ExLlamaV2Cache_Q6 - elif kv_cache_type == 'q4': + elif cache_bits == 4: cache_type = ExLlamaV2Cache_Q4 else: - raise ValueError(f"Unknown cache kv type: {kv_cache_type}") + raise ValueError(f"Invalid kv cache bit width: {cache_bits}") # Use TP if specified diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 15e92f2e16..8824d6b4c2 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -8,7 +8,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from modules import shared -from modules.llamacpp_model import get_llamacpp_quant_type_for_string +from modules.llamacpp_model import get_cache_quant_type_for_bitwidth from modules.llama_cpp_python_hijack import llama_cpp_lib from modules.logging_colors import logger @@ -197,9 +197,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P 'flash_attn': shared.args.flash_attn } - if shared.args.lcpp_cache_type: - params["type_k"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type) - params["type_v"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type) + if shared.args.cache_bits: + params["type_k"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits) + params["type_v"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits) Llama = llama_cpp_lib().Llama model = Llama(**params) diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index f6bdc8d526..864253ac76 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -30,6 +30,15 @@ 'bf16': 30, } +def get_cache_quant_type_for_bitwidth(width: int): + if width == 16: + return llamacpp_quant_mapping['fp16'] + elif width == 8: + return llamacpp_quant_mapping['q8_0'] + elif width == 4: + return llamacpp_quant_mapping['q4_0'] + else: + raise ValueError(f"Unsupported bitwidth: {width}") def get_llamacpp_quant_type_for_string(quant_type: str): quant_type = quant_type.lower() @@ -103,9 +112,9 @@ def from_pretrained(self, path): 'flash_attn': shared.args.flash_attn } - if shared.args.lcpp_cache_type: - params["type_k"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type) - params["type_v"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type) + if shared.args.cache_bits: + params["type_k"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits) + params["type_v"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits) result.model = Llama(**params) if cache_capacity > 0: diff --git a/modules/loaders.py b/modules/loaders.py index 73d15948f4..9312a99b76 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -31,7 +31,7 @@ 'llama.cpp': [ 'n_ctx', 'n_gpu_layers', - 'lcpp_cache_type', + 'cache_bits', 'tensor_split', 'n_batch', 'threads', @@ -53,7 +53,7 @@ 'llamacpp_HF': [ 'n_ctx', 'n_gpu_layers', - 'lcpp_cache_type', + 'cache_bits', 'tensor_split', 'n_batch', 'threads', @@ -85,7 +85,7 @@ 'no_xformers', 'no_sdpa', 'num_experts_per_token', - 'exl_cache_type', + 'cache_bits', 'autosplit', 'enable_tp', 'alpha_value', @@ -100,7 +100,7 @@ 'no_xformers', 'no_sdpa', 'num_experts_per_token', - 'exl_cache_type', + 'cache_bits', 'autosplit', 'enable_tp', 'alpha_value', diff --git a/modules/shared.py b/modules/shared.py index a3162ce929..d5b20ae303 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -47,8 +47,7 @@ 'max_updates_second': 0, 'prompt_lookup_num_tokens': 0, 'custom_stopping_strings': '', - 'lcpp_cache_type': 'fp16', - 'exl_cache_type': 'fp16', + 'cache_bits': '16', 'custom_token_bans': '', 'auto_max_new_tokens': False, 'ban_eos_token': False, @@ -127,7 +126,6 @@ group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.') group.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.') group.add_argument('--tensor_split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.') -group.add_argument('--lcpp_cache_type', type=str, default='fp16', help='KV cache K-quant type. May be one of fp16, q8_0, q4_0.') group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.') group.add_argument('--logits_all', action='store_true', help='Needs to be set for perplexity evaluation to work. Otherwise, ignore it, as it makes prompt processing slower.') group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.') @@ -146,7 +144,6 @@ group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.') group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.') group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.') -group.add_argument('--exl_cache_type', type=str, default='fp16', help='KV cache type; may be one of FP16, FP8, Q8, Q6 or Q4.') group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.') group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.') @@ -208,6 +205,10 @@ group = parser.add_argument_group('Multimodal') group.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.') +# Cache Options +group = parser.add_argument_group('Cache Options') +group.add_argument('--cache-bits', type=int, default=16, help='The number of bits to use during KV caching. Defaults to 16. Valid options are 16, 8, 4 for llama.cpp; 16, 8, 6, 4 for exllama') + # Deprecated parameters group = parser.add_argument_group('Deprecated') group.add_argument('--model_type', type=str, help='DEPRECATED') @@ -305,21 +306,19 @@ def del_key(key, fallback_set): # prevent as much breakage as possible. if not loader: if cache_8bit: - set('lcpp_cache_type', 'q8_0') - set('exl_cache_type', 'fp8') + set('cache_bits', 8) elif cache_4bit: - set('lcpp_cache_type', 'q4_0') - set('exl_cache_type', 'q4') + set('cache_bits', 4) elif loader.lower() in ['exllamav2', 'exllamav2_hf']: if cache_8bit: - set('exl_cache_type', 'fp8') + set('cache_bits', 8) elif cache_4bit: - set('exl_cache_type', 'q4') + set('cache_bits', 4) elif loader.lower() in ['llama.cpp', 'llamacpp_hf']: if cache_4bit: - set('lcpp_cache_type', 'q4_0') + set('cache_bits', 4) elif cache_8bit: - set('lcpp_cache_type', 'q8_0') + set('cache_bits', 8) del_key('cache_4bit', False) del_key('cache_8bit', False) diff --git a/modules/ui.py b/modules/ui.py index 49b2487e54..9ce554fb5b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -87,8 +87,7 @@ def list_model_elements(): 'no_xformers', 'no_sdpa', 'num_experts_per_token', - 'lcpp_cache_type', - 'exl_cache_type', + 'cache_bits', 'autosplit', 'enable_tp', 'threads', diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 4ea1f5df74..b3785cc7b0 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -118,8 +118,7 @@ def create_ui(): shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.') shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices) shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.') - shared.gradio['lcpp_cache_type'] = gr.Dropdown(label="cache_type", value=shared.args.lcpp_cache_type, info='KV cache type', choices=['fp16', 'q8_0', 'q4_0'] ) - shared.gradio['exl_cache_type'] = gr.Dropdown(label="cache_type", value=shared.args.exl_cache_type, info='KV cache type', choices=['fp16', 'fp8', 'q8', 'q6', 'q4']) + shared.gradio['cache_bits'] = gr.Dropdown(label="cache_bits", value=shared.args.cache_bits, choices=[16, 8, 6, 4], info='KV cache bit width. Reduces VRAM usage at the cost of accuracy. 8-bit is considered lossless, 4-bit may degrade accuracy.') shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.') shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.') @@ -187,7 +186,11 @@ def create_ui(): def create_event_handlers(): - shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()), show_progress=False) + loaderobj: gr.Dropdown = shared.gradio['loader'] + loaderobj.change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()), show_progress=False).then( + lambda loader: gr.update(choices=[16, 8, 6, 4] if loader and loader.lower() in ['exllamav2', 'exllamav2_hf'] else [16, 8, 4]), + gradio('loader'), gradio('cache_bits') + ) # In this event handler, the interface state is read and updated # with the model defaults (if any), and then the model is loaded