From f8079d067d11e36a5d24ec162762e83e9d8f1f02 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 16 Dec 2023 10:52:41 -0800 Subject: [PATCH 1/8] UI: save the sent chat message on "no model is loaded" error --- modules/chat.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 7a44c03e2e..613cae1b9d 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -210,10 +210,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess output = copy.deepcopy(history) output = apply_extensions('history', output) state = apply_extensions('state', state) - if shared.model_name == 'None' or shared.model is None: - logger.error("No model is loaded! Select one in the Model tab.") - yield output - return visible_text = None stopping_strings = get_stopping_strings(state) @@ -252,6 +248,9 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess 'internal': output['internal'] } + if shared.model_name == 'None' or shared.model is None: + raise ValueError("No model is loaded! Select one in the Model tab.") + # Generate the prompt kwargs = { '_continue': _continue, From 0087dca2866fe38f075df7e686bec687988c7f5c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 16 Dec 2023 12:28:51 -0800 Subject: [PATCH 2/8] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d35ebe04aa..4b65254d83 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * Dropdown menu for quickly switching between different models. * Large number of extensions (built-in and user-contributed), including Coqui TTS for realistic voice outputs, Whisper STT for voice inputs, translation, [multimodal pipelines](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal), vector databases, Stable Diffusion integration, and a lot more. See [the wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [the extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details. * [Chat with custom characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character). -* Precise chat templates for instruction-following models, including Llama-2-chat, Alpaca, Vicuna, Mistral, and many others. +* Precise chat templates for instruction-following models, including Llama-2-chat, Alpaca, Vicuna, Mistral. * LoRA: train new LoRAs with your own data, load/unload LoRAs on the fly for generation. * Transformers library integration: load models in 4-bit or 8-bit precision through bitsandbytes, use llama.cpp with transformers samplers (`llamacpp_HF` loader), CPU inference in 32-bit precision using PyTorch. * OpenAI-compatible API server with Chat and Completions endpoints -- see the [examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples). From d2ed0a06bf0fa4956d1a338ce0a4944570130918 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 16 Dec 2023 16:34:15 -0800 Subject: [PATCH 3/8] Bump ExLlamav2 to 0.0.11 (adds Mixtral support) --- requirements.txt | 18 +++++++++--------- requirements_amd.txt | 2 +- requirements_amd_noavx2.txt | 2 +- requirements_apple_intel.txt | 2 +- requirements_apple_silicon.txt | 2 +- requirements_cpu_only.txt | 2 +- requirements_cpu_only_noavx2.txt | 2 +- requirements_noavx2.txt | 18 +++++++++--------- requirements_nowheels.txt | 2 +- 9 files changed, 25 insertions(+), 25 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8e3f09cde4..a7d87a92f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10; platform_system != "Darwin" and platform_machine != "x86_64" +exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" gradio==3.50.* markdown numpy==1.24.* @@ -53,14 +53,14 @@ https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121 https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9" https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8" https://github.com/jllllll/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu121torch2.1cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" https://github.com/jllllll/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu121torch2.1cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" https://github.com/jllllll/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu121torch2.1cxx11abiFALSE-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9" diff --git a/requirements_amd.txt b/requirements_amd.txt index 4f4c936984..899dac530c 100644 --- a/requirements_amd.txt +++ b/requirements_amd.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10 +exllamav2==0.0.11 gradio==3.50.* markdown numpy==1.24.* diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt index 915444e1ea..8ad52eab51 100644 --- a/requirements_amd_noavx2.txt +++ b/requirements_amd_noavx2.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10 +exllamav2==0.0.11 gradio==3.50.* markdown numpy==1.24.* diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt index f57f054fc5..1abbc98649 100644 --- a/requirements_apple_intel.txt +++ b/requirements_apple_intel.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10 +exllamav2==0.0.11 gradio==3.50.* markdown numpy==1.24.* diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt index bb97db70a5..e4bf88a654 100644 --- a/requirements_apple_silicon.txt +++ b/requirements_apple_silicon.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10 +exllamav2==0.0.11 gradio==3.50.* markdown numpy==1.24.* diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt index 2925ba7788..81d10f772f 100644 --- a/requirements_cpu_only.txt +++ b/requirements_cpu_only.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10 +exllamav2==0.0.11 gradio==3.50.* markdown numpy==1.24.* diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt index cc6fe912cb..0f82ef6152 100644 --- a/requirements_cpu_only_noavx2.txt +++ b/requirements_cpu_only_noavx2.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10 +exllamav2==0.0.11 gradio==3.50.* markdown numpy==1.24.* diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt index 33a14d2ade..9a3e574528 100644 --- a/requirements_noavx2.txt +++ b/requirements_noavx2.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10; platform_system != "Darwin" and platform_machine != "x86_64" +exllamav2==0.0.11; platform_system != "Darwin" and platform_machine != "x86_64" gradio==3.50.* markdown numpy==1.24.* @@ -53,14 +53,14 @@ https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121 https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9" https://github.com/jllllll/exllama/releases/download/0.0.18/exllama-0.0.18+cu121-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9" -https://github.com/turboderp/exllamav2/releases/download/v0.0.10/exllamav2-0.0.10+cu121-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp38-cp38-win_amd64.whl; platform_system == "Windows" and python_version == "3.8" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp39-cp39-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.9" +https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.11+cu121-cp38-cp38-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.8" https://github.com/jllllll/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu121torch2.1cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11" https://github.com/jllllll/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu121torch2.1cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10" https://github.com/jllllll/flash-attention/releases/download/v2.3.4/flash_attn-2.3.4+cu121torch2.1cxx11abiFALSE-cp39-cp39-win_amd64.whl; platform_system == "Windows" and python_version == "3.9" diff --git a/requirements_nowheels.txt b/requirements_nowheels.txt index 9007bf169e..66b7afc166 100644 --- a/requirements_nowheels.txt +++ b/requirements_nowheels.txt @@ -2,7 +2,7 @@ accelerate==0.25.* colorama datasets einops -exllamav2==0.0.10 +exllamav2==0.0.11 gradio==3.50.* markdown numpy==1.24.* From 41424907b15e62195d72fa6a9db45b77a45bc929 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 16 Dec 2023 16:35:36 -0800 Subject: [PATCH 4/8] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4b65254d83..57b7afeea1 100644 --- a/README.md +++ b/README.md @@ -377,7 +377,7 @@ text-generation-webui └── llama-2-13b-chat.Q4_K_M.gguf ``` -* Other models (like 16-bit transformers models and GPTQ models) are made of several files and must be placed in a subfolder. Example: +* The remaining model types (like 16-bit transformers models and GPTQ models) are made of several files and must be placed in a subfolder. Example: ``` text-generation-webui From 7a84d7b2da2b1ffb1d68c7b52af458d1b98755f1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 16 Dec 2023 22:16:26 -0300 Subject: [PATCH 5/8] Instruct style improvements (#4951) --- css/html_instruct_style.css | 49 +++++++++++++++++++++++-------------- css/main.css | 2 +- js/main.js | 14 +++++++++++ 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css index 1908f879f7..acff04f196 100644 --- a/css/html_instruct_style.css +++ b/css/html_instruct_style.css @@ -1,10 +1,18 @@ +.chat { + background: var(--block-background-fill); + padding: 24px 19px; + padding-right: 19px !important; + border: 1px solid var(--block-border-color); + border-radius: 8px; +} + .message { display: grid; grid-template-columns: 60px 1fr; padding-bottom: 25px; font-size: 15px; font-family: 'Noto Sans', Helvetica, Arial, sans-serif; - line-height: 22px; + line-height: 24px; } .username { @@ -13,11 +21,16 @@ .message-body p, .message-body li { font-size: 15px !important; - line-height: 22.5px !important; + line-height: 24px !important; + list-style-position: outside; } .message-body p, .chat .message-body ul, .chat .message-body ol { - margin-bottom: 23.4375px !important; + margin-bottom: 16px !important; +} + +.chat .message-body ul, .chat .message-body ol { + padding-inline-start: 2em; } .message-body p:last-child, .chat .message-body ul:last-child, .chat .message-body ol:last-child { @@ -34,34 +47,34 @@ .gradio-container .chat .assistant-message { padding: 20px; - border-radius: 20px; - background-color: #0000000f; - margin-top: 9px !important; - margin-bottom: 18px !important; + background: var(--background-fill-secondary); + margin-top: 12px !important; + margin-bottom: 24px !important; + margin-right: 16px; + border-radius: 22px; + border-bottom-left-radius: 0; + border: 1px solid var(--border-color-primary); } .gradio-container .chat .user-message { padding: 20px; + background-color: var(--color-accent-soft); border-radius: 20px; - margin-bottom: 9px !important; + margin-bottom: 12px !important; + margin-left: 16px; + border-radius: 22px; + border-bottom-right-radius: 0; + border: 1px solid var(--border-color-accent-subdued); } .gradio-container .chat .assistant-message:last-child, .gradio-container .chat .user-message:last-child { margin-bottom: 0 !important; } -.dark .chat .assistant-message { - background-color: #1f2937; -} - -.dark .chat .user-message { - background-color: transparent; -} - code { - background-color: white !important; + background-color: #f3f4f6 !important; } .dark code { - background-color: #0e1321 !important; + background-color: #1f2937 !important; } \ No newline at end of file diff --git a/css/main.css b/css/main.css index a3480fe034..a53f99d025 100644 --- a/css/main.css +++ b/css/main.css @@ -332,7 +332,7 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* { margin-left: auto; margin-right: auto; max-width: 880px; - height: 100%; + min-height: var(--chat-height); overflow-y: auto; padding-right: 15px; display: flex; diff --git a/js/main.js b/js/main.js index 1e50e14742..5c05b394c8 100644 --- a/js/main.js +++ b/js/main.js @@ -123,6 +123,8 @@ targetElement.addEventListener("scroll", function() { // Create a MutationObserver instance const observer = new MutationObserver(function(mutations) { mutations.forEach(function(mutation) { + updateChatHeight(); + if(!isScrolled) { targetElement.scrollTop = targetElement.scrollHeight; } @@ -373,3 +375,15 @@ function toggleBigPicture() { } } +//------------------------------------------------ +// Define the --chat-height global CSS variable to +// the height of the chat parent +//------------------------------------------------ +function updateChatHeight() { + const chatContainer = document.getElementById('chat').parentNode.parentNode.parentNode; + const newChatHeight = `${chatContainer.clientHeight}px`; + + document.documentElement.style.setProperty('--chat-height', newChatHeight); +} + +window.addEventListener('resize', updateChatHeight); From aa200f872311f298a0c369adf6c8ef4938616913 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 16 Dec 2023 19:37:25 -0800 Subject: [PATCH 6/8] UI: remove no longer necessary js in Default/Notebook tabs --- js/main.js | 50 -------------------------------------------------- 1 file changed, 50 deletions(-) diff --git a/js/main.js b/js/main.js index 5c05b394c8..9f10915d83 100644 --- a/js/main.js +++ b/js/main.js @@ -155,56 +155,6 @@ const config = { // Start observing the target element observer.observe(targetElement, config); -//------------------------------------------------ -// Notebook box scrolling -//------------------------------------------------ -const notebookElement = document.querySelector("#textbox-notebook textarea"); -let notebookScrolled = false; - -notebookElement.addEventListener("scroll", function() { - let diff = notebookElement.scrollHeight - notebookElement.clientHeight; - if(Math.abs(notebookElement.scrollTop - diff) <= 10 || diff == 0) { - notebookScrolled = false; - } else { - notebookScrolled = true; - } -}); - -const notebookObserver = new MutationObserver(function(mutations) { - mutations.forEach(function(mutation) { - if(!notebookScrolled) { - notebookElement.scrollTop = notebookElement.scrollHeight; - } - }); -}); - -notebookObserver.observe(notebookElement.parentNode.parentNode.parentNode, config); - -//------------------------------------------------ -// Default box scrolling -//------------------------------------------------ -const defaultElement = document.querySelector("#textbox-default textarea"); -let defaultScrolled = false; - -defaultElement.addEventListener("scroll", function() { - let diff = defaultElement.scrollHeight - defaultElement.clientHeight; - if(Math.abs(defaultElement.scrollTop - diff) <= 10 || diff == 0) { - defaultScrolled = false; - } else { - defaultScrolled = true; - } -}); - -const defaultObserver = new MutationObserver(function(mutations) { - mutations.forEach(function(mutation) { - if(!defaultScrolled) { - defaultElement.scrollTop = defaultElement.scrollHeight; - } - }); -}); - -defaultObserver.observe(defaultElement.parentNode.parentNode.parentNode, config); - //------------------------------------------------ // Add some scrollbars //------------------------------------------------ From 12690d3ffc8052f3379eeb60298095ddc46255b0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 17 Dec 2023 02:01:23 -0300 Subject: [PATCH 7/8] Better HF grammar implementation (#4953) --- grammars/japanese.gbnf | 7 - grammars/json.gbnf | 23 +- grammars/json_arr.gbnf | 34 -- grammars/json_w_trailing_space.gbnf | 14 + grammars/list.gbnf | 6 +- grammars/simple_arithmetic.gbnf | 7 + modules/grammar.py | 33 -- modules/grammar/grammar_utils.py | 687 ++++++++++++++++++++++++++++ modules/grammar/logits_process.py | 104 +++++ modules/text_generation.py | 13 +- requirements.txt | 2 - requirements_amd.txt | 2 - requirements_amd_noavx2.txt | 2 - requirements_apple_intel.txt | 2 - requirements_apple_silicon.txt | 2 - requirements_cpu_only.txt | 2 - requirements_cpu_only_noavx2.txt | 2 - requirements_noavx2.txt | 2 - requirements_nowheels.txt | 2 - 19 files changed, 830 insertions(+), 116 deletions(-) delete mode 100644 grammars/japanese.gbnf delete mode 100644 grammars/json_arr.gbnf create mode 100644 grammars/json_w_trailing_space.gbnf create mode 100644 grammars/simple_arithmetic.gbnf delete mode 100644 modules/grammar.py create mode 100644 modules/grammar/grammar_utils.py create mode 100644 modules/grammar/logits_process.py diff --git a/grammars/japanese.gbnf b/grammars/japanese.gbnf deleted file mode 100644 index 43f25ab598..0000000000 --- a/grammars/japanese.gbnf +++ /dev/null @@ -1,7 +0,0 @@ -# A probably incorrect grammar for Japanese -root ::= jp-char+ ([ \t\n] jp-char+)* -jp-char ::= hiragana | katakana | punctuation | cjk -hiragana ::= [ぁ-ゟ] -katakana ::= [ァ-ヿ] -punctuation ::= [、-〾] -cjk ::= [一-鿿] diff --git a/grammars/json.gbnf b/grammars/json.gbnf index a9537cdf9f..108c6fb75f 100644 --- a/grammars/json.gbnf +++ b/grammars/json.gbnf @@ -1,25 +1,14 @@ root ::= object -value ::= object | array | string | number | ("true" | "false" | "null") ws -object ::= - "{" ws ( - string ":" ws value - ("," ws string ":" ws value)* - )? "}" ws +object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" + +value ::= object | array | string | number | ("true" | "false" | "null") ws -array ::= - "[" ws ( - value - ("," ws value)* - )? "]" ws +array ::= "[" ws ( value ("," ws value)* )? "]" ws -string ::= - "\"" ( - [^"\\] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes - )* "\"" ws +string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws -# Optional space: by convention, applied in this grammar after literal chars when allowed + ws ::= ([ \t\n] ws)? diff --git a/grammars/json_arr.gbnf b/grammars/json_arr.gbnf deleted file mode 100644 index ef53e77a0b..0000000000 --- a/grammars/json_arr.gbnf +++ /dev/null @@ -1,34 +0,0 @@ -# This is the same as json.gbnf but we restrict whitespaces at the end of the root array -# Useful for generating JSON arrays - -root ::= arr -value ::= object | array | string | number | ("true" | "false" | "null") ws - -arr ::= - "[\n" ws ( - value - (",\n" ws value)* - )? "]" - -object ::= - "{" ws ( - string ":" ws value - ("," ws string ":" ws value)* - )? "}" ws - -array ::= - "[" ws ( - value - ("," ws value)* - )? "]" ws - -string ::= - "\"" ( - [^"\\] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes - )* "\"" ws - -number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws - -# Optional space: by convention, applied in this grammar after literal chars when allowed -ws ::= ([ \t\n] ws)? diff --git a/grammars/json_w_trailing_space.gbnf b/grammars/json_w_trailing_space.gbnf new file mode 100644 index 0000000000..21549e0c78 --- /dev/null +++ b/grammars/json_w_trailing_space.gbnf @@ -0,0 +1,14 @@ +root ::= object + +object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws + +value ::= object | array | string | number | ("true" | "false" | "null") ws + +array ::= "[" ws ( value ("," ws value)* )? "]" ws + +string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + + +ws ::= ([ \t\n] ws)? diff --git a/grammars/list.gbnf b/grammars/list.gbnf index 51e6c9c4b0..fd5aea10bb 100644 --- a/grammars/list.gbnf +++ b/grammars/list.gbnf @@ -1,4 +1,2 @@ -root ::= item+ - -# Excludes various line break characters -item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" +root ::= "1. " paragraph "\n" ([0-9] [0-9]? ". " paragraph "\n")+ +paragraph ::= [a-zA-Z'.,; ]+ \ No newline at end of file diff --git a/grammars/simple_arithmetic.gbnf b/grammars/simple_arithmetic.gbnf new file mode 100644 index 0000000000..b4451da212 --- /dev/null +++ b/grammars/simple_arithmetic.gbnf @@ -0,0 +1,7 @@ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= num | "(" ws expr ")" ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* +# this is a comment + diff --git a/modules/grammar.py b/modules/grammar.py deleted file mode 100644 index 5f6ad3a637..0000000000 --- a/modules/grammar.py +++ /dev/null @@ -1,33 +0,0 @@ -from torch_grammar import GrammarSampler -from transformers.generation.logits_process import LogitsProcessor - -from modules import shared - -sampler = None -grammar = None -grammar_string = '' - - -class GrammarLogitsProcessor(LogitsProcessor): - def __init__(self, string): - - global sampler, grammar, grammar_string - - if string != grammar_string: - grammar_string = string - if string.strip() != '': - string = string.strip() + '\n' - sampler = GrammarSampler(string, 'root', shared.tokenizer) - else: - sampler = None - - if sampler is not None: - grammar = sampler.logits_processor() - else: - grammar = None - - def __call__(self, input_ids, scores): - if grammar is not None: - scores = grammar(input_ids, scores) - - return scores diff --git a/modules/grammar/grammar_utils.py b/modules/grammar/grammar_utils.py new file mode 100644 index 0000000000..4b37c24a7a --- /dev/null +++ b/modules/grammar/grammar_utils.py @@ -0,0 +1,687 @@ +''' +This file has been 100% copied from this PR to the Transformers library: +https://github.com/huggingface/transformers/pull/27557 + +Author: Saibo-creator +Author GitHub: https://github.com/Saibo-creator + +All credits go to the author. +''' + +import logging +import re +import time +from abc import ABC +from functools import lru_cache +from typing import Dict, List + +import torch + +from modules import shared + +logger = logging.getLogger(__name__) + + +######################## +# EBNF Grammar Parsing # +######################## + +END_OF_ALTERNATE_MARKER = 0 +END_OF_RULE_MARKER = 0 +TO_BE_FILLED_MARKER = 0 +REF_RULE_MARKER = 1 +LITERAL_MARKER = 2 + + +class ParseState: + def __init__(self): + self.symbol_ids = {} + self.grammar_encoding = [] # old name: out_grammar + + +def get_symbol_id(state, src): + if src not in state.symbol_ids: + state.symbol_ids[src] = len(state.symbol_ids) + return state.symbol_ids[src] + + +def generate_symbol_id(state, base_name): + next_id = len(state.symbol_ids) + state.symbol_ids[base_name + "_" + str(next_id)] = next_id + return next_id + + +def is_word_char(c): + return c.isalnum() or c == "-" or c == "_" + + +def hex_to_int(c): + if c.isdigit(): + return int(c) + elif "a" <= c.lower() <= "f": + return ord(c.lower()) - ord("a") + 10 + return -1 + + +def remove_leading_white_space(src, newline_ok): + """ + Skips over whitespace and comments in the input string. + This function processes the input string, skipping over any spaces, tabs, + and content following a '#' character, which denotes a comment. The parsing + of a comment continues until the end of the line (denoted by newline characters + '\r' or '\n'). If the 'newline_ok' parameter is set to False, the function + will stop processing and return the remaining string upon encountering a + newline character, otherwise it will skip over newline characters as well. + Parameters: + src (str): The input string to be processed. + newline_ok (bool): A flag indicating whether encountering a newline character + should stop the parsing (False) or if it should be skipped (True). + Returns: + str: The remaining portion of the input string after skipping whitespace and comments. + """ + pos = 0 + while pos < len(src) and (src[pos].isspace() or src[pos] == "#"): + if src[pos] == "#": + while pos < len(src) and src[pos] not in ("\r", "\n"): + pos += 1 + else: + if not newline_ok and src[pos] in ("\r", "\n"): + break + pos += 1 + return src[pos:] + + +def parse_name(src): + pos = 0 + while pos < len(src) and is_word_char(src[pos]): + pos += 1 + if pos == 0: + raise RuntimeError("expecting name at " + src) + return src[:pos], src[pos:] + + +def parse_char(src): + """ + parse the leading char from the input string + :param src: + :return: char, remaining_src + """ + + # if we have a backslash, it's maybe an escape + if src[0] == "\\": + esc = src[1] + if esc == "x": + first = hex_to_int(src[2]) + if first > -1: + second = hex_to_int(src[3]) + if second > -1: + return (first << 4) + second, src[4:] + raise RuntimeError("expecting \\xNN at " + src) + elif esc in ('"', "[", "]"): + return esc, src[2:] + elif esc == "r": + return "\r", src[2:] + elif esc == "n": + return "\n", src[2:] + elif esc == "t": + return "\t", src[2:] + raise RuntimeError("unknown escape at " + src) + elif src: + return src[0], src[1:] + raise RuntimeError("unexpected end of input") + + +def parse_sequence(state, src, rule_name, outbuf, is_nested): + out_start_pos = len(outbuf) + + # sequence size, will be replaced at end when known + outbuf.append(TO_BE_FILLED_MARKER) + + last_sym_start = len(outbuf) + remaining_src = src + while remaining_src: + if remaining_src[0] == '"': # literal string + remaining_src = remaining_src[1:] + last_sym_start = len(outbuf) + while remaining_src[0] != '"': + char, remaining_src = parse_char(remaining_src) + + # each char of a literal is encoded as a "range" of char - char + outbuf.append(LITERAL_MARKER) + outbuf.append(ord(char)) + outbuf.append(ord(char)) + remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) + elif remaining_src[0] == "[": # char range(s) + remaining_src = remaining_src[1:] + last_sym_start = len(outbuf) + # num chars in range - replaced at end of loop + outbuf.append(TO_BE_FILLED_MARKER) + while remaining_src[0] != "]": + char, remaining_src = parse_char(remaining_src) + + outbuf.append(ord(char)) + if remaining_src[0] == "-" and remaining_src[1] != "]": + endchar_pair, remaining_src = parse_char(remaining_src[1:]) + outbuf.append(ord(endchar_pair)) + else: + # chars that aren't part of a c1-c2 range are just doubled (i.e., c-c) + outbuf.append(ord(char)) + # replace num chars with actual + outbuf[last_sym_start] = len(outbuf) - last_sym_start - 1 + remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) + elif is_word_char(remaining_src[0]): # rule reference + name, remaining_src = parse_name(remaining_src) + ref_rule_id = get_symbol_id(state, name) + remaining_src = remove_leading_white_space(remaining_src, is_nested) + last_sym_start = len(outbuf) + outbuf.append(REF_RULE_MARKER) + outbuf.append(ref_rule_id) + elif remaining_src[0] == "(": # grouping + # parse nested alternates into synthesized rule + remaining_src = remove_leading_white_space(remaining_src[1:], True) + sub_rule_id = generate_symbol_id(state, rule_name) + remaining_src = parse_alternates(state, remaining_src, rule_name, sub_rule_id, True) + last_sym_start = len(outbuf) + # output reference to synthesized rule + outbuf.append(REF_RULE_MARKER) + outbuf.append(sub_rule_id) + if remaining_src[0] != ")": + raise RuntimeError("expecting ')' at " + remaining_src) + remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) + elif remaining_src[0] in ("*", "+", "?"): # repetition operator + if len(outbuf) - out_start_pos - 1 == 0: + raise RuntimeError("expecting preceeding item to */+/? at " + remaining_src) + out_grammar = state.grammar_encoding + + # apply transformation to previous symbol (last_sym_start - + # end) according to rewrite rules: + # S* --> S' ::= S S' | + # S+ --> S' ::= S S' | S + # S? --> S' ::= S | + sub_rule_id = generate_symbol_id(state, rule_name) + out_grammar.append(sub_rule_id) + sub_rule_start = len(out_grammar) + # placeholder for size of 1st alternate + out_grammar.append(TO_BE_FILLED_MARKER) + # add preceding symbol to generated rule + out_grammar.extend(outbuf[last_sym_start:]) + if remaining_src[0] in ("*", "+"): + # cause generated rule to recurse + out_grammar.append(REF_RULE_MARKER) + out_grammar.append(sub_rule_id) + # apply actual size + out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start + # mark end of 1st alternate + out_grammar.append(END_OF_ALTERNATE_MARKER) + sub_rule_start = len(out_grammar) + # placeholder for size of 2nd alternate + out_grammar.append(TO_BE_FILLED_MARKER) + if remaining_src[0] == "+": + # add preceding symbol as alternate only for '+' + out_grammar.extend(outbuf[last_sym_start:]) + # apply actual size of 2nd alternate + out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start + # mark end of 2nd alternate, then end of rule + out_grammar.append(END_OF_ALTERNATE_MARKER) + out_grammar.append(END_OF_RULE_MARKER) + + # in original rule, replace previous symbol with reference to generated rule + outbuf[last_sym_start:] = [1, sub_rule_id] + + remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) + else: + break + # apply actual size of this alternate sequence + outbuf[out_start_pos] = len(outbuf) - out_start_pos + # mark end of alternate + outbuf.append(END_OF_ALTERNATE_MARKER) + return remaining_src + + +def parse_alternates(state, src, rule_name, rule_id, is_nested): + outbuf = [] + remaining_src = parse_sequence(state, src, rule_name, outbuf, is_nested) + while remaining_src and remaining_src[0] == "|": + remaining_src = remove_leading_white_space(remaining_src[1:], True) + remaining_src = parse_sequence(state, remaining_src, rule_name, outbuf, is_nested) + + state.grammar_encoding.append(rule_id) + state.grammar_encoding.extend(outbuf) + state.grammar_encoding.append(0) + return remaining_src + + +def parse_rule(state, src): + name, remaining_src = parse_name(src) + remaining_src = remove_leading_white_space(remaining_src, False) + rule_id = get_symbol_id(state, name) + + if remaining_src[:3] != "::=": + raise RuntimeError("expecting ::= at " + remaining_src) + remaining_src = remove_leading_white_space(remaining_src[3:], True) + + remaining_src = parse_alternates(state, remaining_src, name, rule_id, False) + + if remaining_src and remaining_src[0] == "\r": + remaining_src = remaining_src[2:] if remaining_src[1] == "\n" else remaining_src[1:] + elif remaining_src and remaining_src[0] == "\n": + remaining_src = remaining_src[1:] + elif remaining_src: + raise RuntimeError("expecting newline or end at " + remaining_src) + return remove_leading_white_space(remaining_src, True) + + +def parse_ebnf(src): + try: + state = ParseState() + grammar_repr = remove_leading_white_space(src, True) + last_grammar_repr = "" + while grammar_repr: + if last_grammar_repr: + last_parsed_rule_len = len(last_grammar_repr) - len(grammar_repr) + logger.debug(f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}") + last_grammar_repr = grammar_repr + grammar_repr = parse_rule(state, grammar_repr) + state.grammar_encoding.append(0xFFFF) + return state + except RuntimeError as err: + logger.warning("error parsing grammar:", err) + return ParseState() + + +def print_rule(file, grammar_encoding, index, symbol_id_names): + rule_id = grammar_encoding[index] + print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file) + pos = index + 1 + while grammar_encoding[pos]: + if pos - 1 > index: + print("|", end=" ", file=file) + pos += 1 # sequence size, not needed here + while grammar_encoding[pos]: + if grammar_encoding[pos] == REF_RULE_MARKER: + ref_rule_id = grammar_encoding[pos + 1] + print( + f"<{pos}>{symbol_id_names[ref_rule_id]}", + end=" ", + file=file, + ) + pos += 2 + else: + print("<{}>[".format(pos), end="", file=file) + num_chars = grammar_encoding[pos] + pos += 1 + + for i in range(0, num_chars, 2): + print("{}-".format(chr(grammar_encoding[pos + i])), end="", file=file) + if i + 1 < num_chars: + print("{}".format(chr(grammar_encoding[pos + i + 1])), end="", file=file) + print("]", end=" ", file=file) + pos += num_chars + pos += 1 + print(file=file) + return pos + 1 + + +def print_grammar(file, state): + pos = 0 + symbol_id_names = {v: k for k, v in state.symbol_ids.items()} + print("Grammar Rules:", file=file) + + while state.grammar_encoding[pos] != 0xFFFF: + pos = print_rule(file, state.grammar_encoding, pos, symbol_id_names) + pos = 0 + print("\nBinary representation:", file=file) + while state.grammar_encoding[pos] != 0xFFFF: + print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file) + pos += 1 + print("ffff\n") + + +################################### +# EBNF Grammar Parsing ends here # +################################### + + +class GrammarConstraint(ABC): + def __init__(self, grammar_str, start_rule_name, tokenizer): + self.tt = 0 + self.nt = 0 + state = parse_ebnf(grammar_str) + grammar_encoding = state.grammar_encoding + self.start_rule_id = state.symbol_ids.get(start_rule_name) + + self.eos_token_id = tokenizer.eos_token_id + self.token_trie = TokenTrie(tokenizer) + self.tokenizer = tokenizer + self.grammar_encoding = grammar_encoding + + pos = 0 + rules: Dict[int, int] = {} + + while grammar_encoding[pos] != 0xFFFF: + rule_id = grammar_encoding[pos] + + # Store the current position in the 'rules' list at the index corresponding to rule_id. + # This effectively maps each rule_id to its position in the grammar encoding. + rules[rule_id] = pos + pos += 1 + + # Continue to the next rule in the encoding. + # The loop advances by the size indicated at the current position (grammar_encoding[pos]) + # plus one for the size field itself. + while grammar_encoding[pos]: + pos += 1 + grammar_encoding[pos] + # Now we're at the end of the rule, + # so advance to the next rule by skipping the 0, which means 'end of rule'. + pos += 1 + + self.start_rule_pos = rules[self.start_rule_id] + self.rules_pos_dict: Dict[int, int] = rules + + def init_stacks(self): + # suppose the start rule position is 0, then grammar_encoding[0] = rule_id + # grammar_encoding[1] = rule_size + # grammar_encoding[2] = rule_type + # this is why we need to add 2 to the start rule position + stack = [self.start_rule_pos + 2] + # convert to tuple for caching(immutable) + return self.advance_stack(tuple(stack)) + + # For each stack, resolve rules to find the actual characters that are + # accepted by this stack (not the set of sub-rules). + # This is where the parsing happens. + # The parsing is a top-down, left-to-right, depth-first traversal of the + # grammar. + @lru_cache(maxsize=32768) + def advance_stack(self, stack): + stack = list(stack) + # If the stack is empty, we're done. Because no more tokens should be accepted. + if len(stack) == 0: + return [stack] + + # Get the top of the stack. + pos = stack[-1] + + # If the stack head is a terminal(literal), we can resolve it immediately. + # literal is marked with 2 in the grammar encoding. + if self.grammar_encoding[pos] > 1: + return [stack] + + # The stack head is a nonterminal (a rule reference, 1 in the grammar encoding). + # Resolving this rule gives a set of one or more possible positions + # (e.g. two in `a ::= b | c`) + # We pop the current rule off the stack and, for each option, push: + # - the symbol following this symbol in the current rule; then + # - the first symbol of the resolved rule. + referenced_rule_id = self.grammar_encoding[pos + 1] + + # subpos should points to the size of the subrule + subpos = self.rules_pos_dict[referenced_rule_id] + 1 + stacks: List[List[int]] = [] + + # do depth-first search to find all possible rules and check the next terminal + # When this value is non-zero, it indicates that subpos is not yet at the end of the rule, so we can continue. + # here subpos is a pointer, and the value in the rule encoding can never be 0 except for the end of the rule. + while self.grammar_encoding[subpos]: + new_stack = stack[:-1] + if self.grammar_encoding[pos + 2]: + # check if there is a next symbol in the current rule, e.g. `a ::= b c | d` + # if yes, push the pos to rule_size to the stack + new_stack.append(pos + 2) + + # if the type of the next symbol is not "empty", push the first symbol of the resolved rule to the stack + if self.grammar_encoding[subpos + 1]: + new_stack.append(subpos + 1) + stacks.extend(self.advance_stack(tuple(new_stack))) + # The increment subpos += self.grammar_encoding[subpos] + 1 + # moves subpos forward in the grammar encoding array to the next alternative in the current rule. + subpos += self.grammar_encoding[subpos] + 1 + return stacks + + def accept_char(self, *args, **kwargs): + """Process a byte according to the grammar rules.""" + raise NotImplementedError + + def accept_token_id(self, *args, **kwargs): + """Process a token according to the grammar rules.""" + raise NotImplementedError + + def filter_vocab(self, *args, **kwargs): + raise NotImplementedError + + +class IncrementalGrammarConstraint(GrammarConstraint): + def __init__(self, grammar_str, start_rule_name, tokenizer): + super().__init__(grammar_str, start_rule_name, tokenizer) + + def accept_char(self, byte, stacks): + new_stacks = [] + for stack in stacks: + # stack is empty + if not stack: + continue + + pos = stack[-1] + num_chars = self.grammar_encoding[pos] + + # to make pos point to the size of the char range rule + pos += 1 + found = False + for i in range(0, num_chars, 2): + if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]: + found = True + break + if not found: + continue + + pos += num_chars + new_stack = stack[:-1] + if self.grammar_encoding[pos]: + new_stack.append(pos) + new_stacks.extend(self.advance_stack(tuple(new_stack))) + + return new_stacks + + def accept_string(self, string: str, stacks: List[List[int]]): + _bytes = bytes(string, "utf-8") + for byte in _bytes: + stacks = self.accept_char(byte, stacks) + return stacks + + def accept_token_id(self, token_id: int, stacks: List[List[int]]): + if token_id == self.eos_token_id: + if stacks and all(len(stack) != 0 for stack in stacks): + raise Exception( + f"At least one of the stack should be empty when EOS is reached. However, " + f"the stacks are {stacks}" + ) + return [] + + for byte in self.token_trie.id2str(token_id): + stacks = self.accept_char(byte, stacks) + # check updated stacks + # TODO, I commented this out because it will fail when the stack is empty + # empty stack means the end of the grammar + # assert stacks != [] + + return stacks + + def accept_token_ids(self, token_ids: List[int], stacks: List[List[int]], as_string=True): + if as_string: + string = self.tokenizer.decode(token_ids) + stacks = self.accept_string(string, stacks) + else: + for token_id in token_ids: + stacks = self.accept_token_id(token_id, stacks) + return stacks + + def batch_filter_vocab(self, batch_stacks, device): + batch_acceptance = [] + for stacks in batch_stacks: + batch_acceptance.append(self.filter_vocab(stacks, device)) + return torch.stack(batch_acceptance) + + def filter_vocab(self, stacks, device): + if not stacks: # Check if stacks is empty + # Handle the empty case: for example, return a tensor of False + # The size of the tensor should match the size of your vocabulary + vocab_size = len(self.token_trie) + logger.debug(f"sum of acceptance: {0}") + return torch.zeros(vocab_size, dtype=torch.bool, device=device) + + acceptance_matrix = torch.cat([self.token_acceptance_for_stack(tuple(stack), device) for stack in stacks]) + # Merge stacks: any True => True + acceptance = acceptance_matrix.reshape(len(stacks), -1).any(dim=0) + logger.debug(f"sum of acceptance: {acceptance.sum()}") + return acceptance + + # For each sub-rule in the grammar, cache whether each byte is accepted. + @lru_cache(maxsize=None) + def pos_char_acceptance(self, pos): + acceptance = [False] * 256 + num_chars = self.grammar_encoding[pos] + pos += 1 + for i in range(0, num_chars, 2): + start = self.grammar_encoding[pos + i] + end = self.grammar_encoding[pos + i + 1] + for j in range(start, end + 1): + acceptance[j] = True + return acceptance + + # Probably this should be configurable. If the grammar has an exceedingly + # large number of states, the correct setting is a tradeoff between GPU + # RAM usage and recomputation time. + # + # The main variable that pushes usage up here is number of states in the + # grammar. + @lru_cache(maxsize=32768) + def token_acceptance_for_stack(self, stack, device): + st = time.time() + stack = list(stack) # needs to come in as a tuple for lru_cache + + accepts = [False] * len(self.token_trie) + accepts[self.eos_token_id] = len(stack) == 0 + if len(stack) == 0: + logger.debug("empty stack") + + def traverse_trie(trie, stacks): + for byte, next_trie in trie.items(): + if byte == LEAF: + token_id = next_trie + if token_id != self.eos_token_id: + accepts[token_id] = bool(stacks) + continue + + new_stacks = [] + for stk in stacks: + if not stk: + continue + + pos = stk[-1] + num_chars = self.grammar_encoding[pos] + + if not self.pos_char_acceptance(pos)[byte]: + continue + + pos += num_chars + 1 + new_stack = stk[:-1] + if self.grammar_encoding[pos]: + new_stack.append(pos) + new_stacks.extend(self.advance_stack(tuple(new_stack))) + + if new_stacks: + traverse_trie(next_trie, new_stacks) + + traverse_trie(self.token_trie.trie, [stack]) + + et = time.time() - st + x = torch.tensor(accepts, dtype=torch.bool, device=device) + self.tt += et + self.nt += 1 + return x + + +class StaticGrammarConstraint(GrammarConstraint): + def __init__(self, grammar_str, start_rule_name, tokenizer): + super().__init__(grammar_str, start_rule_name, tokenizer) + + def accept_char(self): + raise NotImplementedError + + +################# +# DATA STRUCTURES +################# + + +LEAF = -1 + + +class TokenTrie: + def __init__(self, tokenizer): + self.eos_token_id = tokenizer.eos_token_id + self.tokens = [] + self.trie = {} + self.load_tokens(tokenizer) + + def id2str(self, token_id): + return self.tokens[token_id] + + def __len__(self): + return len(self.tokens) + + def load_tokens(self, tokenizer): + def replace_hex(match): + hex_value = match.group(1) + return chr(int(hex_value, 16)) + + if "gpt2" in tokenizer.__class__.__name__.lower(): + special = tokenizer.additional_special_tokens_ids + + # Here, the decoder does a string replace on a bunch of sequences + # like ' .' for '.'. This interferes with our assumptions, where a + # token should always have exactly one representation. + # Fortunately(?) text-generation-inference doesn't seem to run this + # cleanup, so we get extraneous spaces. So, in order to generate + # the right token set for TGI, we have to skip the space trimming. + # See: + # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3588-L3600 + def fmt_token(id): + if id in special: + return None + return bytes(tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8") + + elif "llama" in tokenizer.__class__.__name__.lower(): + + def fmt_token(id): + token = tokenizer.convert_ids_to_tokens(id) + token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token) + token = token.replace("▁", " ") + return bytes(token, "utf-8") + + else: + print("Warning: unrecognized tokenizer: using default token formatting") + + def fmt_token(id): + token = tokenizer.convert_ids_to_tokens(id) + return bytes(token, "utf-8") + + # note: vocab_size doesn't work here because there are also + # get_added_vocab() tokens + self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))] + for token_id, token_bytes in enumerate(self.tokens): + if token_bytes is not None: + self.insert_into_trie(self.trie, token_bytes, token_id) + + def insert_into_trie(self, trie, token_bytes, token_id): + current = trie + for byte in token_bytes: + if byte not in current: + current[byte] = {} + current = current[byte] + current[LEAF] = token_id + + +@lru_cache(maxsize=5) +def initialize_grammar(grammar_string): + return IncrementalGrammarConstraint(grammar_string.strip(), start_rule_name="root", tokenizer=shared.tokenizer) diff --git a/modules/grammar/logits_process.py b/modules/grammar/logits_process.py new file mode 100644 index 0000000000..85e8df687b --- /dev/null +++ b/modules/grammar/logits_process.py @@ -0,0 +1,104 @@ +''' +This file has been 100% copied from this PR to the Transformers library: +https://github.com/huggingface/transformers/pull/27557 + +Author: Saibo-creator +Author GitHub: https://github.com/Saibo-creator + +All credits go to the author. +''' + +import math + +import torch +from transformers.generation.logits_process import LogitsProcessor +from transformers.utils import add_start_docstrings + +LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam + search or log softmax for each vocabulary token when using beam search + + Return: + `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores. + +""" + + +class GrammarConstrainedLogitsProcessor(LogitsProcessor): + def __init__(self, grammar_constraint): + self.last_size = None + self.grammar_constraint = grammar_constraint + self.batch_stacks = None + + def filter_logits(self, logits, device): + # resolve each stack to a tensor of True/False for each token + # indicating acceptance + # acceptance = self.grammar_acceptor.filter_vocab(self.stacks, device) + acceptance = self.grammar_constraint.batch_filter_vocab(self.batch_stacks, device) + # logger.debug(acceptance) + # Logits to -inf where False + logits[~acceptance] = -math.inf + + # TODO: batching + def process_logits(self, input_ids, scores, parse_start_index=None): + """ + :param input_ids: + :param scores: + :param parse_start_index: default None, which means generate from scratch. Set to 0 to parse all input_ids + :return: + """ + # we dynamically create stacks at the first call, so that we know the batch size and beam size + if self.batch_stacks is None: + self.batch_stacks = [self.grammar_constraint.init_stacks() for _ in range(len(input_ids))] + + # if self.last_size is not set (which would be the case when processing the first token). + # In this case, do nothing. + if self.last_size is None: + prefix_to_parse = [ + single_input_ids[parse_start_index:] if parse_start_index is not None else [] + for single_input_ids in input_ids + ] + # self.grammar_acceptor.accept_token_ids(prefix_to_parse, self.stacks) + self.batch_stacks = [ + self.grammar_constraint.accept_token_ids(prefix, stack) + for prefix, stack in zip(prefix_to_parse, self.batch_stacks) + ] + # if the length of the current input IDs (input_ids[0]) is exactly one more than self.last_size. + # This is expected in a scenario where inputs are processed incrementally, one token at a time. + elif len(input_ids[0]) == self.last_size + 1: + # self.stacks = self.grammar_acceptor.accept_token_id(input_ids[0][-1], self.stacks) + self.batch_stacks = [ + self.grammar_constraint.accept_token_id(single_input_ids[-1], stack) + for single_input_ids, stack in zip(input_ids, self.batch_stacks) + ] + # ensure that the input size is consistent with the expected incremental processing + # (i.e., one token at a time). + else: + # here we check if the input_ids are one token longer than the last time we processed + # but we don't check if input_ids are actually valid. + # Imagine a scenario where we generate 10 tokens, then we replace the 10 generated tokens with 10 new tokens. + # In this case, the input_ids will be consistent with the last_size, but the input_ids are not valid. + # However, should we really check if the input_ids are valid here? + # If we do, then we need to reparse the whole input_ids at each call, which is not efficient. + # Maybe we should just trust the user to provide valid input_ids? + # The conclusion is that, we assume the input_ids are valid, and our generation will be correct. + # If the input_ids are not valid, then the generation result will be wrong and we don't take responsibility for that. + raise RuntimeError( + "Input ID's length is inconsistent with the current state of " + "the GrammarConstrainedLogitsProcessor. If you want to process " + "another input sequence, please instantiate a new " + "GrammarConstrainedLogitsProcessor." + ) + + self.filter_logits(scores, scores.device) + + self.last_size = len(input_ids[0]) + return scores + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + return self.process_logits(input_ids, scores) diff --git a/modules/text_generation.py b/modules/text_generation.py index f576ba839b..72ccf99600 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -18,7 +18,8 @@ _StopEverythingStoppingCriteria ) from modules.extensions import apply_extensions -from modules.grammar import GrammarLogitsProcessor +from modules.grammar.grammar_utils import initialize_grammar +from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor from modules.html_generator import generate_4chan_html, generate_basic_html from modules.logging_colors import logger from modules.models import clear_torch_cache, local_rank @@ -317,11 +318,17 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()) + # Logits processor processor = state.get('logits_processor', LogitsProcessorList([])) - # In case a processor is passed by itself. if not isinstance(processor, LogitsProcessorList): processor = LogitsProcessorList([processor]) - processor.append(GrammarLogitsProcessor(state['grammar_string'])) + + # Grammar + if state['grammar_string'].strip() != '': + grammar = initialize_grammar(state['grammar_string']) + grammar_processor = GrammarConstrainedLogitsProcessor(grammar) + processor.append(grammar_processor) + apply_extensions('logits_processor', processor, input_ids) generate_params['logits_processor'] = processor diff --git a/requirements.txt b/requirements.txt index a7d87a92f1..827e7654ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_amd.txt b/requirements_amd.txt index 899dac530c..bd8ccbd623 100644 --- a/requirements_amd.txt +++ b/requirements_amd.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.38.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.38.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt index 8ad52eab51..d7e517066a 100644 --- a/requirements_amd_noavx2.txt +++ b/requirements_amd_noavx2.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.38.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.38.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt index 1abbc98649..f0ed23411c 100644 --- a/requirements_apple_intel.txt +++ b/requirements_apple_intel.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt index e4bf88a654..201a55a89c 100644 --- a/requirements_apple_silicon.txt +++ b/requirements_apple_silicon.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt index 81d10f772f..7bd9da9e0c 100644 --- a/requirements_cpu_only.txt +++ b/requirements_cpu_only.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt index 0f82ef6152..d9b73ef9e1 100644 --- a/requirements_cpu_only_noavx2.txt +++ b/requirements_cpu_only_noavx2.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt index 9a3e574528..a193967dc1 100644 --- a/requirements_noavx2.txt +++ b/requirements_noavx2.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" diff --git a/requirements_nowheels.txt b/requirements_nowheels.txt index 66b7afc166..4c1161f985 100644 --- a/requirements_nowheels.txt +++ b/requirements_nowheels.txt @@ -20,8 +20,6 @@ transformers==4.36.* tqdm wandb -git+https://github.com/oobabooga/torch-grammar.git - # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl; platform_system == "Windows" From f1f2c4c3f4e478ab0ff86261c23ce6f2fe2750dc Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 17 Dec 2023 12:08:33 -0300 Subject: [PATCH 8/8] Add --num_experts_per_token parameter (ExLlamav2) (#4955) --- README.md | 1 + modules/exllamav2.py | 1 + modules/exllamav2_hf.py | 1 + modules/loaders.py | 30 ++++++++++++++++-------------- modules/shared.py | 1 + modules/ui.py | 1 + modules/ui_model_menu.py | 1 + 7 files changed, 22 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 57b7afeea1..ad8087ee60 100644 --- a/README.md +++ b/README.md @@ -274,6 +274,7 @@ List of command-line flags |`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. | |`--no_flash_attn` | Force flash-attention to not be used. | |`--cache_8bit` | Use 8-bit cache to save VRAM. | +|`--num_experts_per_token NUM_EXPERTS_PER_TOKEN` | Number of experts to use for generation. Applies to MoE models like Mixtral. | #### AutoGPTQ diff --git a/modules/exllamav2.py b/modules/exllamav2.py index d755a36a31..2cf4a03971 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -48,6 +48,7 @@ def from_pretrained(self, path_to_model): config.scale_pos_emb = shared.args.compress_pos_emb config.scale_alpha_value = shared.args.alpha_value config.no_flash_attn = shared.args.no_flash_attn + config.num_experts_per_token = int(shared.args.num_experts_per_token) model = ExLlamaV2(config) diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index 30e3fe4888..944c39dd29 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -165,5 +165,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P config.scale_pos_emb = shared.args.compress_pos_emb config.scale_alpha_value = shared.args.alpha_value config.no_flash_attn = shared.args.no_flash_attn + config.num_experts_per_token = int(shared.args.num_experts_per_token) return Exllamav2HF(config) diff --git a/modules/loaders.py b/modules/loaders.py index c7e7653e04..9f1c70d121 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -65,24 +65,25 @@ 'logits_all', 'llamacpp_HF_info', ], - 'ExLlama_HF': [ + 'ExLlamav2_HF': [ 'gpu_split', 'max_seq_len', + 'cfg_cache', + 'no_flash_attn', + 'num_experts_per_token', + 'cache_8bit', 'alpha_value', - 'rope_freq_base', 'compress_pos_emb', - 'cfg_cache', 'trust_remote_code', 'no_use_fast', ], - 'ExLlamav2_HF': [ + 'ExLlama_HF': [ 'gpu_split', 'max_seq_len', - 'cfg_cache', - 'no_flash_attn', - 'cache_8bit', 'alpha_value', + 'rope_freq_base', 'compress_pos_emb', + 'cfg_cache', 'trust_remote_code', 'no_use_fast', ], @@ -123,22 +124,23 @@ 'no_use_fast', 'gptq_for_llama_info', ], - 'ExLlama': [ + 'ExLlamav2': [ 'gpu_split', 'max_seq_len', + 'no_flash_attn', + 'num_experts_per_token', + 'cache_8bit', 'alpha_value', - 'rope_freq_base', 'compress_pos_emb', - 'exllama_info', + 'exllamav2_info', ], - 'ExLlamav2': [ + 'ExLlama': [ 'gpu_split', 'max_seq_len', - 'no_flash_attn', - 'cache_8bit', 'alpha_value', + 'rope_freq_base', 'compress_pos_emb', - 'exllamav2_info', + 'exllama_info', ], 'ctransformers': [ 'n_ctx', diff --git a/modules/shared.py b/modules/shared.py index adebe62d34..edd74af132 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -125,6 +125,7 @@ parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.') parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.') parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.') +parser.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.') # AutoGPTQ parser.add_argument('--triton', action='store_true', help='Use triton.') diff --git a/modules/ui.py b/modules/ui.py index 8bfc949108..285e2fc3c6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -73,6 +73,7 @@ def list_model_elements(): 'disable_exllamav2', 'cfg_cache', 'no_flash_attn', + 'num_experts_per_token', 'cache_8bit', 'threads', 'threads_batch', diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 7242d117c3..7f81ca2d1b 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -129,6 +129,7 @@ def create_ui(): shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.') shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.') shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.') + shared.gradio['num_experts_per_token'] = gr.Number(label="Number of experts per token", value=shared.args.num_experts_per_token, info='Only applies to MoE models like Mixtral.') shared.gradio['gptq_for_llama_info'] = gr.Markdown('Legacy loader for compatibility with older GPUs. ExLlama_HF or AutoGPTQ are preferred for GPTQ models when supported.') shared.gradio['exllama_info'] = gr.Markdown("ExLlama_HF is recommended over ExLlama for better integration with extensions and more consistent sampling behavior across loaders.") shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.")