diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp
index 2efb69a..882ba41 100644
--- a/include/minja/chat-template.hpp
+++ b/include/minja/chat-template.hpp
@@ -254,10 +254,25 @@ class chat_template {
                 (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
                 full = full.substr(0, eos_pos_last);
             }
-            if (full.find(prefix) != 0) {
+            size_t common_prefix_length = 0;
+            for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
+                if (prefix[i] != full[i]) {
+                    break;
+                }
+                if (prefix[i] == '<') {
+                    // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+                    // but it removes thinking tags for past messages.
+                    // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+                    continue;
+                }
+                common_prefix_length = i + 1;
+            }
+            auto example = full.substr(common_prefix_length);
+            if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
                 fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
+            } else {
+                tool_call_example_ = example;
             }
-            tool_call_example_ = full.substr(prefix.size());
         }
     } catch (const std::exception & e) {
         fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py
index 5637b3e..fd022ba 100644
--- a/scripts/fetch_templates_and_goldens.py
+++ b/scripts/fetch_templates_and_goldens.py
@@ -43,9 +43,6 @@ def raise_exception(message: str):
     raise ValueError(message)
 
 
-def tojson(eval_ctx, value, indent=None):
-    return json.dumps(value, indent=indent)
-
 
 TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
 
@@ -114,16 +111,22 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext
             # print(out, file=sys.stderr)
             return out
         except BaseException as e:
-            # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
+            # print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
             return ""
 
-    def __init__(self, template, known_eos_tokens, env=None):
+    def __init__(self, template, env=None, filters=None, global_functions=None):
        if not env:
            env = jinja2.Environment(
                trim_blocks=True,
                lstrip_blocks=True,
                extensions=[jinja2.ext.loopcontrols]
            )
+        if filters:
+            for name, func in filters.items():
+                env.filters[name] = func
+        if global_functions:
+            for name, func in global_functions.items():
+                env.globals[name] = func
         self.env = env
         self.template = env.from_string(template)
 
@@ -243,15 +246,24 @@ def make_tool_call(tool_name, arguments):
             }
 
             prefix = self.try_raw_render([user_msg], add_generation_prompt=True)
             full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False)
-            if not full.startswith(prefix):
-                for known_eos_token in known_eos_tokens:
-                    prefix = prefix.rstrip()
-                    if prefix.endswith(known_eos_token):
-                        prefix = prefix[:-len(known_eos_token)]
-                        break
-            if not full.startswith(prefix):
+
+            common_prefix_length = 0
+            for i in range(min(len(prefix), len(full))):
+                if prefix[i] != full[i]:
+                    break
+                if prefix[i] == '<':
+                    # DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
+                    # but it removes thinking tags for past messages.
+                    # The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
+                    continue
+                common_prefix_length = i + 1
+
+            example = full[common_prefix_length:]
+            if "tool_name" not in example and "some_value" not in example:
                 print("Failed to infer a tool call example (possible template bug)", file=sys.stderr)
-            self.tool_call_example = full[len(prefix):]
+            else:
+                self.tool_call_example = example
+
         except Exception as e:
             print(f"Failed to generate tool call example: {e}", file=sys.stderr)
 
@@ -321,7 +333,11 @@ def apply(self, context):
                     message['content'] = [{"type": "text", "text": message['content']}]
 
         try:
-            return self.template.render(**context)
+            out = self.template.render(**context)
+            out = out.replace("\\u0027", "'")
+            out = out.replace('&quot;', '"')
+            out = out.replace('&#39;', "'")
+            return out
         except Exception as e1:
             for message in context['messages']:
                 if message.get("content") is None:
@@ -350,21 +366,14 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c
     async with aiofiles.open(template_file, 'w') as f:
         await f.write(template_src)
 
-    known_eos_tokens = [
-        "<|END_OF_TURN_TOKEN|>",
-        "<end_of_turn>",
-        "</s>",
-        "<|im_end|>",
-        "<|eom_id|>",
-        "<|eot_id|>",
-        "<|end▁of▁sentence|>",
-    ]
-
-    template = chat_template(template_src, known_eos_tokens)
-    template.env.filters['safe'] = lambda x: x
-    template.env.filters['tojson'] = tojson
-    template.env.globals['raise_exception'] = raise_exception
-    template.env.globals['strftime_now'] = strftime_now
+    template = chat_template(template_src,
+                             filters={
+                                 'safe': lambda x: x,
+                             },
+                             global_functions={
+                                 'raise_exception': raise_exception,
+                                 'strftime_now': strftime_now,
+                             })
     caps = template.original_caps
 
     if not context_files:
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 550515d..47b0900 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -31,11 +31,8 @@ target_link_libraries(test-polyfills PRIVATE
 )
 if (NOT CMAKE_CROSSCOMPILING)
     gtest_discover_tests(test-syntax)
-endif()
-
-if (NOT CMAKE_CROSSCOMPILING)
-    gtest_discover_tests(test-syntax)
-    gtest_discover_tests(test-polyfills)
+    add_test(NAME test-polyfills COMMAND test-polyfills)
+    set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
 endif()
 
 add_executable(test-capabilities test-capabilities.cpp)
@@ -82,6 +79,7 @@ set(MODEL_IDS
     MiniMaxAI/MiniMax-Text-01
     indischepartij/MiniCPM-3B-OpenHermes-2.5-v2
     mattshumer/Reflection-Llama-3.1-70B
+    meetkai/functionary-medium-v3.1
     meetkai/functionary-medium-v3.2
     meta-llama/Llama-3.1-8B-Instruct # Gated
     meta-llama/Llama-3.2-3B-Instruct # Gated
diff --git a/tests/contexts/simple.json b/tests/contexts/simple.json
index 560f92f..5e89f22 100644
--- a/tests/contexts/simple.json
+++ b/tests/contexts/simple.json
@@ -11,5 +11,6 @@
   ],
   "add_generation_prompt": true,
   "bos_token": "<|startoftext|>",
-  "eos_token": "<|endoftext|>"
+  "eos_token": "<|endoftext|>",
+  "tools_in_user_message": false
 }
diff --git a/tests/contexts/system.json b/tests/contexts/system.json
index 4d72972..7cbc5c2 100644
--- a/tests/contexts/system.json
+++ b/tests/contexts/system.json
@@ -15,5 +15,6 @@
   ],
   "add_generation_prompt": true,
   "bos_token": "<|startoftext|>",
-  "eos_token": "<|endoftext|>"
+  "eos_token": "<|endoftext|>",
+  "tools_in_user_message": false
 }
diff --git a/tests/contexts/tool_use.json b/tests/contexts/tool_use.json
index 4920d19..cca70cb 100644
--- a/tests/contexts/tool_use.json
+++ b/tests/contexts/tool_use.json
@@ -88,6 +88,7 @@
   "add_generation_prompt": true,
   "bos_token": "<|startoftext|>",
   "eos_token": "<|endoftext|>",
+  "tools_in_user_message": false,
   "builtin_tools": [
"wolfram_alpha", "brave_search" @@ -96,72 +97,72 @@ "todays_date": "2024-09-03", "tools": [ { - "type": "function", "function": { - "name": "ipython", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "name": "ipython", "parameters": { - "type": "object", "properties": { "code": { - "type": "string", - "description": "The code to run in the ipython interpreter." + "description": "The code to run in the ipython interpreter.", + "type": "string" } }, - "required": ["code"] + "required": ["code"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "brave_search", "description": "Executes a web search with Brave.", + "name": "brave_search", "parameters": { - "type": "object", "properties": { "query": { - "type": "string", - "description": "The query to search for." + "description": "The query to search for.", + "type": "string" } }, - "required": ["query"] + "required": ["query"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "wolfram_alpha", "description": "Executes a query with Wolfram Alpha.", + "name": "wolfram_alpha", "parameters": { - "type": "object", "properties": { "query": { - "type": "string", - "description": "The query to execute." + "description": "The query to execute.", + "type": "string" } }, - "required": ["query"] + "required": ["query"], + "type": "object" } - } + }, + "type": "function" }, { - "type": "function", "function": { - "name": "test", "description": "Runs a test.", + "name": "test", "parameters": { - "type": "object", "properties": { "condition": { - "type": "boolean", - "description": "The condition to test." + "description": "The condition to test.", + "type": "boolean" } }, - "required": ["condition"] + "required": ["condition"], + "type": "object" } - } + }, + "type": "function" } ] } \ No newline at end of file diff --git a/tests/test-polyfills.cpp b/tests/test-polyfills.cpp index 9f9d1f7..d1c598b 100644 --- a/tests/test-polyfills.cpp +++ b/tests/test-polyfills.cpp @@ -17,6 +17,22 @@ using namespace minja; +static std::string read_file(const std::string &path) +{ + std::ifstream fs(path, std::ios_base::binary); + if (!fs.is_open()) + { + throw std::runtime_error("Failed to open file: " + path); + } + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + std::string out; + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); + return out; +} + #define TEMPLATE_CHATML \ "{%- for message in messages -%}\n" \ " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ @@ -120,6 +136,7 @@ const json message_tool { { "content", { {"result", 123}, }}, + { "tool_call_id", "123456789"}, }; const auto special_function_tool = json::parse(R"({ @@ -346,7 +363,8 @@ TEST(PolyfillTest, ToolSupported) { " \"role\": \"tool\",\n" " \"content\": {\n" " \"result\": 123\n" - " }\n" + " },\n" + " \"tool_call_id\": \"123456789\"\n" "}\n" "message: ", tmpl.apply(inputs)); @@ -363,9 +381,199 @@ TEST(PolyfillTest, ToolPolyfill) { " \"tool_response\": {\n" " \"content\": {\n" " \"result\": 123\n" - " }\n" + " },\n" + " \"tool_call_id\": \"123456789\"\n" " }\n" "}<|im_end|>\n" "<|im_start|>assistant\n", tmpl.apply(inputs)); -} \ No newline at end of file +} + +#ifndef _WIN32 +TEST(ToolTest, DeepSeekR1) { + chat_template tmpl(read_file("tests/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja"), "", ""); + + auto inputs = chat_template_inputs(); + 
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|tool▁outputs▁begin|><|tool▁output▁begin|>{'result': 123}<|tool▁output▁end|><|tool▁outputs▁end|>",
+        tmpl.apply(inputs));
+}
+
+TEST(ToolTest, CommandR7b) {
+    chat_template tmpl(read_file("tests/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\n"
+        "You are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n"
+        "\n"
+        "Your information cutoff date is June 2024.\n"
+        "\n"
+        "You have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n"
+        "# Default Preamble\n"
+        "The following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n"
+        "- Your name is Command.\n"
+        "- You are a large language model built by Cohere.\n"
+        "- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n"
+        "- If the input is ambiguous, ask clarifying follow-up questions.\n"
+        "- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n"
+        "- Use LaTeX to generate mathematical notation for complex equations.\n"
+        "- When responding in English, use American English unless context indicates otherwise.\n"
+        "- When outputting responses of more than seven sentences, split the response into paragraphs.\n"
+        "- Prefer the active voice.\n"
+        "- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n"
+        "- Use gender-neutral pronouns for unspecified persons.\n"
+        "- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n"
+        "- Use the third person when asked to write a summary.\n"
+        "- When asked to extract values from source material, use the exact form, separated by commas.\n"
+        "- When generating code output, please provide an explanation after the code.\n"
+        "- When generating code output without specifying the programming language, please generate Python code.\n"
+        "- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n"
+        "  {\n"
+        "    \"tool_call_id\": \"\",\n"
+        "    \"results\": {\n"
+        "      \"0\": {\"result\": 123}\n"
+        "    },\n"
+        "    \"is_error\": null\n"
+        "  }\n"
+        "]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+        tmpl.apply(inputs));
+}
+#endif // NOT _WIN32
+
+TEST(ToolTest, MistralNemo) {
+    chat_template tmpl(read_file("tests/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "[TOOL_RESULTS]{\"content\": {'result': 123}, \"call_id\": \"123456789\"}[/TOOL_RESULTS]",
+        tmpl.apply(inputs));
+}
+
+TEST(ToolTest, NousResearchHermes3) {
+    chat_template tmpl(read_file("tests/NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|im_start|>system\n"
+        "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> </tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n"
+        "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n"
+        "<tool_call>\n"
+        "{\"name\": <function-name>, \"arguments\": <args-dict>}\n"
+        "</tool_call><|im_end|>\n"
+        "<tool_response>\n"
+        "{'result': 123}\n"
+        "</tool_response><|im_end|><|im_start|>assistant\n",
+        tmpl.apply(inputs));
+}
+
+TEST(ToolTest, NousResearchHermes2) {
+    chat_template tmpl(read_file("tests/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|im_start|>system\n"
+        "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> </tools> Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n"
+        "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n"
+        "<tool_call>\n"
+        "{\"name\": <function-name>, \"arguments\": <args-dict>}\n"
+        "</tool_call><|im_end|>\n"
+        "<tool_response>\n"
+        "{'result': 123}\n"
+        "</tool_response><|im_end|><|im_start|>assistant\n",
+        tmpl.apply(inputs));
+}
+
+TEST(ToolTest, Llama3_3) {
+    chat_template tmpl(read_file("tests/meta-llama-Llama-3.3-70B-Instruct.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|start_header_id|>system<|end_header_id|>\n"
+        "\n"
+        "Cutting Knowledge Date: December 2023\n"
+        "Today Date: 26 Jul 2024\n"
+        "\n"
+        "<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n"
+        "\n"
+        "{\"result\": 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
+        "\n",
+        tmpl.apply(inputs));
+}
+
+TEST(ToolTest, MeetkaiFunctionary3_1) {
+    chat_template tmpl(read_file("tests/meetkai-functionary-medium-v3.1.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|start_header_id|>system<|end_header_id|>\n"
+        "\n"
+        "\n"
+        "Cutting Knowledge Date: December 2023\n"
+        "\n"
+        "<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n"
+        "\n"
+        "{'result': 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
+        "\n",
+        tmpl.apply(inputs));
+}
+
+TEST(ToolTest, MeetkaiFunctionary3_2) {
+    chat_template tmpl(read_file("tests/meetkai-functionary-medium-v3.2.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|start_header_id|>system<|end_header_id|>\n"
+        "\n"
+        "You are capable of executing available function(s) if required.\n"
+        "Only execute function(s) when absolutely necessary.\n"
+        "Ask for the required input to:recipient==all\n"
+        "Use JSON for function arguments.\n"
+        "Respond in this format:\n"
+        ">>>${recipient}\n"
+        "${content}\n"
+        "Available functions:\n"
+        "// Supported function definitions that should be called when necessary.\n"
+        "namespace functions {\n"
+        "\n"
+        "} // namespace functions<|eot_id|><|start_header_id|>tool<|end_header_id|>\n"
+        "\n"
+        "{'result': 123}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
+        "\n"
+        ">>>",
+        tmpl.apply(inputs));
+}
+
+/*
+https://github.com/google/minja/issues/7
+TEST(ToolTest, FirefunctionV2) {
+    chat_template tmpl(read_file("tests/fireworks-ai-llama-3-firefunction-v2.jinja"), "", "");
+
+    auto inputs = chat_template_inputs();
+    inputs.messages = json::array({message_tool});
+
+    EXPECT_EQ(
+        "<|im_start|>tool\n"
+        "{\n"
+        "  \"result\": 123\n"
+        "}\n"
+        "<|im_end|>",
+        tmpl.apply(inputs));
+}
+*/
diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp
index 302ebbd..965375f 100644
--- a/tests/test-supported-template.cpp
+++ b/tests/test-supported-template.cpp
@@ -128,19 +128,21 @@ int main(int argc, char *argv[]) {
 
     struct minja::chat_template_inputs inputs;
     inputs.messages = ctx.at("messages");
ctx.at("tools") : json(); + ctx.erase("messages"); + + if (ctx.contains("tools")) { + inputs.tools = ctx.at("tools"); + ctx.erase("tools"); + } inputs.add_generation_prompt = ctx.at("add_generation_prompt"); + ctx.erase("add_generation_prompt"); std::istringstream ss(TEST_DATE); std::tm tm = {}; ss >> std::get_time(&tm, "%Y-%m-%d"); inputs.now = std::chrono::system_clock::from_time_t(std::mktime(&tm)); - if (ctx.contains("tools")) { - inputs.extra_context = json { - {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})}, - }; - } + inputs.extra_context = ctx; std::string actual; try {