From 6b9a34b68eea0fdd9d632cc0d4879ea1e4819622 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 1 Feb 2025 11:39:12 +0000 Subject: [PATCH 01/20] TMP inputs refactor --- include/minja/chat-template.hpp | 51 ++++++++++++++++++++++++--------- tests/test-chat-template.cpp | 19 +++++++----- 2 files changed, 50 insertions(+), 20 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 58e119a..8e05652 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -33,6 +33,17 @@ struct chat_template_caps { bool requires_typed_content = false; }; +struct chat_template_inputs { + nlohmann::ordered_json messages; + nlohmann::ordered_json tools; + bool add_generation_prompt; + nlohmann::ordered_json extra_context; + // Epoch time in milliseconds. + uint64_t now; + // Timezone offset in minutes. + int64_t timezone_offset; +}; + class chat_template { private: @@ -49,7 +60,15 @@ class chat_template { const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const { try { - auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); + chat_template_inputs inputs { + messages, + tools, + add_generation_prompt, + extra_context, + /* now= */ 0, + /* timezone_offset= */ 0, + }; + auto prompt = apply(inputs, /* adjust_inputs= */ false); // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); return prompt; } catch (const std::exception & e) { @@ -184,10 +203,11 @@ class chat_template { const chat_template_caps & original_caps() const { return caps_; } std::string apply( - const nlohmann::ordered_json & messages, - const nlohmann::ordered_json & tools, - bool add_generation_prompt, - const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + const chat_template_inputs & inputs, + // const nlohmann::ordered_json & messages, + // const nlohmann::ordered_json & tools, + // bool add_generation_prompt, + // const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), bool adjust_inputs = true) const { json actual_messages; @@ -227,9 +247,9 @@ class chat_template { pending_system.clear(); } }; - auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools; + auto needs_tools_in_system = !inputs.tools.is_null() && inputs.tools.size() > 0 && !caps_.supports_tools; - for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) { + for (const auto & message_ : needs_tools_in_system ? add_system(inputs.messages, "Available tools: " + inputs.tools.dump(2)) : inputs.messages) { auto message = message_; if (!message.contains("role") || !message.contains("content")) { throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); @@ -319,22 +339,27 @@ class chat_template { flush_sys(); } } else { - actual_messages = messages; + actual_messages = inputs.messages; } auto context = minja::Context::make(json({ {"messages", actual_messages}, - {"add_generation_prompt", add_generation_prompt}, + {"add_generation_prompt", inputs.add_generation_prompt}, {"bos_token", bos_token_}, {"eos_token", eos_token_}, + // {"strftime_now", Value::callable([=](const std::shared_ptr & context, minja::ArgumentsValue & args) { + // args.expectArgs("strftime_now", {1, 1}, {0, 0}); + // auto format = args.args[0].get(); + // return Value(std::to_string(inputs.now)); + // })}, })); - if (!tools.is_null()) { - auto tools_val = minja::Value(tools); + if (!inputs.tools.is_null()) { + auto tools_val = minja::Value(inputs.tools); context->set("tools", tools_val); } - if (!extra_context.is_null()) { - for (auto & kv : extra_context.items()) { + if (!inputs.extra_context.is_null()) { + for (auto & kv : inputs.extra_context.items()) { minja::Value val(kv.value()); context->set(kv.key(), val); } diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 6f8bcb6..f6110cd 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -118,15 +118,20 @@ int main(int argc, char *argv[]) { return 1; } + struct minja::chat_template_inputs inputs; + inputs.messages = ctx.at("messages"); + inputs.tools = ctx.contains("tools") ? ctx.at("tools") : json(); + inputs.add_generation_prompt = ctx.at("add_generation_prompt"); + if (ctx.contains("tools")) { + inputs.extra_context = json { + {"builtin_tools", { + {"wolfram_alpha", "brave_search"} + }}, + }; + } std::string actual; try { - actual = tmpl.apply( - ctx.at("messages"), - ctx.contains("tools") ? ctx.at("tools") : json(), - ctx.at("add_generation_prompt"), - ctx.contains("tools") ? json{ - {"builtin_tools", {"wolfram_alpha", "brave_search"}}} - : json()); + actual = tmpl.apply(inputs); } catch (const std::exception &e) { std::cerr << "Error applying template: " << e.what() << std::endl; return 1; From 4d6481998eb3a4d620fa6eae4498e55faed6af16 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 14:19:11 +0000 Subject: [PATCH 02/20] Tool support backfil: provide automated tool call example --- include/minja/chat-template.hpp | 50 ++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 58e119a..1900950 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -41,6 +41,7 @@ class chat_template { std::string bos_token_; std::string eos_token_; std::shared_ptr template_root_; + std::string tool_call_example_; std::string try_raw_render( const nlohmann::ordered_json & messages, @@ -176,6 +177,43 @@ class chat_template { caps_.supports_tool_responses = contains(out, "Some response!"); caps_.supports_tool_call_id = contains(out, "call_911_"); } + + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (json { + {"arg1", "some_value"}, + }).dump()}, + }}, + }, + })}, + }; + const json tools; + auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); + auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); + if (full.find(prefix) != 0) { + if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + } else { + throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); + } + } else { + + } + tool_call_example_ = full.substr(prefix.size()); + } } const std::string & source() const { return source_; } @@ -229,7 +267,17 @@ class chat_template { }; auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools; - for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) { + json adjusted_messages; + if (needs_tools_in_system) { + adjusted_messages = add_system(messages, + "\n\n" + "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" + "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); + } else { + adjusted_messages = messages; + } + + for (const auto & message_ : adjusted_messages) { auto message = message_; if (!message.contains("role") || !message.contains("content")) { throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); From 0026057b987c5b4bb02ffda4480c2a3e94e775e4 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:39:30 +0000 Subject: [PATCH 03/20] Switch naming towards polyfills (and skip tool-related logic if there's no tools) --- include/minja/chat-template.hpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 1900950..862cfeb 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -50,7 +50,7 @@ class chat_template { const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const { try { - auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); + auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* apply_polyfills= */ false); // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); return prompt; } catch (const std::exception & e) { @@ -226,19 +226,21 @@ class chat_template { const nlohmann::ordered_json & tools, bool add_generation_prompt, const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), - bool adjust_inputs = true) const + bool apply_polyfills = true) const { json actual_messages; - auto needs_adjustments = adjust_inputs && (false + auto needs_polyfills = apply_polyfills && (false || !caps_.supports_system_role - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments + || (!tools.is_null() && (false + || !caps_.supports_tools + || !caps_.supports_tool_responses + || !caps_.supports_tool_calls + || caps_.requires_object_arguments + )) || caps_.requires_typed_content ); - if (needs_adjustments) { + if (needs_polyfills) { actual_messages = json::array(); auto add_message = [&](const json & msg) { From 23f4378a560223db2f0cf1399046859a4d39b480 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:40:34 +0000 Subject: [PATCH 04/20] rename test-chat-template -> test-supported-template --- tests/CMakeLists.txt | 8 ++++---- ...template.cpp => test-supported-template.cpp} | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 5 deletions(-) rename tests/{test-chat-template.cpp => test-supported-template.cpp} (89%) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 96c368a..0d1c40e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -39,9 +39,9 @@ add_test(NAME test-syntax-jinja2 COMMAND test-syntax) set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYTHON_EXECUTABLE=${Python_EXECUTABLE};PYTHONPATH=${CMAKE_SOURCE_DIR}") -add_executable(test-chat-template test-chat-template.cpp) -target_compile_features(test-chat-template PUBLIC cxx_std_17) -target_link_libraries(test-chat-template PRIVATE nlohmann_json::nlohmann_json) +add_executable(test-supported-template test-supported-template.cpp) +target_compile_features(test-supported-template PUBLIC cxx_std_17) +target_link_libraries(test-supported-template PRIVATE nlohmann_json::nlohmann_json) set(MODEL_IDS # List of model IDs to test the chat template of. @@ -133,7 +133,7 @@ foreach(test_case ${CHAT_TEMPLATE_TEST_CASES}) separate_arguments(test_args UNIX_COMMAND "${test_case}") list(GET test_args -1 last_arg) string(REGEX REPLACE "^[^ ]+/([^ /\\]+)\\.[^.]+$" "\\1" test_name "${last_arg}") - add_test(NAME ${test_name} COMMAND $ ${test_args}) + add_test(NAME ${test_name} COMMAND $ ${test_args}) set_tests_properties(${test_name} PROPERTIES SKIP_RETURN_CODE 127) endforeach() diff --git a/tests/test-chat-template.cpp b/tests/test-supported-template.cpp similarity index 89% rename from tests/test-chat-template.cpp rename to tests/test-supported-template.cpp index 6f8bcb6..f23cd8c 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-supported-template.cpp @@ -55,6 +55,14 @@ static std::string read_file(const std::string &path) { return out; } +static void write_file(const std::string &path, const std::string &content) { + std::ofstream fs(path, std::ios_base::binary); + if (!fs.is_open()) { + throw std::runtime_error("Failed to open file: " + path); + } + fs.write(content.data(), content.size()); +} + #ifndef _WIN32 static json caps_to_json(const minja::chat_template_caps &caps) { return { @@ -132,7 +140,14 @@ int main(int argc, char *argv[]) { return 1; } - assert_equals(expected, actual); + if (expected != actual) { + if (getenv("WRITE_GOLDENS")) { + write_file(golden_file, actual); + std::cerr << "Updated golden file: " << golden_file << std::endl; + } else { + assert_equals(expected, actual); + } + } // Some unresolved CRLF issues again with the goldens on Windows. #ifndef _WIN32 From 51c5ac5e738f66219dee3eb1517c0431989b802d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:51:32 +0000 Subject: [PATCH 05/20] refactor fetch_templates_and_goldens --- include/minja/chat-template.hpp | 97 ++++--- scripts/fetch_templates_and_goldens.py | 366 +++++++++++++------------ 2 files changed, 248 insertions(+), 215 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 862cfeb..eee28cb 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -178,42 +178,42 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } - if (!caps_.supports_tools) { - const json user_msg { - {"role", "user"}, - {"content", "Hey"}, - }; - const json tool_call_msg { - {"role", "assistant"}, - {"content", nullptr}, - {"tool_calls", json::array({ - { - // TODO: detect if requires numerical id or fixed length == 6 like Nemo - {"id", "call_1___"}, - {"type", "function"}, - {"function", { - {"name", "tool_name"}, - {"arguments", (json { - {"arg1", "some_value"}, - }).dump()}, - }}, - }, - })}, - }; - const json tools; - auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); - auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); - if (full.find(prefix) != 0) { - if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { - prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - } else { - throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); - } - } else { - - } - tool_call_example_ = full.substr(prefix.size()); - } + // if (!caps_.supports_tools) { + // const json user_msg { + // {"role", "user"}, + // {"content", "Hey"}, + // }; + // const json tool_call_msg { + // {"role", "assistant"}, + // {"content", nullptr}, + // {"tool_calls", json::array({ + // { + // // TODO: detect if requires numerical id or fixed length == 6 like Nemo + // {"id", "call_1___"}, + // {"type", "function"}, + // {"function", { + // {"name", "tool_name"}, + // {"arguments", (json { + // {"arg1", "some_value"}, + // }).dump()}, + // }}, + // }, + // })}, + // }; + // const json tools; + // auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); + // auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); + // if (full.find(prefix) != 0) { + // if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + // prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + // } else { + // throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); + // } + // } else { + + // } + // tool_call_example_ = full.substr(prefix.size()); + // } } const std::string & source() const { return source_; } @@ -232,13 +232,19 @@ class chat_template { auto needs_polyfills = apply_polyfills && (false || !caps_.supports_system_role - || (!tools.is_null() && (false - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments - )) + || !caps_.supports_tools + || !caps_.supports_tool_responses + || !caps_.supports_tool_calls + || caps_.requires_object_arguments || caps_.requires_typed_content + // || !caps_.supports_system_role + // || (!tools.is_null() && (false + // || !caps_.supports_tools + // || !caps_.supports_tool_responses + // || !caps_.supports_tool_calls + // || caps_.requires_object_arguments + // )) + // || caps_.requires_typed_content ); if (needs_polyfills) { actual_messages = json::array(); @@ -272,9 +278,10 @@ class chat_template { json adjusted_messages; if (needs_tools_in_system) { adjusted_messages = add_system(messages, - "\n\n" - "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" - "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); + "Available tools: " + tools.dump(2)); + // "\n\n" + // "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" + // "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); } else { adjusted_messages = messages; } diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 619b539..3251b10 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -99,120 +99,208 @@ def to_json(self): # "requires_non_null_content": self.requires_non_null_content, "requires_typed_content": self.requires_typed_content, }, indent=2) - -def detect_caps(template_file, template): - basic_extra_context = { - "bos_token": "<|startoftext|>", - "eos_token": "<|endoftext|>", - } - def try_raw_render(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]): + + +class chat_template: + + def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]): + basic_extra_context = { + "bos_token": "<|startoftext|>", + "eos_token": "<|endoftext|>", + } + try: - out = template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) + out = self.template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) # print(out, file=sys.stderr) return out except BaseException as e: # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) return "" - - caps = TemplateCaps() - - user_needle = "" - sys_needle = "" - dummy_str_user_msg = {"role": "user", "content": user_needle } - dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} - - caps.requires_typed_content = \ - (user_needle not in try_raw_render([dummy_str_user_msg])) \ - and (user_needle in try_raw_render([dummy_typed_user_msg])) - dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg - - needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} - - caps.supports_system_role = sys_needle in try_raw_render([needle_system_msg, dummy_user_msg]) - - out = try_raw_render([dummy_user_msg], tools=[{ - "name": "some_tool", - "type": "function", - "function": { + + def __init__(self, template, env=None): + if not env: + env = jinja2.Environment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[jinja2.ext.loopcontrols] + ) + self.env = env + self.template = env.from_string(template) + + caps = TemplateCaps() + + user_needle = "" + sys_needle = "" + dummy_str_user_msg = {"role": "user", "content": user_needle } + dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} + + caps.requires_typed_content = \ + (user_needle not in self.try_raw_render([dummy_str_user_msg])) \ + and (user_needle in self.try_raw_render([dummy_typed_user_msg])) + dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg + + needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} + + caps.supports_system_role = sys_needle in self.try_raw_render([needle_system_msg, dummy_user_msg]) + + out = self.try_raw_render([dummy_user_msg], tools=[{ "name": "some_tool", - "description": "Some tool", - "parameters": { - "type": "object", - "properties": { - "arg": { - "type": "string", - "description": "Some arg", + "type": "function", + "function": { + "name": "some_tool", + "description": "Some tool", + "parameters": { + "type": "object", + "properties": { + "arg": { + "type": "string", + "description": "Some arg", + }, }, + "required": ["arg"], }, - "required": ["arg"], }, - }, - }]) - caps.supports_tools = "some_tool" in out - - def make_tool_calls_msg(tool_calls, content=None): - return { - "role": "assistant", - "content": content, - "tool_calls": tool_calls, - } - def make_tool_call(tool_name, arguments): - return { - "id": "call_1___", - "type": "function", - "function": { - "arguments": arguments, - "name": tool_name, + }]) + caps.supports_tools = "some_tool" in out + + def make_tool_calls_msg(tool_calls, content=None): + return { + "role": "assistant", + "content": content, + "tool_calls": tool_calls, + } + def make_tool_call(tool_name, arguments): + return { + "id": "call_1___", + "type": "function", + "function": { + "arguments": arguments, + "name": tool_name, + } } - } - - dummy_args_obj = {"argument_needle": "print('Hello, World!')"} - - out = try_raw_render([ - dummy_user_msg, - make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), - ]) - tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out - out = try_raw_render([ - dummy_user_msg, - make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), - ]) - tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out - caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments - caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments - - empty_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}]) - none_out = try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}]) - caps.requires_non_null_content = \ - (user_needle in try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ - and (user_needle not in try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) - - if caps.supports_tool_calls: - dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) - tc1 = make_tool_call("test_tool1", dummy_args) - tc2 = make_tool_call("test_tool2", dummy_args) - out = try_raw_render([ + dummy_args_obj = {"argument_needle": "print('Hello, World!')"} + + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1, tc2]), + make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]), ]) - caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out - - out = try_raw_render([ + tool_call_renders_str_arguments = '"argument_needle":' in out or "'argument_needle':" in out + out = self.try_raw_render([ dummy_user_msg, - make_tool_calls_msg([tc1]), - { - "role": "tool", - "name": "test_tool1", - "content": "Some response!", - "tool_call_id": "call_911_", - } + make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), ]) - caps.supports_tool_responses = "Some response!" in out - caps.supports_tool_call_id = "call_911_" in out - - return caps - + tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out + + caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments + caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments + + caps.requires_non_null_content = \ + (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ + and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) + + if caps.supports_tool_calls: + dummy_args = dummy_args_obj if caps.requires_object_arguments else json.dumps(dummy_args_obj) + tc1 = make_tool_call("test_tool1", dummy_args) + tc2 = make_tool_call("test_tool2", dummy_args) + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1, tc2]), + ]) + caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out + + out = self.try_raw_render([ + dummy_user_msg, + make_tool_calls_msg([tc1]), + { + "role": "tool", + "name": "test_tool1", + "content": "Some response!", + "tool_call_id": "call_911_", + } + ]) + caps.supports_tool_responses = "Some response!" in out + caps.supports_tool_call_id = "call_911_" in out + + self.original_caps = caps + + def needs_polyfills(self, context): + has_tools = context.get('tools') is not None + caps = self.original_caps + return not caps.supports_system_role \ + or (has_tools is not None and (False \ + or not caps.supports_tools \ + or not caps.supports_tool_responses \ + or not caps.supports_tool_calls \ + or caps.requires_object_arguments \ + )) \ + or caps.requires_typed_content + + def apply(self, context): + + caps = self.original_caps + has_tools = 'tools' in context + + if self.needs_polyfills(context): + if has_tools and not caps.supports_tools: + add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}") + + for message in context['messages']: + if 'tool_calls' in message: + for tool_call in message['tool_calls']: + if caps.requires_object_arguments: + if tool_call.get('type') == 'function': + arguments = tool_call['function']['arguments'] + try: + arguments = json.loads(arguments) + except: + pass + tool_call['function']['arguments'] = arguments + if not caps.supports_tool_calls: + message['content'] = json.dumps({ + "tool_calls": [ + { + "name": tc['function']['name'], + "arguments": json.loads(tc['function']['arguments']), + "id": tc.get('id'), + } + for tc in message['tool_calls'] + ], + "content": None if message.get('content', '') == '' else message['content'], + }, indent=2) + del message['tool_calls'] + if message.get('role') == 'tool' and not caps.supports_tool_responses: + message['role'] = 'user' + message['content'] = json.dumps({ + "tool_response": { + "tool": message['name'], + "content": message['content'], + "tool_call_id": message.get('tool_call_id'), + } + }, indent=2) + del message['name'] + + if caps.requires_typed_content: + for message in context['messages']: + if 'content' in message and isinstance(message['content'], str): + message['content'] = [{"type": "text", "text": message['content']}] + + try: + return self.template.render(**context) + except Exception as e1: + for message in context['messages']: + if message.get("content") is None: + message["content"] = "" + + try: + return self.template.render(**context) + except Exception as e2: + logger.info(f" ERROR: {e2} (after first error: {e1})") + return f"ERROR: {e2}" + + + + async def handle_chat_template(output_folder, model_id, variant, template_src, context_files): if '{% generation %}' in template_src: print('Removing {% generation %} blocks from template', file=sys.stderr) @@ -221,24 +309,19 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c model_name = model_id.replace("/", "-") base_name = f'{model_name}-{variant}' if variant else model_name template_file = join_cmake_path(output_folder, f'{base_name}.jinja') + caps_file = join_cmake_path(output_folder, f'{base_name}.caps.json') async with aiofiles.open(template_file, 'w') as f: await f.write(template_src) - env = jinja2.Environment( - trim_blocks=True, - lstrip_blocks=True, - extensions=[jinja2.ext.loopcontrols] - ) - template = env.from_string(template_src) - - env.filters['safe'] = lambda x: x - env.filters['tojson'] = tojson - env.globals['raise_exception'] = raise_exception - env.globals['strftime_now'] = strftime_now + template = chat_template(template_src) + template.env.filters['safe'] = lambda x: x + template.env.filters['tojson'] = tojson + template.env.globals['raise_exception'] = raise_exception + template.env.globals['strftime_now'] = strftime_now - caps = detect_caps(template_file, template) + caps = template.original_caps if not context_files: print(f"{template_file} {caps_file} n/a {template_file}") @@ -252,74 +335,17 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c async with aiofiles.open(context_file, 'r') as f: context = json.loads(await f.read()) - has_tools = 'tools' in context - needs_tools_in_system = has_tools and not caps.supports_tools - - if not caps.supports_tool_calls and has_tools: + if not caps.supports_tool_calls and context.get('tools') is not None: print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr) continue + needs_tools_in_system = len(context.get('tools', [])) > 0 and not caps.supports_tools if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system): continue output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt') - if needs_tools_in_system: - add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}") - - for message in context['messages']: - if 'tool_calls' in message: - for tool_call in message['tool_calls']: - if caps.requires_object_arguments: - if tool_call.get('type') == 'function': - arguments = tool_call['function']['arguments'] - try: - arguments = json.loads(arguments) - except: - pass - tool_call['function']['arguments'] = arguments - if not caps.supports_tool_calls: - message['content'] = json.dumps({ - "tool_calls": [ - { - "name": tc['function']['name'], - "arguments": json.loads(tc['function']['arguments']), - "id": tc.get('id'), - } - for tc in message['tool_calls'] - ], - "content": None if message.get('content', '') == '' else message['content'], - }, indent=2) - del message['tool_calls'] - if message.get('role') == 'tool' and not caps.supports_tool_responses: - message['role'] = 'user' - message['content'] = json.dumps({ - "tool_response": { - "tool": message['name'], - "content": message['content'], - "tool_call_id": message.get('tool_call_id'), - } - }, indent=2) - del message['name'] - - if caps.requires_typed_content: - for message in context['messages']: - if 'content' in message and isinstance(message['content'], str): - message['content'] = [{"type": "text", "text": message['content']}] - - try: - output = template.render(**context) - except Exception as e1: - for message in context['messages']: - if message.get("content") is None: - message["content"] = "" - - try: - output = template.render(**context) - except Exception as e2: - logger.info(f" ERROR: {e2} (after first error: {e1})") - output = f"ERROR: {e2}" - + output = template.apply(context) async with aiofiles.open(output_file, 'w') as f: await f.write(output) From 876fecb35b765b25031ba94da682fbe06cc991bf Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 20:59:51 +0000 Subject: [PATCH 06/20] print test command in format pastable to vscode debugger --- tests/test-supported-template.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index f23cd8c..e35eeb5 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -104,10 +104,8 @@ int main(int argc, char *argv[]) { return 127; } - std::cout << "# Testing template: " << tmpl_file << std::endl - << "# With caps: " << caps_file << std::endl - << "# With context: " << ctx_file << std::endl - << "# Against golden file: " << golden_file << std::endl + std::cout << "# Testing template:\n" + << "# ./build/bin/test-supported-template " << json::array({tmpl_file, caps_file, ctx_file, golden_file}).dump() << std::endl << std::flush; auto ctx = json::parse(read_file(ctx_file)); From 53fabddc8a3ac7b4a316961b51032908035fa637 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 21:19:05 +0000 Subject: [PATCH 07/20] Update chat-template.hpp --- include/minja/chat-template.hpp | 74 +++++++++++++++++---------------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index eee28cb..e6c8420 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -178,41 +178,45 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } - // if (!caps_.supports_tools) { - // const json user_msg { - // {"role", "user"}, - // {"content", "Hey"}, - // }; - // const json tool_call_msg { - // {"role", "assistant"}, - // {"content", nullptr}, - // {"tool_calls", json::array({ - // { - // // TODO: detect if requires numerical id or fixed length == 6 like Nemo - // {"id", "call_1___"}, - // {"type", "function"}, - // {"function", { - // {"name", "tool_name"}, - // {"arguments", (json { - // {"arg1", "some_value"}, - // }).dump()}, - // }}, - // }, - // })}, - // }; - // const json tools; - // auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); - // auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); - // if (full.find(prefix) != 0) { - // if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { - // prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - // } else { - // throw std::runtime_error("prefix not found at start of full: " + prefix + " vs " + full); - // } - // } else { - - // } - // tool_call_example_ = full.substr(prefix.size()); + // try { + if (!caps_.supports_tools) { + const json user_msg { + {"role", "user"}, + {"content", "Hey"}, + }; + const json tool_call_msg { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + { + // TODO: detect if requires numerical id or fixed length == 6 like Nemo + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"name", "tool_name"}, + {"arguments", (json { + {"arg1", "some_value"}, + }).dump()}, + }}, + }, + })}, + }; + const json tools; + auto prefix = apply(json::array({user_msg}), tools, /* add_generation_prompt= */ true); + auto full = apply(json::array({user_msg, tool_call_msg}), tools, /* add_generation_prompt= */ false); + if (full.find(prefix) != 0) { + if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { + prefix = prefix.substr(0, prefix.size() - eos_token_.size()); + // } else { + // throw std::runtime_error("# prefix not found at start of prefix:\n" + prefix + "\n# vs full:\n" + full + "\n#"); + } + } else { + + } + tool_call_example_ = full.substr(prefix.size()); + } + // } catch (const std::exception & e) { + // fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); // } } From fb7a3b31c23ec3c122839c1a17bccf2a65ded36c Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 21:41:44 +0000 Subject: [PATCH 08/20] fix compilation --- examples/chat-template.cpp | 22 ++++++++++++---------- include/minja/chat-template.hpp | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/chat-template.cpp b/examples/chat-template.cpp index 8161b2b..d1838e7 100644 --- a/examples/chat-template.cpp +++ b/examples/chat-template.cpp @@ -19,14 +19,16 @@ int main() { /* bos_token= */ "<|start|>", /* eos_token= */ "<|end|>" ); - std::cout << tmpl.apply( - json::parse(R"([ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there"} - ])"), - json::parse(R"([ - {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}} - ])"), - /* add_generation_prompt= */ true, - /* extra_context= */ {}) << std::endl; + + minja::chat_template_inputs inputs; + inputs.messages = json::parse(R"([ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"} + ])"); + inputs.add_generation_prompt = true; + inputs.tools = json::parse(R"([ + {"type": "function", "function": {"name": "google_search", "arguments": {"query": "2+2"}}} + ])"); + + std::cout << tmpl.apply(inputs) << std::endl; } diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index b1f8ee3..5e901f1 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -231,7 +231,7 @@ class chat_template { inputs.messages = json::array({user_msg}); inputs.add_generation_prompt = true; auto prefix = apply(inputs); - + inputs.messages.push_back(tool_call_msg); inputs.add_generation_prompt = false; auto full = apply(inputs); From 354e77a9da87caec5a15c7294c200c8cc536aa6b Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 21:58:35 +0000 Subject: [PATCH 09/20] implement strftime_now + respect new opts --- include/minja/chat-template.hpp | 40 ++++++++++++++++++++----------- tests/test-supported-template.cpp | 8 +++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 5e901f1..8989fa3 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -204,7 +204,8 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } - // try { +#if 0 + try { if (!caps_.supports_tools) { const json user_msg { {"role", "user"}, @@ -247,9 +248,10 @@ class chat_template { } tool_call_example_ = full.substr(prefix.size()); } - // } catch (const std::exception & e) { - // fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); - // } + } catch (const std::exception & e) { + fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); + } +#endif } const std::string & source() const { return source_; } @@ -415,18 +417,28 @@ class chat_template { auto context = minja::Context::make(json({ {"messages", actual_messages}, {"add_generation_prompt", inputs.add_generation_prompt}, - {"bos_token", bos_token_}, - {"eos_token", eos_token_}, - // {"strftime_now", Value::callable([=](const std::shared_ptr & context, minja::ArgumentsValue & args) { - // args.expectArgs("strftime_now", {1, 1}, {0, 0}); - // auto format = args.args[0].get(); - // return Value(std::to_string(inputs.now)); - // })}, })); - + if (opts.use_bos_token) { + context->set("bos_token", bos_token_); + } + if (opts.use_eos_token) { + context->set("eos_token", eos_token_); + } + if (opts.define_strftime_now) { + auto now = inputs.now; + context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { + args.expectArgs("strftime_now", {1, 1}, {0, 0}); + auto format = args.args[0].get(); + + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + return ss.str(); + })); + } if (!inputs.tools.is_null()) { - auto tools_val = minja::Value(inputs.tools); - context->set("tools", tools_val); + context->set("tools", minja::Value(inputs.tools)); } if (!inputs.extra_context.is_null()) { for (auto & kv : inputs.extra_context.items()) { diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 71625d6..a7f69a3 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -18,6 +18,8 @@ #undef NDEBUG #include +#define TEST_DATE (getenv("TEST_DATE") ? getenv("TEST_DATE") : "2024-07-26") + using json = nlohmann::ordered_json; template @@ -128,6 +130,12 @@ int main(int argc, char *argv[]) { inputs.messages = ctx.at("messages"); inputs.tools = ctx.contains("tools") ? ctx.at("tools") : json(); inputs.add_generation_prompt = ctx.at("add_generation_prompt"); + + std::istringstream ss(TEST_DATE); + std::tm tm = {}; + ss >> std::get_time(&tm, "%Y-%m-%d"); + inputs.now = std::chrono::system_clock::from_time_t(std::mktime(&tm)); + if (ctx.contains("tools")) { inputs.extra_context = json { {"builtin_tools", { From a6a5cba00eea62fe3385a8dd1deeb4fc88f92fa5 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:17:54 +0000 Subject: [PATCH 10/20] more defensive against non arrays in reject/accept + fix test builtin_tools --- include/minja/minja.hpp | 5 +++++ tests/test-supported-template.cpp | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 0eadae8..069358e 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2695,6 +2695,10 @@ inline std::shared_ptr Context::builtins() { return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); + auto filter_fn = context->get(args.args[1]); if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); @@ -2772,6 +2776,7 @@ inline std::shared_ptr Context::builtins() { auto & items = args.args[0]; if (items.is_null()) return Value::array(); + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); auto attr_name = args.args[1].get(); bool has_test = false; diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index a7f69a3..6cbfb7c 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -139,7 +139,7 @@ int main(int argc, char *argv[]) { if (ctx.contains("tools")) { inputs.extra_context = json { {"builtin_tools", { - {"wolfram_alpha", "brave_search"} + json::array({"wolfram_alpha", "brave_search"}) }}, }; } From f74b40d43ce6215dac33d6d341936d0917b43c6e Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:25:56 +0000 Subject: [PATCH 11/20] rename test script --- README.md | 4 ++-- scripts/{run_fuzzing_mode.sh => fuzzing_tests.sh} | 0 scripts/{run_tests.sh => tests.sh} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename scripts/{run_fuzzing_mode.sh => fuzzing_tests.sh} (100%) rename scripts/{run_tests.sh => tests.sh} (100%) diff --git a/README.md b/README.md index 85a3089..d45c45e 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ Main limitations (non-exhaustive list): huggingface-cli login ``` -- Build & run tests (shorthand: `./scripts/run_tests.sh`): +- Build & run tests (shorthand: `./scripts/tests.sh`): ```bash rm -fR build && \ @@ -195,7 +195,7 @@ Main limitations (non-exhaustive list): - Build in [fuzzing mode](https://github.com/google/fuzztest/blob/main/doc/quickstart-cmake.md#fuzzing-mode) & run all fuzzing tests (optionally, set a higher `TIMEOUT` as env var): ```bash - ./scripts/run_fuzzing_mode.sh + ./scripts/fuzzing_tests.sh ``` - If your model's template doesn't run fine, please consider the following before [opening a bug](https://github.com/googlestaging/minja/issues/new): diff --git a/scripts/run_fuzzing_mode.sh b/scripts/fuzzing_tests.sh similarity index 100% rename from scripts/run_fuzzing_mode.sh rename to scripts/fuzzing_tests.sh diff --git a/scripts/run_tests.sh b/scripts/tests.sh similarity index 100% rename from scripts/run_tests.sh rename to scripts/tests.sh From a10a911edc4bc58312660e9e9e14147576971f0c Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:26:14 +0000 Subject: [PATCH 12/20] more type defensiveness --- include/minja/chat-template.hpp | 3 +-- include/minja/minja.hpp | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 8989fa3..eed6946 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -442,8 +442,7 @@ class chat_template { } if (!inputs.extra_context.is_null()) { for (auto & kv : inputs.extra_context.items()) { - minja::Value val(kv.value()); - context->set(kv.key(), val); + context->set(kv.key(), minja::Value(kv.value())); } } diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 069358e..c304b5c 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2615,6 +2615,7 @@ inline std::shared_ptr Context::builtins() { })); globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { auto do_join = [](Value & items, const std::string & sep) { + if (!items.is_array()) throw std::runtime_error("object is not iterable: " + items.dump()); std::ostringstream oss; auto first = true; for (size_t i = 0, n = items.size(); i < n; ++i) { From 968ea9f8a3fead0efd3637f164df15468d15c300 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:28:29 +0000 Subject: [PATCH 13/20] fix json typo --- tests/test-supported-template.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 6cbfb7c..a1cf1db 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -138,9 +138,7 @@ int main(int argc, char *argv[]) { if (ctx.contains("tools")) { inputs.extra_context = json { - {"builtin_tools", { - json::array({"wolfram_alpha", "brave_search"}) - }}, + {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})}, }; } std::string actual; From 16664ae4691e210a1b5055948db089577905953a Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 22:41:56 +0000 Subject: [PATCH 14/20] Update tools prompt --- include/minja/chat-template.hpp | 55 ++++++++++++-------------- scripts/fetch_templates_and_goldens.py | 4 +- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index eed6946..5483f77 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -204,7 +204,6 @@ class chat_template { caps_.supports_tool_call_id = contains(out, "call_911_"); } -#if 0 try { if (!caps_.supports_tools) { const json user_msg { @@ -228,30 +227,33 @@ class chat_template { }, })}, }; - chat_template_inputs inputs; - inputs.messages = json::array({user_msg}); - inputs.add_generation_prompt = true; - auto prefix = apply(inputs); - - inputs.messages.push_back(tool_call_msg); - inputs.add_generation_prompt = false; - auto full = apply(inputs); + std::string prefix, full; + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg}); + inputs.add_generation_prompt = true; + prefix = apply(inputs); + } + { + chat_template_inputs inputs; + inputs.messages = json::array({user_msg, tool_call_msg}); + inputs.add_generation_prompt = false; + full = apply(inputs); + } if (full.find(prefix) != 0) { if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) { prefix = prefix.substr(0, prefix.size() - eos_token_.size()); - // } else { - // throw std::runtime_error("# prefix not found at start of prefix:\n" + prefix + "\n# vs full:\n" + full + "\n#"); } - } else { - + } + if (full.find(prefix) != 0) { + fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n"); } tool_call_example_ = full.substr(prefix.size()); } } catch (const std::exception & e) { fprintf(stderr, "Failed to generate tool call example: %s\n", e.what()); } -#endif } const std::string & source() const { return source_; } @@ -265,21 +267,16 @@ class chat_template { { json actual_messages; + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto needs_polyfills = opts.apply_polyfills && (false || !caps_.supports_system_role - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments + || (has_tools && (false + || !caps_.supports_tools + || !caps_.supports_tool_responses + || !caps_.supports_tool_calls + || caps_.requires_object_arguments + )) || caps_.requires_typed_content - // || !caps_.supports_system_role - // || (!tools.is_null() && (false - // || !caps_.supports_tools - // || !caps_.supports_tool_responses - // || !caps_.supports_tool_calls - // || caps_.requires_object_arguments - // )) - // || caps_.requires_typed_content ); if (needs_polyfills) { actual_messages = json::array(); @@ -313,9 +310,9 @@ class chat_template { json adjusted_messages; if (needs_tools_in_system) { adjusted_messages = add_system(inputs.messages, - "Available tools: " + inputs.tools.dump(2)); // "\n\n" - // "You can call any of the following tools to satisfy the user's requests: " + tools.dump(2) + "\n\n" + "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2)); + // "\n\n" // "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); } else { adjusted_messages = inputs.messages; @@ -459,7 +456,7 @@ class chat_template { std::string existing_system = messages_with_system.at(0).at("content"); messages_with_system[0] = json { {"role", "system"}, - {"content", existing_system + "\n" + system_prompt}, + {"content", existing_system + "\n\n" + system_prompt}, }; } else { messages_with_system.insert(messages_with_system.begin(), json { diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 3251b10..a139d1e 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -66,7 +66,7 @@ def add_system(messages, system_prompt): existing_system = messages[0]["content"] messages[0] = { "role": "system", - "content": existing_system + "\n" + system_prompt, + "content": existing_system + "\n\n" + system_prompt, } else: messages.insert(0, { @@ -243,7 +243,7 @@ def apply(self, context): if self.needs_polyfills(context): if has_tools and not caps.supports_tools: - add_system(context['messages'], f"Available tools: {json.dumps(context['tools'], indent=2)}") + add_system(context['messages'], f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}") for message in context['messages']: if 'tool_calls' in message: From 01bc7f3fed148942eb36a317f09df1b5577fce08 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:01:24 +0000 Subject: [PATCH 15/20] expose all polyfills as options --- include/minja/chat-template.hpp | 73 +++++++++++++++++++++---------- tests/test-supported-template.cpp | 7 ++- 2 files changed, 55 insertions(+), 25 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 5483f77..46b8db6 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -46,6 +46,14 @@ struct chat_template_options { bool use_bos_token = true; bool use_eos_token = true; bool define_strftime_now = true; + + bool polyfill_tools = true; + bool polyfill_tool_call_examples = true; + bool polyfill_tool_calls = true; + bool polyfill_tool_responses = true; + bool polyfill_system_role = true; + bool polyfill_object_arguments = true; + bool polyfill_typed_content = true; }; class chat_template { @@ -268,21 +276,43 @@ class chat_template { json actual_messages; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tool_calls = false; + auto has_tool_responses = false; + auto has_string_content = false; + for (const auto & message : inputs.messages) { + if (!message["tool_calls"].is_null()) { + has_tool_calls = true; + } + if (message["role"] == "tool") { + has_tool_responses = true; + } + if (message["content"].is_string()) { + has_string_content = true; + } + } + + auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; + auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; + auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; + auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; + auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; + auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; + auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; + auto needs_polyfills = opts.apply_polyfills && (false - || !caps_.supports_system_role - || (has_tools && (false - || !caps_.supports_tools - || !caps_.supports_tool_responses - || !caps_.supports_tool_calls - || caps_.requires_object_arguments - )) - || caps_.requires_typed_content + || polyfill_system_role + || polyfill_tools + || polyfill_tool_calls + || polyfill_tool_responses + || polyfill_object_arguments + || polyfill_typed_content ); + if (needs_polyfills) { actual_messages = json::array(); auto add_message = [&](const json & msg) { - if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { actual_messages.push_back({ {"role", msg.at("role")}, {"content", {{ @@ -305,15 +335,12 @@ class chat_template { pending_system.clear(); } }; - auto needs_tools_in_system = !inputs.tools.is_null() && inputs.tools.size() > 0 && !caps_.supports_tools; json adjusted_messages; - if (needs_tools_in_system) { + if (polyfill_tools) { adjusted_messages = add_system(inputs.messages, - // "\n\n" - "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2)); - // "\n\n" - // "Example tool call syntax:\n\n" + tool_call_example_ + "\n\n"); + "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2) + + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); } else { adjusted_messages = inputs.messages; } @@ -326,7 +353,7 @@ class chat_template { std::string role = message.at("role"); if (message.contains("tool_calls")) { - if (caps_.requires_object_arguments || !caps_.supports_tool_calls) { + if (polyfill_object_arguments || polyfill_tool_calls) { for (auto & tool_call : message.at("tool_calls")) { if (tool_call["type"] == "function") { auto & function = tool_call.at("function"); @@ -341,7 +368,7 @@ class chat_template { } } } - if (!caps_.supports_tool_calls) { + if (polyfill_tool_calls) { auto content = message.at("content"); auto tool_calls = json::array(); for (const auto & tool_call : message.at("tool_calls")) { @@ -368,7 +395,7 @@ class chat_template { message.erase("tool_calls"); } } - if (!caps_.supports_tool_responses && role == "tool") { + if (polyfill_tool_responses && role == "tool") { message["role"] = "user"; auto obj = json { {"tool_response", { @@ -385,7 +412,7 @@ class chat_template { message.erase("name"); } - if (!message["content"].is_null() && !caps_.supports_system_role) { + if (!message["content"].is_null() && polyfill_system_role) { std::string content = message.at("content"); if (role == "system") { if (!pending_system.empty()) pending_system += "\n"; @@ -404,9 +431,7 @@ class chat_template { } add_message(message); } - if (!caps_.supports_system_role) { - flush_sys(); - } + flush_sys(); } else { actual_messages = inputs.messages; } @@ -426,13 +451,13 @@ class chat_template { context->set("strftime_now", Value::callable([now](const std::shared_ptr &, minja::ArgumentsValue & args) { args.expectArgs("strftime_now", {1, 1}, {0, 0}); auto format = args.args[0].get(); - + auto time = std::chrono::system_clock::to_time_t(now); auto local_time = *std::localtime(&time); std::ostringstream ss; ss << std::put_time(&local_time, format.c_str()); return ss.str(); - })); + })); } if (!inputs.tools.is_null()) { context->set("tools", minja::Value(inputs.tools)); diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index a1cf1db..96d7cfa 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -141,9 +141,14 @@ int main(int argc, char *argv[]) { {"builtin_tools", json::array({"wolfram_alpha", "brave_search"})}, }; } + + minja::chat_template_options opts; + // TODO: implement logic for examples in python + opts.polyfill_tool_call_examples = false; + std::string actual; try { - actual = tmpl.apply(inputs); + actual = tmpl.apply(inputs, opts); } catch (const std::exception &e) { std::cerr << "Error applying template: " << e.what() << std::endl; return 1; From a1608667d1316f0726dbae1807ae40a260d1a9e7 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:30:22 +0000 Subject: [PATCH 16/20] Fully align tool call examples backfill test python logic w/ c++ --- include/minja/chat-template.hpp | 9 +-- scripts/fetch_templates_and_goldens.py | 92 +++++++++++++++++++------- tests/test-supported-template.cpp | 6 +- 3 files changed, 74 insertions(+), 33 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 46b8db6..dfd46d7 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -218,6 +218,9 @@ class chat_template { {"role", "user"}, {"content", "Hey"}, }; + const json args { + {"arg1", "some_value"}, + }; const json tool_call_msg { {"role", "assistant"}, {"content", nullptr}, @@ -228,9 +231,7 @@ class chat_template { {"type", "function"}, {"function", { {"name", "tool_name"}, - {"arguments", (json { - {"arg1", "some_value"}, - }).dump()}, + {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))}, }}, }, })}, @@ -339,7 +340,7 @@ class chat_template { json adjusted_messages; if (polyfill_tools) { adjusted_messages = add_system(inputs.messages, - "You can call any of the following tools to satisfy the user's requests: " + inputs.tools.dump(2) + + "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) + (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_)); } else { adjusted_messages = inputs.messages; diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index a139d1e..66238bc 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -86,7 +86,7 @@ class TemplateCaps: requires_object_arguments: bool = False requires_non_null_content: bool = False requires_typed_content: bool = False - + def to_json(self): return json.dumps({ "supports_tools": self.supports_tools, @@ -108,7 +108,7 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext "bos_token": "<|startoftext|>", "eos_token": "<|endoftext|>", } - + try: out = self.template.render(messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, **basic_extra_context, **extra_context) # print(out, file=sys.stderr) @@ -117,7 +117,7 @@ def try_raw_render(self, messages, *, tools=[], add_generation_prompt=False, ext # print(f"{template_file}: Error rendering template with messages {messages}: {e}", file=sys.stderr, flush=True) return "" - def __init__(self, template, env=None): + def __init__(self, template, known_eos_tokens, env=None): if not env: env = jinja2.Environment( trim_blocks=True, @@ -128,21 +128,21 @@ def __init__(self, template, env=None): self.template = env.from_string(template) caps = TemplateCaps() - + user_needle = "" sys_needle = "" dummy_str_user_msg = {"role": "user", "content": user_needle } dummy_typed_user_msg = {"role": "user", "content": [{"type": "text", "text": user_needle}]} - + caps.requires_typed_content = \ (user_needle not in self.try_raw_render([dummy_str_user_msg])) \ and (user_needle in self.try_raw_render([dummy_typed_user_msg])) dummy_user_msg = dummy_typed_user_msg if caps.requires_typed_content else dummy_str_user_msg - + needle_system_msg = {"role": "system", "content": [{"type": "text", "text": sys_needle}] if caps.requires_typed_content else sys_needle} - + caps.supports_system_role = sys_needle in self.try_raw_render([needle_system_msg, dummy_user_msg]) - + out = self.try_raw_render([dummy_user_msg], tools=[{ "name": "some_tool", "type": "function", @@ -162,7 +162,7 @@ def __init__(self, template, env=None): }, }]) caps.supports_tools = "some_tool" in out - + def make_tool_calls_msg(tool_calls, content=None): return { "role": "assistant", @@ -178,7 +178,7 @@ def make_tool_call(tool_name, arguments): "name": tool_name, } } - + dummy_args_obj = {"argument_needle": "print('Hello, World!')"} out = self.try_raw_render([ @@ -191,10 +191,10 @@ def make_tool_call(tool_name, arguments): make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]), ]) tool_call_renders_obj_arguments = '"argument_needle":' in out or "'argument_needle':" in out - + caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments - + caps.requires_non_null_content = \ (user_needle in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": ''}])) \ and (user_needle not in self.try_raw_render([dummy_user_msg, {"role": "assistant", "content": None}])) @@ -208,7 +208,7 @@ def make_tool_call(tool_name, arguments): make_tool_calls_msg([tc1, tc2]), ]) caps.supports_parallel_tool_calls = "test_tool1" in out and "test_tool2" in out - + out = self.try_raw_render([ dummy_user_msg, make_tool_calls_msg([tc1]), @@ -221,9 +221,42 @@ def make_tool_call(tool_name, arguments): ]) caps.supports_tool_responses = "Some response!" in out caps.supports_tool_call_id = "call_911_" in out - + + self.tool_call_example = None + try: + if not caps.supports_tools: + user_msg = {"role": "user", "content": "Hey"} + args = {"arg1": "some_value"} + tool_call_msg = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1___", + "type": "function", + "function": { + "name": "tool_name", + "arguments": args if caps.requires_object_arguments else json.dumps(args), + }, + }, + ], + } + prefix = self.try_raw_render([user_msg], add_generation_prompt=True) + full = self.try_raw_render([user_msg, tool_call_msg], add_generation_prompt=False) + if not full.startswith(prefix): + for known_eos_token in known_eos_tokens: + prefix = prefix.rstrip() + if prefix.endswith(known_eos_token): + prefix = prefix[:-len(known_eos_token)] + break + if not full.startswith(prefix): + print("Failed to infer a tool call example (possible template bug)", file=sys.stderr) + self.tool_call_example = full[len(prefix):] + except Exception as e: + print(f"Failed to generate tool call example: {e}", file=sys.stderr) + self.original_caps = caps - + def needs_polyfills(self, context): has_tools = context.get('tools') is not None caps = self.original_caps @@ -237,13 +270,15 @@ def needs_polyfills(self, context): or caps.requires_typed_content def apply(self, context): - + caps = self.original_caps has_tools = 'tools' in context if self.needs_polyfills(context): if has_tools and not caps.supports_tools: - add_system(context['messages'], f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}") + add_system(context['messages'], + f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}" + + ("\n\nExample tool call syntax:\n\n" + self.tool_call_example if self.tool_call_example is not None else "")) for message in context['messages']: if 'tool_calls' in message: @@ -299,7 +334,7 @@ def apply(self, context): return f"ERROR: {e2}" - + async def handle_chat_template(output_folder, model_id, variant, template_src, context_files): if '{% generation %}' in template_src: @@ -315,21 +350,30 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c async with aiofiles.open(template_file, 'w') as f: await f.write(template_src) - template = chat_template(template_src) + known_eos_tokens = [ + "<|END_OF_TURN_TOKEN|>", + "", + "", + "<|im_end|>", + "<|eom_id|>", + "<|eot_id|>", + "<|end▁of▁sentence|>", + ] + + template = chat_template(template_src, known_eos_tokens) template.env.filters['safe'] = lambda x: x template.env.filters['tojson'] = tojson template.env.globals['raise_exception'] = raise_exception template.env.globals['strftime_now'] = strftime_now - caps = template.original_caps - + if not context_files: print(f"{template_file} {caps_file} n/a {template_file}") return async with aiofiles.open(caps_file, 'w') as f: await f.write(caps.to_json()) - + for context_file in context_files: context_name = os.path.basename(context_file).replace(".json", "") async with aiofiles.open(context_file, 'r') as f: @@ -338,7 +382,7 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c if not caps.supports_tool_calls and context.get('tools') is not None: print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr) continue - + needs_tools_in_system = len(context.get('tools', [])) > 0 and not caps.supports_tools if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system): continue @@ -362,7 +406,7 @@ async def async_hf_download(repo_id: str, filename: str) -> str: async def process_model(output_folder: str, model_id: str, context_files: list): try: config_str = await async_hf_download(model_id, "tokenizer_config.json") - + try: config = json.loads(config_str) except json.JSONDecodeError: diff --git a/tests/test-supported-template.cpp b/tests/test-supported-template.cpp index 96d7cfa..302ebbd 100644 --- a/tests/test-supported-template.cpp +++ b/tests/test-supported-template.cpp @@ -142,13 +142,9 @@ int main(int argc, char *argv[]) { }; } - minja::chat_template_options opts; - // TODO: implement logic for examples in python - opts.polyfill_tool_call_examples = false; - std::string actual; try { - actual = tmpl.apply(inputs, opts); + actual = tmpl.apply(inputs); } catch (const std::exception &e) { std::cerr << "Error applying template: " << e.what() << std::endl; return 1; From e5afc512dd8fd87f19ab204ac7f1b8f7cfab21bf Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:44:10 +0000 Subject: [PATCH 17/20] Add / deprecate old chat_template::apply overload --- include/minja/chat-template.hpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index dfd46d7..2c3d96c 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -270,6 +270,28 @@ class chat_template { const std::string & eos_token() const { return eos_token_; } const chat_template_caps & original_caps() const { return caps_; } + // Deprecated, please use the form with chat_template_inputs and chat_template_options + std::string apply( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + bool apply_polyfills = true) + { + fprintf(stderr, "[%s] Deprecated!\n", __func__); + chat_template_inputs inputs; + inputs.messages = messages; + inputs.tools = tools; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context = extra_context; + inputs.now = std::chrono::system_clock::now(); + + chat_template_options opts; + opts.apply_polyfills = apply_polyfills; + + return apply(inputs, opts); + } + std::string apply( const chat_template_inputs & inputs, const chat_template_options & opts = chat_template_options()) const From f9969534fce007440e0b7ccf58673091a79cc657 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 3 Feb 2025 23:58:35 +0000 Subject: [PATCH 18/20] fix crash --- include/minja/chat-template.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 2c3d96c..69ee4e8 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -303,13 +303,13 @@ class chat_template { auto has_tool_responses = false; auto has_string_content = false; for (const auto & message : inputs.messages) { - if (!message["tool_calls"].is_null()) { + if (message.contains("tool_calls") && !message["tool_calls"].is_null()) { has_tool_calls = true; } - if (message["role"] == "tool") { + if (message.contains("role") && message["role"] == "tool") { has_tool_responses = true; } - if (message["content"].is_string()) { + if (message.contains("content") && message["content"].is_string()) { has_string_content = true; } } From 7e7730e22c5fa6cc7f9d9eb2da2a43e6fb264e45 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 00:40:41 +0000 Subject: [PATCH 19/20] mute another win arm64 deprecation --- tests/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0d1c40e..6a34dd4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -41,6 +41,9 @@ set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYT add_executable(test-supported-template test-supported-template.cpp) target_compile_features(test-supported-template PUBLIC cxx_std_17) +if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(test-supported-template PUBLIC _CRT_SECURE_NO_WARNINGS) +endif() target_link_libraries(test-supported-template PRIVATE nlohmann_json::nlohmann_json) set(MODEL_IDS From e083fc1e2c937f42839f1bde92a5924b014253b4 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 4 Feb 2025 01:07:04 +0000 Subject: [PATCH 20/20] mute another win arm64 deprecation --- examples/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 750bdea..7d75913 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -12,5 +12,8 @@ foreach(example add_executable(${example} ${example}.cpp) target_compile_features(${example} PUBLIC cxx_std_17) target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json) + if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + target_compile_definitions(${example} PUBLIC _CRT_SECURE_NO_WARNINGS) + endif() endforeach()