Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 81 additions & 31 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1201,9 +1201,9 @@ class DictExpr : public Expression {

class SliceExpr : public Expression {
public:
std::shared_ptr<Expression> start, end;
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
std::shared_ptr<Expression> start, end, step;
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
Value do_evaluate(const std::shared_ptr<Context> &) const override {
throw std::runtime_error("SliceExpr not implemented");
}
Expand All @@ -1220,19 +1220,54 @@ class SubscriptExpr : public Expression {
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
auto target_value = base->evaluate(context);
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
bool reverse = slice->step && slice->step->evaluate(context).get<int64_t>() == -1;
if (slice->step && !reverse) {
throw std::runtime_error("Slicing with step other than -1 is not supported");
}

int64_t start = slice->start ? slice->start->evaluate(context).get<int64_t>() : (reverse ? target_value.size() - 1 : 0);
int64_t end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (reverse ? -1 : target_value.size());

size_t len = target_value.size();

if (slice->start && start < 0) {
start = (int64_t)len + start;
}
if (slice->end && end < 0) {
end = (int64_t)len + end;
}

if (target_value.is_string()) {
std::string s = target_value.get<std::string>();
if (start < 0) start = s.size() + start;
if (end < 0) end = s.size() + end;
return s.substr(start, end - start);
} else if (target_value.is_array()) {
if (start < 0) start = target_value.size() + start;
if (end < 0) end = target_value.size() + end;

std::string result_str;
if (reverse) {
for (int64_t i = start; i > end; --i) {
if (i >= 0 && i < (int64_t)len) {
result_str += s[i];
} else if (i < 0) {
break;
}
}
} else {
result_str = s.substr(start, end - start);
}
return result_str;

} else if (target_value.is_array()) {
auto result = Value::array();
for (auto i = start; i < end; ++i) {
result.push_back(target_value.at(i));
if (reverse) {
for (int64_t i = start; i > end; --i) {
if (i >= 0 && i < (int64_t)len) {
result.push_back(target_value.at(i));
} else if (i < 0) {
break;
}
}
} else {
for (auto i = start; i < end; ++i) {
result.push_back(target_value.at(i));
}
}
return result;
} else {
Expand Down Expand Up @@ -1306,6 +1341,8 @@ class BinaryOpExpr : public Expression {
if (name == "iterable") return l.is_iterable();
if (name == "sequence") return l.is_array();
if (name == "defined") return !l.is_null();
if (name == "true") return l.to_bool();
if (name == "false") return !l.to_bool();
throw std::runtime_error("Unknown type for 'is' operator: " + name);
};
auto value = eval();
Expand Down Expand Up @@ -1521,6 +1558,10 @@ class MethodCallExpr : public Expression {
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
auto suffix = vargs.args[0].get<std::string>();
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
} else if (method->get_name() == "startswith") {
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
auto prefix = vargs.args[0].get<std::string>();
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
} else if (method->get_name() == "title") {
vargs.expectArgs("title method", {0, 0}, {0, 0});
auto res = str;
Expand Down Expand Up @@ -2083,28 +2124,37 @@ class Parser {

while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
if (!consumeToken("[").empty()) {
std::shared_ptr<Expression> index;
std::shared_ptr<Expression> index;
auto slice_loc = get_location();
std::shared_ptr<Expression> start, end, step;
bool c1 = false, c2 = false;

if (!peekSymbols({ ":" })) {
start = parseExpression();
}

if (!consumeToken(":").empty()) {
c1 = true;
if (!peekSymbols({ ":", "]" })) {
end = parseExpression();
}
if (!consumeToken(":").empty()) {
auto slice_end = parseExpression();
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
} else {
auto slice_start = parseExpression();
if (!consumeToken(":").empty()) {
consumeSpaces();
if (peekSymbols({ "]" })) {
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
} else {
auto slice_end = parseExpression();
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
}
} else {
index = std::move(slice_start);
c2 = true;
if (!peekSymbols({ "]" })) {
step = parseExpression();
}
}
if (!index) throw std::runtime_error("Empty index in subscript");
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
}

if ((c1 || c2) && (start || end || step)) {
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
} else {
index = std::move(start);
}
if (!index) throw std::runtime_error("Empty index in subscript");
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");

value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
} else if (!consumeToken(".").empty()) {
auto identifier = parseIdentifier();
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ set(MODEL_IDS
ValiantLabs/Llama3.1-8B-Enigma
xwen-team/Xwen-72B-Chat
xwen-team/Xwen-7B-Chat
Qwen/Qwen3-4B

# Broken, TODO:
# ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8
Expand Down
18 changes: 18 additions & 0 deletions tests/test-syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_EQ(
"1",
render(R"({{ 1 | safe }})", {}, {}));
EXPECT_EQ(
"True,False",
render(R"({{ 'abc'.startswith('ab') }},{{ ''.startswith('a') }})", {}, {}));
EXPECT_EQ(
"True,False",
render(R"({{ 'abc'.endswith('bc') }},{{ ''.endswith('a') }})", {}, {}));
Expand Down Expand Up @@ -217,6 +220,18 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_EQ(
"False",
render(R"({% set foo = true %}{{ not foo is defined }})", {}, {}));
EXPECT_EQ(
"True",
render(R"({% set foo = true %}{{ foo is true }})", {}, {}));
EXPECT_EQ(
"False",
render(R"({% set foo = true %}{{ foo is false }})", {}, {}));
EXPECT_EQ(
"True",
render(R"({% set foo = false %}{{ foo is not true }})", {}, {}));
EXPECT_EQ(
"False",
render(R"({% set foo = false %}{{ foo is not false }})", {}, {}));
EXPECT_EQ(
R"({"a": "b"})",
render(R"({{ {"a": "b"} | tojson }})", {}, {}));
Expand Down Expand Up @@ -465,6 +480,9 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_EQ(
"[1, 2, 3][0, 1][1, 2]",
render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {}));
EXPECT_EQ(
"[3, 2, 1, 0][3, 2, 1][2, 1, 0][2, 1]",
render("{% set x = [0, 1, 2, 3] %}{{ x[::-1] }}{{ x[:0:-1] }}{{ x[2::-1] }}{{ x[2:0:-1] }}", {}, {}));
EXPECT_EQ(
"a",
render("{{ ' a ' | trim }}", {}, {}));
Expand Down