37 changes: 28 additions & 9 deletions include/minja/chat-template.hpp
@@ -25,21 +25,22 @@ class chat_template {
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
bool requires_typed_content_ = false;
bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false;
std::string source_;
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_;

std::string try_render(
std::string try_raw_render(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
try {
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
return prompt;
} catch (const std::exception & e) {
@@ -60,7 +61,7 @@ class chat_template {
supports_tools_ = source.find("tools") != std::string::npos;

auto renders_string_arguments =
try_render({
try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
@@ -81,7 +82,7 @@
}, {}, false).find("{\"code\": \"print") != std::string::npos;
if (!renders_string_arguments) {
auto renders_object_arguments =
try_render({
try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
@@ -106,10 +107,13 @@
}
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;

supports_system_role_ = try_render({
supports_system_role_ = try_raw_render({
{{"role", "system"}, {"content", "<System Needle>"}},
{{"role", "user"}, {"content", "Hey"}}
}, {}, false).find("<System Needle>") != std::string::npos;

requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
}

const std::string & source() const { return source_; }
@@ -122,19 +126,34 @@
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
bool adjust_inputs = true) const
{
json actual_messages;

// First, "fix" messages so they have a chance to be rendered correctly by the template

if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
actual_messages = json::array();

auto add_message = [&](const json & msg) {
if (requires_typed_content_ && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
actual_messages.push_back({
{"role", msg.at("role")},
{"content", {{
{"type", "text"},
{"text", msg.at("content")},
}}},
});
} else {
actual_messages.push_back(msg);
}
};

std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
add_message({
{"role", "user"},
{"content", pending_system},
});
@@ -217,7 +236,7 @@
}
}
}
actual_messages.push_back(message);
add_message(message);
}
flush_sys();
} else {
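A note on the chat-template.hpp change above, for readers skimming the diff: the new `requires_typed_content_` path only rewrites the shape of string-content messages before rendering. The following is a minimal standalone sketch (not part of the PR) of that rewrite, mirroring the `add_message` lambda and assuming nothing beyond nlohmann/json:

```cpp
// Sketch only: wrap a plain string "content" into a single typed text part, so
// templates that only iterate over content parts (requires_typed_content_) still see the text.
#include <cstdio>
#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;

int main() {
    json msg = {{"role", "user"}, {"content", "Hey"}};
    if (msg.contains("content") && msg.at("content").is_string()) {
        json text = msg.at("content");
        msg["content"] = json::array({ {{"type", "text"}, {"text", text}} });
    }
    std::printf("%s\n", msg.dump().c_str());
    // prints: {"role":"user","content":[{"type":"text","text":"Hey"}]}
}
```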
28 changes: 26 additions & 2 deletions include/minja/minja.hpp
@@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };

class TemplateToken {
public:
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };

static std::string typeToString(Type t) {
switch (t) {
@@ -712,6 +712,8 @@ class TemplateToken {
case Type::EndMacro: return "endmacro";
case Type::Filter: return "filter";
case Type::EndFilter: return "endfilter";
case Type::Generation: return "generation";
case Type::EndGeneration: return "endgeneration";
}
return "Unknown";
}
@@ -788,6 +790,14 @@ struct EndForTemplateToken : public TemplateToken {
EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {}
};

struct GenerationTemplateToken : public TemplateToken {
GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {}
};

struct EndGenerationTemplateToken : public TemplateToken {
EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {}
};

struct SetTemplateToken : public TemplateToken {
std::string ns;
std::vector<std::string> var_names;
@@ -2149,7 +2159,7 @@ class Parser {
static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
@@ -2229,6 +2239,12 @@ class Parser {
} else if (keyword == "endfor") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndForTemplateToken>(location, pre_space, post_space));
} else if (keyword == "generation") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<GenerationTemplateToken>(location, pre_space, post_space));
} else if (keyword == "endgeneration") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
} else if (keyword == "set") {
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");

@@ -2330,6 +2346,13 @@ class Parser {
throw unterminated(**start);
}
children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
} else if (dynamic_cast<GenerationTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) {
throw unterminated(**start);
}
// Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking).
children.emplace_back(std::move(body));
} else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
@@ -2397,6 +2420,7 @@ class Parser {
|| dynamic_cast<EndFilterTemplateToken*>(token.get())
|| dynamic_cast<EndIfTemplateToken*>(token.get())
|| dynamic_cast<ElseTemplateToken*>(token.get())
|| dynamic_cast<EndGenerationTemplateToken*>(token.get())
|| dynamic_cast<ElifTemplateToken*>(token.get())) {
it--; // unconsume the token
break; // exit the loop
14 changes: 14 additions & 0 deletions scripts/fetch_templates_and_goldens.py
@@ -56,6 +56,11 @@ def join_cmake_path(parent, child):
return '/'.join(x.replace(r'\\', '/') for x in (parent, child))

def handle_chat_template(output_folder, model_id, variant, template_src, context_files):

if '{% generation %}' in template_src:
print('Removing {% generation %} blocks from template', file=sys.stderr)
template_src = template_src.replace('{% generation %}', '').replace('{% endgeneration %}', '')

model_name = model_id.replace("/", "-")
base_name = f'{model_name}-{variant}' if variant else model_name
template_file = join_cmake_path(output_folder, f'{base_name}.jinja')
@@ -126,6 +131,10 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}
{"role": "system", "content": "System Needle"},
{"role": "user", "content": "Hey"}
], extra_context=basic_extra_context, expect_strings=["System Needle"])

requires_typed_content = \
not renders([{"role": "user", "content": "Hey"}], extra_context=basic_extra_context, expect_strings=["Hey"]) \
and renders([{"role": "user", "content": [{"type": "text", "text": "Hey"}]}], extra_context=basic_extra_context, expect_strings=["Hey"])

for context_file in context_files:
context_name = os.path.basename(context_file).replace(".json", "")
@@ -148,6 +157,11 @@ def renders(messages, *, tools=[], add_generation_prompt=False, extra_context={}
arguments = tool_call['function']['arguments']
tool_call['function']['arguments'] = json.loads(arguments)

if requires_typed_content:
for message in context['messages']:
if 'content' in message and isinstance(message['content'], str):
message['content'] = [{"type": "text", "text": message['content']}]

try:
output = template.render(**context)
except Exception as e1:
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
@@ -47,6 +47,7 @@ set(MODEL_IDS
deepseek-ai/DeepSeek-V2.5
google/gemma-2-2b-it # Gated
google/gemma-7b-it # Gated
MiniMaxAI/MiniMax-Text-01
indischepartij/MiniCPM-3B-OpenHermes-2.5-v2
mattshumer/Reflection-Llama-3.1-70B
meetkai/functionary-medium-v3.2
4 changes: 4 additions & 0 deletions tests/test-syntax.cpp
@@ -374,6 +374,9 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_EQ(
"[]",
render(R"({{ None | items | list | tojson }})", {}, {}));
EXPECT_EQ(
"Foo",
render(R"({% generation %}Foo{% endgeneration %})", {}, {}));
}
EXPECT_EQ(
"[[1, 2]]",
@@ -493,6 +496,7 @@ TEST(SyntaxTest, SimpleCases) {

EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated for"));
EXPECT_THAT([]() { render("{% generation %}", {}, {}); }, ThrowsWithSubstr("Unterminated generation"));
EXPECT_THAT([]() { render("{% if 1 %}{% else %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter"));
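For reference, and not part of this PR: one more assertion in the same style as the tests above, relying only on the pass-through behaviour added in minja.hpp (text around a `{% generation %}` block and the block's body should concatenate unchanged):

```cpp
// Sketch only: {% generation %}/{% endgeneration %} render their body as-is at inference time.
EXPECT_EQ(
    "ABC",
    render(R"(A{% generation %}B{% endgeneration %}C)", {}, {}));
```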