Skip to content

Commit

Permalink
Add cache_8bit option
Browse files Browse the repository at this point in the history
  • Loading branch information
oobabooga committed Nov 2, 2023
1 parent 42f8163 commit c065547
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ Optionally, you can use the following command-line flags:
|`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
|`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
|`--no_flash_attn` | Force flash-attention to not be used. |
|`--cache_8bit` | Use 8-bit cache to save VRAM. |

#### AutoGPTQ

Expand Down
7 changes: 6 additions & 1 deletion modules/exllamav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Config,
ExLlamaV2Tokenizer
)
Expand Down Expand Up @@ -57,7 +58,11 @@ def from_pretrained(self, path_to_model):
model.load(split)

tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)
if shared.args.cache_8bit:
cache = ExLlamaV2Cache_8bit(model)
else:
cache = ExLlamaV2Cache(model)

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

result = self()
Expand Down
20 changes: 16 additions & 4 deletions modules/exllamav2_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from typing import Any, Dict, Optional, Union

import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Config
)
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
Expand Down Expand Up @@ -40,11 +45,18 @@ def __init__(self, config: ExLlamaV2Config):
self.generation_config = GenerationConfig()
self.loras = None

self.ex_cache = ExLlamaV2Cache(self.ex_model)
self.past_seq = None
if shared.args.cache_8bit:
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
else:
self.ex_cache = ExLlamaV2Cache(self.ex_model)

self.past_seq = None
if shared.args.cfg_cache:
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
if shared.args.cache_8bit:
self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
else:
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)

self.past_seq_negative = None

def _validate_model_class(self):
Expand Down
4 changes: 4 additions & 0 deletions modules/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
'gpu_split',
'max_seq_len',
'cfg_cache',
'no_flash_attn',
'cache_8bit',
'alpha_value',
'compress_pos_emb',
'use_fast',
Expand All @@ -56,6 +58,8 @@
'ExLlamav2': [
'gpu_split',
'max_seq_len',
'no_flash_attn',
'cache_8bit',
'alpha_value',
'compress_pos_emb',
],
Expand Down
1 change: 1 addition & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')

# AutoGPTQ
parser.add_argument('--triton', action='store_true', help='Use triton.')
Expand Down
2 changes: 2 additions & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def list_model_elements():
'no_use_cuda_fp16',
'disable_exllama',
'cfg_cache',
'no_flash_attn',
'cache_8bit',
'threads',
'threads_batch',
'n_batch',
Expand Down
2 changes: 2 additions & 0 deletions modules/ui_model_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def create_ui():
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code)
shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.')
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa support is currently only kept for compatibility with older GPUs. AutoGPTQ or ExLlama is preferred when compatible. GPTQ-for-LLaMa is installed by default with the webui on supported systems. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
Expand Down

0 comments on commit c065547

Please sign in to comment.