Commit 31641ef

Collapse everything behind 'cache-bits' argument.
dinerburger committed Dec 10, 2024
1 parent 037caec commit 31641ef
Showing 8 changed files with 54 additions and 48 deletions.
18 changes: 8 additions & 10 deletions modules/exllamav2.py
@@ -59,22 +59,20 @@ def from_pretrained(self, path_to_model):
             model.load(split)
 
         # Determine the correct cache type
-        kv_cache_type = 'fp16'
-        if shared.args.exl_cache_type:
-            kv_cache_type = shared.args.exl_cache_type.lower()
+        cache_bits = 16
+        if shared.args.cache_bits:
+            cache_bits = shared.args.cache_bits
 
-        if kv_cache_type == 'fp16':
+        if cache_bits == 16:
             cache_type = ExLlamaV2Cache
-        elif kv_cache_type == 'fp8':
-            cache_type = ExLlamaV2Cache_8bit
-        elif kv_cache_type == 'q8':
+        elif cache_bits == 8:
             cache_type = ExLlamaV2Cache_Q8
-        elif kv_cache_type == 'q6':
+        elif cache_bits == 6:
             cache_type = ExLlamaV2Cache_Q6
-        elif kv_cache_type == 'q4':
+        elif cache_bits == 4:
             cache_type = ExLlamaV2Cache_Q4
         else:
-            raise ValueError(f"Unknown cache kv type: {kv_cache_type}")
+            raise ValueError(f"Invalid kv cache bit width: {cache_bits}")
 
         # Use TP if specified
         if shared.args.enable_tp:
18 changes: 8 additions & 10 deletions modules/exllamav2_hf.py
@@ -47,22 +47,20 @@ def __init__(self, config: ExLlamaV2Config):
             self.ex_model.load(split)
 
         # Determine the correct cache type
-        kv_cache_type = 'fp16'
-        if shared.args.exl_cache_type:
-            kv_cache_type = shared.args.exl_cache_type.lower()
+        cache_bits = 16
+        if shared.args.cache_bits:
+            cache_bits = shared.args.cache_bits
 
-        if kv_cache_type == 'fp16':
+        if cache_bits == 16:
             cache_type = ExLlamaV2Cache
-        elif kv_cache_type == 'fp8':
-            cache_type = ExLlamaV2Cache_8bit
-        elif kv_cache_type == 'q8':
+        elif cache_bits == 8:
             cache_type = ExLlamaV2Cache_Q8
-        elif kv_cache_type == 'q6':
+        elif cache_bits == 6:
             cache_type = ExLlamaV2Cache_Q6
-        elif kv_cache_type == 'q4':
+        elif cache_bits == 4:
             cache_type = ExLlamaV2Cache_Q4
         else:
-            raise ValueError(f"Unknown cache kv type: {kv_cache_type}")
+            raise ValueError(f"Invalid kv cache bit width: {cache_bits}")
 
 
         # Use TP if specified
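
The bit-width branch above is now duplicated verbatim between modules/exllamav2.py and modules/exllamav2_hf.py. A follow-up could hoist it into one shared helper; a minimal sketch under that assumption (the function name and its placement are hypothetical, not part of this commit):

    # Hypothetical shared helper; mirrors the selection logic added in this commit.
    from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8

    def exllamav2_cache_class_for_bits(cache_bits: int):
        # Map the --cache-bits value to the corresponding ExLlamaV2 cache class.
        mapping = {
            16: ExLlamaV2Cache,
            8: ExLlamaV2Cache_Q8,
            6: ExLlamaV2Cache_Q6,
            4: ExLlamaV2Cache_Q4,
        }
        if cache_bits not in mapping:
            raise ValueError(f"Invalid kv cache bit width: {cache_bits}")
        return mapping[cache_bits]

Both loaders would then reduce to cache_type = exllamav2_cache_class_for_bits(cache_bits).
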
8 changes: 4 additions & 4 deletions modules/llamacpp_hf.py
@@ -8,7 +8,7 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from modules import shared
-from modules.llamacpp_model import get_llamacpp_quant_type_for_string
+from modules.llamacpp_model import get_cache_quant_type_for_bitwidth
 from modules.llama_cpp_python_hijack import llama_cpp_lib
 from modules.logging_colors import logger
 
@@ -197,9 +197,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             'flash_attn': shared.args.flash_attn
         }
 
-        if shared.args.lcpp_cache_type:
-            params["type_k"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type)
-            params["type_v"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type)
+        if shared.args.cache_bits:
+            params["type_k"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits)
+            params["type_v"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits)
 
         Llama = llama_cpp_lib().Llama
         model = Llama(**params)
15 changes: 12 additions & 3 deletions modules/llamacpp_model.py
@@ -30,6 +30,15 @@
     'bf16': 30,
 }
 
+def get_cache_quant_type_for_bitwidth(width: int):
+    if width == 16:
+        return llamacpp_quant_mapping['fp16']
+    elif width == 8:
+        return llamacpp_quant_mapping['q8_0']
+    elif width == 4:
+        return llamacpp_quant_mapping['q4_0']
+    else:
+        raise ValueError(f"Unsupported bitwidth: {width}")
 
 def get_llamacpp_quant_type_for_string(quant_type: str):
     quant_type = quant_type.lower()
@@ -103,9 +112,9 @@ def from_pretrained(self, path):
             'flash_attn': shared.args.flash_attn
         }
 
-        if shared.args.lcpp_cache_type:
-            params["type_k"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type)
-            params["type_v"] = get_llamacpp_quant_type_for_string(shared.args.lcpp_cache_type)
+        if shared.args.cache_bits:
+            params["type_k"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits)
+            params["type_v"] = get_cache_quant_type_for_bitwidth(shared.args.cache_bits)
 
         result.model = Llama(**params)
         if cache_capacity > 0:
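
For reference, the behavior of the new helper as defined above (return values are entries of the existing llamacpp_quant_mapping table; note that a 6-bit width is accepted by the ExLlamaV2 path but rejected here):

    from modules.llamacpp_model import get_cache_quant_type_for_bitwidth

    get_cache_quant_type_for_bitwidth(16)  # llamacpp_quant_mapping['fp16']
    get_cache_quant_type_for_bitwidth(8)   # llamacpp_quant_mapping['q8_0']
    get_cache_quant_type_for_bitwidth(4)   # llamacpp_quant_mapping['q4_0']
    get_cache_quant_type_for_bitwidth(6)   # raises ValueError("Unsupported bitwidth: 6")
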
8 changes: 4 additions & 4 deletions modules/loaders.py
@@ -31,7 +31,7 @@
     'llama.cpp': [
         'n_ctx',
         'n_gpu_layers',
-        'lcpp_cache_type',
+        'cache_bits',
         'tensor_split',
         'n_batch',
         'threads',
@@ -53,7 +53,7 @@
     'llamacpp_HF': [
         'n_ctx',
         'n_gpu_layers',
-        'lcpp_cache_type',
+        'cache_bits',
         'tensor_split',
         'n_batch',
         'threads',
@@ -85,7 +85,7 @@
         'no_xformers',
         'no_sdpa',
         'num_experts_per_token',
-        'exl_cache_type',
+        'cache_bits',
         'autosplit',
         'enable_tp',
         'alpha_value',
@@ -100,7 +100,7 @@
         'no_xformers',
         'no_sdpa',
         'num_experts_per_token',
-        'exl_cache_type',
+        'cache_bits',
         'autosplit',
         'enable_tp',
         'alpha_value',
23 changes: 11 additions & 12 deletions modules/shared.py
@@ -47,8 +47,7 @@
     'max_updates_second': 0,
     'prompt_lookup_num_tokens': 0,
     'custom_stopping_strings': '',
-    'lcpp_cache_type': 'fp16',
-    'exl_cache_type': 'fp16',
+    'cache_bits': '16',
     'custom_token_bans': '',
     'auto_max_new_tokens': False,
     'ban_eos_token': False,
@@ -127,7 +126,6 @@
 group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
 group.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.')
 group.add_argument('--tensor_split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.')
-group.add_argument('--lcpp_cache_type', type=str, default='fp16', help='KV cache K-quant type. May be one of fp16, q8_0, q4_0.')
 group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.')
 group.add_argument('--logits_all', action='store_true', help='Needs to be set for perplexity evaluation to work. Otherwise, ignore it, as it makes prompt processing slower.')
 group.add_argument('--no_offload_kqv', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.')
@@ -146,7 +144,6 @@
 group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
 group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.')
 group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.')
-group.add_argument('--exl_cache_type', type=str, default='fp16', help='KV cache type; may be one of FP16, FP8, Q8, Q6 or Q4.')
 group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.')
 group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.')
 
@@ -208,6 +205,10 @@
 group = parser.add_argument_group('Multimodal')
 group.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
 
+# Cache Options
+group = parser.add_argument_group('Cache Options')
+group.add_argument('--cache-bits', type=int, default=16, help='The number of bits to use during KV caching. Defaults to 16. Valid options are 16, 8, 4 for llama.cpp; 16, 8, 6, 4 for exllama')
+
 # Deprecated parameters
 group = parser.add_argument_group('Deprecated')
 group.add_argument('--model_type', type=str, help='DEPRECATED')
@@ -305,21 +306,19 @@ def del_key(key, fallback_set):
     # prevent as much breakage as possible.
     if not loader:
         if cache_8bit:
-            set('lcpp_cache_type', 'q8_0')
-            set('exl_cache_type', 'fp8')
+            set('cache_bits', 8)
         elif cache_4bit:
-            set('lcpp_cache_type', 'q4_0')
-            set('exl_cache_type', 'q4')
+            set('cache_bits', 4)
     elif loader.lower() in ['exllamav2', 'exllamav2_hf']:
         if cache_8bit:
-            set('exl_cache_type', 'fp8')
+            set('cache_bits', 8)
         elif cache_4bit:
-            set('exl_cache_type', 'q4')
+            set('cache_bits', 4)
     elif loader.lower() in ['llama.cpp', 'llamacpp_hf']:
         if cache_4bit:
-            set('lcpp_cache_type', 'q4_0')
+            set('cache_bits', 4)
         elif cache_8bit:
-            set('lcpp_cache_type', 'q8_0')
+            set('cache_bits', 8)
 
     del_key('cache_4bit', False)
     del_key('cache_8bit', False)
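
Taken together, the shared.py changes collapse the two loader-specific cache flags into a single --cache-bits argument and remap the deprecated cache_8bit/cache_4bit settings onto it. Hypothetical invocations (server.py is assumed as the entry point; the commands are illustrative, not taken from the commit):

    python server.py --loader exllamav2 --cache-bits 8   # selects ExLlamaV2Cache_Q8
    python server.py --loader llama.cpp --cache-bits 4   # sets type_k/type_v to the q4_0 quant type
    python server.py --loader llama.cpp --cache-bits 6   # fails at model load with "Unsupported bitwidth: 6"

One detail worth noting: the default added to the settings dict at the top of shared.py is the string '16', while the argparse default and the UI dropdown choices are integers, so any code comparing that settings value against an int would need to cast it first.
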
3 changes: 1 addition & 2 deletions modules/ui.py
@@ -87,8 +87,7 @@ def list_model_elements():
         'no_xformers',
         'no_sdpa',
         'num_experts_per_token',
-        'lcpp_cache_type',
-        'exl_cache_type',
+        'cache_bits',
         'autosplit',
         'enable_tp',
         'threads',
9 changes: 6 additions & 3 deletions modules/ui_model_menu.py
@@ -118,8 +118,7 @@ def create_ui():
                 shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
                 shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
                 shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.')
-                shared.gradio['lcpp_cache_type'] = gr.Dropdown(label="cache_type", value=shared.args.lcpp_cache_type, info='KV cache type', choices=['fp16', 'q8_0', 'q4_0'] )
-                shared.gradio['exl_cache_type'] = gr.Dropdown(label="cache_type", value=shared.args.exl_cache_type, info='KV cache type', choices=['fp16', 'fp8', 'q8', 'q6', 'q4'])
+                shared.gradio['cache_bits'] = gr.Dropdown(label="cache_bits", value=shared.args.cache_bits, choices=[16, 8, 6, 4], info='KV cache bit width. Reduces VRAM usage at the cost of accuracy. 8-bit is considered lossless, 4-bit may degrade accuracy.')
                 shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
                 shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.')
                 shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')
@@ -187,7 +186,11 @@ def create_ui():
 
 
 def create_event_handlers():
-    shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()), show_progress=False)
+    loaderobj: gr.Dropdown = shared.gradio['loader']
+    loaderobj.change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()), show_progress=False).then(
+        lambda loader: gr.update(choices=[16, 8, 6, 4] if loader and loader.lower() in ['exllamav2', 'exllamav2_hf'] else [16, 8, 4]),
+        gradio('loader'), gradio('cache_bits')
+    )
 
     # In this event handler, the interface state is read and updated
     # with the model defaults (if any), and then the model is loaded
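
The new .then() handler narrows the dropdown's choices per loader but leaves the current value untouched, so a selection of 6 would linger after switching to a llama.cpp loader. If that matters, the update could also clamp the value; a sketch reusing the names from the diff (untested, and it assumes the webui's gradio() helper accepts multiple component names, as it does elsewhere in the codebase):

    def update_cache_bits(loader, current_bits):
        # ExLlamaV2 loaders support a 6-bit KV cache; llama.cpp does not.
        choices = [16, 8, 6, 4] if loader and loader.lower() in ['exllamav2', 'exllamav2_hf'] else [16, 8, 4]
        # Fall back to full precision if the previous selection is no longer offered.
        return gr.update(choices=choices, value=current_bits if current_bits in choices else 16)

    loaderobj.change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()), show_progress=False).then(
        update_cache_bits, gradio('loader', 'cache_bits'), gradio('cache_bits')
    )
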
