diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index 77d0ca4..979e53f 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -228,6 +228,9 @@ class Value : public std::enable_shared_from_this { } Value get(const Value& key) { if (array_) { + if (!key.is_number_integer()) { + return Value(); + } auto index = key.get(); return array_->at(index < 0 ? array_->size() + index : index); } else if (object_) { @@ -236,7 +239,7 @@ class Value : public std::enable_shared_from_this { if (it == object_->end()) return Value(); return it->second; } - throw std::runtime_error("Value is not an array or object: " + dump()); + return Value(); } void set(const Value& key, const Value& value) { if (!object_) throw std::runtime_error("Value is not an object: " + dump()); @@ -618,7 +621,7 @@ class Expression { Value evaluate(const std::shared_ptr & context) const { try { return do_evaluate(context); - } catch (const std::runtime_error & e) { + } catch (const std::exception & e) { std::ostringstream out; out << e.what(); if (location.source) out << error_location_suffix(*location.source, location.pos); @@ -769,7 +772,7 @@ class TemplateNode { void render(std::ostringstream & out, const std::shared_ptr & context) const { try { do_render(out, context); - } catch (const std::runtime_error & e) { + } catch (const std::exception & e) { std::ostringstream err; err << e.what(); if (location_.source) err << error_location_suffix(*location_.source, location_.pos); @@ -1092,15 +1095,24 @@ class SubscriptExpr : public Expression { if (!index) throw std::runtime_error("SubscriptExpr.index is null"); auto target_value = base->evaluate(context); if (auto slice = dynamic_cast(index.get())) { - if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array"); - - auto start = slice->start ? slice->start->evaluate(context).get() : 0; - auto end = slice->end ? slice->end->evaluate(context).get() : target_value.size(); - auto result = Value::array(); - for (auto i = start; i < end; ++i) { - result.push_back(target_value.at(i)); + auto start = slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + if (target_value.is_string()) { + std::string s = target_value.get(); + if (start < 0) start = s.size() + start; + if (end < 0) end = s.size() + end; + return s.substr(start, end - start); + } else if (target_value.is_array()) { + if (start < 0) start = target_value.size() + start; + if (end < 0) end = target_value.size() + end; + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings"); } - return result; } else { auto index_value = index->evaluate(context); if (target_value.is_null()) { @@ -1247,6 +1259,9 @@ class MethodCallExpr : public Expression { if (!object) throw std::runtime_error("MethodCallExpr.object is null"); if (!method) throw std::runtime_error("MethodCallExpr.method is null"); auto obj = object->evaluate(context); + if (obj.is_null()) { + throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null"); + } if (obj.is_array()) { if (method->get_name() == "append") { args.expectArgs("append method", {1, 1}, {0, 0}); @@ -2140,7 +2155,7 @@ class Parser { } } return tokens; - } catch (const std::runtime_error & e) { + } catch (const std::exception & e) { throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); } } @@ -2403,6 +2418,10 @@ inline std::shared_ptr Context::builtins() { globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { return args.at("value"); })); + globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("value"); + return items.to_str(); + })); globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { auto & items = args.at("items"); if (!items.is_array()) throw std::runtime_error("object is not iterable"); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 973cab1..1277813 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -58,6 +58,7 @@ set(MODEL_IDS # "google/gemma-2-2b-it" # "mistralai/Mistral-7B-Instruct-v0.2" # "mistralai/Mixtral-8x7B-Instruct-v0.1" + # "mistralai/Mistral-Nemo-Instruct-2407", # "CohereForAI/c4ai-command-r-plus" ) diff --git a/tests/contexts/tool_use.json b/tests/contexts/tool_use.json index 6acaef3..4920d19 100644 --- a/tests/contexts/tool_use.json +++ b/tests/contexts/tool_use.json @@ -9,7 +9,7 @@ "content": "", "tool_calls": [ { - "id": "call_1", + "id": "call_1___", "type": "function", "function": { "arguments": "{\"code\": \"print('Hello, World!')\"}", @@ -20,6 +20,7 @@ }, { "role": "tool", + "tool_call_id": "call_1___", "name": "ipython", "content": "{\"stdout\": \"Hello, World!\"}" }, @@ -36,7 +37,7 @@ "content": "", "tool_calls": [ { - "id": "call_2", + "id": "call_2___", "type": "function", "function": { "arguments": "{\"condition\":true}", @@ -47,6 +48,7 @@ }, { "role": "tool", + "tool_call_id": "call_2___", "name": "test", "content": "true" }, @@ -63,7 +65,7 @@ "content": "", "tool_calls": [ { - "id": "call_3", + "id": "call_3___", "type": "function", "function": { "arguments": "{\"query\": \"what is truth anyway am I right?\"}", @@ -74,6 +76,7 @@ }, { "role": "tool", + "tool_call_id": "call_3___", "name": "brave_search", "content": "{\"title\":\"Truth: don't ask the web, ask an LLM instead!\",\"url\":\"https://en.wikipedia.org/wiki/Truth\"}" }, @@ -161,4 +164,4 @@ } } ] -} +} \ No newline at end of file diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index d3298dd..469335a 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -52,6 +52,15 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_EQ( "a b", render(R"( {{- 'a' -}}{{ ' ' }}{{- 'b' -}} )", {}, {})); + EXPECT_EQ( + "bc", + render(R"({{ "abcd"[1:-1] }})", {}, {})); + EXPECT_EQ( + "[1, 2]", + render(R"({{ [0, 1, 2, 3][1:-1] }})", {}, {})); + EXPECT_EQ( + "9", + render(R"({{ "123456789" | length }})", {}, {})); EXPECT_EQ( " end", render(R"( {%- if True %}{%- endif %}{{ ' ' }}{%- for x in [] %}foo{% endfor %}end)", {}, {}));