From 15d49a404e20cbb0881e7bc54f91caafcc1226fb Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 26 Jan 2025 20:30:59 +0000 Subject: [PATCH 1/3] Support select & rejectattr --- include/minja/minja.hpp | 113 +++++++++++++++++++++------------------- tests/test-syntax.cpp | 6 +++ 2 files changed, 66 insertions(+), 53 deletions(-) diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 80bdd4b..604e613 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -2648,31 +2648,34 @@ inline std::shared_ptr Context::builtins() { return filter.call(context, actual_args); }); }; - // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject - globals.set("reject", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs("reject", {2, (std::numeric_limits::max)()}, {0, 0}); - auto & items = args.args[0]; - auto filter_fn = context->get(args.args[1]); - if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + auto select_or_reject = [make_filter](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); - auto filter_args = Value::array(); - for (size_t i = 2, n = args.args.size(); i < n; i++) { - filter_args.push_back(args.args[i]); - } - auto filter = make_filter(filter_fn, filter_args); + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); - auto res = Value::array(); - for (size_t i = 0, n = items.size(); i < n; i++) { - auto & item = items.at(i); - ArgumentsValue filter_args; - filter_args.args.emplace_back(item); - auto pred_res = filter.call(context, filter_args); - if (!pred_res.to_bool()) { - res.push_back(item); + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + ArgumentsValue filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (pred_res.to_bool() == (is_select ? true : false)) { + res.push_back(item); + } } - } - return res; - })); + return res; + }); + }; + globals.set("select", select_or_reject(/* is_select= */ true)); + globals.set("reject", select_or_reject(/* is_select= */ false)); globals.set("map", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { auto res = Value::array(); if (args.args.size() == 1 && @@ -2720,41 +2723,45 @@ inline std::shared_ptr Context::builtins() { if (!text.empty() && text.back() == '\n') out += "\n"; return out; })); - globals.set("selectattr", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs("selectattr", {2, (std::numeric_limits::max)()}, {0, 0}); - auto & items = args.args[0]; - if (items.is_null()) - return Value::array(); - auto attr_name = args.args[1].get(); - - bool has_test = false; - Value test_fn; - ArgumentsValue test_args {{Value()}, {}}; - if (args.args.size() >= 3) { - has_test = true; - test_fn = context->get(args.args[2]); - if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); - for (size_t i = 3, n = args.args.size(); i < n; i++) { - test_args.args.emplace_back(args.args[i]); + auto select_or_reject_attr = [](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + ArgumentsValue test_args {{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } + test_args.kwargs = args.kwargs; } - test_args.kwargs = args.kwargs; - } - auto res = Value::array(); - for (size_t i = 0, n = items.size(); i < n; i++) { - auto & item = items.at(i); - auto attr = item.get(attr_name); - if (has_test) { - test_args.args[0] = attr; - if (test_fn.call(context, test_args).to_bool()) { - res.push_back(item); + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) { + res.push_back(item); + } + } else { + res.push_back(attr); } - } else { - res.push_back(attr); } - } - return res; - })); + return res; + }); + }; + globals.set("selectattr", select_or_reject_attr(/* is_select= */ true)); + globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false)); globals.set("range", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { std::vector startEndStep(3); std::vector param_set(3); diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index 21e5ea4..54088b8 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -183,6 +183,9 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( R"([{'a': 1}])", render(R"({{ [{"a": 1}, {"a": 2}, {}] | selectattr("a", "equalto", 1) | list }})", {}, {})); + EXPECT_EQ( + R"([{'a': 2}, {}])", + render(R"({{ [{"a": 1}, {"a": 2}, {}] | rejectattr("a", "equalto", 1) | list }})", {}, {})); EXPECT_EQ( "[1, 2]", render(R"({{ [{"a": 1}, {"a": 2}] | map(attribute="a") | list }})", {}, {})); @@ -251,6 +254,9 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "Tools: 1, 3...", render("{{ 'Tools: ' + [1, 2, 3] | reject('equalto', 2) | join(', ') + '...' }}", {}, {})); + EXPECT_EQ( + "Tools: 2...", + render("{{ 'Tools: ' + [1, 2, 3] | select('equalto', 2) | join(', ') + '...' }}", {}, {})); EXPECT_EQ( "1, 2, 3", render("{{ [1, 2, 3] | join(', ') }}", {}, {})); From ac93c241bba08a73acb8b2485d091edc8e720e8c Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 26 Jan 2025 20:51:52 +0000 Subject: [PATCH 2/3] test code_interpreter when supported explicitly by template --- scripts/fetch_templates_and_goldens.py | 6 +++ tests/contexts/tool_use_code_interpreter.json | 43 +++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 tests/contexts/tool_use_code_interpreter.json diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index 5a8348c..da07aae 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -85,6 +85,7 @@ def handle_chat_template(output_folder, model_id, variant, template_src, context env.globals['strftime_now'] = strftime_now template_handles_tools = 'tools' in template_src + supports_code_interpreter = 'code_interpreter' in template_src def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}, expect_strings=[]): @@ -142,6 +143,11 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={} context = json.load(f) if not template_handles_tools and 'tools' in context: + print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr) + continue + + if not supports_code_interpreter and 'tools' in context and any(t['type'] == 'code_interpreter' for t in context['tools']): + print(f'Skipping {context_name} test as code_interpreter seems unsupported by template {template_file}', file=sys.stderr) continue if not supports_system_role and any(m['role'] == 'system' for m in context['messages']): diff --git a/tests/contexts/tool_use_code_interpreter.json b/tests/contexts/tool_use_code_interpreter.json new file mode 100644 index 0000000..ba6f159 --- /dev/null +++ b/tests/contexts/tool_use_code_interpreter.json @@ -0,0 +1,43 @@ +{ + "messages": [ + { + "role": "user", + "content": "Print a hello world message with python." + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1___", + "type": "function", + "function": { + "arguments": "print('Hello, World!')", + "name": "python" + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "call_1___", + "name": "python", + "content": "{\"stdout\": \"Hello, World!\"}" + } + ], + "add_generation_prompt": true, + "bos_token": "<|startoftext|>", + "eos_token": "<|endoftext|>", + "builtin_tools": [ + "wolfram_alpha", + "brave_search", + "code_interpreter" + ], + "cutting_knowledge_date": "2023-04-01", + "todays_date": "2024-09-03", + "tools": [ + { + "type": "code_interpreter" + } + ] +} \ No newline at end of file From 97c5fed63f5b5218e2c4fa268c5c16df234dbcfa Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 26 Jan 2025 20:58:37 +0000 Subject: [PATCH 3/3] Best effort deserialize of tools arguments when object required (to accommodate python raw string) --- include/minja/chat-template.hpp | 9 +++++++-- scripts/fetch_templates_and_goldens.py | 6 +++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index a89eb55..91d71a1 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -59,7 +59,7 @@ class chat_template { /* .keep_trailing_newline = */ false, }); supports_tools_ = source.find("tools") != std::string::npos; - + auto renders_string_arguments = try_raw_render({ { @@ -173,7 +173,12 @@ class chat_template { if (tool_call["type"] == "function") { auto & function = tool_call.at("function"); std::string arguments = function.at("arguments"); - function["arguments"] = json::parse(arguments); + try { + function["arguments"] = json::parse(arguments); + } catch (const std::exception & ecvt) { + fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); + function["arguments"] = arguments; + } } } } diff --git a/scripts/fetch_templates_and_goldens.py b/scripts/fetch_templates_and_goldens.py index da07aae..3813ff6 100644 --- a/scripts/fetch_templates_and_goldens.py +++ b/scripts/fetch_templates_and_goldens.py @@ -161,7 +161,11 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={} for tool_call in message['tool_calls']: if tool_call.get('type') == 'function': arguments = tool_call['function']['arguments'] - tool_call['function']['arguments'] = json.loads(arguments) + try: + arguments = json.loads(arguments) + except: + pass + tool_call['function']['arguments'] = arguments if requires_typed_content: for message in context['messages']: