Skip to content

Commit aa48e37

Browse files
ochafikochafikCISC
authored
server: inject date_string in llama 3.x template + fix date for firefunction v2 (ggml-org#12802)
* Inject date_string in llama 3.x + fix for functionary v2 ggml-org#12729 * move/fix detection of functionary v3.1 before llama 3.x, fix & test their non-tool mode Co-authored-by: Sigbjørn Skjæret <[email protected]> * generate more tokens in test_completion_with_required_tool_tiny_fast to avoid truncation --------- Co-authored-by: ochafik <[email protected]> Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent e3a9421 commit aa48e37

File tree

5 files changed

+180
-107
lines changed

5 files changed

+180
-107
lines changed

common/chat.cpp

+125-105
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66

77
#include <optional>
88

9+
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
10+
auto time = std::chrono::system_clock::to_time_t(now);
11+
auto local_time = *std::localtime(&time);
12+
std::ostringstream ss;
13+
ss << std::put_time(&local_time, format.c_str());
14+
auto res = ss.str();
15+
return res;
16+
}
17+
918
typedef minja::chat_template common_chat_template;
1019

1120
struct common_chat_templates {
@@ -24,6 +33,7 @@ struct templates_params {
2433
std::string grammar;
2534
bool add_generation_prompt = true;
2635
bool extract_reasoning = true;
36+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
2737
};
2838

2939
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
939949
}
940950
}
941951

942-
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
952+
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
943953
auto builtin_tools = json::array();
944954
common_chat_params data;
945-
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
946-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
947-
std::vector<std::string> tool_rules;
955+
if (!inputs.tools.is_null()) {
956+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
957+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
958+
std::vector<std::string> tool_rules;
948959

949-
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
950-
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
951-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
952-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
953-
expect_tool_parameters(name, parameters, {"query"});
954-
} else if (name == "python" || name == "code_interpreter") {
955-
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
956-
expect_tool_parameters(name, parameters, {"code"});
957-
} else {
958-
return false;
959-
}
960+
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
961+
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
962+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
963+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
964+
expect_tool_parameters(name, parameters, {"query"});
965+
} else if (name == "python" || name == "code_interpreter") {
966+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
967+
expect_tool_parameters(name, parameters, {"code"});
968+
} else {
969+
return false;
970+
}
960971

961-
std::vector<std::string> kvs;
962-
for (const auto & [key, value] : parameters.at("properties").items()) {
963-
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
964-
}
972+
std::vector<std::string> kvs;
973+
for (const auto & [key, value] : parameters.at("properties").items()) {
974+
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
975+
}
965976

966-
tool_rules.push_back(
967-
builder.add_rule(
968-
name + "-call",
969-
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
970-
builtin_tools.push_back(name);
977+
tool_rules.push_back(
978+
builder.add_rule(
979+
name + "-call",
980+
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
981+
builtin_tools.push_back(name);
971982

972-
return true;
973-
};
983+
return true;
984+
};
974985

975-
foreach_function(inputs.tools, [&](const json & tool) {
976-
const auto & function = tool.at("function");
977-
std::string name = function.at("name");
978-
auto parameters = function.at("parameters");
979-
builder.resolve_refs(parameters);
986+
foreach_function(inputs.tools, [&](const json & tool) {
987+
const auto & function = tool.at("function");
988+
std::string name = function.at("name");
989+
auto parameters = function.at("parameters");
990+
builder.resolve_refs(parameters);
980991

981-
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
982-
if (allow_python_tag_builtin_tools) {
983-
handle_builtin_tool(name, parameters);
992+
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
993+
if (allow_python_tag_builtin_tools) {
994+
handle_builtin_tool(name, parameters);
995+
}
996+
tool_rules.push_back(
997+
builder.add_rule(
998+
name + "-call",
999+
"\"{\" space "
1000+
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
1001+
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
1002+
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
1003+
"\"}\" space"));
1004+
});
1005+
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
1006+
data.grammar_triggers.push_back({
1007+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
1008+
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
1009+
});
1010+
if (!builtin_tools.empty()) {
1011+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1012+
data.preserved_tokens.push_back("<|python_tag|>");
9841013
}
985-
tool_rules.push_back(
986-
builder.add_rule(
987-
name + "-call",
988-
"\"{\" space "
989-
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
990-
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
991-
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
992-
"\"}\" space"));
1014+
// Allow a few empty lines on top of the usual constrained json schema space rule.
1015+
builder.add_rule("root", string_join(tool_rules, " | "));
1016+
data.additional_stops.push_back("<|eom_id|>");
9931017
});
994-
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
995-
data.grammar_triggers.push_back({
996-
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
997-
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
998-
});
999-
if (!builtin_tools.empty()) {
1000-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1001-
data.preserved_tokens.push_back("<|python_tag|>");
1002-
}
1003-
// Allow a few empty lines on top of the usual constrained json schema space rule.
1004-
builder.add_rule("root", string_join(tool_rules, " | "));
1005-
});
1006-
data.additional_stops.push_back("<|eom_id|>");
1018+
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
1019+
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
1020+
: COMMON_CHAT_FORMAT_LLAMA_3_X;
1021+
} else {
1022+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
1023+
}
10071024
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
1025+
{"date_string", format_time(inputs.now, "%d %b %Y")},
10081026
{"tools_in_user_message", false},
10091027
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
10101028
});
1011-
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
1012-
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
1013-
: COMMON_CHAT_FORMAT_LLAMA_3_X;
10141029
return data;
10151030
}
10161031
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
11501165
LOG_DBG("%s\n", __func__);
11511166
common_chat_params data;
11521167
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
1153-
{"datetime", "Jan 29 2025 13:00:00 GMT"},
1168+
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
11541169
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
11551170
});
11561171
if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
12851300
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
12861301
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
12871302
common_chat_params data;
1288-
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
1289-
std::string python_code_argument_name;
1290-
auto has_raw_python = false;
12911303

1292-
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1293-
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1294-
std::vector<std::string> tool_rules;
1295-
foreach_function(inputs.tools, [&](const json & tool) {
1296-
const auto & function = tool.at("function");
1297-
const auto & parameters = function.at("parameters");
1298-
std::string name = function.at("name");
1299-
if (name == "python" || name == "ipython") {
1300-
if (!parameters.contains("type")) {
1301-
throw std::runtime_error("Missing type in python tool");
1302-
}
1303-
has_raw_python = true;
1304-
const auto & type = parameters.at("type");
1305-
if (type == "object") {
1306-
auto properties = parameters.at("properties");
1307-
for (auto it = properties.begin(); it != properties.end(); ++it) {
1308-
if (it.value().at("type") == "string") {
1309-
if (!python_code_argument_name.empty()) {
1310-
throw std::runtime_error("Multiple string arguments found in python tool");
1304+
if (!inputs.tools.is_null()) {
1305+
std::string python_code_argument_name;
1306+
auto has_raw_python = false;
1307+
1308+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1309+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1310+
std::vector<std::string> tool_rules;
1311+
foreach_function(inputs.tools, [&](const json & tool) {
1312+
const auto & function = tool.at("function");
1313+
const auto & parameters = function.at("parameters");
1314+
std::string name = function.at("name");
1315+
if (name == "python" || name == "ipython") {
1316+
if (!parameters.contains("type")) {
1317+
throw std::runtime_error("Missing type in python tool");
1318+
}
1319+
has_raw_python = true;
1320+
const auto & type = parameters.at("type");
1321+
if (type == "object") {
1322+
auto properties = parameters.at("properties");
1323+
for (auto it = properties.begin(); it != properties.end(); ++it) {
1324+
if (it.value().at("type") == "string") {
1325+
if (!python_code_argument_name.empty()) {
1326+
throw std::runtime_error("Multiple string arguments found in python tool");
1327+
}
1328+
python_code_argument_name = it.key();
13111329
}
1312-
python_code_argument_name = it.key();
13131330
}
1331+
if (python_code_argument_name.empty()) {
1332+
throw std::runtime_error("No string argument found in python tool");
1333+
}
1334+
} else if (type != "string") {
1335+
throw std::runtime_error("Invalid type in python tool: " + type.dump());
13141336
}
1315-
if (python_code_argument_name.empty()) {
1316-
throw std::runtime_error("No string argument found in python tool");
1317-
}
1318-
} else if (type != "string") {
1319-
throw std::runtime_error("Invalid type in python tool: " + type.dump());
13201337
}
1338+
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
1339+
});
1340+
if (has_raw_python) {
1341+
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
1342+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1343+
data.preserved_tokens.push_back("<|python_tag|>");
13211344
}
1322-
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
1345+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
1346+
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
1347+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
13231348
});
1324-
if (has_raw_python) {
1325-
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
1326-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
1327-
data.preserved_tokens.push_back("<|python_tag|>");
1328-
}
1329-
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
1330-
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
1331-
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
1332-
});
1349+
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
1350+
} else {
1351+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
1352+
}
13331353

13341354
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
13351355
// TODO: if (has_raw_python)
1336-
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
13371356
return data;
13381357
}
13391358
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
15931612
params.extract_reasoning = inputs.extract_reasoning;
15941613
params.tool_choice = inputs.tool_choice;
15951614
params.grammar = inputs.grammar;
1615+
params.now = inputs.now;
15961616
if (!inputs.json_schema.empty()) {
15971617
params.json_schema = json::parse(inputs.json_schema);
15981618
}
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
16441664
return common_chat_params_init_firefunction_v2(tmpl, params);
16451665
}
16461666

1647-
// Plain handler (no tools)
1648-
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
1649-
return common_chat_params_init_without_tools(tmpl, params);
1650-
}
1651-
16521667
// Functionary v3.1 (w/ tools)
16531668
if (src.find("<|start_header_id|>") != std::string::npos
16541669
&& src.find("<function=") != std::string::npos) {
16551670
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
16561671
}
16571672

1658-
// Llama 3.1, 3.2, 3.3 (w/ tools)
1673+
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
16591674
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
16601675
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
1661-
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
1676+
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
1677+
}
1678+
1679+
// Plain handler (no tools)
1680+
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
1681+
return common_chat_params_init_without_tools(tmpl, params);
16621682
}
16631683

16641684
// Mistral Nemo (w/ tools)

common/chat.h

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include "common.h"
6+
#include <chrono>
67
#include <string>
78
#include <vector>
89

@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
7172
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
7273
bool parallel_tool_calls = false;
7374
bool extract_reasoning = true;
75+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
7476
};
7577

7678
struct common_chat_params {

tests/test-chat.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,9 @@ static void test_template_output_parsers() {
832832
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
833833
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
834834
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
835-
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
835+
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
836+
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
837+
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
836838

837839
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
838840
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,

0 commit comments

Comments
 (0)