Skip to content

Commit 929895d

Browse files
committed
fix: google toolCall
1 parent 7b019f2 commit 929895d

File tree

5 files changed

+146
-19
lines changed

5 files changed

+146
-19
lines changed

internal/common/messages.go

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,51 @@ func MessageCombiner[T any](
1818
buffer := new(bytes.Buffer)
1919
msgs := make([]map[string]string, 0)
2020
for _, message := range messages {
21+
if message.Is("role", "assistant") && message.Has("tool_calls") {
22+
if buffer.Len() > 0 {
23+
msgs = append(msgs, map[string]string{
24+
"role": previous,
25+
"content": buffer.String(),
26+
})
27+
buffer.Reset()
28+
}
29+
30+
previous = message.GetString("role")
31+
toolCalls := message.GetSlice("tool_calls")
32+
if len(toolCalls) == 0 {
33+
continue
34+
}
35+
36+
var toolCall pkg.Keyv[interface{}] = toolCalls[0].(map[string]interface{})
37+
keyv := toolCall.GetKeyv("function")
38+
39+
msgs = append(msgs, map[string]string{
40+
"tool_calls": "yes",
41+
"role": previous,
42+
"name": keyv.GetString("name"),
43+
"content": keyv.GetString("arguments"),
44+
})
45+
continue
46+
}
47+
48+
if message.Is("role", "tool") || message.Is("role", "function") {
49+
if buffer.Len() > 0 {
50+
msgs = append(msgs, map[string]string{
51+
"role": previous,
52+
"content": buffer.String(),
53+
})
54+
buffer.Reset()
55+
}
56+
57+
previous = message.GetString("role")
58+
msgs = append(msgs, map[string]string{
59+
"role": previous,
60+
"name": message.GetString("name"),
61+
"content": message.GetString("content"),
62+
})
63+
continue
64+
}
65+
2166
str := strings.TrimSpace(message.GetString("content"))
2267
if str == "" {
2368
continue
@@ -33,7 +78,7 @@ func MessageCombiner[T any](
3378
continue
3479
}
3580

36-
if previous == message["role"] {
81+
if message.Is("role", previous) {
3782
buffer.WriteString(str)
3883
continue
3984
}

internal/plugin/llm/gemini/adapter.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/bincooo/emit.io"
1313
"github.com/gin-gonic/gin"
1414
"net/http"
15+
"net/url"
1516
"reflect"
1617
"strings"
1718
"sync"
@@ -104,6 +105,10 @@ func complete(ctx *gin.Context) {
104105
ctx.Set(ginTokens, tokens)
105106
r, err := build(ctx.Request.Context(), proxies, cookie, newMessages, completion)
106107
if err != nil {
108+
var urlError *url.Error
109+
if errors.As(err, &urlError) {
110+
urlError.URL = strings.ReplaceAll(urlError.URL, cookie, "AIzaSy***")
111+
}
107112
response.Error(ctx, -1, err)
108113
return
109114
}

internal/plugin/llm/gemini/fetch.go

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,27 +46,65 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
4646
completion.TopP = 0.95
4747
}
4848

49+
toStrings := func(slice []interface{}) (values []string) {
50+
for _, v := range slice {
51+
values = append(values, v.(string))
52+
}
53+
return
54+
}
55+
56+
condition := func(str string) string {
57+
switch str {
58+
case "string":
59+
return "STRING"
60+
case "boolean":
61+
return "BOOLEAN"
62+
case "number":
63+
return "NUMBER"
64+
default:
65+
if strings.HasPrefix(str, "array") {
66+
return "ARRAY"
67+
}
68+
return "OBJECT"
69+
}
70+
}
71+
72+
// fix: type 枚举必须符合google定义,否则报错400
73+
// https://ai.google.dev/api/rest/v1beta/Schema?hl=zh-cn#type
74+
var fix func(keyv pkg.Keyv[interface{}]) pkg.Keyv[interface{}]
75+
{
76+
fix = func(keyv pkg.Keyv[interface{}]) pkg.Keyv[interface{}] {
77+
if keyv.Has("type") {
78+
keyv.Set("type", condition(keyv.GetString("type")))
79+
}
80+
for k, _ := range keyv {
81+
child := keyv.GetKeyv(k)
82+
if child != nil {
83+
keyv.Set(k, fix(child))
84+
}
85+
}
86+
return keyv
87+
}
88+
}
89+
4990
// 参数基本与openai对齐
5091
_funcDecls := make([]funcDecl, 0)
5192
if toolsL := len(completion.Tools); toolsL > 0 {
5293
for _, v := range completion.Tools {
5394
kv := v.GetKeyv("function").GetKeyv("parameters")
54-
required, ok := kv.Get("required")
55-
if !ok {
56-
required = []string{}
57-
}
58-
95+
required := kv.GetSlice("required")
5996
_funcDecls = append(_funcDecls, funcDecl{
97+
// 必须为 a-z、A-Z、0-9,或包含下划线和短划线,长度上限为 63 个字符
6098
Name: strings.Replace(v.GetKeyv("function").GetString("name"), "-", "_", -1),
6199
Description: v.GetKeyv("function").GetString("description"),
62100
Params: struct {
63101
Properties map[string]interface{} `json:"properties"`
64102
Required []string `json:"required"`
65103
Type string `json:"type"`
66104
}{
67-
Properties: kv.GetKeyv("properties"),
68-
Required: required.([]string),
69-
Type: kv.GetString("function"),
105+
Properties: fix(kv.GetKeyv("properties")),
106+
Required: toStrings(required),
107+
Type: condition(kv.GetString("type")),
70108
},
71109
})
72110
}
@@ -120,7 +158,7 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
120158
// 函数调用
121159
payload["tools"] = []map[string]interface{}{
122160
{
123-
"function_declarations": _funcDecls,
161+
"functionDeclarations": _funcDecls,
124162
},
125163
}
126164
}
@@ -148,10 +186,10 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
148186

149187
if res.StatusCode != http.StatusOK {
150188
h := res.Header
151-
if c := h.Get("content-type"); strings.Contains(c, "application/json") {
189+
if c := h.Get("content-type"); !strings.Contains(c, "text/html") {
152190
bts, e := io.ReadAll(res.Body)
153191
if e == nil {
154-
return nil, fmt.Errorf("%s: %s", res.Status, bts)
192+
logger.Errorf("%s", bts)
155193
}
156194
}
157195
return nil, errors.New(res.Status)

internal/plugin/llm/gemini/message.go

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func mergeMessages(messages []pkg.Keyv[interface{}]) (newMessages []map[string]i
171171
// role类型转换
172172
condition := func(expr string) string {
173173
switch expr {
174-
case "end":
174+
case "function", "tool", "end":
175175
return expr
176176
case "assistant":
177177
return "model"
@@ -185,18 +185,57 @@ func mergeMessages(messages []pkg.Keyv[interface{}]) (newMessages []map[string]i
185185
tokens += com.CalcTokens(message["content"])
186186
if condition(role) == condition(next) {
187187
// cache buffer
188-
if role == "function" {
189-
buffer.WriteString(fmt.Sprintf("这是系统内置tools工具的返回结果: (%s)\n\n##\n%s\n##", message["name"], message["content"]))
190-
return nil
191-
}
192188
buffer.WriteString(message["content"])
193189
return nil
194190
}
195191

196192
defer buffer.Reset()
197193
buffer.WriteString(fmt.Sprintf(message["content"]))
194+
var result []map[string]interface{}
195+
196+
if role == "tool" || role == "function" {
197+
var args interface{}
198+
if err := json.Unmarshal([]byte(message["content"]), &args); err != nil {
199+
logger.Error(err)
200+
return nil
201+
}
202+
203+
result = append(result, map[string]interface{}{
204+
"role": "user",
205+
"parts": []interface{}{
206+
map[string]interface{}{
207+
"functionResponse": map[string]interface{}{
208+
"name": message["name"],
209+
"response": args,
210+
},
211+
},
212+
},
213+
})
214+
return result
215+
}
216+
217+
if toolCalls, ok := message["tool_calls"]; ok && role == "assistant" && toolCalls == "yes" {
218+
var args interface{}
219+
if err := json.Unmarshal([]byte(message["content"]), &args); err != nil {
220+
logger.Error(err)
221+
return nil
222+
}
223+
224+
result = append(result, map[string]interface{}{
225+
"role": "assistant",
226+
"parts": []interface{}{
227+
map[string]interface{}{
228+
"functionCall": map[string]interface{}{
229+
"name": message["name"],
230+
"args": args,
231+
},
232+
},
233+
},
234+
})
235+
return result
236+
}
237+
198238
if role == "system" {
199-
var result []map[string]interface{}
200239
result = append(result, map[string]interface{}{
201240
"role": "user",
202241
"parts": []interface{}{

internal/plugin/toolcall.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ func toolCacheHash(completion pkg.ChatCompletion) (hash string) {
205205
return "-1"
206206
}
207207

208-
return common.HashString(hash)
208+
return common.HashString(completion.Model + hash)
209209
}
210210

211211
func buildTemplate(ctx *gin.Context, completion pkg.ChatCompletion, template string) (message string, err error) {

0 commit comments

Comments
 (0)