-
Notifications
You must be signed in to change notification settings - Fork 10.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Copy minja from google/minja@58f0ca6 * Add --jinja and --chat-template-file flags * Add missing <optional> include * Avoid print in get_hf_chat_template.py * No designated initializers yet * Try and work around msvc++ non-macro max resolution quirk * Update test_chat_completion.py * Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template * Refactor test-chat-template * Test templates w/ minja * Fix deprecation * Add --jinja to llama-run * Update common_chat_format_example to use minja template wrapper * Test chat_template in e2e test * Update utils.py * Update test_chat_completion.py * Update run.cpp * Update arg.cpp * Refactor common_chat_* functions to accept minja template + use_jinja option * Attempt to fix linkage of LLAMA_CHATML_TEMPLATE * Revert LLAMA_CHATML_TEMPLATE refactor * Normalize newlines in test-chat-templates for windows tests * Forward decl minja::chat_template to avoid eager json dep * Flush stdout in chat template before potential crash * Fix copy elision warning * Rm unused optional include * Add missing optional include to server.cpp * Disable jinja test that has a cryptic windows failure * minja: fix vigogne (google/minja#22) * Apply suggestions from code review Co-authored-by: Xuan Son Nguyen <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]> * Finish suggested renamings * Move chat_templates inside server_context + remove mutex * Update --chat-template-file w/ recent change to --chat-template * Refactor chat template validation * Guard against missing eos/bos tokens (null token otherwise throws in llama_vocab::impl::token_get_attr) * Warn against missing eos / bos tokens when jinja template references them * rename: common_chat_template[s] * reinstate assert on chat_templates.template_default * Update minja to google/minja@b8437df * Update minja to google/minja#25 * Update minja from google/minja#27 * rm unused optional header --------- Co-authored-by: Xuan Son Nguyen <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
- Loading branch information
1 parent
e28245f
commit 6171c9d
Showing
22 changed files
with
3,563 additions
and
133 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
/* | ||
Copyright 2024 Google LLC | ||
Use of this source code is governed by an MIT-style | ||
license that can be found in the LICENSE file or at | ||
https://opensource.org/licenses/MIT. | ||
*/ | ||
// SPDX-License-Identifier: MIT | ||
#pragma once | ||
|
||
#include "minja.hpp" | ||
#include <json.hpp> | ||
#include <string> | ||
#include <vector> | ||
|
||
using json = nlohmann::ordered_json; | ||
|
||
namespace minja { | ||
|
||
class chat_template { | ||
public: | ||
|
||
private: | ||
bool supports_tools_ = true; | ||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. | ||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified. | ||
bool requires_object_arguments_ = false; | ||
bool supports_system_role_ = true; | ||
bool supports_parallel_tool_calls_ = false; | ||
std::string source_; | ||
std::string bos_token_; | ||
std::string eos_token_; | ||
std::shared_ptr<minja::TemplateNode> template_root_; | ||
|
||
std::string try_render( | ||
const nlohmann::ordered_json & messages, | ||
const nlohmann::ordered_json & tools, | ||
bool add_generation_prompt, | ||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const | ||
{ | ||
try { | ||
auto prompt = apply(messages, tools, add_generation_prompt, extra_context); | ||
// fprintf(stderr, "Prompt: %s\n", prompt.c_str()); | ||
return prompt; | ||
} catch (const std::exception & e) { | ||
// fprintf(stderr, "Error: %s\n", e.what()); | ||
return ""; | ||
} | ||
} | ||
|
||
public: | ||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) | ||
: source_(source), bos_token_(bos_token), eos_token_(eos_token) | ||
{ | ||
template_root_ = minja::Parser::parse(source_, { | ||
/* .trim_blocks = */ true, | ||
/* .lstrip_blocks = */ true, | ||
/* .keep_trailing_newline = */ false, | ||
}); | ||
supports_tools_ = source.find("tools") != std::string::npos; | ||
|
||
auto renders_string_arguments = | ||
try_render({ | ||
{ | ||
{"role", "user"}, | ||
{"content", "Hey"} | ||
}, | ||
{ | ||
{"role", "assistant"}, | ||
{"tool_calls", json::array({ | ||
{ | ||
{"id", "call_1___"}, | ||
{"type", "function"}, | ||
{"function", { | ||
{"arguments", "{\"code\": \"print('Hello, World!')\"}"}, | ||
{"name", "ipython"}, | ||
}}, | ||
}, | ||
})}, | ||
} | ||
}, {}, false).find("{\"code\": \"print") != std::string::npos; | ||
if (!renders_string_arguments) { | ||
auto renders_object_arguments = | ||
try_render({ | ||
{ | ||
{"role", "user"}, | ||
{"content", "Hey"} | ||
}, | ||
{ | ||
{"role", "assistant"}, | ||
{"tool_calls", json::array({ | ||
{ | ||
{"id", "call_1___"}, | ||
{"type", "function"}, | ||
{"function", { | ||
{"arguments", { | ||
{"code", "print('Hello, World!')"}, | ||
}}, | ||
{"name", "ipython"}, | ||
}}, | ||
}, | ||
})}, | ||
} | ||
}, {}, false).find("{\"code\": \"print") != std::string::npos; | ||
requires_object_arguments_ = renders_object_arguments; | ||
} | ||
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos; | ||
|
||
supports_system_role_ = try_render({ | ||
{{"role", "system"}, {"content", "<System Needle>"}}, | ||
{{"role", "user"}, {"content", "Hey"}} | ||
}, {}, false).find("<System Needle>") != std::string::npos; | ||
} | ||
|
||
const std::string & source() const { return source_; } | ||
const std::string & bos_token() const { return bos_token_; } | ||
const std::string & eos_token() const { return eos_token_; } | ||
bool supports_tools() const { return supports_tools_; } | ||
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; } | ||
|
||
std::string apply( | ||
const nlohmann::ordered_json & messages, | ||
const nlohmann::ordered_json & tools, | ||
bool add_generation_prompt, | ||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const | ||
{ | ||
json actual_messages; | ||
|
||
// First, "fix" messages so they have a chance to be rendered correctly by the template | ||
|
||
if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) { | ||
actual_messages = json::array(); | ||
|
||
std::string pending_system; | ||
auto flush_sys = [&]() { | ||
if (!pending_system.empty()) { | ||
actual_messages.push_back({ | ||
{"role", "user"}, | ||
{"content", pending_system}, | ||
}); | ||
pending_system.clear(); | ||
} | ||
}; | ||
for (const auto & message_ : messages) { | ||
auto message = message_; | ||
if (!message.contains("role") || !message.contains("content")) { | ||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); | ||
} | ||
std::string role = message.at("role"); | ||
|
||
if (message.contains("tool_calls")) { | ||
if (requires_object_arguments_ || !supports_tools_) { | ||
for (auto & tool_call : message.at("tool_calls")) { | ||
if (tool_call["type"] == "function") { | ||
auto & function = tool_call.at("function"); | ||
std::string arguments = function.at("arguments"); | ||
function["arguments"] = json::parse(arguments); | ||
} | ||
} | ||
} | ||
if (!supports_tools_) { | ||
auto content = message.at("content"); | ||
auto tool_calls = json::array(); | ||
for (const auto & tool_call : message.at("tool_calls")) { | ||
if (tool_call.at("type") != "function") { | ||
continue; | ||
} | ||
const auto & function = tool_call.at("function"); | ||
auto tc = json { | ||
{"name", function.at("name")}, | ||
{"arguments", function.at("arguments")}, | ||
}; | ||
if (tool_call.contains("id")) { | ||
tc["id"] = tool_call["id"]; | ||
} | ||
tool_calls.push_back(tc); | ||
} | ||
auto obj = json { | ||
{"tool_calls", tool_calls}, | ||
}; | ||
if (!content.is_null() && content != "") { | ||
obj["content"] = content; | ||
} | ||
message["content"] = obj.dump(2); | ||
message.erase("tool_calls"); | ||
} | ||
} | ||
if (!supports_tools_ && role == "tool") { | ||
message["role"] = "user"; | ||
auto obj = json { | ||
{"tool_response", { | ||
{"tool", message.at("name")}, | ||
{"content", message.at("content")}, | ||
}}, | ||
}; | ||
if (message.contains("tool_call_id")) { | ||
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); | ||
} | ||
message["content"] = obj.dump(2); | ||
message.erase("name"); | ||
} | ||
|
||
if (!message["content"].is_null() && !supports_system_role_) { | ||
std::string content = message.at("content"); | ||
if (role == "system") { | ||
if (!pending_system.empty()) pending_system += "\n"; | ||
pending_system += content; | ||
continue; | ||
} else { | ||
if (role == "user") { | ||
if (!pending_system.empty()) { | ||
message["content"] = pending_system + (content.empty() ? "" : "\n" + content); | ||
pending_system.clear(); | ||
} | ||
} else { | ||
flush_sys(); | ||
} | ||
} | ||
} | ||
actual_messages.push_back(message); | ||
} | ||
flush_sys(); | ||
} else { | ||
actual_messages = messages; | ||
} | ||
|
||
auto context = minja::Context::make(json({ | ||
{"messages", actual_messages}, | ||
{"add_generation_prompt", add_generation_prompt}, | ||
{"bos_token", bos_token_}, | ||
{"eos_token", eos_token_}, | ||
})); | ||
|
||
if (!tools.is_null()) { | ||
auto tools_val = minja::Value(tools); | ||
context->set("tools", tools_val); | ||
} | ||
if (!extra_context.is_null()) { | ||
for (auto & kv : extra_context.items()) { | ||
minja::Value val(kv.value()); | ||
context->set(kv.key(), val); | ||
} | ||
} | ||
|
||
return template_root_->render(context); | ||
} | ||
}; | ||
|
||
} // namespace minja |
Oops, something went wrong.