Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ option(MINJA_FUZZTEST_ENABLED "minja: fuzztests enabled"
option(MINJA_FUZZTEST_FUZZING_MODE "minja: run fuzztests (if enabled) in fuzzing mode" OFF)
option(MINJA_USE_VENV "minja: use Python venv for build" MINJA_USE_VENV_DEFAULT)

# Note: tests require C++14 because google/fuzztest depends on a version of gtest that requires it
# (and we don't want to use an older version of fuzztest)
# Examples are built w/ C++11 to check the compatibility of the library.
set(CMAKE_CXX_STANDARD 17)

set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL")
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ It is **not general purpose**: it includes just what’s needed for actual chat
- See `MODEL_IDS` in [tests/CMakeLists.txt](./tests/CMakeLists.txt) for the list of models currently supported
- Easy to integrate to/with projects such as [llama.cpp](https://github.com/ggerganov/llama.cpp) or [gemma.cpp](https://github.com/google/gemma.cpp):
- Header-only
- C++11
- C++17
- Only depend on [nlohmann::json](https://github.com/nlohmann/json) (no Boost)
- Keep codebase small (currently 2.5k LoC) and easy to understand
- *Decent* performance compared to Python.
Expand Down
1 change: 0 additions & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ foreach(example
raw
)
add_executable(${example} ${example}.cpp)
set_target_properties(${example} PROPERTIES CXX_STANDARD 11)
target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json)

endforeach()
103 changes: 46 additions & 57 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@

using json = nlohmann::ordered_json;

/* Backport make_unique from C++14. */
template <class T, class... Args>
typename std::unique_ptr<T> nonstd_make_unique(Args &&...args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

namespace minja {

class Context;
Expand All @@ -51,8 +45,8 @@ class Value : public std::enable_shared_from_this<Value> {
}

Value get_named(const std::string & name) {
for (const auto & p : kwargs) {
if (p.first == name) return p.second;
for (const auto & [key, value] : kwargs) {
if (key == name) return value;
}
return Value();
}
Expand Down Expand Up @@ -489,13 +483,11 @@ inline json Value::get<json>() const {
}
if (object_) {
json res = json::object();
for (const auto& item : *object_) {
const auto & key = item.first;
auto json_value = item.second.get<json>();
for (const auto& [key, value] : *object_) {
if (key.is_string()) {
res[key.get<std::string>()] = json_value;
res[key.get<std::string>()] = value.get<json>();
} else if (key.is_primitive()) {
res[key.dump()] = json_value;
res[key.dump()] = value.get<json>();
} else {
throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump());
}
Expand Down Expand Up @@ -610,8 +602,8 @@ class Expression {
for (const auto& arg : this->args) {
vargs.args.push_back(arg->evaluate(context));
}
for (const auto& arg : this->kwargs) {
vargs.kwargs.push_back({arg.first, arg.second->evaluate(context)});
for (const auto& [name, value] : this->kwargs) {
vargs.kwargs.push_back({name, value->evaluate(context)});
}
return vargs;
}
Expand Down Expand Up @@ -974,13 +966,11 @@ class MacroNode : public TemplateNode {
auto & param_name = params[i].first;
call_context->set(param_name, arg);
}
for (size_t i = 0, n = args.kwargs.size(); i < n; i++) {
auto & arg = args.kwargs[i];
auto & arg_name = arg.first;
for (auto & [arg_name, value] : args.kwargs) {
auto it = named_param_positions.find(arg_name);
if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);

call_context->set(arg_name, arg.second);
call_context->set(arg_name, value);
param_set[it->second] = true;
}
// Set default values for parameters that were not passed
Expand Down Expand Up @@ -1106,10 +1096,10 @@ class DictExpr : public Expression {
: Expression(location), elements(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto result = Value::object();
for (const auto& e : elements) {
if (!e.first) throw std::runtime_error("Dict key is null");
if (!e.second) throw std::runtime_error("Dict value is null");
result.set(e.first->evaluate(context), e.second->evaluate(context));
for (const auto& [key, value] : elements) {
if (!key) throw std::runtime_error("Dict key is null");
if (!value) throw std::runtime_error("Dict value is null");
result.set(key->evaluate(context), value->evaluate(context));
}
return result;
}
Expand Down Expand Up @@ -1462,7 +1452,7 @@ class Parser {
escape = true;
} else if (*it == quote) {
++it;
return nonstd_make_unique<std::string>(std::move(result));
return std::make_unique<std::string>(std::move(result));
} else {
result += *it;
}
Expand Down Expand Up @@ -1609,8 +1599,8 @@ class Parser {
}

auto location = get_location();
auto if_expr = parseIfExpression();
return std::make_shared<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
auto [condition, else_expr] = parseIfExpression();
return std::make_shared<IfExpr>(location, std::move(condition), std::move(left), std::move(else_expr));
}

Location get_location() const {
Expand All @@ -1627,7 +1617,7 @@ class Parser {
else_expr = parseExpression();
if (!else_expr) throw std::runtime_error("Expected 'else' expression");
}
return std::make_pair(std::move(condition), std::move(else_expr));
return std::pair(std::move(condition), std::move(else_expr));
}

std::shared_ptr<Expression> parseLogicalOr() {
Expand Down Expand Up @@ -2012,7 +2002,7 @@ class Parser {
if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary");
auto value = parseExpression();
if (!value) throw std::runtime_error("Expected value in dictionary");
elements.emplace_back(std::make_pair(std::move(key), std::move(value)));
elements.emplace_back(std::pair(std::move(key), std::move(value)));
};

parseKeyValuePair();
Expand Down Expand Up @@ -2087,7 +2077,7 @@ class Parser {
auto pre_space = parsePreSpace(group[1]);
auto content = group[2];
auto post_space = parsePostSpace(group[3]);
tokens.push_back(nonstd_make_unique<CommentTemplateToken>(location, pre_space, post_space, content));
tokens.push_back(std::make_unique<CommentTemplateToken>(location, pre_space, post_space, content));
} else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) {
auto pre_space = parsePreSpace(group[1]);
auto expr = parseExpression();
Expand All @@ -2097,7 +2087,7 @@ class Parser {
}

auto post_space = parsePostSpace(group[1]);
tokens.push_back(nonstd_make_unique<ExpressionTemplateToken>(location, pre_space, post_space, std::move(expr)));
tokens.push_back(std::make_unique<ExpressionTemplateToken>(location, pre_space, post_space, std::move(expr)));
} else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) {
auto pre_space = parsePreSpace(group[1]);

Expand All @@ -2115,19 +2105,19 @@ class Parser {
if (!condition) throw std::runtime_error("Expected condition in if block");

auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<IfTemplateToken>(location, pre_space, post_space, std::move(condition)));
tokens.push_back(std::make_unique<IfTemplateToken>(location, pre_space, post_space, std::move(condition)));
} else if (keyword == "elif") {
auto condition = parseExpression();
if (!condition) throw std::runtime_error("Expected condition in elif block");

auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<ElifTemplateToken>(location, pre_space, post_space, std::move(condition)));
tokens.push_back(std::make_unique<ElifTemplateToken>(location, pre_space, post_space, std::move(condition)));
} else if (keyword == "else") {
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<ElseTemplateToken>(location, pre_space, post_space));
tokens.push_back(std::make_unique<ElseTemplateToken>(location, pre_space, post_space));
} else if (keyword == "endif") {
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<EndIfTemplateToken>(location, pre_space, post_space));
tokens.push_back(std::make_unique<EndIfTemplateToken>(location, pre_space, post_space));
} else if (keyword == "for") {
static std::regex recursive_tok(R"(recursive\b)");
static std::regex if_tok(R"(if\b)");
Expand All @@ -2145,10 +2135,10 @@ class Parser {
auto recursive = !consumeToken(recursive_tok).empty();

auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
tokens.push_back(std::make_unique<ForTemplateToken>(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive));
} else if (keyword == "endfor") {
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<EndForTemplateToken>(location, pre_space, post_space));
tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
} else if (keyword == "set") {
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");

Expand All @@ -2172,34 +2162,34 @@ class Parser {
}
}
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<SetTemplateToken>(location, pre_space, post_space, ns, var_names, std::move(value)));
tokens.push_back(std::make_unique<SetTemplateToken>(location, pre_space, post_space, ns, var_names, std::move(value)));
} else if (keyword == "endset") {
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<EndSetTemplateToken>(location, pre_space, post_space));
tokens.push_back(std::make_unique<EndSetTemplateToken>(location, pre_space, post_space));
} else if (keyword == "macro") {
auto macroname = parseIdentifier();
if (!macroname) throw std::runtime_error("Expected macro name in macro block");
auto params = parseParameters();

auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<MacroTemplateToken>(location, pre_space, post_space, std::move(macroname), std::move(params)));
tokens.push_back(std::make_unique<MacroTemplateToken>(location, pre_space, post_space, std::move(macroname), std::move(params)));
} else if (keyword == "endmacro") {
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
} else if (keyword == "filter") {
auto filter = parseExpression();
if (!filter) throw std::runtime_error("Expected expression in filter block");

auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<FilterTemplateToken>(location, pre_space, post_space, std::move(filter)));
tokens.push_back(std::make_unique<FilterTemplateToken>(location, pre_space, post_space, std::move(filter)));
} else if (keyword == "endfilter") {
auto post_space = parseBlockClose();
tokens.push_back(nonstd_make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
} else {
throw std::runtime_error("Unexpected block: " + keyword);
}
} else if (!(text = consumeToken(text_regex, SpaceHandling::Keep)).empty()) {
tokens.push_back(nonstd_make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
} else {
if (it != end) throw std::runtime_error("Unexpected character");
}
Expand Down Expand Up @@ -2352,14 +2342,13 @@ static Value simple_function(const std::string & fn_name, const std::vector<std:
throw std::runtime_error("Too many positional params for " + fn_name);
}
}
for (size_t i = 0, n = args.kwargs.size(); i < n; i++) {
auto & arg = args.kwargs[i];
auto named_pos_it = named_positions.find(arg.first);
for (auto & [name, value] : args.kwargs) {
auto named_pos_it = named_positions.find(name);
if (named_pos_it == named_positions.end()) {
throw std::runtime_error("Unknown argument " + arg.first + " for function " + fn_name);
throw std::runtime_error("Unknown argument " + name + " for function " + fn_name);
}
provided_args[named_pos_it->second] = true;
args_obj.set(arg.first, arg.second);
args_obj.set(name, value);
}
return fn(context, args_obj);
});
Expand Down Expand Up @@ -2481,8 +2470,8 @@ inline std::shared_ptr<Context> Context::builtins() {
globals.set("namespace", Value::callable([=](const std::shared_ptr<Context> &, Value::Arguments & args) {
auto ns = Value::object();
args.expectArgs("namespace", {0, 0}, {0, std::numeric_limits<size_t>::max()});
for (auto & arg : args.kwargs) {
ns.set(arg.first, arg.second);
for (auto & [name, value] : args.kwargs) {
ns.set(name, value);
}
return ns;
}));
Expand Down Expand Up @@ -2652,17 +2641,17 @@ inline std::shared_ptr<Context> Context::builtins() {
param_set[i] = true;
}
}
for (auto & arg : args.kwargs) {
for (auto & [name, value] : args.kwargs) {
size_t i;
if (arg.first == "start") i = 0;
else if (arg.first == "end") i = 1;
else if (arg.first == "step") i = 2;
else throw std::runtime_error("Unknown argument " + arg.first + " for function range");
if (name == "start") i = 0;
else if (name == "end") i = 1;
else if (name == "step") i = 2;
else throw std::runtime_error("Unknown argument " + name + " for function range");

if (param_set[i]) {
throw std::runtime_error("Duplicate argument " + arg.first + " for function range");
throw std::runtime_error("Duplicate argument " + name + " for function range");
}
startEndStep[i] = arg.second.get<int64_t>();
startEndStep[i] = value.get<int64_t>();
param_set[i] = true;
}
if (!param_set[1]) {
Expand Down
1 change: 0 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ if (MINJA_FUZZTEST_ENABLED)
fuzztest_setup_fuzzing_flags()
endif()
add_executable(test-fuzz test-fuzz.cpp)
set_target_properties(test-fuzz PROPERTIES CXX_STANDARD 17)
target_include_directories(test-fuzz PRIVATE ${fuzztest_BINARY_DIR})
target_link_libraries(test-fuzz PRIVATE nlohmann_json::nlohmann_json)
link_fuzztest(test-fuzz)
Expand Down
Loading