@@ -361,6 +361,13 @@ static json oaicompat_completion_params_parse(
     llama_params["__oaicompat"] = true;
     json tool_name_map;
     const std::vector<json> expanded_messages = expand_messages(body, tool_name_map);
+    llama_params["tool_field"] = "tool_calls";
+    if (body.contains("tools") && !body["tools"].empty()) {
+        llama_params["tool_field"] = "tool_calls";
+    }
+    else if (body.contains("functions") && !body["functions"].empty()) {
+        llama_params["tool_field"] = "function_call";
+    }
     llama_params["prompt"] = format_chat(model, chat_template, expanded_messages);
     llama_params["tool_name_map"] = tool_name_map;
 
@@ -518,7 +525,6 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
-
     bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
     std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
 
@@ -527,6 +533,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
     bool stopped_limit = json_value(result, "stopped_limit", false);
     std::string content = json_value(result, "content", std::string(""));
     std::vector<json> parsed_content = rubra_fc_json_tool_extractor(content);
+    std::string tool_field = json_value(result, "tool_field", std::string("tool_calls"));
 
     std::string finish_reason;
     if (stopped_word || stopped_eos) {
@@ -535,7 +542,6 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
     if (stopped_limit) {
         finish_reason = "length";
     }
-
     std::time_t t = std::time(0);
 
     json choices;
@@ -544,6 +550,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
         choices = json::array({json{{"finish_reason", finish_reason},
                                     {"index", 0},
                                     {"delta", json::object()}}});
+
     } else {
         if (first) {
             if (content.empty()) {
@@ -592,10 +599,27 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
                     };
                     oai_format_tool_calls.push_back(tool_call);
                 }
-                choices = json::array({json{{"finish_reason", nullptr},
+                if (tool_field == "tool_calls") {
+                    choices = json::array({json{{"finish_reason", nullptr},
                                             {"index", 0},
-                                            {"delta", json{{"tool_calls", oai_format_tool_calls},
+                                            {"delta", json{{tool_field, oai_format_tool_calls},
                                                            {"role", "assistant"}}}}});
+                }
+                else {
+                    choices = json::array({json{{"finish_reason", nullptr},
+                                            {"index", 0},
+                                            {"delta", json{{tool_field, oai_format_tool_calls[0]["function"]},
+                                                           {"role", "assistant"}}}}});
+                }
+
+                json second_ret = json{
+                    {"choices", choices},
+                    {"created", t},
+                    {"id", completion_id},
+                    {"model", modelname},
+                    {"object", "chat.completion.chunk"}};
+
+                return std::vector<json>({initial_ret, second_ret});
             }
 
         }
@@ -632,10 +656,18 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
                 };
                 oai_format_tool_calls.push_back(tool_call);
             }
-            choices = json::array({json{{"finish_reason", nullptr},
+            if (tool_field == "tool_calls") {
+                choices = json::array({json{{"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{{tool_field, oai_format_tool_calls},
+                                        {"role", "assistant"}}}}});
+            }
+            else {
+                choices = json::array({json{{"finish_reason", nullptr},
                                         {"index", 0},
-                                        {"delta", json{{"tool_calls", oai_format_tool_calls},
+                                        {"delta", json{{tool_field, oai_format_tool_calls[0]["function"]},
                                         {"role", "assistant"}}}}});
+            }
         }
 
     }
@@ -657,7 +689,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
             {"total_tokens", num_tokens_predicted + num_prompt_tokens}
         }});
     }
-
+
     return std::vector<json>({ret});
 }
 
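Side note, not part of the patch: the streaming hunks above switch the delta between OpenAI's newer `tool_calls` array and the legacy `function_call` object, which is selected when the request arrived with `functions` instead of `tools`. Below is a minimal, self-contained sketch of that shaping with nlohmann::json; `make_delta`, `get_weather`, and the example arguments are invented for illustration and do not exist in the server code.

// Sketch only: shows the two delta shapes the patch streams, not the actual server code.
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Hypothetical helper mirroring the branching added in the diff: "tool_calls"
// streams the whole array; "function_call" (legacy "functions" request)
// streams only the first call's {name, arguments} object.
static json make_delta(const std::string & tool_field, const json & oai_format_tool_calls) {
    if (tool_field == "tool_calls") {
        return json{{tool_field, oai_format_tool_calls},
                    {"role", "assistant"}};
    }
    return json{{tool_field, oai_format_tool_calls[0]["function"]},
                {"role", "assistant"}};
}

int main() {
    // One parsed call in the OpenAI tool_calls shape.
    json oai_format_tool_calls = json::array({
        {{"id", "call_0"},
         {"type", "function"},
         {"function", {{"name", "get_weather"},
                       {"arguments", "{\"city\":\"Berlin\"}"}}}}
    });

    std::cout << make_delta("tool_calls", oai_format_tool_calls).dump(2) << "\n\n";
    std::cout << make_delta("function_call", oai_format_tool_calls).dump(2) << "\n";
    return 0;
}

The legacy branch unwraps only the first parsed call's `function` object because the old `function_call` field is a single {name, arguments} pair rather than an array, which is why both hunks index `oai_format_tool_calls[0]["function"]` on that path.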