
Commit a98f4bc

add support for the deprecated function_call/functions fields; function calling now works correctly with Dify.
1 parent de003e5 commit a98f4bc

File tree: 3 files changed, +72 −26 lines

  examples/server/function-call.hpp
  examples/server/server.cpp
  examples/server/utils.hpp

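Background, for readers unfamiliar with the two wire formats: OpenAI's original function-calling API used a top-level "functions" array in the request and a "function_call" object on assistant messages, with tool output returned under role "function"; these fields were later deprecated in favor of "tools", "tool_calls", and role "tool". Some clients, Dify among them, still speak the old format, which is what this commit teaches the server to accept on input and reproduce on output. A minimal sketch of the two request shapes, in nlohmann::json with hypothetical placeholder content:

    // Sketch only: the current vs. the deprecated OpenAI request shapes that
    // the server must now both understand. Tool/function definitions and
    // arguments are hypothetical placeholders.
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    int main() {
        // current shape: "tools" in the request, "tool_calls" on assistant
        // messages, tool output under role "tool"
        json current = {
            {"tools", json::array()},
            {"messages", json::array({
                json{{"role", "assistant"}, {"tool_calls", json::array()}},
                json{{"role", "tool"}, {"content", "{}"}}
            })}
        };

        // deprecated shape: "functions" in the request, "function_call" on
        // assistant messages, tool output under role "function"
        json deprecated = {
            {"functions", json::array()},
            {"messages", json::array({
                json{{"role", "assistant"},
                     {"function_call", {{"name", "f"}, {"arguments", "{}"}}}},
                json{{"role", "function"}, {"content", "{}"}}
            })}
        };
        (void)current; (void)deprecated;
        return 0;
    }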

examples/server/function-call.hpp

Lines changed: 26 additions & 13 deletions
@@ -226,6 +226,9 @@ std::string construct_json_tool_call_str(const json& tool_calls, nlohmann::order
 }
 
 
+
+
+
 const std::vector<json> expand_messages(const json & body, json &tool_name_map) {
     std::string function_str = "";
     if (body.contains("tools") && !body["tools"].empty()) {
@@ -243,13 +246,11 @@ const std::vector<json> expand_messages(const json & body, json &tool_name_map)
     for (size_t i = 0; i < body["messages"].size(); ++i) {
         if (body["messages"][i]["role"] != "tool" and func_observation_map.size() > 0) {
             // insert the observation from the tool call before the next message
-            std::string observation_str = "";
-            std::vector<std::string> func_observation_array;
+            json func_json_array = json::array();
             for (const auto& [key, value] : func_observation_map) {
-                func_observation_array.push_back(value);
+                func_json_array.push_back(value);
             }
-            json func_json_array = func_observation_array;
-            observation_str = std::string("start observation ") + func_json_array.dump() + std::string(" end observation");
+            std::string observation_str = "start observation " + func_json_array.dump() + " end observation";
             json observation_call;
             observation_call["role"] = "user";
             observation_call["content"] = observation_str;
@@ -274,10 +275,15 @@ const std::vector<json> expand_messages(const json & body, json &tool_name_map)
             }
         }
         // else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i]["content"].is_null() or body["messages"][i]["content"]=="") and !body["messages"][i]["tool_calls"].is_null() and !body["messages"][i]["tool_calls"].empty()){
-        else if (body["messages"][i]["role"] == "assistant" and body["messages"][i].contains("tool_calls")){
+        else if (body["messages"][i]["role"] == "assistant" and (body["messages"][i].contains("tool_calls") or body["messages"][i].contains("function_call"))){
             // convert OpenAI function call format to Rubra format
-            // string tool_call_str = construct_python_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
-            std::string tool_call_str = construct_json_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
+            std::string tool_call_str;
+            if (body["messages"][i].contains("tool_calls")) {
+                tool_call_str = construct_json_tool_call_str(body["messages"][i]["tool_calls"], func_observation_map);
+            }
+            else {
+                tool_call_str = std::string("starttoolcall") + body["messages"][i]["function_call"].dump() + std::string("endtoolcall");
+            }
             json function_call;
             function_call["role"] = "assistant";
             function_call["content"] = tool_call_str;
@@ -293,20 +299,27 @@ const std::vector<json> expand_messages(const json & body, json &tool_name_map)
             }
 
         }
+        else if (body["messages"][i]["role"] == "function") {
+            json func_json_array = json::array();
+            func_json_array.push_back(body["messages"][i]["content"]);
+            std::string observation_str = std::string("start observation ") + func_json_array.dump() + std::string(" end observation");
+            json observation_call;
+            observation_call["role"] = "user";
+            observation_call["content"] = observation_str;
+            temp_vec.push_back(observation_call);
+        }
         else {
             temp_vec.push_back(body["messages"][i]);
         }
 
     }
     if (func_observation_map.size() > 0) {
         // insert the observation from the tool call before the next message
-        std::string observation_str = "";
-        std::vector<std::string> func_observation_array;
+        json func_json_array = json::array();
         for (const auto& [key, value] : func_observation_map) {
-            func_observation_array.push_back(value);
+            func_json_array.push_back(value);
         }
-        json func_json_array = func_observation_array;
-        observation_str = std::string("start observation ") + func_json_array.dump() + std::string(" end observation");
+        std::string observation_str = "start observation " + func_json_array.dump() + " end observation";
         json observation_call;
         observation_call["role"] = "user";
         observation_call["content"] = observation_str;

examples/server/server.cpp

Lines changed: 7 additions & 6 deletions
@@ -3073,6 +3073,7 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
+
         const int id_task = ctx_server.queue_tasks.get_new_id();
 
         ctx_server.queue_results.add_waiting_task_id(id_task);
@@ -3091,14 +3092,13 @@ int main(int argc, char ** argv) {
             }
             ctx_server.queue_results.remove_waiting_task_id(id_task);
         } else {
-            const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [id_task, &ctx_server, completion_id, data](size_t, httplib::DataSink & sink) {
                 std::string last_str = "";
                 bool is_function_call = false;
                 bool checked_function_call = false;
                 json last_result_data;
-
-                auto process_and_send_data = [&](const json& data) {
-                    std::vector<json> result_array = format_partial_response_oaicompat(data, completion_id);
+                auto process_and_send_data = [&](const json& res_data) {
+                    std::vector<json> result_array = format_partial_response_oaicompat(res_data, completion_id);
 
                     for (const auto& item : result_array) {
                         if (!item.empty()) {
@@ -3116,7 +3116,9 @@ int main(int argc, char ** argv) {
 
                 while (true) {
                     server_task_result result = ctx_server.queue_results.recv(id_task);
-
+                    if (data.contains("tool_field")) {
+                        result.data["tool_field"] = data["tool_field"];
+                    }
                     if (result.error) {
                         const std::string error_str = "error: " + result.data.dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n";
                         LOG_VERBOSE("data stream", {{"to_send", error_str}});
@@ -3132,7 +3134,6 @@ int main(int argc, char ** argv) {
                         std::string str_to_check = last_str + content;
                         is_function_call = (str_to_check.find("starttool") != std::string::npos);
                     }
-
                     if (!is_function_call && !last_str.empty()) {
                         std::string temp_str = content;
                         result.data["content"] = last_str;

examples/server/utils.hpp

Lines changed: 39 additions & 7 deletions
@@ -361,6 +361,13 @@ static json oaicompat_completion_params_parse(
     llama_params["__oaicompat"] = true;
     json tool_name_map;
     const std::vector<json> expanded_messages = expand_messages(body, tool_name_map);
+    llama_params["tool_field"] = "tool_calls";
+    if (body.contains("tools") && !body["tools"].empty()) {
+        llama_params["tool_field"] = "tool_calls";
+    }
+    else if (body.contains("functions") && !body["functions"].empty()) {
+        llama_params["tool_field"] = "function_call";
+    }
     llama_params["prompt"] = format_chat(model, chat_template, expanded_messages);
     llama_params["tool_name_map"] = tool_name_map;
 
@@ -518,7 +525,6 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
-
     bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
     std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
 
@@ -527,6 +533,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
     bool stopped_limit = json_value(result, "stopped_limit", false);
     std::string content = json_value(result, "content", std::string(""));
     std::vector<json> parsed_content = rubra_fc_json_tool_extractor(content);
+    std::string tool_field = json_value(result, "tool_field", std::string("tool_calls"));
 
     std::string finish_reason;
     if (stopped_word || stopped_eos) {
@@ -535,7 +542,6 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
     if (stopped_limit) {
         finish_reason = "length";
     }
-
     std::time_t t = std::time(0);
 
     json choices;
@@ -544,6 +550,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
         choices = json::array({json{{"finish_reason", finish_reason},
                                     {"index", 0},
                                     {"delta", json::object()}}});
+
     } else {
         if (first) {
             if (content.empty()) {
@@ -592,10 +599,27 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
                 };
                 oai_format_tool_calls.push_back(tool_call);
             }
-            choices = json::array({json{{"finish_reason", nullptr},
+            if (tool_field == "tool_calls") {
+                choices = json::array({json{{"finish_reason", nullptr},
                                         {"index", 0},
-                                        {"delta", json{{"tool_calls", oai_format_tool_calls},
+                                        {"delta", json{{tool_field, oai_format_tool_calls},
                                         {"role", "assistant"}}}}});
+            }
+            else {
+                choices = json::array({json{{"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{{tool_field, oai_format_tool_calls[0]["function"]},
+                                        {"role", "assistant"}}}}});
+            }
+
+            json second_ret = json{
+                {"choices", choices},
+                {"created", t},
+                {"id", completion_id},
+                {"model", modelname},
+                {"object", "chat.completion.chunk"}};
+
+            return std::vector<json>({initial_ret, second_ret});
         }
 
     }
@@ -632,10 +656,18 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
             };
             oai_format_tool_calls.push_back(tool_call);
         }
-        choices = json::array({json{{"finish_reason", nullptr},
+        if (tool_field == "tool_calls") {
+            choices = json::array({json{{"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{{tool_field, oai_format_tool_calls},
+                                        {"role", "assistant"}}}}});
+        }
+        else {
+            choices = json::array({json{{"finish_reason", nullptr},
                                     {"index", 0},
-                                    {"delta", json{{"tool_calls", oai_format_tool_calls},
+                                    {"delta", json{{tool_field, oai_format_tool_calls[0]["function"]},
                                     {"role", "assistant"}}}}});
+        }
     }
 
 }
@@ -657,7 +689,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
                 {"total_tokens", num_tokens_predicted + num_prompt_tokens}
             }});
     }
-
+
     return std::vector<json>({ret});
 }
 
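With tool_field threaded through, format_partial_response_oaicompat emits one of two streaming delta shapes: the tool_calls branch sends the full array of calls, while the deprecated function_call branch sends only the "function" object of the first call, which is what old-style clients such as Dify expect. A sketch of both chunk payloads for a hypothetical parsed call:

    // Sketch of the two streaming delta shapes built in
    // format_partial_response_oaicompat. The tool call itself is hypothetical;
    // the branch logic mirrors the diff.
    #include <iostream>
    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    int main() {
        json call = {
            {"id", "call_0"},
            {"type", "function"},
            {"function", {{"name", "get_weather"},
                          {"arguments", "{\"city\":\"Paris\"}"}}}
        };
        json oai_format_tool_calls = json::array({call});

        for (std::string tool_field : {"tool_calls", "function_call"}) {
            json delta;
            if (tool_field == "tool_calls") {
                // current format: the full array of calls
                delta = json{{tool_field, oai_format_tool_calls}, {"role", "assistant"}};
            } else {
                // deprecated format: only the first call's "function" object
                delta = json{{tool_field, oai_format_tool_calls[0]["function"]}, {"role", "assistant"}};
            }
            json choices = json::array({json{{"finish_reason", nullptr},
                                             {"index", 0},
                                             {"delta", delta}}});
            std::cout << choices.dump(2) << "\n";
        }
        return 0;
    }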