Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions include/minja/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,10 @@ class chat_template {
inputs.add_generation_prompt = false;
full = apply(inputs);
}

if (full.find(prefix) != 0) {
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
}
auto eos_pos_last = full.rfind(eos_token_);
if (eos_pos_last == prefix.size() - eos_token_.size() ||
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
full = full.substr(0, eos_pos_last);
}
if (full.find(prefix) != 0) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
Expand Down Expand Up @@ -363,7 +362,7 @@ class chat_template {
if (polyfill_tools) {
adjusted_messages = add_system(inputs.messages,
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
} else {
adjusted_messages = inputs.messages;
}
Expand Down
2 changes: 1 addition & 1 deletion scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def apply(self, context):
if has_tools and not caps.supports_tools:
add_system(context['messages'],
f"You can call any of the following tools to satisfy the user's requests: {json.dumps(context['tools'], indent=2)}" +
("\n\nExample tool call syntax:\n\n" + self.tool_call_example if self.tool_call_example is not None else ""))
("\n\nExample tool call syntax:\n\n" + self.tool_call_example + "\n\n" if self.tool_call_example is not None else ""))

for message in context['messages']:
if 'tool_calls' in message:
Expand Down
19 changes: 18 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,25 @@ target_link_libraries(test-syntax PRIVATE
gtest_main
gmock
)

add_executable(test-polyfills test-polyfills.cpp)
target_compile_features(test-polyfills PUBLIC cxx_std_17)
if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
target_compile_definitions(test-polyfills PUBLIC _CRT_SECURE_NO_WARNINGS)
target_compile_options(gtest PRIVATE -Wno-language-extension-token)
endif()
target_link_libraries(test-polyfills PRIVATE
nlohmann_json::nlohmann_json
gtest_main
gmock
)
if (NOT CMAKE_CROSSCOMPILING)
gtest_discover_tests(test-syntax)
endif()

if (NOT CMAKE_CROSSCOMPILING)
gtest_discover_tests(test-syntax)
gtest_discover_tests(test-polyfills)
endif()

add_executable(test-capabilities test-capabilities.cpp)
Expand Down Expand Up @@ -54,7 +71,7 @@ set(MODEL_IDS
# minja implementation on the same template and context, and compare the output with the golden.
#
# For Gated models, you'll need to run `huggingface-cli login` (and be granted access) to download their template.

abacusai/Fewshot-Metamath-OrcaVicuna-Mistral
bofenghuang/vigogne-2-70b-chat
CohereForAI/c4ai-command-r-plus # Gated
Expand Down
Loading