diff --git a/.gitignore b/.gitignore
index 6295887..4049566 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ dist/
 .DS_Store
 Testing/
 .vscode/
+__pycache__/
\ No newline at end of file
diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp
index 91d71a1..75ba5d9 100644
--- a/include/minja/chat-template.hpp
+++ b/include/minja/chat-template.hpp
@@ -17,17 +17,26 @@ using json = nlohmann::ordered_json;
 
 namespace minja {
 
+struct chat_template_caps {
+    bool supports_tools = false;
+    bool supports_tool_calls = false;
+    bool supports_tool_responses = false;
+    bool supports_system_role = false;
+    bool supports_parallel_tool_calls = false;
+    bool supports_tool_call_id = false;
+    // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
+    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+    bool requires_object_arguments = false;
+    // CohereForAI/c4ai-command-r-plus simple variant
+    bool requires_non_null_content = false;
+    // MiniMaxAI/MiniMax-Text-01 special
+    bool requires_typed_content = false;
+};
+
 class chat_template {
-  public:
   private:
-    bool supports_tools_ = true;
-    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
-    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
-    bool requires_object_arguments_ = false;
-    bool requires_typed_content_ = false;
-    bool supports_system_role_ = true;
-    bool supports_parallel_tool_calls_ = false;
+    chat_template_caps caps_;
     std::string source_;
     std::string bos_token_;
     std::string eos_token_;
@@ -41,15 +50,16 @@ class chat_template {
     {
         try {
             auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
-            // fprintf(stderr, "Prompt: %s\n", prompt.c_str());
+            // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
             return prompt;
         } catch (const std::exception & e) {
-            // fprintf(stderr, "Error: %s\n", e.what());
+            // fprintf(stderr, "try_raw_render error: %s\n", e.what());
             return "";
         }
     }
 
   public:
+
     chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
         : source_(source), bos_token_(bos_token), eos_token_(eos_token)
     {
@@ -58,69 +68,120 @@ class chat_template {
             /* .lstrip_blocks = */ true,
             /* .keep_trailing_newline = */ false,
         });
-        supports_tools_ = source.find("tools") != std::string::npos;
-        auto renders_string_arguments =
-            try_raw_render({
-                {
-                    {"role", "user"},
-                    {"content", "Hey"}
-                },
-                {
-                    {"role", "assistant"},
-                    {"tool_calls", json::array({
-                        {
-                            {"id", "call_1___"},
-                            {"type", "function"},
-                            {"function", {
-                                {"arguments", "{\"code\": \"print('Hello, World!')\"}"},
-                                {"name", "ipython"},
+        auto contains = [](const std::string & haystack, const std::string & needle) {
+            return haystack.find(needle) != std::string::npos;
+        };
+
+        const std::string user_needle = "<User Needle>";
+        const std::string sys_needle = "<System Needle>";
+        const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
+        const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
+
+        caps_.requires_typed_content =
+            !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
+            && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
+
+        const auto dummy_user_msg = caps_.requires_typed_content
+            ? dummy_typed_user_msg
+            : dummy_str_user_msg;
+        const json needle_system_msg = {
+            {"role", "system"},
+            {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
+        };
+
+        caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
+
+        auto out = try_raw_render(json::array({
+            dummy_user_msg
+        }), json::array({
+            {
+                {"name", "some_tool"},
+                {"type", "function"},
+                {"function", {
+                    {"name", "some_tool"},
+                    {"description", "Some tool."},
+                    {"parameters", {
+                        {"type", "object"},
+                        {"properties", {
+                            {"arg", {
+                                {"type", "string"},
+                                {"description", "Some argument."},
                             }},
-                        },
-                    })},
-                }
-            }, {}, false).find("{\"code\": \"print") != std::string::npos;
-        if (!renders_string_arguments) {
-            auto renders_object_arguments =
-                try_raw_render({
-                    {
-                        {"role", "user"},
-                        {"content", "Hey"}
-                    },
-                    {
-                        {"role", "assistant"},
-                        {"tool_calls", json::array({
-                            {
-                                {"id", "call_1___"},
-                                {"type", "function"},
-                                {"function", {
-                                    {"arguments", {
-                                        {"code", "print('Hello, World!')"},
-                                    }},
-                                    {"name", "ipython"},
-                                }},
-                            },
-                        })},
-                    }
-                }, {}, false).find("{\"code\": \"print") != std::string::npos;
-            requires_object_arguments_ = renders_object_arguments;
-        }
-        supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
+                        }},
+                        {"required", json::array({ "arg" })},
+                    }},
+                }},
+            },
+        }), false);
+        caps_.supports_tools = contains(out, "some_tool");
 
-        supports_system_role_ = try_raw_render({
-            {{"role", "system"}, {"content", "<System Needle>"}},
-            {{"role", "user"}, {"content", "Hey"}}
-        }, {}, false).find("<System Needle>") != std::string::npos;
+        auto make_tool_calls_msg = [&](const json & tool_calls) {
+            return json {
+                {"role", "assistant"},
+                {"content", nullptr},
+                {"tool_calls", tool_calls},
+            };
+        };
+        auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
+            return json {
+                {"id", "call_1___"},
+                {"type", "function"},
+                {"function", {
+                    {"arguments", arguments},
+                    {"name", tool_name},
+                }},
+            };
+        };
+        const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
+
+        // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
+        out = try_raw_render(json::array({
+            dummy_user_msg,
+            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
+        }), {}, false);
+        auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+        out = try_raw_render(json::array({
+            dummy_user_msg,
+            make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
+        }), {}, false);
+        auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+
+        caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
+        caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
+        auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
+        auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
+        caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
+
+        if (caps_.supports_tool_calls) {
+            auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
+            auto tc1 = make_tool_call("test_tool1", dummy_args);
+            auto tc2 = make_tool_call("test_tool2", dummy_args);
+            auto out = try_raw_render(json::array({
+                dummy_user_msg,
+                make_tool_calls_msg(json::array({tc1, tc2})),
+            }), {}, false);
+            caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
 
-        requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
-            && try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
+            out = try_raw_render(json::array({
+                dummy_user_msg,
+                make_tool_calls_msg(json::array({tc1})),
+                {
+                    {"role", "tool"},
+                    {"name", "test_tool1"},
+                    {"content", "Some response!"},
+                    {"tool_call_id", "call_911_"},
+                }
+            }), {}, false);
+            caps_.supports_tool_responses = contains(out, "Some response!");
+            caps_.supports_tool_call_id = contains(out, "call_911_");
+        }
     }
 
     const std::string & source() const { return source_; }
     const std::string & bos_token() const { return bos_token_; }
     const std::string & eos_token() const { return eos_token_; }
 
-    bool supports_tools() const { return supports_tools_; }
-    bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
+    const chat_template_caps & original_caps() const { return caps_; }
 
     std::string apply(
         const nlohmann::ordered_json & messages,
@@ -131,13 +192,19 @@ class chat_template {
     {
         json actual_messages;
 
-        // First, "fix" messages so they have a chance to be rendered correctly by the template
-
-        if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
+        auto needs_adjustments = adjust_inputs && (false
+            || !caps_.supports_system_role
+            || !caps_.supports_tools
+            || !caps_.supports_tool_responses
+            || !caps_.supports_tool_calls
+            || caps_.requires_object_arguments
+            || caps_.requires_typed_content
+        );
+        if (needs_adjustments) {
             actual_messages = json::array();
 
             auto add_message = [&](const json & msg) {
-                if (requires_typed_content_ && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
+                if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
                     actual_messages.push_back({
                         {"role", msg.at("role")},
                         {"content", {{
@@ -160,7 +227,9 @@ class chat_template {
                     pending_system.clear();
                 }
             };
-            for (const auto & message_ : messages) {
+            auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
+
+            for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
                 auto message = message_;
                 if (!message.contains("role") || !message.contains("content")) {
                     throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
@@ -168,21 +237,22 @@ class chat_template {
                 std::string role = message.at("role");
 
                 if (message.contains("tool_calls")) {
-                    if (requires_object_arguments_ || !supports_tools_) {
+                    if (caps_.requires_object_arguments || !caps_.supports_tool_calls) {
                         for (auto & tool_call : message.at("tool_calls")) {
                             if (tool_call["type"] == "function") {
                                 auto & function = tool_call.at("function");
-                                std::string arguments = function.at("arguments");
-                                try {
-                                    function["arguments"] = json::parse(arguments);
-                                } catch (const std::exception & ecvt) {
-                                    fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
-                                    function["arguments"] = arguments;
+                                auto & arguments = function.at("arguments");
+                                if (arguments.is_string()) {
+                                    try {
+                                        arguments = json::parse(arguments.get<std::string>());
+                                    } catch (const std::exception & ecvt) {
+                                        fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
+                                    }
                                 }
                             }
                         }
                     }
-                    if (!supports_tools_) {
+                    if (!caps_.supports_tool_calls) {
                         auto content = message.at("content");
                         auto tool_calls = json::array();
                         for (const auto & tool_call : message.at("tool_calls")) {
@@ -209,7 +279,7 @@ class chat_template {
                         message.erase("tool_calls");
                     }
                 }
-                if (!supports_tools_ && role == "tool") {
+                if (!caps_.supports_tool_responses && role == "tool") {
                     message["role"] = "user";
                     auto obj = json {
                         {"tool_response", {
@@ -224,7 +294,7 @@ class chat_template {
                     message.erase("name");
                 }
 
-                if (!message["content"].is_null() && !supports_system_role_) {
+                if (!message["content"].is_null() && !caps_.supports_system_role) {
                     std::string content = message.at("content");
                     if (role == "system") {
                         if (!pending_system.empty()) pending_system += "\n";
@@ -243,7 +313,9 @@ class chat_template {
                 }
                 add_message(message);
             }
-            flush_sys();
+            if (!caps_.supports_system_role) {
+                flush_sys();
+            }
         } else {
             actual_messages = messages;
         }
@@ -266,7 +338,28 @@ class chat_template {
             }
         }
 
-        return template_root_->render(context);
+        auto ret = template_root_->render(context);
+        // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
+        // fprintf(stderr, "apply: %s\n\n", ret.c_str());
+        return ret;
+    }
+
+    static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
+        json messages_with_system = messages;
+
+        if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
+            std::string existing_system = messages_with_system.at(0).at("content");
+            messages_with_system[0] = json {
+                {"role", "system"},
+                {"content", existing_system + "\n" + system_prompt},
+            };
+        } else {
+            messages_with_system.insert(messages_with_system.begin(), json {
+                {"role", "system"},
+                {"content", system_prompt},
+            });
+        }
+        return messages_with_system;
+    }
 };
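
[Editor's note — illustrative sketch, not part of the change] A minimal example of how a consumer might use the new capabilities API. The toy template string and the main() harness are hypothetical, and apply() is assumed to keep extra_context and adjust_inputs defaulted (the latter to true), as the try_raw_render call above suggests:

    #include "chat-template.hpp"

    #include <iostream>

    using json = nlohmann::ordered_json;

    int main() {
        // A toy template that renders plain string content and ignores tools.
        minja::chat_template tmpl(
            "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}\n{% endfor %}",
            /* bos_token= */ "", /* eos_token= */ "");

        // The constructor has already probed the template, so callers can
        // branch on detected capabilities instead of guessing.
        const auto & caps = tmpl.original_caps();
        if (!caps.supports_tools) {
            // apply() would fold any tools into a system message for this
            // template (see add_system above).
        }

        json messages = json::array({
            {{"role", "system"}, {"content", "Be terse."}},
            {{"role", "user"}, {"content", "Hi"}},
        });
        std::cout << tmpl.apply(messages, /* tools= */ json(), /* add_generation_prompt= */ true) << std::endl;
    }
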
diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp
index dd0ae6c..604e613 100644
--- a/include/minja/minja.hpp
+++ b/include/minja/minja.hpp
@@ -1270,11 +1270,6 @@ class BinaryOpExpr : public Expression {
         }
         auto r = right->evaluate(context);
-        if (op != Op::Eq && op != Op::Ne) {
-            if (r.is_null() || (l.is_null() && (op != Op::In && op != Op::NotIn))) {
-                throw std::runtime_error("unsupported operand type(s)");
-            }
-        }
         switch (op) {
             case Op::StrConcat: return l.to_str() + r.to_str();
             case Op::Add:       return l + r;
@@ -2152,11 +2147,11 @@ class Parser {
     }
 
     std::runtime_error unexpected(const TemplateToken & token) const {
-        return std::runtime_error("Encountered unknown tag '" + TemplateToken::typeToString(token.type) + "'"
+        return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type)
             + error_location_suffix(*template_str, token.location.pos));
    }
    std::runtime_error unterminated(const TemplateToken & token) const {
-        return std::runtime_error("Unexpected end of template. Jinja was looking for the following tags: '" + TemplateToken::typeToString(token.type) + "'"
+        return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type)
             + error_location_suffix(*template_str, token.location.pos));
    }
diff --git a/requirements.txt b/requirements.txt
index 2ec508b..f27dfab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,4 @@
+aiofiles
+aiohttp
 huggingface_hub
 jinja2
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/analyze_capabilities.py b/scripts/analyze_capabilities.py
new file mode 100755
index 0000000..25a7f61
--- /dev/null
+++ b/scripts/analyze_capabilities.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+import json
+import os
+from pathlib import Path
+from typing import Dict, List
+
+
+def generate_markdown_table(files_data: List[tuple[str, Dict]]) -> str:
+    """Generate a markdown table from the capabilities data."""
+    if not files_data:
+        return "No capability files found."
+
+    all_caps = set()
+    for _, data in files_data:
+        all_caps.update(data.keys())
+    all_caps = sorted(all_caps)
+
+    lines = [
+        "| Model | " + " | ".join(c.replace('_', ' ') for c in all_caps) + " |",
+        "|" + "|".join("-" * (len(cap) + 2) for cap in ["Model"] + list(all_caps)) + "|",
+    ]
+
+    # Sort data by most supports and least requires
+    def sort_key(item):
+        model, data = item
+        supports_count = sum(1 for k, v in data.items()
+                             if k.startswith("supports_") and str(v).lower() == "true")
+        requires_count = sum(1 for k, v in data.items()
+                             if k.startswith("requires_") and str(v).lower() == "true")
+        return (-supports_count, requires_count)  # negative for descending supports
+
+    for model, data in sorted(files_data, key=sort_key):
+        model_name = os.path.basename(model).replace(".caps.json", "")
+        row = [model_name]
+        for cap in all_caps:
+            raw_value = str(data.get(cap, "N/A")).lower()
+            if raw_value == "true":
+                if cap.startswith("supports_"):
+                    value = "✅"
+                elif cap.startswith("requires_"):
+                    value = "⚠️"
+                else:
+                    value = raw_value
+            elif raw_value == "false":
+                value = ""
+            else:
+                value = raw_value
+            row.append(value)
+        lines.append("| " + " | ".join(row) + " |")
+
+    return "\n".join(lines)
+
+def main():
+    script_dir = Path(__file__).parent
+    build_dir = script_dir.parent / "build"
+
+    files_data = [
+        (str(f), json.loads(f.read_text()))
+        for f in list((build_dir / "tests").rglob("*.caps.json"))
+    ]
+
+    markdown = generate_markdown_table(files_data)
+
+    (build_dir / "capabilities.md").write_text(markdown)
+
+    print(markdown)
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py
index 3813ff6..619b539 100644
--- a/scripts/fetch_templates_and_goldens.py
+++ b/scripts/fetch_templates_and_goldens.py
@@ -15,19 +15,24 @@ Example:
 
     pip install -r requirements.txt
-    python scripts/fetch_templates_and_goldens.py ./test_files tests/contexts/*.json mistralai/Mistral-Large-Instruct-2407 meetkai/functionary-medium-v3.1.jinja microsoft/Phi-3-medium-4k-instruct Qwen/Qwen2-7B-Instruct
+    python scripts/fetch_templates_and_goldens.py ./test_files tests/contexts/*.json CohereForAI/c4ai-command-r-plus mistralai/Mistral-Large-Instruct-2407 meetkai/functionary-medium-v3.1.jinja microsoft/Phi-3-medium-4k-instruct Qwen/Qwen2-7B-Instruct
 '''
+from dataclasses import dataclass
 import logging
 import datetime
 import os
 import sys
-from huggingface_hub import hf_hub_download
+import asyncio
+import aiofiles
+from huggingface_hub import AsyncInferenceClient
+from huggingface_hub.utils import build_hf_headers
 import json
 import jinja2
 import jinja2.ext
 import re
 import argparse
+import aiohttp
 import shutil
 
 logging.basicConfig(level=logging.INFO, format='%(message)s')
@@ -55,8 +60,160 @@ def join_cmake_path(parent, child):
     '''
     return '/'.join(x.replace(r'\\', '/') for x in (parent, child))
 
-def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
+def add_system(messages, system_prompt):
+    if len(messages) > 0 and messages[0]["role"] == "system":
+        existing_system = messages[0]["content"]
+        messages[0] = {
+            "role": "system",
+            "content": existing_system + "\n" + system_prompt,
+        }
+    else:
+        messages.insert(0, {
+            "role": "system",
+            "content": system_prompt,
+        })
+
+@dataclass
+class TemplateCaps:
+    supports_tools: bool = False
+    supports_tool_calls: bool = False
+    supports_tool_responses: bool = False
+    supports_system_role: bool = False
+    supports_parallel_tool_calls: bool = False
+    supports_tool_call_id: bool = False
+    requires_object_arguments: bool = False
+    requires_non_null_content: bool = False
+    requires_typed_content: bool = False
+
+    def to_json(self):
+        return json.dumps({
+            "supports_tools": self.supports_tools,
+            "supports_tool_calls": self.supports_tool_calls,
+            "supports_tool_responses": self.supports_tool_responses,
+            "supports_system_role": self.supports_system_role,
+            "supports_parallel_tool_calls": self.supports_parallel_tool_calls,
+            "supports_tool_call_id": self.supports_tool_call_id,
+            "requires_object_arguments": self.requires_object_arguments,
+            # "requires_non_null_content": self.requires_non_null_content,
+            "requires_typed_content": self.requires_typed_content,
+        }, indent=2)
+
+def detect_caps(template_file, template):
+    basic_extra_context = {
+        "bos_token": "<|startoftext|>",
+        "eos_token": "<|endoftext|>",
+    }
+    def try_raw_render(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
+        try:
+            out = template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context)
+            # print(out, file=sys.stderr)
+            return out
+        except BaseException as e:
+            # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
+            return ""
+
+    caps = TemplateCaps()
+
+    user_needle = "<User Needle>"
+    sys_needle = "<System Needle>"
+    dummy_str_user_msg = {"role": "user", "content": user_needle}
+    dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]}
+
+    caps.requires_typed_content = \
+        (user_needle not in try_raw_render([dummy_str_user_msg])) \
+        and (user_needle in try_raw_render([dummy_typed_user_msg]))
+    dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg
+
+    needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle}
+
+    caps.supports_system_role = sys_needle in try_raw_render([needle_system_msg, dummy_user_msg])
+
+    out = try_raw_render([dummy_user_msg], tools=[{
+        "name": "some_tool",
+        "type": "function",
+        "function": {
+            "name": "some_tool",
+            "description": "Some tool",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "arg": {
+                        "type": "string",
+                        "description": "Some arg",
+                    },
+                },
+                "required": ["arg"],
+            },
+        },
+    }])
+    caps.supports_tools = "some_tool" in out
+
+    def make_tool_calls_msg(tool_calls, content=None):
+        return {
+            "role": "assistant",
+            "content": content,
+            "tool_calls": tool_calls,
+        }
+    def make_tool_call(tool_name, arguments):
+        return {
+            "id": "call_1___",
+            "type": "function",
+            "function": {
+                "arguments": arguments,
+                "name": tool_name,
+            }
+        }
+
+    dummy_args_obj = {"argument_needle": "print('Hello, World!')"}
+
+    out = try_raw_render([
+        dummy_user_msg,
+        make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]),
+    ])
+    tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out
+    out = try_raw_render([
+        dummy_user_msg,
+        make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]),
+    ])
+    tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out
+
+    caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments
+    caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments
+
+    empty_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])
+    none_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])
+    caps.requires_non_null_content = \
+        (user_needle in empty_out) \
+        and (user_needle not in none_out)
+
+    if caps.supports_tool_calls:
+        dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj)
+        tc1 = make_tool_call("test_tool1", dummy_args)
+        tc2 = make_tool_call("test_tool2", dummy_args)
+        out = try_raw_render([
+            dummy_user_msg,
+            make_tool_calls_msg([tc1, tc2]),
+        ])
+        caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out
+
+        out = try_raw_render([
+            dummy_user_msg,
+            make_tool_calls_msg([tc1]),
+            {
+                "role": "tool",
+                "name": "test_tool1",
+                "content": "Some response!",
+                "tool_call_id": "call_911_",
+            }
+        ])
+        caps.supports_tool_responses = "Some response!" in out
+        caps.supports_tool_call_id = "call_911_" in out
+
+    return caps
+
+async def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
     if '{% generation %}' in template_src:
         print('Removing {% generation %} blocks from template', file=sys.stderr)
         template_src = template_src.replace('{% generation %}', '').replace('{% endgeneration %}', '')
@@ -64,13 +221,10 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
     model_name = model_id.replace("/", "-")
     base_name = f'{model_name}-{variant}' if variant else model_name
     template_file = join_cmake_path(output_folder, f'{base_name}.jinja')
+    caps_file = join_cmake_path(output_folder, f'{base_name}.caps.json')
 
-    with open(template_file, 'w') as f:
-        f.write(template_src)
-
-    if not context_files:
-        print(f"{template_file} n/a {template_file}")
-        return
+    async with aiofiles.open(template_file, 'w') as f:
+        await f.write(template_src)
 
     env = jinja2.Environment(
         trim_blocks=True,
@@ -84,81 +238,39 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
     env.globals['raise_exception'] = raise_exception
     env.globals['strftime_now'] = strftime_now
 
-    template_handles_tools = 'tools' in template_src
-    supports_code_interpreter = 'code_interpreter' in template_src
+    caps = detect_caps(template_file, template)
+    if not context_files:
+        print(f"{template_file} {caps_file} n/a {template_file}")
+        return
 
-    def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
-        try:
-            prompt = template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **extra_context)
-            for str in expect_strings:
-                if str not in prompt:
-                    # print(f"Expected string not found: {str}\nin prompt:\n{prompt}", file=sys.stderr, flush=True)
-                    return False
-            return True
-        except Exception as e:
-            # print(f"Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
-            return False
-
-    basic_extra_context = {
-        "bos_token": "<|startoftext|>",
-        "eos_token": "<|endoftext|>",
-    }
-    renders_string_arguments = renders([
-        {"role": "user", "content": "Hey"},
-        {"role": "assistant", "tool_calls": [{
-            "id": "call_1___",
-            "type": "function",
-            "function": {
-                "arguments": "{\"code\": \"print('Hello, World!')\"}",
-                "name": "ipython"
-            }
-        }]}
-    ], extra_context=basic_extra_context, expect_strings=[r'{"code": "print'])
-    renders_object_arguments = renders([
-        {"role": "user", "content": "Hey"},
-        {"role": "assistant", "tool_calls": [{
-            "id": "call_1___",
-            "type": "function",
-            "function": {
-                "arguments": {"code": "print('Hello, World!')"},
-                "name": "ipython"
-            }
-        }]}
-    ], extra_context=basic_extra_context, expect_strings=[r'{"code": "print'])
-    requires_object_arguments = not renders_string_arguments and renders_object_arguments
-
-    supports_system_role = renders([
-        {"role": "system", "content": "System Needle"},
-        {"role": "user", "content": "Hey"}
-    ], extra_context=basic_extra_context, expect_strings=["System Needle"])
-
-    requires_typed_content = \
-        not renders([{"role": "user", "content": "Hey"}], extra_context=basic_extra_context, expect_strings=["Hey"]) \
-        and renders([{"role": "user", "content": [{"type": "text", "text": "Hey"}]}], extra_context=basic_extra_context, expect_strings=["Hey"])
+    async with aiofiles.open(caps_file, 'w') as f:
+        await f.write(caps.to_json())
 
     for context_file in context_files:
         context_name = os.path.basename(context_file).replace(".json", "")
-        with open(context_file, 'r') as f:
-            context = json.load(f)
+        async with aiofiles.open(context_file, 'r') as f:
+            context = json.loads(await f.read())
 
-        if not template_handles_tools and 'tools' in context:
+        has_tools = 'tools' in context
+        needs_tools_in_system = has_tools and not caps.supports_tools
+
+        if not caps.supports_tool_calls and has_tools:
             print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr)
             continue
-
-        if not supports_code_interpreter and 'tools' in context and any(t['type'] == 'code_interpreter' for t in context['tools']):
-            print(f'Skipping {context_name} test as code_interpreter seems unsupported by template {template_file}', file=sys.stderr)
-            continue
-
-        if not supports_system_role and any(m['role'] == 'system' for m in context['messages']):
+
+        if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system):
             continue
 
         output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt')
 
-        if requires_object_arguments:
-            for message in context['messages']:
-                if 'tool_calls' in message:
-                    for tool_call in message['tool_calls']:
+        if needs_tools_in_system:
+            add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}")
+
+        for message in context['messages']:
+            if 'tool_calls' in message:
+                for tool_call in message['tool_calls']:
+                    if caps.requires_object_arguments:
                         if tool_call.get('type') == 'function':
                             arguments = tool_call['function']['arguments']
                             try:
@@ -166,8 +278,31 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
                             except:
                                 pass
                             tool_call['function']['arguments'] = arguments
-
-        if requires_typed_content:
+                if not caps.supports_tool_calls:
+                    message['content'] = json.dumps({
+                        "tool_calls": [
+                            {
+                                "name": tc['function']['name'],
+                                "arguments": json.loads(tc['function']['arguments']),
+                                "id": tc.get('id'),
+                            }
+                            for tc in message['tool_calls']
+                        ],
+                        "content": None if message.get('content', '') == '' else message['content'],
+                    }, indent=2)
+                    del message['tool_calls']
+            if message.get('role') == 'tool' and not caps.supports_tool_responses:
+                message['role'] = 'user'
+                message['content'] = json.dumps({
+                    "tool_response": {
+                        "tool": message['name'],
+                        "content": message['content'],
+                        "tool_call_id": message.get('tool_call_id'),
+                    }
+                }, indent=2)
+                del message['name']
+
+        if caps.requires_typed_content:
             for message in context['messages']:
                 if 'content' in message and isinstance(message['content'], str):
                     message['content'] = [{"type": "text", "text": message['content']}]
 
         try:
             output = template.render(**context)
         except Exception as e1:
-            for message in context["messages"]:
+            for message in context['messages']:
                 if message.get("content") is None:
                     message["content"] = ""
 
@@ -185,14 +320,47 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
                 logger.info(f"  ERROR: {e2} (after first error: {e1})")
                 output = f"ERROR: {e2}"
 
-        with open(output_file, 'w') as f:
-            f.write(output)
+        async with aiofiles.open(output_file, 'w') as f:
+            await f.write(output)
 
-        # Output the line of arguments for the C++ test binary
-        print(f"{template_file} {context_file} {output_file}")
+        print(f"{template_file} {caps_file} {context_file} {output_file}")
 
+async def async_hf_download(repo_id: str, filename: str) -> str:
+    headers = build_hf_headers()
+    url = f"https://huggingface.co/{repo_id}/raw/main/{filename}"
+    async with aiohttp.ClientSession() as session:
+        async with session.get(url, headers=headers) as response:
+            response.raise_for_status()
+            return await response.text()
 
-def main():
+async def process_model(output_folder: str, model_id: str, context_files: list):
+    try:
+        config_str = await async_hf_download(model_id, "tokenizer_config.json")
+
+        try:
+            config = json.loads(config_str)
+        except json.JSONDecodeError:
+            config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
+
+        assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json!'
+        chat_template = config['chat_template']
+        if isinstance(chat_template, str):
+            await handle_chat_template(output_folder, model_id, None, chat_template, context_files)
+        else:
+            await asyncio.gather(*[
+                handle_chat_template(output_folder, model_id, ct['name'], ct['template'], context_files)
+                for ct in chat_template
+            ])
+    except Exception as e:
+        logger.error(f"Error processing model {model_id}: {e}")
+        await handle_chat_template(output_folder, model_id, None, str(e), [])
+
+async def async_copy_file(src: str, dst: str):
+    async with aiofiles.open(src, 'rb') as fsrc:
+        async with aiofiles.open(dst, 'wb') as fdst:
+            await fdst.write(await fsrc.read())
+
+async def main():
     parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.")
     parser.add_argument("output_folder", help="Folder to store all output files")
     parser.add_argument("json_context_files_or_model_ids", nargs="+", help="List of context JSON files or HuggingFace model IDs")
@@ -210,31 +378,17 @@ async def main():
     if not os.path.isdir(output_folder):
         os.makedirs(output_folder)
 
-    # Copy context files to the output folder
-    for context_file in context_files:
-        shutil.copy(context_file, output_folder)
-
-    for model_id in model_ids:
-        try:
-            with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
-                config_str = f.read()
-
-            try:
-                config = json.loads(config_str)
-            except json.JSONDecodeError:
-                config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
-
-            assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json!'
-            chat_template = config['chat_template']
-            if isinstance(chat_template, str):
-                handle_chat_template(output_folder, model_id, None, chat_template, context_files)
-            else:
-                for ct in chat_template:
-                    handle_chat_template(output_folder, model_id, ct['name'], ct['template'], context_files)
-        except Exception as e:
-            logger.error(f"Error processing model {model_id}: {e}")
-            handle_chat_template(output_folder, model_id, None, str(e), [])
+    # Copy context files to the output folder asynchronously
+    await asyncio.gather(*[
+        async_copy_file(context_file, os.path.join(output_folder, os.path.basename(context_file)))
+        for context_file in context_files
+    ])
 
+    # Process models concurrently
+    await asyncio.gather(*[
+        process_model(output_folder, model_id, context_files)
+        for model_id in model_ids
+    ])
 
 if __name__ == '__main__':
-    main()
+    asyncio.run(main())
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 8791f17..9ccc942 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -21,6 +21,20 @@ if (NOT CMAKE_CROSSCOMPILING)
     gtest_discover_tests(test-syntax)
 endif()
 
+add_executable(test-capabilities test-capabilities.cpp)
+target_compile_features(test-capabilities PUBLIC cxx_std_17)
+if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    target_compile_definitions(test-capabilities PUBLIC _CRT_SECURE_NO_WARNINGS)
+    target_compile_options(gtest PRIVATE -Wno-language-extension-token)
+endif()
+target_link_libraries(test-capabilities PRIVATE
+    nlohmann_json::nlohmann_json
+    gtest_main
+    gmock
+)
+add_test(NAME test-capabilities COMMAND test-capabilities)
+set_tests_properties(test-capabilities PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
+
 add_test(NAME test-syntax-jinja2 COMMAND test-syntax)
 set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYTHON_EXECUTABLE=${Python_EXECUTABLE};PYTHONPATH=${CMAKE_SOURCE_DIR}")
 
@@ -42,16 +56,15 @@ set(MODEL_IDS
     bofenghuang/vigogne-2-70b-chat
     CohereForAI/c4ai-command-r-plus # Gated
     databricks/dbrx-instruct # Gated
-    deepseek-ai/deepseek-coder-33b-instruct
-    deepseek-ai/DeepSeek-Coder-V2-Instruct
-    deepseek-ai/DeepSeek-V2.5
     google/gemma-2-2b-it # Gated
     google/gemma-7b-it # Gated
    MiniMaxAI/MiniMax-Text-01
     indischepartij/MiniCPM-3B-OpenHermes-2.5-v2
     mattshumer/Reflection-Llama-3.1-70B
     meetkai/functionary-medium-v3.2
+    meta-llama/Llama-3.1-8B-Instruct # Gated
     meta-llama/Llama-3.2-3B-Instruct # Gated
+    meta-llama/Llama-3.3-70B-Instruct # Gated
     meta-llama/Meta-Llama-3.1-8B-Instruct # Gated
     microsoft/Phi-3-medium-4k-instruct
     microsoft/Phi-3-mini-4k-instruct
@@ -80,16 +93,23 @@ set(MODEL_IDS
     TheBloke/FusionNet_34Bx2_MoE-AWQ
 
     # Broken, TODO:
-    # meetkai/functionary-medium-v3.1    # jinja2 expectation is computed w/ wrong escapes
+    # meetkai/functionary-medium-v3.1 # jinja2 expectation is computed w/ wrong escapes
     # fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7
-    # ai21labs/AI21-Jamba-1.5-Large        # https://github.com/google/minja/issues/8
-
-    # Can't find template(s), TODO:
-    # apple/OpenELM-1_1B-Instruct
-    # dreamgen/WizardLM-2-7B
-    # xai-org/grok-1
+    # ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8
 )
 
+if(NOT WIN32)
+    list(APPEND MODEL_IDS
+        # Needs investigation
+        deepseek-ai/deepseek-coder-33b-instruct
+        deepseek-ai/DeepSeek-Coder-V2-Instruct
+        deepseek-ai/DeepSeek-V2.5
+        deepseek-ai/DeepSeek-R1-Distill-Llama-8B
+        deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
+        deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
+    )
+endif()
+
 # Create one test case for each {template, context} combination
 file(GLOB CONTEXT_FILES "${CMAKE_SOURCE_DIR}/tests/contexts/*.json")
 execute_process(
diff --git a/tests/contexts/tool_use_code_interpreter.json b/tests/contexts/tool_use_code_interpreter.json
deleted file mode 100644
index ba6f159..0000000
--- a/tests/contexts/tool_use_code_interpreter.json
+++ /dev/null
@@ -1,43 +0,0 @@
-{
-  "messages": [
-    {
-      "role": "user",
-      "content": "Print a hello world message with python."
-    },
-    {
-      "role": "assistant",
-      "content": "",
-      "tool_calls": [
-        {
-          "id": "call_1___",
-          "type": "function",
-          "function": {
-            "arguments": "print('Hello, World!')",
-            "name": "python"
-          }
-        }
-      ]
-    },
-    {
-      "role": "tool",
-      "tool_call_id": "call_1___",
-      "name": "python",
-      "content": "{\"stdout\": \"Hello, World!\"}"
-    }
-  ],
-  "add_generation_prompt": true,
-  "bos_token": "<|startoftext|>",
-  "eos_token": "<|endoftext|>",
-  "builtin_tools": [
-    "wolfram_alpha",
-    "brave_search",
-    "code_interpreter"
-  ],
-  "cutting_knowledge_date": "2023-04-01",
-  "todays_date": "2024-09-03",
-  "tools": [
-    {
-      "type": "code_interpreter"
-    }
-  ]
-}
\ No newline at end of file
diff --git a/tests/test-capabilities.cpp b/tests/test-capabilities.cpp
new file mode 100644
index 0000000..225581f
--- /dev/null
+++ b/tests/test-capabilities.cpp
@@ -0,0 +1,235 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#include "chat-template.hpp"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+
+#undef NDEBUG
+#include <cassert>
+
+using json = nlohmann::ordered_json;
+
+static std::string read_file(const std::string &path)
+{
+    std::ifstream fs(path, std::ios_base::binary);
+    if (!fs.is_open())
+    {
+        throw std::runtime_error("Failed to open file: " + path);
+    }
+    fs.seekg(0, std::ios_base::end);
+    auto size = fs.tellg();
+    fs.seekg(0);
+    std::string out;
+    out.resize(static_cast<size_t>(size));
+    fs.read(&out[0], static_cast<std::streamsize>(size));
+    return out;
+}
+
+static minja::chat_template_caps get_caps(const std::string &path)
+{
+    auto caps = minja::chat_template(read_file(path), "", "").original_caps();
+
+    auto print = [](const std::string &name, bool value) {
+        std::cout << "    " << (value ? "EXPECT_TRUE" : "EXPECT_FALSE") << "(caps." << name << ");" << std::endl;
+    };
+    auto test_info = ::testing::UnitTest::GetInstance()->current_test_info();
+
+    std::cout << "TEST(" << test_info->test_suite_name() << ", " << test_info->name() << ") {" << std::endl;
+    std::cout << "    auto caps = get_caps(\"" << path << "\");" << std::endl;
+    print("supports_system_role", caps.supports_system_role);
+    print("supports_tools", caps.supports_tools);
+    print("supports_tool_calls", caps.supports_tool_calls);
+    print("supports_tool_responses", caps.supports_tool_responses);
+    print("supports_parallel_tool_calls", caps.supports_parallel_tool_calls);
+    print("requires_object_arguments", caps.requires_object_arguments);
+    // print("requires_non_null_content", caps.requires_non_null_content);
+    print("requires_typed_content", caps.requires_typed_content);
+    std::cout << "}\n" << std::endl;
+
+    return caps;
+}
+
+TEST(CapabilitiesTest, Gemma7b) {
+    auto caps = get_caps("tests/google-gemma-7b-it.jinja");
+    EXPECT_FALSE(caps.supports_system_role);
+    EXPECT_FALSE(caps.supports_tools);
+    EXPECT_FALSE(caps.supports_tool_calls);
+    EXPECT_FALSE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+#ifndef _WIN32
+TEST(CapabilitiesTest, DeepSeekR1Distill)
+{
+    auto caps = get_caps("tests/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_FALSE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_TRUE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_FALSE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+#endif
+
+TEST(CapabilitiesTest, FunctionaryMediumV3_2) {
+    auto caps = get_caps("tests/meetkai-functionary-medium-v3.2.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_TRUE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_FALSE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, MetaLlama3_1_8BInstruct) {
+    auto caps = get_caps("tests/meta-llama-Llama-3.1-8B-Instruct.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_TRUE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, MetaLlama3_2_3BInstruct) {
+    auto caps = get_caps("tests/meta-llama-Llama-3.2-3B-Instruct.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_TRUE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, MetaLlama3_3_70BInstruct) {
+    auto caps = get_caps("tests/meta-llama-Llama-3.3-70B-Instruct.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_TRUE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, MiniMaxAIText01) {
+    auto caps = get_caps("tests/MiniMaxAI-MiniMax-Text-01.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_FALSE(caps.supports_tools);
+    EXPECT_FALSE(caps.supports_tool_calls);
+    EXPECT_FALSE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_FALSE(caps.requires_non_null_content);
+    EXPECT_TRUE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, Mistral7BInstruct) {
+    auto caps = get_caps("tests/mistralai-Mistral-7B-Instruct-v0.2.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_FALSE(caps.supports_tools);
+    EXPECT_FALSE(caps.supports_tool_calls);
+    EXPECT_FALSE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, MistralNemoInstruct) {
+    auto caps = get_caps("tests/mistralai-Mistral-Nemo-Instruct-2407.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_TRUE(caps.supports_parallel_tool_calls);
+    EXPECT_TRUE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, NousResearchHermes3Llama3_1_70BToolUse) {
+    auto caps = get_caps("tests/NousResearch-Hermes-3-Llama-3.1-70B-tool_use.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_TRUE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, NousResearchHermes2ProLlama3_8BToolUse) {
+    auto caps = get_caps("tests/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_TRUE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, CommandRPlusDefault) {
+    auto caps = get_caps("tests/CohereForAI-c4ai-command-r-plus-default.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_FALSE(caps.supports_tools);
+    EXPECT_FALSE(caps.supports_tool_calls);
+    EXPECT_FALSE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, CommandRPlusRag) {
+    auto caps = get_caps("tests/CohereForAI-c4ai-command-r-plus-rag.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_FALSE(caps.supports_tools);
+    EXPECT_FALSE(caps.supports_tool_calls);
+    EXPECT_FALSE(caps.supports_tool_responses);
+    EXPECT_FALSE(caps.supports_parallel_tool_calls);
+    EXPECT_FALSE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
+
+TEST(CapabilitiesTest, CommandRPlusToolUse) {
+    auto caps = get_caps("tests/CohereForAI-c4ai-command-r-plus-tool_use.jinja");
+    EXPECT_TRUE(caps.supports_system_role);
+    EXPECT_TRUE(caps.supports_tools);
+    EXPECT_TRUE(caps.supports_tool_calls);
+    EXPECT_TRUE(caps.supports_tool_responses);
+    EXPECT_TRUE(caps.supports_parallel_tool_calls);
+    EXPECT_TRUE(caps.requires_object_arguments);
+    // EXPECT_TRUE(caps.requires_non_null_content);
+    EXPECT_FALSE(caps.requires_typed_content);
+}
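
[Editor's note — illustrative, not part of the change] get_caps() above doubles as a generator: each run prints a ready-to-paste GoogleTest case reflecting the current detection results, which is presumably how the expectation blocks in this file get regenerated when a template or the detection logic changes. For a hypothetical template, the emitted text looks roughly like:

    TEST(CapabilitiesTest, SomeModel) {
        auto caps = get_caps("tests/some-model.jinja");
        EXPECT_TRUE(caps.supports_system_role);
        EXPECT_FALSE(caps.supports_tools);
        EXPECT_FALSE(caps.supports_tool_calls);
        EXPECT_FALSE(caps.supports_tool_responses);
        EXPECT_FALSE(caps.supports_parallel_tool_calls);
        EXPECT_FALSE(caps.requires_object_arguments);
        EXPECT_FALSE(caps.requires_typed_content);
    }
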
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index b0bb9d4..6f8bcb6 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -23,8 +23,19 @@ using json = nlohmann::ordered_json;
 template <class T>
 static void assert_equals(const T &expected, const T &actual){
     if (expected != actual) {
-        std::cerr << "Expected: " << expected << std::endl;
-        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << "Expected: " << expected << "\n\n";
+        std::cerr << "Actual: " << actual << "\n\n";
+        auto i_divergence = std::min(expected.size(), actual.size());
+        for (size_t i = 0; i < i_divergence; i++) {
+            if (expected[i] != actual[i]) {
+                i_divergence = i;
+                break;
+            }
+        }
+        std::cerr << "Divergence at index " << i_divergence << "\n\n";
+        std::cerr << "Expected suffix: " << expected.substr(i_divergence) << "\n\n";
+        std::cerr << "Actual suffix: " << actual.substr(i_divergence) << "\n\n";
+        std::cerr << std::flush;
         throw std::runtime_error("Test failed");
     }
 
@@ -44,10 +55,26 @@ static std::string read_file(const std::string &path) {
     return out;
 }
 
+#ifndef _WIN32
+static json caps_to_json(const minja::chat_template_caps &caps) {
+    return {
+        {"supports_tools", caps.supports_tools},
+        {"supports_tool_calls", caps.supports_tool_calls},
+        {"supports_tool_responses", caps.supports_tool_responses},
+        {"supports_system_role", caps.supports_system_role},
+        {"supports_parallel_tool_calls", caps.supports_parallel_tool_calls},
+        {"supports_tool_call_id", caps.supports_tool_call_id},
+        {"requires_object_arguments", caps.requires_object_arguments},
+        // {"requires_non_null_content", caps.requires_non_null_content},
+        {"requires_typed_content", caps.requires_typed_content},
+    };
+}
+#endif
+
 int main(int argc, char *argv[]) {
-    if (argc != 4)
+    if (argc != 5)
     {
-        std::cerr << "Usage: " << argv[0] << " <template_file> <context_file> <golden_file>" << std::endl;
+        std::cerr << "Usage: " << argv[0] << " <template_file> <caps_file> <context_file> <golden_file>" << std::endl;
         for (int i = 0; i < argc; i++)
         {
             std::cerr << "argv[" << i << "] = " << argv[i] << std::endl;
@@ -57,11 +84,12 @@ int main(int argc, char *argv[]) {
     try {
         std::string tmpl_file = argv[1];
-        std::string ctx_file = argv[2];
-        std::string golden_file = argv[3];
-
+        std::string caps_file = argv[2];
+        std::string ctx_file = argv[3];
+        std::string golden_file = argv[4];
+
         auto tmpl_str = read_file(tmpl_file);
-
+
         if (ctx_file == "n/a") {
             std::cout << "# Skipping template: " << tmpl_file << "\n" << tmpl_str << std::endl;
         }
 
         std::cout << "# Testing template: " << tmpl_file << std::endl
+                  << "# With caps: " << caps_file << std::endl
                   << "# With context: " << ctx_file << std::endl
                   << "# Against golden file: " << golden_file << std::endl
                   << std::flush;
@@ -104,6 +133,15 @@ int main(int argc, char *argv[]) {
         }
 
         assert_equals(expected, actual);
+
+        // Some unresolved CRLF issues again with the goldens on Windows.
+#ifndef _WIN32
+        // Checks that the Python & C++ capability detection codes are in sync.
+        auto expected_caps = minja::normalize_newlines(read_file(caps_file));
+        auto caps = caps_to_json(tmpl.original_caps()).dump(2);
+        assert_equals(expected_caps, caps);
+#endif
+
         std::cout << "Test passed successfully." << std::endl;
         return 0;
     } catch (const std::exception &e) {
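
[Editor's note — illustrative, not part of the change] The test-syntax.cpp edits below track the BinaryOpExpr change in minja.hpp above: the blanket "unsupported operand type(s)" throw on null operands is gone, so the 'a' + None throw-expectations are dropped and a None | trim case is added for the non-jinja2 path. A rough sketch of the newly asserted behavior, using the README-style embedding API (assumed here; it does not appear in this diff):

    #include <minja/minja.hpp>

    #include <iostream>

    int main() {
        using json = nlohmann::ordered_json;
        // Per the new expectation, trimming None yields an empty string
        // instead of raising an error.
        auto tmpl = minja::Parser::parse("[{{ None | trim }}]", /* options= */ {});
        auto context = minja::Context::make(minja::Value(json::object()));
        std::cout << tmpl->render(context) << std::endl;  // prints "[]"
    }
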
diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp
index db1787d..ebe5e19 100644
--- a/tests/test-syntax.cpp
+++ b/tests/test-syntax.cpp
@@ -73,9 +73,6 @@ TEST(SyntaxTest, SimpleCases) {
     auto ThrowsWithSubstr = [](const std::string & expected_substr) {
         return testing::Throws<std::runtime_error>(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr)));
     };
-    // EXPECT_EQ(
-    //     "\r\nhey\r\nho!",
-    //     render("\r\n{{ 'hey\r\nho!' }}\r\n", {}, {}));
     EXPECT_EQ(
         " b",
         render(R"( {% set _ = 1 %} {% set _ = 2 %}b)", {}, lstrip_trim_blocks));
@@ -452,6 +449,11 @@ TEST(SyntaxTest, SimpleCases) {
     EXPECT_EQ(
         "a",
         render("{{ ' a ' | trim }}", {}, {}));
+    if (!getenv("USE_JINJA2")) {
+        EXPECT_EQ(
+            "",
+            render(R"({{ None | trim }})", {}, {}));
+    }
     EXPECT_EQ(
         "[0, 1, 2][4, 5, 6][0, 2, 4, 6, 8]",
         render("{{ range(3) | list }}{{ range(4, 7) | list }}{{ range(0, 10, 2) | list }}", {}, {}));
@@ -484,33 +486,28 @@ TEST(SyntaxTest, SimpleCases) {
         "",
         render("{% if 1 %}{% elif 1 %}{% else %}{% endif %}", {}, {}));
 
-    EXPECT_THAT([]() { render(R"({{ 'a' + None }})", {}, {}); }, testing::Throws<std::runtime_error>());
-    EXPECT_THAT([]() { render(R"({{ None + 'b' }})", {}, {}); }, testing::Throws<std::runtime_error>());
-    EXPECT_THAT([]() { render(R"({{ 'a' in None }})", {}, {}); }, testing::Throws<std::runtime_error>());
-    EXPECT_EQ(
-        "False,True,False",
-        render(R"({{ None in [] }},{{ None == None }},{{ None != None }})", {}, {}));
     if (!getenv("USE_JINJA2")) {
         // TODO: capture stderr from jinja2 and test these.
+        EXPECT_THAT([]() { render("{%- set _ = [].pop() -%}", {}, {}); }, ThrowsWithSubstr("pop from empty list"));
         EXPECT_THAT([]() { render("{%- set _ = {}.pop() -%}", {}, {}); }, ThrowsWithSubstr("pop"));
         EXPECT_THAT([]() { render("{%- set _ = {}.pop('foooo') -%}", {}, {}); }, ThrowsWithSubstr("foooo"));
 
-        EXPECT_THAT([]() { render("{% else %}", {}, {}); }, ThrowsWithSubstr("Encountered unknown tag 'else'"));
+        EXPECT_THAT([]() { render("{% else %}", {}, {}); }, ThrowsWithSubstr("Unexpected else"));
 
-        EXPECT_THAT([]() { render("{% else %}", {}, {}); }, ThrowsWithSubstr("Encountered unknown tag 'else'"));
-        EXPECT_THAT([]() { render("{% endif %}", {}, {}); }, ThrowsWithSubstr("Encountered unknown tag 'endif'"));
-        EXPECT_THAT([]() { render("{% elif 1 %}", {}, {}); }, ThrowsWithSubstr("Encountered unknown tag 'elif'"));
-        EXPECT_THAT([]() { render("{% endfor %}", {}, {}); }, ThrowsWithSubstr("Encountered unknown tag 'endfor'"));
-        EXPECT_THAT([]() { render("{% endfilter %}", {}, {}); }, ThrowsWithSubstr("Encountered unknown tag 'endfilter'"));
+        EXPECT_THAT([]() { render("{% else %}", {}, {}); }, ThrowsWithSubstr("Unexpected else"));
+        EXPECT_THAT([]() { render("{% endif %}", {}, {}); }, ThrowsWithSubstr("Unexpected endif"));
+        EXPECT_THAT([]() { render("{% elif 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected elif"));
+        EXPECT_THAT([]() { render("{% endfor %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfor"));
+        EXPECT_THAT([]() { render("{% endfilter %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfilter"));
 
-        EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected end of template. Jinja was looking for the following tags: 'if'"));
-        EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected end of template. Jinja was looking for the following tags: 'for'"));
-        EXPECT_THAT([]() { render("{% generation %}", {}, {}); }, ThrowsWithSubstr("Unexpected end of template. Jinja was looking for the following tags: 'generation'"));
-        EXPECT_THAT([]() { render("{% if 1 %}{% else %}", {}, {}); }, ThrowsWithSubstr("Unexpected end of template. Jinja was looking for the following tags: 'if'"));
-        EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unexpected end of template. Jinja was looking for the following tags: 'if'"));
-        EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unexpected end of template. Jinja was looking for the following tags: 'filter'"));
+        EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
+        EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated for"));
+        EXPECT_THAT([]() { render("{% generation %}", {}, {}); }, ThrowsWithSubstr("Unterminated generation"));
+        EXPECT_THAT([]() { render("{% if 1 %}{% else %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
+        EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
+        EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter"));
     }
 
     EXPECT_EQ(