Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
# ubuntu-22.04,
ubuntu-latest,
# windows-2019,
# windows-latest,
windows-latest,
]
type: [
Release,
Expand Down Expand Up @@ -65,5 +65,4 @@ jobs:
run: cmake --build ${{github.workspace}}/build --config ${{ matrix.type }} --parallel

- name: Test
working-directory: ${{github.workspace}}/build
run: ctest --test-dir tests --output-on-failure --verbose -C ${{ matrix.type }}
run: ctest --test-dir build --output-on-failure --verbose -C ${{ matrix.type }}
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ foreach(example
raw
)
add_executable(${example} ${example}.cpp)
target_compile_features(${example} PUBLIC cxx_std_17)
target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json)

endforeach()
88 changes: 60 additions & 28 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
#include <unordered_set>
#include <json.hpp>

// Line terminator used when this header writes line breaks into generated text:
// CRLF when compiling for Windows (_WIN32 defined), LF otherwise.
#ifdef _WIN32
#define ENDL "\r\n"
#else
#define ENDL "\n"
#endif

using json = nlohmann::ordered_json;

namespace minja {
Expand All @@ -32,6 +38,15 @@ struct Options {

struct ArgumentsValue;

// Converts Windows line endings to Unix ones.
//
// When compiled for Windows (_WIN32 defined), every "\r\n" pair in `s` is
// replaced by a single "\n"; a '\r' that is not directly followed by '\n' is
// kept. On every other platform the input is returned unchanged.
//
// Implemented as a single linear pass rather than with std::regex: the pattern
// is a fixed two-character literal, and this function is applied to whole
// templates and whole render outputs.
static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
    std::string out;
    out.reserve(s.size());
    for (size_t i = 0, n = s.size(); i < n; ++i) {
        // Drop the '\r' of a "\r\n" pair; the '\n' is copied on the next iteration.
        if (s[i] == '\r' && i + 1 < n && s[i + 1] == '\n') continue;
        out += s[i];
    }
    return out;
#else
    return s;
#endif
}

/* Values that behave roughly like in Python. */
class Value : public std::enable_shared_from_this<Value> {
public:
Expand Down Expand Up @@ -76,7 +91,7 @@ class Value : public std::enable_shared_from_this<Value> {
void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const {
auto print_indent = [&](int level) {
if (indent > 0) {
out << "\n";
out << ENDL;
for (int i = 0, n = level * indent; i < n; ++i) out << ' ';
}
};
Expand Down Expand Up @@ -547,11 +562,11 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
auto max_line = std::count(start, end, '\n') + 1;
auto col = pos - std::string(start, it).rfind('\n');
std::ostringstream out;
out << " at row " << line << ", column " << col << ":\n";
if (line > 1) out << get_line(line - 1) << "\n";
out << get_line(line) << "\n";
out << std::string(col - 1, ' ') << "^" << "\n";
if (line < max_line) out << get_line(line + 1) << "\n";
out << " at row " << line << ", column " << col << ":" ENDL;
if (line > 1) out << get_line(line - 1) << ENDL;
out << get_line(line) << ENDL;
out << std::string(col - 1, ' ') << "^" << ENDL;
if (line < max_line) out << get_line(line + 1) << ENDL;

return out.str();
}
Expand Down Expand Up @@ -786,7 +801,7 @@ class TemplateNode {
std::string render(const std::shared_ptr<Context> & context) const {
std::ostringstream out;
render(out, context);
return out.str();
return normalize_newlines(out.str());
}
};

Expand Down Expand Up @@ -1214,8 +1229,8 @@ class BinaryOpExpr : public Expression {
if (!l.to_bool()) return Value(false);
return right->evaluate(context).to_bool();
} else if (op == Op::Or) {
if (l.to_bool()) return Value(true);
return right->evaluate(context).to_bool();
if (l.to_bool()) return l;
return right->evaluate(context);
}

auto r = right->evaluate(context);
Expand Down Expand Up @@ -1292,6 +1307,10 @@ struct ArgumentsExpression {
// Returns a copy of `s` with leading and trailing whitespace removed.
//
// Whitespace is space, '\t', '\n', '\r', '\f' and '\v' (the characters the
// previous `\s`-based regex matched for plain chars in the "C" locale).
// Whitespace in the interior of the string, including newlines, is preserved.
// Returns an empty string when `s` is empty or consists only of whitespace.
//
// Uses std::string searches instead of std::regex: it is much cheaper, and it
// does not depend on how a given standard library anchors `^` / `$` around
// embedded newlines.
static std::string strip(const std::string & s) {
    static const char * const whitespace = " \t\n\r\f\v";
    const auto start = s.find_first_not_of(whitespace);
    if (start == std::string::npos) return std::string();
    // A non-whitespace character exists, so find_last_not_of cannot fail here.
    const auto end = s.find_last_not_of(whitespace);
    return s.substr(start, end - start + 1);
}

static std::string html_escape(const std::string & s) {
Expand All @@ -1302,7 +1321,7 @@ static std::string html_escape(const std::string & s) {
case '&': result += "&amp;"; break;
case '<': result += "&lt;"; break;
case '>': result += "&gt;"; break;
case '"': result += "&quot;"; break;
case '"': result += "&#34;"; break;
case '\'': result += "&apos;"; break;
default: result += c; break;
}
Expand Down Expand Up @@ -2101,13 +2120,14 @@ class Parser {
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex text_regex(R"([\s\S\n\r]*?($|(?=\{\{|\{%|\{#)))");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");

TemplateTokenVector tokens;
std::vector<std::string> group;
std::string text;
std::smatch match;

try {
while (it != end) {
Expand Down Expand Up @@ -2228,10 +2248,15 @@ class Parser {
} else {
throw std::runtime_error("Unexpected block: " + keyword);
}
} else if (!(text = consumeToken(text_regex, SpaceHandling::Keep)).empty()) {
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
auto text_end = it + match.position();
text = std::string(it, text_end);
it = text_end;
tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
} else {
if (it != end) throw std::runtime_error("Unexpected character");
text = std::string(it, end);
it = end;
tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
}
}
return tokens;
Expand Down Expand Up @@ -2280,24 +2305,31 @@ class Parser {
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;

auto text = text_token->text;
if (pre_space == SpaceHandling::Strip) {
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
text = std::regex_replace(text, leading_space_regex, "");
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
static std::regex leading_line(R"(^[ \t]*\r?\n)");
text = std::regex_replace(text, leading_line, "");
}
if (post_space == SpaceHandling::Strip) {
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
text = std::regex_replace(text, trailing_space_regex, "");
} else if (options.lstrip_blocks && it != end) {
static std::regex trailing_last_line_space_regex(R"((\r?\n)[ \t]*$)");
text = std::regex_replace(text, trailing_last_line_space_regex, "$1");
auto i = text.size();
while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--;
if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) {
text.resize(i);
}
}
if (pre_space == SpaceHandling::Strip) {
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
text = std::regex_replace(text, leading_space_regex, "");
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
if (text.length() > 0 && text[0] == '\n') {
text.erase(0, 1);
}
}

if (it == end && !options.keep_trailing_newline) {
static std::regex r(R"(\r?\n$)");
text = std::regex_replace(text, r, ""); // Strip one trailing newline
auto i = text.size();
if (i > 0 && text[i - 1] == '\n') {
i--;
if (i > 0 && text[i - 1] == '\r') i--;
text.resize(i);
}
}
children.emplace_back(std::make_shared<TextNode>(token->location, text));
} else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
Expand Down Expand Up @@ -2357,7 +2389,7 @@ class Parser {
public:

static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
Parser parser(std::make_shared<std::string>(template_str), options);
Parser parser(std::make_shared<std::string>(normalize_newlines(template_str)), options);
auto tokens = parser.tokenize();
TemplateTokenIterator begin = tokens.begin();
auto it = begin;
Expand Down Expand Up @@ -2627,11 +2659,11 @@ inline std::shared_ptr<Context> Context::builtins() {
while (std::getline(iss, line, '\n')) {
auto needs_indent = !is_first || first;
if (is_first) is_first = false;
else out += "\n";
else out += ENDL;
if (needs_indent) out += indent;
out += line;
}
if (!text.empty() && text.back() == '\n') out += "\n";
if (!text.empty() && text.back() == '\n') out += ENDL;
return out;
}));
globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
Expand Down
21 changes: 21 additions & 0 deletions scripts/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 Google LLC
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
#
# SPDX-License-Identifier: MIT
import sys
import json
from jinja2 import Environment
import jinja2.ext
from pathlib import Path

# Usage: render.py <input.json> <output file>
# The input JSON must provide 'options' (keyword arguments for jinja2.Environment),
# 'template' (the template source) and 'bindings' (the render context).
input_file, output_file = sys.argv[1:3]
# JSON is UTF-8 by specification; be explicit so the result does not depend on
# the platform's locale encoding (e.g. a legacy code page on Windows).
data = json.loads(Path(input_file).read_text(encoding='utf-8'))
# print(json.dumps(data, indent=2), file=sys.stderr)

env = Environment(**data['options'], extensions=[jinja2.ext.loopcontrols])
tmpl = env.from_string(data['template'])
output = tmpl.render(data['bindings'])
# Write with the same explicit encoding so non-ASCII template text round-trips.
Path(output_file).write_text(output, encoding='utf-8')
17 changes: 15 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@
# SPDX-License-Identifier: MIT

add_executable(test-syntax test-syntax.cpp)
target_compile_features(test-syntax PUBLIC cxx_std_17)
target_link_libraries(test-syntax PRIVATE
nlohmann_json::nlohmann_json
gtest_main
gmock
)
gtest_discover_tests(test-syntax)

add_test(NAME test-syntax-jinja2 COMMAND test-syntax)
set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYTHON_EXECUTABLE=${Python_EXECUTABLE};PYTHONPATH=${CMAKE_SOURCE_DIR}")


add_executable(test-chat-template test-chat-template.cpp)
target_compile_features(test-chat-template PUBLIC cxx_std_17)
target_link_libraries(test-chat-template PRIVATE nlohmann_json::nlohmann_json)

set(MODEL_IDS
Expand Down Expand Up @@ -68,15 +74,21 @@ set(MODEL_IDS
TheBloke/FusionNet_34Bx2_MoE-AWQ

# Broken, TODO:
# fireworks-ai/llama-3-firefunction-v2
# fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7
# ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8

# Can't find template(s), TODO:
# ai21labs/Jamba-v0.1
# apple/OpenELM-1_1B-Instruct
# dreamgen/WizardLM-2-7B
# xai-org/grok-1
)

# On Windows, drop the listed model from MODEL_IDS.
# NOTE(review): the reason for the exclusion is not recorded here; presumably
# this model's template test fails on Windows. Confirm and link an issue.
if(WIN32)
list(REMOVE_ITEM MODEL_IDS
bofenghuang/vigogne-2-70b-chat
)
endif()

# Create one test case for each {template, context} combination
file(GLOB CONTEXT_FILES "${CMAKE_SOURCE_DIR}/tests/contexts/*.json")
execute_process(
Expand Down Expand Up @@ -109,6 +121,7 @@ if (MINJA_FUZZTEST_ENABLED)
fuzztest_setup_fuzzing_flags()
endif()
add_executable(test-fuzz test-fuzz.cpp)
target_compile_features(test-fuzz PUBLIC cxx_std_17)
target_include_directories(test-fuzz PRIVATE ${fuzztest_BINARY_DIR})
target_link_libraries(test-fuzz PRIVATE nlohmann_json::nlohmann_json)
link_fuzztest(test-fuzz)
Expand Down
2 changes: 1 addition & 1 deletion tests/test-chat-template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ static std::string read_file(const std::string &path) {
std::string out;
out.resize(static_cast<size_t>(size));
fs.read(&out[0], static_cast<std::streamsize>(size));
return out;
return minja::normalize_newlines(out);
}

int main(int argc, char *argv[]) {
Expand Down
Loading
Loading