Skip to content

Commit

Permalink
Fix chat templates with slices, add tokenizer config for `mistralai/M…
Browse files Browse the repository at this point in the history
…istral-7B-Instruct-v0.1` (#648)
  • Loading branch information
yatarkan authored Jul 25, 2024
1 parent a769b33 commit e449ffe
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 48 deletions.
63 changes: 28 additions & 35 deletions src/cpp/src/tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "utils.hpp"
#include <jinja2cpp/template.h>
#include <jinja2cpp/template_env.h>
#include <jinja2cpp/user_callable.h>
#include "tokenizers_path.hpp"
#include "circular_buffer_queue.hpp"
#include <fstream>
Expand Down Expand Up @@ -368,40 +369,32 @@ class Tokenizer::TokenizerImpl {
bool add_generation_prompt,
const std::string& chat_template) const {
auto chat_tpl = chat_template.empty() ? m_chat_template : chat_template;
// Jinja2Cpp does not support slicing, e.g. [1:].
// In templates slicing is used typically in the header to find system prompt.
// If header containts that typical expression we update template and
// extract system message manually from ChatHistory.
std::string header_with_slice = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}";
std::string replacement_string = "{% if false %}{% set placeholder = false %}";

std::string system_message = "";
size_t pos = chat_tpl.find(header_with_slice);
if (pos != std::string::npos) {
chat_tpl.replace(pos, header_with_slice.length(), replacement_string);

if (!history.empty() && history[0].at("role") == "system") {
system_message = history[0].at("content");
history.erase(history.begin());
}

// Jinja2Cpp does not support Python-style slicing, e.g. [1:].
// If chat template contains such slicing, we replace it with custom function `slice()` (user-defined callable)
// that is defined below and does the same list slicing logic.
std::string slice_string = "messages[1:]";
std::string replacement_slice_string = "slice(messages, 1)";
size_t slice_pos = chat_tpl.find(slice_string);
if (slice_pos != std::string::npos) {
chat_tpl.replace(slice_pos, slice_string.length(), replacement_slice_string);
}

// Jinja2Cpp accepts system_message only as a string and incorrectly handles it as a bool.
// Both this patters are found frequently in chat templates, replace so that jinja2cpp
// will not stumble on them.
std::pair<std::string, std::string> replace_str_map[] = {
{"{% set system_message = false %}", ""},
{"system_message != false", "true"},
};
if (!system_message.empty()) {
for (const auto& [from, to] : replace_str_map) {
size_t pos = 0;
while ((pos = chat_tpl.find(from, pos)) != std::string::npos) {
chat_tpl.replace(pos, from.size(), to);
pos += to.size();
jinja2::UserCallable slice_callable = jinja2::MakeCallable(
[](const jinja2::ValuesList& list, const int64_t start) {
if (list.empty())
return jinja2::Value();
jinja2::ValuesList result;
int64_t stop = list.size();
int64_t step = 1;
for (int64_t i = start; i < stop && i < list.size(); i += step)
{
result.push_back(list.at(i));
}
}
}

return jinja2::Value(result);
},
jinja2::ArgInfo{"list"}, jinja2::ArgInfo{"start"}
);

jinja2::TemplateEnv env;
env.GetSettings().lstripBlocks = true;
Expand All @@ -421,13 +414,13 @@ class Tokenizer::TokenizerImpl {
{"bos_token", m_bos_token},
{"eos_token", m_eos_token},
{"pad_token", m_pad_token},
{"system_message", system_message.empty() ? jinja2::EmptyValue() : jinja2::Value{system_message}},
{"add_generation_prompt", add_generation_prompt},
{"slice", slice_callable},
};

try {
return tpl.RenderAsString(params).value();
} catch (const std::bad_alloc& error) {
} catch (const std::exception& error) {
OPENVINO_THROW("Chat template for the current model is not supported by Jinja2Cpp. "
"Please apply template manually to your prompt before calling generate. "
"For exmaple: <start_of_turn>user{user_prompt}<end_of_turn><start_of_turn>model");
Expand Down
23 changes: 10 additions & 13 deletions tests/python_tests/ov_genai_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,38 +99,35 @@ def get_chat_templates():
# TODO: Need to support chat templates in more models: CVS-145963
# Either ov_genai is unable to parse chat_template or results do not match with HF.
"meta-llama/Meta-Llama-3-8B-Instruct",
"databricks/dbrx-instruct",
"databricks/dbrx-instruct", # Chat template is not supported by Jinja2Cpp
"mosaicml/mpt-30b-chat",
"deepseek-ai/deepseek-coder-6.7b-instruct",
"maldv/winter-garden-7b-alpha",
"ishorn5/RTLCoder-Deepseek-v1.1",
"deepseek-ai/deepseek-coder-6.7b-instruct", # Chat template is not supported by Jinja2Cpp
"maldv/winter-garden-7b-alpha", # Chat template is not supported by Jinja2Cpp
"ishorn5/RTLCoder-Deepseek-v1.1", # Chat template is not supported by Jinja2Cpp
"openchat/openchat-3.5-0106",
"casperhansen/llama-3-70b-instruct-awq",
"TheBloke/deepseek-coder-33B-instruct-GPTQ",
"AI-Sweden-Models/gpt-sw3-356m-instruct",
"google/gemma-7b-it",
"THUDM/cogvlm2-llama3-chat-19B",
"KnutJaegersberg/internlm-20b-llama",
"alpindale/WizardLM-2-8x22B",
"maywell/Synatra-Mixtral-8x7B",
"MediaTek-Research/Breeze-7B-Instruct-v1_0",
"bofenghuang/vigostral-7b-chat",
"meetkai/functionary-small-v2.5",
"nvidia/Llama3-ChatQA-1.5-8B",
"meetkai/functionary-small-v2.5", # Chat template is not supported by Jinja2Cpp
"openchat/openchat-3.6-8b-20240522",
"tenyx/TenyxChat-7B-v1",
"LoneStriker/TinyLlama-1.1B-32k-Instruct-3.0bpw-h6-exl2",
"yam-peleg/Hebrew-Gemma-11B-V2",
"shenzhi-wang/Llama3-8B-Chinese-Chat",
"shenzhi-wang/Llama3-8B-Chinese-Chat", # AssertionError
"nlpai-lab/KULLM3",
"HuggingFaceH4/zephyr-7b-gemma-sft-v0.1",
"MediaTek-Research/Breeze-7B-Instruct-v0_1",
"shanchen/llama3-8B-slerp-biomed-chat-chinese",
"MediaTek-Research/Breeze-7B-Instruct-v0_1",
"shanchen/llama3-8B-slerp-biomed-chat-chinese", # AssertionError
"MLP-KTLim/llama-3-Korean-Bllossom-8B",
"lucyknada/microsoft_WizardLM-2-7B",
"aloobun/CosmicBun-8B",
"aloobun/CosmicBun-8B", # Chat template is not supported by Jinja2Cpp
"codellama/CodeLlama-70b-Instruct-hf",
"gorilla-llm/gorilla-openfunctions-v2",
"gorilla-llm/gorilla-openfunctions-v2", # Chat template is not supported by Jinja2Cpp
"BramVanroy/Llama-2-13b-chat-dutch"
}
from tokenizer_configs import get_tokenizer_configs
Expand Down
7 changes: 7 additions & 0 deletions tests/python_tests/tokenizer_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,5 +980,12 @@ def get_tokenizer_configs():
"pad_token": None,
"unk_token": "<unk>",
"chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{%set system_message = 'Je bent een behulpzame, respectvolle en eerlijke assistent. Antwoord altijd zo behulpzaam mogelijk. Je antwoorden mogen geen schadelijke, onethische, racistische, seksistische, gevaarlijke of illegale inhoud bevatten. Zorg ervoor dat je antwoorden sociaal onbevooroordeeld en positief van aard zijn.\n\nAls een vraag nergens op slaat of feitelijk niet coherent is, leg dan uit waarom in plaats van iets niet correct te antwoorden. Als je het antwoord op een vraag niet weet, deel dan geen onjuiste informatie.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\n' + content.strip() + '\n<</SYS>>\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
},
"mistralai/Mistral-7B-Instruct-v0.1": {
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": None,
"unk_token": "<unk>",
"chat_template": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n"
}
}

0 comments on commit e449ffe

Please sign in to comment.