diff --git a/README.md b/README.md
index 85a3089..d45c45e 100644
--- a/README.md
+++ b/README.md
@@ -157,7 +157,7 @@ Main limitations (non-exhaustive list):
   huggingface-cli login
   ```
 
-- Build & run tests (shorthand: `./scripts/run_tests.sh`):
+- Build & run tests (shorthand: `./scripts/tests.sh`):
 
   ```bash
   rm -fR build && \
@@ -195,7 +195,7 @@ Main limitations (non-exhaustive list):
 - Build in [fuzzing mode](https://github.com/google/fuzztest/blob/main/doc/quickstart-cmake.md#fuzzing-mode) & run all fuzzing tests (optionally, set a higher `TIMEOUT` as env var):
 
   ```bash
-  ./scripts/run_fuzzing_mode.sh
+  ./scripts/fuzzing_tests.sh
   ```
 
 - If your model's template doesn't run fine, please consider the following before [opening a bug](https://github.com/googlestaging/minja/issues/new):
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 750bdea..7d75913 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -12,5 +12,8 @@ foreach(example
     add_executable(${example} ${example}.cpp)
     target_compile_features(${example} PUBLIC cxx_std_17)
     target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json)
+    if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+        target_compile_definitions(${example} PUBLIC _CRT_SECURE_NO_WARNINGS)
+    endif()
 endforeach()
diff --git a/examples/chat-template.cpp b/examples/chat-template.cpp
index 8161b2b..d1838e7 100644
--- a/examples/chat-template.cpp
+++ b/examples/chat-template.cpp
@@ -19,14 +19,16 @@ int main() {
         /* bos_token= */ "<|start|>",
         /* eos_token= */ "<|end|>"
     );
-    std::cout << tmpl.apply(
-        json::parse(R"([
-            {"role": "user", "content": "Hello"},
-            {"role": "assistant", "content": "Hi there"}
-        ])"),
-        json::parse(R"([
-            {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}}
-        ])"),
-        /* add_generation_prompt= */ true,
-        /* extra_context= */ {}) << std::endl;
+
+    minja::chat_template_inputs inputs;
+    inputs.messages = json::parse(R"([
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "Hi there"}
+    ])");
+    inputs.add_generation_prompt = true;
+    inputs.tools = json::parse(R"([
+        {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}}
+    ])");
+
+    std::cout << tmpl.apply(inputs) << std::endl;
 }
diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp
index 58e119a..69ee4e8 100644
--- a/include/minja/chat-template.hpp
+++ b/include/minja/chat-template.hpp
@@ -33,6 +33,29 @@ struct chat_template_caps {
     bool requires_typed_content = false;
 };
 
+struct chat_template_inputs {
+    nlohmann::ordered_json messages;
+    nlohmann::ordered_json tools;
+    bool add_generation_prompt = true;
+    nlohmann::ordered_json extra_context;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+};
+
+struct chat_template_options {
+    bool apply_polyfills = true;
+    bool use_bos_token = true;
+    bool use_eos_token = true;
+    bool define_strftime_now = true;
+
+    bool polyfill_tools = true;
+    bool polyfill_tool_call_examples = true;
+    bool polyfill_tool_calls = true;
+    bool polyfill_tool_responses = true;
+    bool polyfill_system_role = true;
+    bool polyfill_object_arguments = true;
+    bool polyfill_typed_content = true;
+};
+
 class chat_template {
   private:
@@ -41,6 +64,7 @@ class chat_template {
     std::string bos_token_;
     std::string eos_token_;
     std::shared_ptr<minja::TemplateNode> template_root_;
+    std::string tool_call_example_;
 
     std::string try_raw_render(
         const nlohmann::ordered_json & messages,
@@ -49,7 +73,18 @@
class chat_template { const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const { try { - auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + // Use fixed date for tests + inputs.now = std::chrono::system_clock::from_time_t(0); + + chat_template_options opts; + opts.apply_polyfills = false; + + auto prompt = apply(inputs, opts); // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); return prompt; } catch (const std::exception & e) { @@ -176,6 +211,58 @@ class chat_template { caps_.supports_tool_responses = contains(out, "Some response!"); caps_.supports_tool_call_id = contains(out, "call_911_"); } + + try { + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json args { + {"arg1", "some_value"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))}, + }}, + }, + })}, + }; + std::string prefix, full; + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg}); + inputs.add_generation_prompt = true; + prefix = apply(inputs); + } + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg, tool_call_msg}); + inputs.add_generation_prompt = false; + full = apply(inputs); + } + + if (full.find(prefix) != 0) { + if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + } + } + if (full.find(prefix) != 0) { + fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); + } + tool_call_example_ = full.substr(prefix.size()); + } + } catch (const std::exception & e) { + fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); + } } const std::string & source() const { return source_; } @@ -183,28 +270,72 @@ class chat_template { const std::string & eos_token() const { return eos_token_; } const chat_template_caps & original_caps() const { return caps_; } + // Deprecated, please use the form with chat_template_inputs and chat_template_options std::string apply( const nlohmann::ordered_json & messages, const nlohmann::ordered_json & tools, bool add_generation_prompt, const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), - bool adjust_inputs = true) const + bool apply_polyfills = true) + { + fprintf(stderr, "[%s] Deprecated!\n", __func__); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + inputs.now = std::chrono::system_clock::now(); + + chat_template_options opts; + opts.apply_polyfills = apply_polyfills; + + return apply(inputs, opts); + } + + std::string apply( + const chat_template_inputs & inputs, + const chat_template_options & opts = chat_template_options()) const { json actual_messages; - auto needs_adjustments = adjust_inputs && (false - || !caps_.supports_system_role - || !caps_.supports_tools - || !caps_.supports_tool_responses - || 
!caps_.supports_tool_calls
-            || caps_.requires_object_arguments
-            || caps_.requires_typed_content
+        auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+        auto has_tool_calls = false;
+        auto has_tool_responses = false;
+        auto has_string_content = false;
+        for (const auto & message : inputs.messages) {
+            if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
+                has_tool_calls = true;
+            }
+            if (message.contains("role") && message["role"] == "tool") {
+                has_tool_responses = true;
+            }
+            if (message.contains("content") && message["content"].is_string()) {
+                has_string_content = true;
+            }
+        }
+
+        auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
+        auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
+        auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
+        auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
+        auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
+        auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
+        auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
+
+        auto needs_polyfills = opts.apply_polyfills && (false
+            || polyfill_system_role
+            || polyfill_tools
+            || polyfill_tool_calls
+            || polyfill_tool_responses
+            || polyfill_object_arguments
+            || polyfill_typed_content
         );
-        if (needs_adjustments) {
+
+        if (needs_polyfills) {
             actual_messages = json::array();
 
             auto add_message = [&](const json & msg) {
-                if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
+                if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
                     actual_messages.push_back({
                         {"role", msg.at("role")},
                         {"content", {{
@@ -227,9 +358,17 @@ class chat_template {
                     pending_system.clear();
                 }
             };
-            auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
 
-            for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
+            json adjusted_messages;
+            if (polyfill_tools) {
+                adjusted_messages = add_system(inputs.messages,
+                    "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
+                    (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
+            } else {
+                adjusted_messages = inputs.messages;
+            }
+
+            for (const auto & message_ : adjusted_messages) {
                 auto message = message_;
                 if (!message.contains("role") || !message.contains("content")) {
                     throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
@@ -237,7 +376,7 @@ class chat_template {
                 std::string role = message.at("role");
 
                 if (message.contains("tool_calls")) {
-                    if (caps_.requires_object_arguments || !caps_.supports_tool_calls) {
+                    if (polyfill_object_arguments || polyfill_tool_calls) {
                         for (auto & tool_call : message.at("tool_calls")) {
                             if (tool_call["type"] == "function") {
                                 auto & function = tool_call.at("function");
@@ -252,7 +391,7 @@ class chat_template {
                             }
                         }
                     }
-                    if (!caps_.supports_tool_calls) {
+                    if (polyfill_tool_calls) {
                         auto content = message.at("content");
                         auto tool_calls = json::array();
                         for (const auto & tool_call : message.at("tool_calls")) {
@@ -279,7 +418,7 @@ class chat_template {
                         message.erase("tool_calls");
                     }
                 }
-                if (!caps_.supports_tool_responses && role == "tool") {
+                if (polyfill_tool_responses && role == "tool") {
                     message["role"] = "user";
                     auto obj = json {
                         {"tool_response", {
@@ -296,7 +435,7 @@ class chat_template {
                     message.erase("name");
                 }
 
-                if (!message["content"].is_null() && !caps_.supports_system_role) {
+                if (!message["content"].is_null() && polyfill_system_role) {
                     std::string content = message.at("content");
                     if (role == "system") {
                         if (!pending_system.empty()) pending_system += "\n";
@@ -315,28 +454,40 @@ class chat_template {
                 }
                 add_message(message);
             }
-            if (!caps_.supports_system_role) {
-                flush_sys();
-            }
+            flush_sys();
         } else {
-            actual_messages = messages;
+            actual_messages = inputs.messages;
         }
 
         auto context = minja::Context::make(json({
             {"messages", actual_messages},
-            {"add_generation_prompt", add_generation_prompt},
-            {"bos_token", bos_token_},
-            {"eos_token", eos_token_},
+            {"add_generation_prompt", inputs.add_generation_prompt},
         }));
-
-        if (!tools.is_null()) {
-            auto tools_val = minja::Value(tools);
-            context->set("tools", tools_val);
+        if (opts.use_bos_token) {
+            context->set("bos_token", bos_token_);
+        }
+        if (opts.use_eos_token) {
+            context->set("eos_token", eos_token_);
+        }
+        if (opts.define_strftime_now) {
+            auto now = inputs.now;
+            context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
+                args.expectArgs("strftime_now", {1, 1}, {0, 0});
+                auto format = args.args[0].get<std::string>();
+
+                auto time = std::chrono::system_clock::to_time_t(now);
+                auto local_time = *std::localtime(&time);
+                std::ostringstream ss;
+                ss << std::put_time(&local_time, format.c_str());
+                return ss.str();
+            }));
+        }
+        if (!inputs.tools.is_null()) {
+            context->set("tools", minja::Value(inputs.tools));
         }
-        if (!extra_context.is_null()) {
-            for (auto & kv : extra_context.items()) {
-                minja::Value val(kv.value());
-                context->set(kv.key(), val);
+        if (!inputs.extra_context.is_null()) {
+            for (auto & kv : inputs.extra_context.items()) {
+                context->set(kv.key(), minja::Value(kv.value()));
             }
         }
 
@@ -353,7 +504,7 @@ class chat_template {
             std::string existing_system = messages_with_system.at(0).at("content");
             messages_with_system[0] = json {
                 {"role", "system"},
-                {"content", existing_system + "\n" + system_prompt},
+                {"content", existing_system + "\n\n" + system_prompt},
             };
         } else {
             messages_with_system.insert(messages_with_system.begin(), json {
diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp
index 0eadae8..c304b5c
100644
--- a/include/minja/minja.hpp
+++ b/include/minja/minja.hpp
@@ -2615,6 +2615,7 @@ inline std::shared_ptr<Context> Context::builtins() {
   }));
   globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr<Context> &, Value & args) {
     auto do_join = [](Value & items, const std::string & sep) {
+      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
       std::ostringstream oss;
       auto first = true;
       for (size_t i = 0, n = items.size(); i < n; ++i) {
@@ -2695,6 +2696,10 @@ inline std::shared_ptr<Context> Context::builtins() {
     return Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
       args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
       auto & items = args.args[0];
+      if (items.is_null())
+        return Value::array();
+      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
+
       auto filter_fn = context->get(args.args[1]);
       if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
@@ -2772,6 +2777,7 @@ inline std::shared_ptr<Context> Context::builtins() {
       auto & items = args.args[0];
       if (items.is_null())
         return Value::array();
+      if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump());
       auto attr_name = args.args[1].get<std::string>();
 
       bool has_test = false;
diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py
index 619b539..66238bc 100644
--- a/scripts/fetch_templates_and_goldens.py
+++ b/scripts/fetch_templates_and_goldens.py
@@ -66,7 +66,7 @@ def add_system(messages, system_prompt):
         existing_system = messages[0]["content"]
         messages[0] = {
             "role": "system",
-            "content": existing_system + "\n" + system_prompt,
+            "content": existing_system + "\n\n" + system_prompt,
         }
     else:
         messages.insert(0, {
@@ -86,7 +86,7 @@ class TemplateCaps:
     requires_object_arguments: bool = False
     requires_non_null_content: bool = False
    requires_typed_content: bool = False
-    
+
     def to_json(self):
         return json.dumps({
             "supports_tools": self.supports_tools,
@@ -99,120 +99,243 @@ def to_json(self):
             # "requires_non_null_content": self.requires_non_null_content,
             "requires_typed_content": self.requires_typed_content,
         }, indent=2)
-    
-def detect_caps(template_file, template):
-    basic_extra_context = {
-        "bos_token": "<|startoftext|>",
-        "eos_token": "<|endoftext|>",
-    }
-    def try_raw_render(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
+
+
+class chat_template:
+
+    def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]):
+        basic_extra_context = {
+            "bos_token": "<|startoftext|>",
+            "eos_token": "<|endoftext|>",
+        }
+
         try:
-            out = template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context)
+            out = self.template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context)
             # print(out, file=sys.stderr)
             return out
         except BaseException as e:
             # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True)
             return ""
-    
-    caps = TemplateCaps()
-    
-    user_needle = "<User Needle>"
-    sys_needle = "<System Needle>"
-    dummy_str_user_msg = {"role": "user", "content": user_needle }
-    dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]}
-    
-    caps.requires_typed_content = \
-        (user_needle not in try_raw_render([dummy_str_user_msg])) \
-        and
(user_needle in try_raw_render([dummy_typed_user_msg]))
-    dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg
-    
-    needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle}
-    
-    caps.supports_system_role = sys_needle in try_raw_render([needle_system_msg, dummy_user_msg])
-    
-    out = try_raw_render([dummy_user_msg], tools=[{
-        "name": "some_tool",
-        "type": "function",
-        "function": {
+
+    def __init__(self, template, known_eos_tokens, env=None):
+        if not env:
+            env = jinja2.Environment(
+                trim_blocks=True,
+                lstrip_blocks=True,
+                extensions=[jinja2.ext.loopcontrols]
+            )
+        self.env = env
+        self.template = env.from_string(template)
+
+        caps = TemplateCaps()
+
+        user_needle = "<User Needle>"
+        sys_needle = "<System Needle>"
+        dummy_str_user_msg = {"role": "user", "content": user_needle }
+        dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]}
+
+        caps.requires_typed_content = \
+            (user_needle not in self.try_raw_render([dummy_str_user_msg])) \
+            and (user_needle in self.try_raw_render([dummy_typed_user_msg]))
+        dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg
+
+        needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle}
+
+        caps.supports_system_role = sys_needle in self.try_raw_render([needle_system_msg, dummy_user_msg])
+
+        out = self.try_raw_render([dummy_user_msg], tools=[{
             "name": "some_tool",
-            "description": "Some tool",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "arg": {
-                        "type": "string",
-                        "description": "Some arg",
+            "type": "function",
+            "function": {
+                "name": "some_tool",
+                "description": "Some tool",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "arg": {
+                            "type": "string",
+                            "description": "Some arg",
+                        },
                     },
+                    "required": ["arg"],
                 },
-                "required": ["arg"],
             },
-        },
-    }])
-    caps.supports_tools = "some_tool" in out
-    
-    def make_tool_calls_msg(tool_calls, content=None):
-        return {
-            "role": "assistant",
-            "content": content,
-            "tool_calls": tool_calls,
-        }
-    def make_tool_call(tool_name, arguments):
-        return {
-            "id": "call_1___",
-            "type": "function",
-            "function": {
-                "arguments": arguments,
-                "name": tool_name,
+        }])
+        caps.supports_tools = "some_tool" in out
+
+        def make_tool_calls_msg(tool_calls, content=None):
+            return {
+                "role": "assistant",
+                "content": content,
+                "tool_calls": tool_calls,
+            }
+        def make_tool_call(tool_name, arguments):
+            return {
+                "id": "call_1___",
+                "type": "function",
+                "function": {
+                    "arguments": arguments,
+                    "name": tool_name,
+                }
             }
-        }
-    
-    dummy_args_obj = {"argument_needle": "print('Hello, World!')"}
-    out = try_raw_render([
-        dummy_user_msg,
-        make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]),
-    ])
-    tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out
-    out = try_raw_render([
-        dummy_user_msg,
-        make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]),
-    ])
-    tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out
-    
-    caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments
-    caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments
-    
-    empty_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])
-    none_out = try_raw_render([dummy_user_msg, {"role":
"assistant", "content": None}]) - caps.requires_non_null_content = \ - (user_needle in try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ - and (user_needle not in try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) - - if caps.supports_tool_calls: - dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) - tc1 = make_tool_call("test_tool1", dummy_args) - tc2 = make_tool_call("test_tool2", dummy_args) - out = try_raw_render([ + dummy_args_obj = {"argument_needle": "print('Hello, World!')"} + + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1, tc2]), + make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), ]) - caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out - - out = try_raw_render([ + tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1]), - { - "role": "tool", - "name": "test_tool1", - "content": "Some response!", - "tool_call_id": "call_911_", - } + make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), ]) - caps.supports_tool_responses = "Some response!" in out - caps.supports_tool_call_id = "call_911_" in out - - return caps - + tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out + + caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments + caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments + + caps.requires_non_null_content = \ + (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ + and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) + + if caps.supports_tool_calls: + dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) + tc1 = make_tool_call("test_tool1", dummy_args) + tc2 = make_tool_call("test_tool2", dummy_args) + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1, tc2]), + ]) + caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out + + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1]), + { + "role": "tool", + "name": "test_tool1", + "content": "Some response!", + "tool_call_id": "call_911_", + } + ]) + caps.supports_tool_responses = "Some response!" 
in out
+            caps.supports_tool_call_id = "call_911_" in out
+
+        self.tool_call_example = None
+        try:
+            if not caps.supports_tools:
+                user_msg = {"role": "user", "content": "Hey"}
+                args = {"arg1": "some_value"}
+                tool_call_msg = {
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [
+                        {
+                            "id": "call_1___",
+                            "type": "function",
+                            "function": {
+                                "name": "tool_name",
+                                "arguments": args if caps.requires_object_arguments else json.dumps(args),
+                            },
+                        },
+                    ],
+                }
+                prefix = self.try_raw_render([user_msg], add_generation_prompt=True)
+                full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False)
+                if not full.startswith(prefix):
+                    for known_eos_token in known_eos_tokens:
+                        prefix = prefix.rstrip()
+                        if prefix.endswith(known_eos_token):
+                            prefix = prefix[:-len(known_eos_token)]
+                            break
+                if not full.startswith(prefix):
+                    print("Failed to infer a tool call example (possible template bug)", file=sys.stderr)
+                self.tool_call_example = full[len(prefix):]
+        except Exception as e:
+            print(f"Failed to generate tool call example: {e}", file=sys.stderr)
+
+        self.original_caps = caps
+
+    def needs_polyfills(self, context):
+        has_tools = context.get('tools') is not None
+        caps = self.original_caps
+        return not caps.supports_system_role \
+            or (has_tools and (False \
+                or not caps.supports_tools \
+                or not caps.supports_tool_responses \
+                or not caps.supports_tool_calls \
+                or caps.requires_object_arguments \
+            )) \
+            or caps.requires_typed_content
+
+    def apply(self, context):
+
+        caps = self.original_caps
+        has_tools = 'tools' in context
+
+        if self.needs_polyfills(context):
+            if has_tools and not caps.supports_tools:
+                add_system(context['messages'],
+                           f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}" +
+                           ("\n\nExample tool call syntax:\n\n" + self.tool_call_example if self.tool_call_example is not None else ""))
+
+            for message in context['messages']:
+                if 'tool_calls' in message:
+                    for tool_call in message['tool_calls']:
+                        if caps.requires_object_arguments:
+                            if tool_call.get('type') == 'function':
+                                arguments = tool_call['function']['arguments']
+                                try:
+                                    arguments = json.loads(arguments)
+                                except:
+                                    pass
+                                tool_call['function']['arguments'] = arguments
+                    if not caps.supports_tool_calls:
+                        message['content'] = json.dumps({
+                            "tool_calls": [
+                                {
+                                    "name": tc['function']['name'],
+                                    "arguments": json.loads(tc['function']['arguments']),
+                                    "id": tc.get('id'),
+                                }
+                                for tc in message['tool_calls']
+                            ],
+                            "content": None if message.get('content', '') == '' else message['content'],
+                        }, indent=2)
+                        del message['tool_calls']
+                if message.get('role') == 'tool' and not caps.supports_tool_responses:
+                    message['role'] = 'user'
+                    message['content'] = json.dumps({
+                        "tool_response": {
+                            "tool": message['name'],
+                            "content": message['content'],
+                            "tool_call_id": message.get('tool_call_id'),
+                        }
+                    }, indent=2)
+                    del message['name']
+
+            if caps.requires_typed_content:
+                for message in context['messages']:
+                    if 'content' in message and isinstance(message['content'], str):
+                        message['content'] = [{"type": "text", "text": message['content']}]
+
+        try:
+            return self.template.render(**context)
+        except Exception as e1:
+            for message in context['messages']:
+                if message.get("content") is None:
+                    message["content"] = ""
+
+            try:
+                return self.template.render(**context)
+            except Exception as e2:
+                logger.info(f"  ERROR: {e2} (after first error: {e1})")
+                return f"ERROR: {e2}"
+
+
+
 async def
handle_chat_template(output_folder, model_id, variant, template_src, context_files):
     if '{% generation %}' in template_src:
         print('Removing {% generation %} blocks from template', file=sys.stderr)
@@ -221,105 +344,52 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
     model_name = model_id.replace("/", "-")
     base_name = f'{model_name}-{variant}' if variant else model_name
     template_file = join_cmake_path(output_folder, f'{base_name}.jinja')
+    caps_file = join_cmake_path(output_folder, f'{base_name}.caps.json')
 
     async with aiofiles.open(template_file, 'w') as f:
         await f.write(template_src)
 
-    env = jinja2.Environment(
-        trim_blocks=True,
-        lstrip_blocks=True,
-        extensions=[jinja2.ext.loopcontrols]
-    )
-    template = env.from_string(template_src)
-
-    env.filters['safe'] = lambda x: x
-    env.filters['tojson'] = tojson
-    env.globals['raise_exception'] = raise_exception
-    env.globals['strftime_now'] = strftime_now
-
-    caps = detect_caps(template_file, template)
-
+    known_eos_tokens = [
+        "<|END_OF_TURN_TOKEN|>",
+        "<end_of_turn>",
+        "</s>",
+        "<|im_end|>",
+        "<|eom_id|>",
+        "<|eot_id|>",
+        "<|end▁of▁sentence|>",
+    ]
+
+    template = chat_template(template_src, known_eos_tokens)
+    template.env.filters['safe'] = lambda x: x
+    template.env.filters['tojson'] = tojson
+    template.env.globals['raise_exception'] = raise_exception
+    template.env.globals['strftime_now'] = strftime_now
+    caps = template.original_caps
+
     if not context_files:
         print(f"{template_file} {caps_file} n/a {template_file}")
         return
 
     async with aiofiles.open(caps_file, 'w') as f:
         await f.write(caps.to_json())
-    
+
     for context_file in context_files:
         context_name = os.path.basename(context_file).replace(".json", "")
         async with aiofiles.open(context_file, 'r') as f:
             context = json.loads(await f.read())
 
-        has_tools = 'tools' in context
-        needs_tools_in_system = has_tools and not caps.supports_tools
-
-        if not caps.supports_tool_calls and has_tools:
+        if not caps.supports_tool_calls and context.get('tools') is not None:
             print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr)
             continue
-        
+
+        needs_tools_in_system = len(context.get('tools', [])) > 0 and not caps.supports_tools
         if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system):
             continue
 
         output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt')
 
-        if needs_tools_in_system:
-            add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}")
-
-        for message in context['messages']:
-            if 'tool_calls' in message:
-                for tool_call in message['tool_calls']:
-                    if caps.requires_object_arguments:
-                        if tool_call.get('type') == 'function':
-                            arguments = tool_call['function']['arguments']
-                            try:
-                                arguments = json.loads(arguments)
-                            except:
-                                pass
-                            tool_call['function']['arguments'] = arguments
-                if not caps.supports_tool_calls:
-                    message['content'] = json.dumps({
-                        "tool_calls": [
-                            {
-                                "name": tc['function']['name'],
-                                "arguments": json.loads(tc['function']['arguments']),
-                                "id": tc.get('id'),
-                            }
-                            for tc in message['tool_calls']
-                        ],
-                        "content": None if message.get('content', '') == '' else message['content'],
-                    }, indent=2)
-                    del message['tool_calls']
-            if message.get('role') == 'tool' and not caps.supports_tool_responses:
-                message['role'] = 'user'
-                message['content'] = json.dumps({
-                    "tool_response": {
-                        "tool": message['name'],
-                        "content": message['content'],
-                        "tool_call_id": message.get('tool_call_id'),
-                    }
-                },
indent=2)
-                del message['name']
-
-        if caps.requires_typed_content:
-            for message in context['messages']:
-                if 'content' in message and isinstance(message['content'], str):
-                    message['content'] = [{"type": "text", "text": message['content']}]
-
-        try:
-            output = template.render(**context)
-        except Exception as e1:
-            for message in context['messages']:
-                if message.get("content") is None:
-                    message["content"] = ""
-
-            try:
-                output = template.render(**context)
-            except Exception as e2:
-                logger.info(f"  ERROR: {e2} (after first error: {e1})")
-                output = f"ERROR: {e2}"
-
+        output = template.apply(context)
         async with aiofiles.open(output_file, 'w') as f:
             await f.write(output)
@@ -336,7 +406,7 @@ async def async_hf_download(repo_id: str, filename: str) -> str:
 async def process_model(output_folder: str, model_id: str, context_files: list):
     try:
         config_str = await async_hf_download(model_id, "tokenizer_config.json")
-        
+
         try:
             config = json.loads(config_str)
         except json.JSONDecodeError:
diff --git a/scripts/run_fuzzing_mode.sh b/scripts/fuzzing_tests.sh
similarity index 100%
rename from scripts/run_fuzzing_mode.sh
rename to scripts/fuzzing_tests.sh
diff --git a/scripts/run_tests.sh b/scripts/tests.sh
similarity index 100%
rename from scripts/run_tests.sh
rename to scripts/tests.sh
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 96c368a..6a34dd4 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -39,9 +39,12 @@ add_test(NAME test-syntax-jinja2 COMMAND test-syntax)
 set_tests_properties(test-syntax-jinja2 PROPERTIES
     ENVIRONMENT "USE_JINJA2=1;PYTHON_EXECUTABLE=${Python_EXECUTABLE};PYTHONPATH=${CMAKE_SOURCE_DIR}")
 
-add_executable(test-chat-template test-chat-template.cpp)
-target_compile_features(test-chat-template PUBLIC cxx_std_17)
-target_link_libraries(test-chat-template PRIVATE nlohmann_json::nlohmann_json)
+add_executable(test-supported-template test-supported-template.cpp)
+target_compile_features(test-supported-template PUBLIC cxx_std_17)
+if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    target_compile_definitions(test-supported-template PUBLIC _CRT_SECURE_NO_WARNINGS)
+endif()
+target_link_libraries(test-supported-template PRIVATE nlohmann_json::nlohmann_json)
 
 set(MODEL_IDS
     # List of model IDs to test the chat template of.
@@ -133,7 +136,7 @@ foreach(test_case ${CHAT_TEMPLATE_TEST_CASES})
     separate_arguments(test_args UNIX_COMMAND "${test_case}")
     list(GET test_args -1 last_arg)
     string(REGEX REPLACE "^[^ ]+/([^ /\\]+)\\.[^.]+$" "\\1" test_name "${last_arg}")
-    add_test(NAME ${test_name} COMMAND $<TARGET_FILE:test-chat-template> ${test_args})
+    add_test(NAME ${test_name} COMMAND $<TARGET_FILE:test-supported-template> ${test_args})
     set_tests_properties(${test_name} PROPERTIES SKIP_RETURN_CODE 127)
 endforeach()
 
diff --git a/tests/test-chat-template.cpp b/tests/test-supported-template.cpp
similarity index 74%
rename from tests/test-chat-template.cpp
rename to tests/test-supported-template.cpp
index 6f8bcb6..302ebbd 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-supported-template.cpp
@@ -18,6 +18,8 @@
 #undef NDEBUG
 #include <cassert>
 
+#define TEST_DATE (getenv("TEST_DATE") ?
getenv("TEST_DATE") : "2024-07-26")
+
 using json = nlohmann::ordered_json;
 
 template <class T>
@@ -55,6 +57,14 @@ static std::string read_file(const std::string &path) {
     return out;
 }
 
+static void write_file(const std::string &path, const std::string &content) {
+    std::ofstream fs(path, std::ios_base::binary);
+    if (!fs.is_open()) {
+        throw std::runtime_error("Failed to open file: " + path);
+    }
+    fs.write(content.data(), content.size());
+}
+
 #ifndef _WIN32
 static json caps_to_json(const minja::chat_template_caps &caps) {
     return {
@@ -96,10 +106,8 @@ int main(int argc, char *argv[]) {
         return 127;
     }
 
-    std::cout << "# Testing template: " << tmpl_file << std::endl
-              << "# With caps: " << caps_file << std::endl
-              << "# With context: " << ctx_file << std::endl
-              << "# Against golden file: " << golden_file << std::endl
+    std::cout << "# Testing template:\n"
+              << "# ./build/bin/test-supported-template " << json::array({tmpl_file, caps_file, ctx_file, golden_file}).dump() << std::endl
               << std::flush;
 
     auto ctx = json::parse(read_file(ctx_file));
@@ -118,21 +126,38 @@ int main(int argc, char *argv[]) {
         return 1;
     }
 
+    struct minja::chat_template_inputs inputs;
+    inputs.messages = ctx.at("messages");
+    inputs.tools = ctx.contains("tools") ? ctx.at("tools") : json();
+    inputs.add_generation_prompt = ctx.at("add_generation_prompt");
+
+    std::istringstream ss(TEST_DATE);
+    std::tm tm = {};
+    ss >> std::get_time(&tm, "%Y-%m-%d");
+    inputs.now = std::chrono::system_clock::from_time_t(std::mktime(&tm));
+
+    if (ctx.contains("tools")) {
+        inputs.extra_context = json {
+            {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})},
+        };
+    }
+
     std::string actual;
     try {
-        actual = tmpl.apply(
-            ctx.at("messages"),
-            ctx.contains("tools") ? ctx.at("tools") : json(),
-            ctx.at("add_generation_prompt"),
-            ctx.contains("tools") ? json{
-                {"builtin_tools", {"wolfram_alpha", "brave_search"}}}
-            : json());
+        actual = tmpl.apply(inputs);
     } catch (const std::exception &e) {
         std::cerr << "Error applying template: " << e.what() << std::endl;
         return 1;
    }
 
-    assert_equals(expected, actual);
+    if (expected != actual) {
+        if (getenv("WRITE_GOLDENS")) {
+            write_file(golden_file, actual);
+            std::cerr << "Updated golden file: " << golden_file << std::endl;
+        } else {
+            assert_equals(expected, actual);
+        }
+    }
 
     // Some unresolved CRLF issues again with the goldens on Windows.
 #ifndef _WIN32
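
For reference, the calling convention this patch introduces: the old positional `apply(messages, tools, add_generation_prompt, extra_context, ...)` overload is kept but deprecated, and callers migrate to the `chat_template_inputs` / `chat_template_options` pair. Below is a minimal sketch of the new call shape; the template string, special tokens, and include paths are illustrative rather than taken from a real model, and it assumes the headers are reachable as `minja/chat-template.hpp` per the diff above.

```cpp
#include <chrono>
#include <iostream>

#include <minja/chat-template.hpp>  // assumed include path, per include/minja/chat-template.hpp above
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    // Illustrative toy template; any HF-style chat template source works here.
    minja::chat_template tmpl(
        "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}{{ eos_token }}{% endfor %}"
        "{% if add_generation_prompt %}<|assistant|>{% endif %}",
        /* bos_token= */ "<|start|>",
        /* eos_token= */ "<|end|>");

    minja::chat_template_inputs inputs;
    inputs.messages = json::parse(R"([{"role": "user", "content": "Hello"}])");
    // add_generation_prompt already defaults to true in chat_template_inputs.
    // Pin the clock so templates that call strftime_now() render deterministically,
    // mirroring what try_raw_render and the test harness do in the diff.
    inputs.now = std::chrono::system_clock::from_time_t(0);

    minja::chat_template_options opts;
    opts.apply_polyfills = false;  // render verbatim; leave true to get the capability shims

    std::cout << tmpl.apply(inputs, opts) << std::endl;
}
```

On the test side, the same patch makes golden maintenance easier: setting `WRITE_GOLDENS` in the environment rewrites a mismatching golden file instead of asserting, and `TEST_DATE` (default `2024-07-26`) overrides the pinned date fed to `inputs.now`.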