From 611b631b529ed8399be37f3fd4cb5e31bf08104e Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 09:51:43 +0000 Subject: [PATCH 1/3] Add test-polyfills --- tests/CMakeLists.txt | 19 ++- tests/test-polyfills.cpp | 354 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 372 insertions(+), 1 deletion(-) create mode 100644 tests/test-polyfills.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6a34dd4..89e3dc8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,8 +17,25 @@ target_link_libraries(test-syntax PRIVATE gtest_main gmock ) + +add_executable(test-polyfills test-polyfills.cpp) +target_compile_features(test-polyfills PUBLIC cxx_std_17) +if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-polyfills PUBLIC _CRT_SECURE_NO_WARNINGS) + target_compile_options(gtest PRIVATE -Wno-language-extension-token) +endif() +target_link_libraries(test-polyfills PRIVATE + nlohmann_json::nlohmann_json + gtest_main + gmock +) +if (NOT CMAKE_CROSSCOMPILING) + gtest_discover_tests(test-syntax) +endif() + if (NOT CMAKE_CROSSCOMPILING) gtest_discover_tests(test-syntax) + gtest_discover_tests(test-polyfills) endif() add_executable(test-capabilities test-capabilities.cpp) @@ -54,7 +71,7 @@ set(MODEL_IDS # minja implementation on the same template and context, and compare the output with the golden. # # For Gated models, you'll need to run `huggingface-cli login` (and be granted access) to download their template. - + abacusai/Fewshot-Metamath-OrcaVicuna-Mistral bofenghuang/vigogne-2-70b-chat CohereForAI/c4ai-command-r-plus # Gated diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp new file mode 100644 index 0000000..d0fcbfc --- /dev/null +++ b/tests/test-polyfills.cpp @@ -0,0 +1,354 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#include "minja.hpp" +#include +#include + +#include +#include +#include +#include "chat-template.hpp" + +using namespace minja; + +#define TEMPLATE_CHATML \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + + +#define TEMPLATE_CHATML_NO_SYSTEM \ + "{%- for message in messages -%}\n" \ + " {%- if message.role == 'system' -%}\n" \ + " {{- raise_exception('System role not supported') -}}\n" \ + " {%- endif -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + + +#define TEMPLATE_DUMMY \ + "{%- for tool in tools -%}\n" \ + " {{- 'tool: ' + (tool | tojson(indent=2)) + '\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- for message in messages -%}\n" \ + " {{- 'message: ' + (message | tojson(indent=2)) + '\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- 'message: ' -}}\n" \ + "{%- endif -%}" + + +const json message_user_text { + { "role", "user" }, + { "content", "I need help" }, +}; +const json message_assistant_text { + { "role", "assistant" }, + { "content", "Hello, world!" }, +}; +const json message_system { + { "role", "system" }, + { "content", "I am The System!" }, +}; +const json tool_calls = json::array({{ + { "type", "function" }, + { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, +}}); + +const json message_assistant_call { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + }, + }}, +}; +const json message_assistant_call_id { + { "role", "assistant"}, + { "content", {}}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + {"id", "123456789"}, + }, + }}, + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", tool_calls } +}; +const json message_assistant_call_idx { + { "role", "assistant"}, + { "content", {}}, + { "tool_plan", "I'm not so sure"}, + { "tool_calls", { + { + { "type", "function" }, + { "function", { + { "name", "special_function" }, + { "arguments", "{\"arg1\": 1}" }, + }}, + {"id", "0"}, + }, + }}, + { "role", "assistant" }, + { "content", {} }, + { "tool_calls", tool_calls } +}; +const json message_tool { + { "role", "tool"}, + { "content", { + {"result", 123}, + }}, +}; + +const auto special_function_tool = json::parse(R"({ + "type": "function", + "function": { + "name": "special_function", + "description": "I'm special", + "parameters": { + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." + } + }, + "required": ["arg1"] + } + } +})"); + +auto ThrowsWithSubstr = [](const std::string & expected_substr) { + return testing::Throws(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr))); +}; + +static chat_template_options options_no_polyfills() { + chat_template_options opts; + opts.apply_polyfills = false; + opts.polyfill_system_role = false; + opts.polyfill_tools = false; + opts.polyfill_tool_call_examples = false; + opts.polyfill_tool_calls = false; + opts.polyfill_tool_responses = false; + opts.polyfill_object_arguments = false; + opts.polyfill_typed_content = false; + return opts; +}; + +TEST(PolyfillTest, NoPolyFill) { + chat_template tmpl(TEMPLATE_CHATML, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text}); + + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs, options_no_polyfills())); + + inputs.add_generation_prompt = false; + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n", + tmpl.apply(inputs, options_no_polyfills())); + + inputs.messages = json::array({message_user_text, message_assistant_text}); + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n" + "Hello, world!<|im_end|>\n", + tmpl.apply(inputs, options_no_polyfills())); +} + +TEST(PolyfillTest, SystemRoleSupported) { + chat_template chatml(TEMPLATE_CHATML, "<|im_end|>", ""); + chat_template dummy(TEMPLATE_DUMMY, "", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_system, message_user_text}); + + EXPECT_EQ( + "<|im_start|>system\n" + "I am The System!<|im_end|>\n" + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + chatml.apply(inputs)); + EXPECT_EQ( + "message: {\n" + " \"role\": \"system\",\n" + " \"content\": \"I am The System!\"\n" + "}\n" + "message: {\n" + " \"role\": \"user\",\n" + " \"content\": \"I need help\"\n" + "}\n" + "message: ", + dummy.apply(inputs)); +} + +TEST(PolyfillTest, SystemRolePolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_system, message_user_text}); + + EXPECT_THAT( + [&]() { tmpl.apply(inputs, options_no_polyfills()); }, + ThrowsWithSubstr("System role not supported")); + + EXPECT_EQ( + "<|im_start|>user\n" + "I am The System!\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolCallSupported) { + chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text, message_assistant_call_id}); + + EXPECT_EQ( + "message: {\n" + " \"role\": \"user\",\n" + " \"content\": \"I need help\"\n" + "}\n" + "message: {\n" + " \"role\": \"assistant\",\n" + " \"content\": null,\n" + " \"tool_calls\": [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " }\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}\n" + "message: ", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolCallPolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text, message_assistant_call_id}); + + EXPECT_EQ( + "<|im_start|>user\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n" + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"special_function\",\n" + " \"arguments\": {\n" + " \"arg1\": 1\n" + " },\n" + " \"id\": \"123456789\"\n" + " }\n" + " ]\n" + "}<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolsPolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_user_text}); + inputs.tools = json::array({special_function_tool}); + + EXPECT_EQ( + "<|im_start|>user\n" + "You can call any of the following tools to satisfy the user's requests: [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"description\": \"I'm special\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"arg1\": {\n" + " \"type\": \"integer\",\n" + " \"description\": \"The arg.\"\n" + " }\n" + " },\n" + " \"required\": [\n" + " \"arg1\"\n" + " ]\n" + " }\n" + " }\n" + " }\n" + "]\n" + "\n" + "Example tool call syntax:\n" + "\n" + "{\n" + " \"tool_calls\": [\n" + " {\n" + " \"name\": \"tool_name\",\n" + " \"arguments\": {\n" + " \"arg1\": \"some_value\"\n" + " },\n" + " \"id\": \"call_1___\"\n" + " }\n" + " ]\n" + "}<|im_end|>\n" + "\n" + "I need help<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} + +TEST(PolyfillTest, ToolPolyfill) { + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "<|im_start|>user\n{\n" + " \"tool_response\": {\n" + " \"content\": {\n" + " \"result\": 123\n" + " }\n" + " }\n" + "}<|im_end|>\n" + "<|im_start|>assistant\n", + tmpl.apply(inputs)); +} \ No newline at end of file From a12dc8a9d09fbd1e14c897d5812f8e96f5b93a1d Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 09:56:27 +0000 Subject: [PATCH 2/3] Update test-polyfills.cpp --- tests/test-polyfills.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp index d0fcbfc..0b52836 100644 --- a/tests/test-polyfills.cpp +++ b/tests/test-polyfills.cpp @@ -328,13 +328,30 @@ TEST(PolyfillTest, ToolsPolyfill) { " \"id\": \"call_1___\"\n" " }\n" " ]\n" - "}<|im_end|>\n" + "}<|im_end|>\n" // TODO: fix this "\n" "I need help<|im_end|>\n" "<|im_start|>assistant\n", tmpl.apply(inputs)); } +TEST(PolyfillTest, ToolSupported) { + chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + + auto inputs = chat_template_inputs(); + inputs.messages = json::array({message_tool}); + + EXPECT_EQ( + "message: {\n" + " \"role\": \"tool\",\n" + " \"content\": {\n" + " \"result\": 123\n" + " }\n" + "}\n" + "message: ", + tmpl.apply(inputs)); +} + TEST(PolyfillTest, ToolPolyfill) { chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); From b9c335e493a5ea082770f2aaaa350f9205e802d3 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 10:20:18 +0000 Subject: [PATCH 3/3] Fix tool call example optional final eos elision --- include/minja/chat-template.hpp | 11 +++++------ scripts/fetch_templates_and_goldens.py | 2 +- tests/test-polyfills.cpp | 22 +++++++++++----------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 0e88fb3..2efb69a 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -249,11 +249,10 @@ class chat_template { inputs.add_generation_prompt = false; full = apply(inputs); } - - if (full.find(prefix) != 0) { - if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { - prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - } + auto eos_pos_last = full.rfind(eos_token_); + if (eos_pos_last == prefix.size() - eos_token_.size() || + (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { + full = full.substr(0, eos_pos_last); } if (full.find(prefix) != 0) { fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); @@ -363,7 +362,7 @@ class chat_template { if (polyfill_tools) { adjusted_messages = add_system(inputs.messages, "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + - (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n")); } else { adjusted_messages = inputs.messages; } diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 66238bc..5637b3e 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -278,7 +278,7 @@ def apply(self, context): if has_tools and not caps.supports_tools: add_system(context['messages'], f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}" + - ("\n\nExample tool call syntax:\n\n" + self.tool_call_example if self.tool_call_example is not None else "")) + ("\n\nExample tool call syntax:\n\n" + self.tool_call_example + "\n\n" if self.tool_call_example is not None else "")) for message in context['messages']: if 'tool_calls' in message: diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp index 0b52836..9f9d1f7 100644 --- a/tests/test-polyfills.cpp +++ b/tests/test-polyfills.cpp @@ -158,7 +158,7 @@ static chat_template_options options_no_polyfills() { }; TEST(PolyfillTest, NoPolyFill) { - chat_template tmpl(TEMPLATE_CHATML, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_user_text}); @@ -185,7 +185,7 @@ TEST(PolyfillTest, NoPolyFill) { } TEST(PolyfillTest, SystemRoleSupported) { - chat_template chatml(TEMPLATE_CHATML, "<|im_end|>", ""); + chat_template chatml(TEMPLATE_CHATML, "", ""); chat_template dummy(TEMPLATE_DUMMY, "", ""); auto inputs = chat_template_inputs(); @@ -212,7 +212,7 @@ TEST(PolyfillTest, SystemRoleSupported) { } TEST(PolyfillTest, SystemRolePolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_system, message_user_text}); @@ -230,7 +230,7 @@ TEST(PolyfillTest, SystemRolePolyfill) { } TEST(PolyfillTest, ToolCallSupported) { - chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_DUMMY, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_user_text, message_assistant_call_id}); @@ -261,7 +261,7 @@ TEST(PolyfillTest, ToolCallSupported) { } TEST(PolyfillTest, ToolCallPolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_user_text, message_assistant_call_id}); @@ -286,14 +286,14 @@ TEST(PolyfillTest, ToolCallPolyfill) { } TEST(PolyfillTest, ToolsPolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML, "", "<|im_end|>"); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_user_text}); inputs.tools = json::array({special_function_tool}); EXPECT_EQ( - "<|im_start|>user\n" + "<|im_start|>system\n" "You can call any of the following tools to satisfy the user's requests: [\n" " {\n" " \"type\": \"function\",\n" @@ -328,15 +328,15 @@ TEST(PolyfillTest, ToolsPolyfill) { " \"id\": \"call_1___\"\n" " }\n" " ]\n" - "}<|im_end|>\n" // TODO: fix this - "\n" + "}\n\n<|im_end|>\n" + "<|im_start|>user\n" "I need help<|im_end|>\n" "<|im_start|>assistant\n", tmpl.apply(inputs)); } TEST(PolyfillTest, ToolSupported) { - chat_template tmpl(TEMPLATE_DUMMY, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_DUMMY, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_tool}); @@ -353,7 +353,7 @@ TEST(PolyfillTest, ToolSupported) { } TEST(PolyfillTest, ToolPolyfill) { - chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "<|im_end|>", ""); + chat_template tmpl(TEMPLATE_CHATML_NO_SYSTEM, "", ""); auto inputs = chat_template_inputs(); inputs.messages = json::array({message_tool});