From e449ffed5e8d23f2bb442da2a9a6faf71caf55f7 Mon Sep 17 00:00:00 2001 From: Yaroslav Tarkan Date: Thu, 25 Jul 2024 14:25:32 +0300 Subject: [PATCH] Fix chat templates with slices, add tokenizer config for `mistralai/Mistral-7B-Instruct-v0.1` (#648) --- src/cpp/src/tokenizer.cpp | 63 ++++++++++------------- tests/python_tests/ov_genai_test_utils.py | 23 ++++----- tests/python_tests/tokenizer_configs.py | 7 +++ 3 files changed, 45 insertions(+), 48 deletions(-) diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index b1e36033ee..c6039d87bd 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -6,6 +6,7 @@ #include "utils.hpp" #include #include +#include #include "tokenizers_path.hpp" #include "circular_buffer_queue.hpp" #include @@ -368,40 +369,32 @@ class Tokenizer::TokenizerImpl { bool add_generation_prompt, const std::string& chat_template) const { auto chat_tpl = chat_template.empty() ? m_chat_template : chat_template; - // Jinja2Cpp does not support slicing, e.g. [1:]. - // In templates slicing is used typically in the header to find system prompt. - // If header containts that typical expression we update template and - // extract system message manually from ChatHistory. - std::string header_with_slice = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}"; - std::string replacement_string = "{% if false %}{% set placeholder = false %}"; - - std::string system_message = ""; - size_t pos = chat_tpl.find(header_with_slice); - if (pos != std::string::npos) { - chat_tpl.replace(pos, header_with_slice.length(), replacement_string); - - if (!history.empty() && history[0].at("role") == "system") { - system_message = history[0].at("content"); - history.erase(history.begin()); - } + + // Jinja2Cpp does not support Python-style slicing, e.g. [1:]. + // If chat template contains such slicing, we replace it with custom function `slice()` (user-defined callable) + // that is defined below and does the same list slicing logic. + std::string slice_string = "messages[1:]"; + std::string replacement_slice_string = "slice(messages, 1)"; + size_t slice_pos = chat_tpl.find(slice_string); + if (slice_pos != std::string::npos) { + chat_tpl.replace(slice_pos, slice_string.length(), replacement_slice_string); } - - // Jinja2Cpp accepts system_message only as a string and incorrectly handles it as a bool. - // Both this patters are found frequently in chat templates, replace so that jinja2cpp - // will not stumble on them. - std::pair replace_str_map[] = { - {"{% set system_message = false %}", ""}, - {"system_message != false", "true"}, - }; - if (!system_message.empty()) { - for (const auto& [from, to] : replace_str_map) { - size_t pos = 0; - while ((pos = chat_tpl.find(from, pos)) != std::string::npos) { - chat_tpl.replace(pos, from.size(), to); - pos += to.size(); + jinja2::UserCallable slice_callable = jinja2::MakeCallable( + [](const jinja2::ValuesList& list, const int64_t start) { + if (list.empty()) + return jinja2::Value(); + jinja2::ValuesList result; + int64_t stop = list.size(); + int64_t step = 1; + for (int64_t i = start; i < stop && i < list.size(); i += step) + { + result.push_back(list.at(i)); } - } - } + + return jinja2::Value(result); + }, + jinja2::ArgInfo{"list"}, jinja2::ArgInfo{"start"} + ); jinja2::TemplateEnv env; env.GetSettings().lstripBlocks = true; @@ -421,13 +414,13 @@ class Tokenizer::TokenizerImpl { {"bos_token", m_bos_token}, {"eos_token", m_eos_token}, {"pad_token", m_pad_token}, - {"system_message", system_message.empty() ? jinja2::EmptyValue() : jinja2::Value{system_message}}, {"add_generation_prompt", add_generation_prompt}, + {"slice", slice_callable}, }; - + try { return tpl.RenderAsString(params).value(); - } catch (const std::bad_alloc& error) { + } catch (const std::exception& error) { OPENVINO_THROW("Chat template for the current model is not supported by Jinja2Cpp. " "Please apply template manually to your prompt before calling generate. " "For exmaple: user{user_prompt}model"); diff --git a/tests/python_tests/ov_genai_test_utils.py b/tests/python_tests/ov_genai_test_utils.py index edfadb0988..ad5b7254cd 100644 --- a/tests/python_tests/ov_genai_test_utils.py +++ b/tests/python_tests/ov_genai_test_utils.py @@ -99,11 +99,11 @@ def get_chat_templates(): # TODO: Need to support chat templates in more models: CVS-145963 # Either ov_genai is unable to parse chat_template or results do not match with HF. "meta-llama/Meta-Llama-3-8B-Instruct", - "databricks/dbrx-instruct", + "databricks/dbrx-instruct", # Chat template is not supported by Jinja2Cpp "mosaicml/mpt-30b-chat", - "deepseek-ai/deepseek-coder-6.7b-instruct", - "maldv/winter-garden-7b-alpha", - "ishorn5/RTLCoder-Deepseek-v1.1", + "deepseek-ai/deepseek-coder-6.7b-instruct", # Chat template is not supported by Jinja2Cpp + "maldv/winter-garden-7b-alpha", # Chat template is not supported by Jinja2Cpp + "ishorn5/RTLCoder-Deepseek-v1.1", # Chat template is not supported by Jinja2Cpp "openchat/openchat-3.5-0106", "casperhansen/llama-3-70b-instruct-awq", "TheBloke/deepseek-coder-33B-instruct-GPTQ", @@ -111,26 +111,23 @@ def get_chat_templates(): "google/gemma-7b-it", "THUDM/cogvlm2-llama3-chat-19B", "KnutJaegersberg/internlm-20b-llama", - "alpindale/WizardLM-2-8x22B", "maywell/Synatra-Mixtral-8x7B", "MediaTek-Research/Breeze-7B-Instruct-v1_0", "bofenghuang/vigostral-7b-chat", - "meetkai/functionary-small-v2.5", - "nvidia/Llama3-ChatQA-1.5-8B", + "meetkai/functionary-small-v2.5", # Chat template is not supported by Jinja2Cpp "openchat/openchat-3.6-8b-20240522", "tenyx/TenyxChat-7B-v1", "LoneStriker/TinyLlama-1.1B-32k-Instruct-3.0bpw-h6-exl2", "yam-peleg/Hebrew-Gemma-11B-V2", - "shenzhi-wang/Llama3-8B-Chinese-Chat", + "shenzhi-wang/Llama3-8B-Chinese-Chat", # AssertionError "nlpai-lab/KULLM3", "HuggingFaceH4/zephyr-7b-gemma-sft-v0.1", - "MediaTek-Research/Breeze-7B-Instruct-v0_1", - "shanchen/llama3-8B-slerp-biomed-chat-chinese", + "MediaTek-Research/Breeze-7B-Instruct-v0_1", + "shanchen/llama3-8B-slerp-biomed-chat-chinese", # AssertionError "MLP-KTLim/llama-3-Korean-Bllossom-8B", - "lucyknada/microsoft_WizardLM-2-7B", - "aloobun/CosmicBun-8B", + "aloobun/CosmicBun-8B", # Chat template is not supported by Jinja2Cpp "codellama/CodeLlama-70b-Instruct-hf", - "gorilla-llm/gorilla-openfunctions-v2", + "gorilla-llm/gorilla-openfunctions-v2", # Chat template is not supported by Jinja2Cpp "BramVanroy/Llama-2-13b-chat-dutch" } from tokenizer_configs import get_tokenizer_configs diff --git a/tests/python_tests/tokenizer_configs.py b/tests/python_tests/tokenizer_configs.py index eb83f50836..4e8197ff5f 100644 --- a/tests/python_tests/tokenizer_configs.py +++ b/tests/python_tests/tokenizer_configs.py @@ -980,5 +980,12 @@ def get_tokenizer_configs(): "pad_token": None, "unk_token": "", "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{%set system_message = 'Je bent een behulpzame, respectvolle en eerlijke assistent. Antwoord altijd zo behulpzaam mogelijk. Je antwoorden mogen geen schadelijke, onethische, racistische, seksistische, gevaarlijke of illegale inhoud bevatten. Zorg ervoor dat je antwoorden sociaal onbevooroordeeld en positief van aard zijn.\n\nAls een vraag nergens op slaat of feitelijk niet coherent is, leg dan uit waarom in plaats van iets niet correct te antwoorden. Als je het antwoord op een vraag niet weet, deel dan geen onjuiste informatie.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\n' + content.strip() + '\n<>\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + }, + "mistralai/Mistral-7B-Instruct-v0.1": { + "bos_token": "", + "eos_token": "", + "pad_token": None, + "unk_token": "", + "chat_template": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n" } } \ No newline at end of file