Commit a37e593

feat: 1. update the Google model names and add the flash model; 2. add a tool enhancement tag for default tool selection
1 parent 6c3c228 commit a37e593

7 files changed (+126, -66 lines)

flags.md

Lines changed: 10 additions & 0 deletions
@@ -275,4 +275,14 @@ flag: histories
   "model": "coze",
   "stream": false
 }
+```
+
+#### tools flag: enables default-selection mode, so that when tool matching fails a default tool is chosen; only parameterless tools are supported
+```text
+flag: tool
+
+attribute:
+  id: (string) the name value of an entry in tool_function; defaults to -1
+
+<tool id="xxx">
 ```
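For context, the new flag is read out of the chat messages themselves. A request that opts a specific tool into the default-selection behaviour might look like the sketch below; the tool name weather_report, the model id, and the exact placement in the last user message are illustrative assumptions, not part of this commit.

```text
{
  "model": "gemini-1.5-flash-latest",
  "messages": [
    {
      "role": "user",
      "content": "<tool id=\"weather_report\"> What's the weather like in Hangzhou today?"
    }
  ],
  "tools": [ ... ]
}
```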

internal/agent/com.go

Lines changed: 28 additions & 17 deletions
@@ -13,33 +13,44 @@ output: {{$value.content}}
 {{end -}}
 {{end}}
 
-<Instruction>
-You are an intelligent assistant. Besides answering the user's questions, you can also use tools. Sometimes you may rely on a tool's result to answer the user more accurately.
+
+You are an intelligent assistant focused on choosing tools for the user. Sometimes you may rely on a tool's result to answer the user more accurately.
 
 Tools are declared using the JSON Schema format, where toolId is the tool's id, description is the tool's description, parameters are the tool's parameters (including each parameter's type and description), and required is the list of required parameters.
 
 Based on the tool descriptions, decide whether to answer directly or to use a tool. During the task, USER stands for the user's input, TOOL_RESPONSE for a tool's result, and ASSISTANT for your output.
+{{- if eq .toolDef "-1" }}
 Every output of yours must start with 0 or 1, indicating whether a tool call is needed:
-0: do not use a tool, answer directly
+0: do not use a tool.
 1: use a tool and return the tool-call arguments.
-
+{{- else }}
+This output of yours must start with 1, indicating whether a tool call is needed:
+0: do not use a tool.
+1: use a tool and return the tool-call arguments.
+{{- end }}
 For example:
 
-USER: Hello there
-ANSWER: 0: Hello, how can I help you?
-USER: What is the weather like in Hangzhou today?
-ANSWER: 1: {"toolId":"testToolId",arguments:{"city": "Hangzhou"}}
+USER: Hello there <|end|>
+{{- if eq .toolDef "-1" }}
+ANSWER: 0: <|end|>
+{{- else }}
+ANSWER: 1: {"toolId":"{{.toolDef}}",arguments:{}} <|end|>
+{{- end }}
+USER: What is the weather like in Hangzhou today? <|end|>
+ANSWER: 1: {"toolId":"testToolId",arguments:{"city": "Hangzhou"}} <|end|>
 TOOL_RESPONSE: """
 Sunny......
 """
-ANSWER: 0: It is sunny in Hangzhou today.
-USER: With today's weather in Hangzhou, where is a good place to visit?
-ANSWER: 1: {"toolId":"testToolId2",arguments:{"query": "Hangzhou weather where to visit"}}
+USER: With today's weather in Hangzhou, where is a good place to visit? <|end|>
+ANSWER: 1: {"toolId":"testToolId2",arguments:{"query": "Hangzhou weather where to visit"}} <|end|>
 TOOL_RESPONSE: """
 Sunny. West Lake, Lingyin Temple, Qiandao Lake...
 """
-ANSWER: 0: It is sunny in Hangzhou today; West Lake, Lingyin Temple, and Qiandao Lake are good places to visit.
-</Instruction>
+{{- if eq .toolDef "-1" }}
+ANSWER: 0: <|end|>
+{{- else }}
+ANSWER: 1: {"toolId":"{{.toolDef}}",arguments:{}} <|end|>
+{{- end }}
 
 Now, let's begin! Below are the tools you can use this time:
 
@@ -58,13 +69,13 @@ ANSWER: 0: It is sunny in Hangzhou today; West Lake, Lingyin Temple, and Qiandao Lake
         "type": "{{$v.type}}",
         "description": "{{$v.description}}"
       }
-      {{end -}}
+      {{- end }}
     }
   },
-  "required": {{$value.function.parameters.required}}
+  "required": [{{join $value.function.parameters.required ", " }}]
 },
-{{end -}}
-{{end -}}
+{{- end -}}
+{{- end}}
 ]
 """
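The substantive change in the schema block above is the required line: the old expression printed the Go slice with fmt, which renders as [city] rather than a JSON array, while the new [{{join ...}}] form quotes each element. The following standalone sketch (not repository code) reproduces both template lines with a join helper shaped like the one registered in toolcall.go further down; the parameter names city and date are made up.

```go
package main

import (
	"fmt"
	"os"
	"strings"
	"text/template"
)

func main() {
	// "required" arrives from the OpenAI-style tool definition as []interface{}.
	required := []interface{}{"city", "date"}

	funcs := template.FuncMap{
		// Same shape as the "join" helper added in toolcall.go:
		// wrap every element in quotes and join with the separator.
		"join": func(slice []interface{}, sep string) string {
			if len(slice) == 0 {
				return ""
			}
			var result []string
			for _, v := range slice {
				result = append(result, fmt.Sprintf("\"%v\"", v))
			}
			return strings.Join(result, sep)
		},
	}

	// Old template line: prints the slice with fmt, e.g. `"required": [city date]` (not valid JSON).
	oldLine := template.Must(template.New("old").Parse(`"required": {{.}}`))
	_ = oldLine.Execute(os.Stdout, required)
	fmt.Println()

	// New template line: renders a JSON array, e.g. `"required": ["city", "date"]`.
	newLine := template.Must(template.New("new").Funcs(funcs).Parse(`"required": [{{join . ", "}}]`))
	_ = newLine.Execute(os.Stdout, required)
	fmt.Println()
}
```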

internal/common/parser.go

Lines changed: 15 additions & 0 deletions
@@ -542,6 +542,7 @@ func xmlFlagsToContents(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (han
 		"pad",      // marker used in bing: pads the guided dialogue to try to avoid apologies
 		"notebook", // notebook mode
 		"histories",
+		"tool",
 	})
 )

@@ -671,6 +672,20 @@ func xmlFlagsToContents(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (han
 			}
 			continue
 		}
+
+		if node.t == XML_TYPE_X && node.tag == "tool" {
+			id := "-1"
+			if e, ok := node.attr["id"]; ok {
+				if o, k := e.(string); k {
+					id = o
+				}
+			}
+			clean(content[node.index:node.end])
+			if id != "-1" {
+				ctx.Set("tool", id)
+			}
+			continue
+		}
 	}
 }
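In effect, the new branch lifts the id attribute out of a <tool> tag, strips the tag from the message, and stores the value on the gin context for the tool-call template. A simplified standalone sketch of that behaviour follows; the real code walks parsed XML-like nodes inside xmlFlagsToContents, and the regular expression and the weather_report name here are illustrative only.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// extractToolFlag mirrors the observable effect of the new "tool" branch:
// read the id attribute from a <tool id="..."> tag, remove the tag from the
// message content, and report the id ("-1" means the flag was absent).
func extractToolFlag(content string) (cleaned, toolID string) {
	toolID = "-1"
	re := regexp.MustCompile(`<tool\s+id="([^"]*)"\s*/?>`)
	if m := re.FindStringSubmatch(content); m != nil {
		toolID = m[1]
		cleaned = strings.TrimSpace(re.ReplaceAllString(content, ""))
		return cleaned, toolID
	}
	return content, toolID
}

func main() {
	// "weather_report" is a hypothetical tool name used only for illustration.
	cleaned, id := extractToolFlag(`<tool id="weather_report"> What's the weather in Hangzhou?`)
	fmt.Println(cleaned) // What's the weather in Hangzhou?
	fmt.Println(id)      // weather_report, i.e. what would be stored via ctx.Set("tool", id)
}
```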

internal/middle/gemini/adapter.go

Lines changed: 8 additions & 4 deletions
@@ -20,7 +20,6 @@ import (
 )
 
 const MODEL = "gemini"
-const GOOGLE_BASE = "https://generativelanguage.googleapis.com/%s?alt=sse&key=%s"
 const login = "http://127.0.0.1:8081/v1/login"
 
 var (
@@ -55,7 +54,7 @@ type API struct {
 
 func (API) Match(_ *gin.Context, model string) bool {
 	switch model {
-	case "gemini-1.0", "gemini-1.5":
+	case "gemini-1.0-pro-latest", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest":
 		return true
 	default:
 		return false
@@ -65,12 +64,17 @@ func (API) Match(_ *gin.Context, model string) bool {
 func (API) Models() []middle.Model {
 	return []middle.Model{
 		{
-			Id:      "gemini-1.0",
+			Id:      "gemini-1.0-pro-latest",
 			Object:  "model",
 			Created: 1686935002,
 			By:      "gemini-adapter",
 		}, {
-			Id:      "gemini-1.5",
+			Id:      "gemini-1.5-pro-latest",
+			Object:  "model",
+			Created: 1686935002,
+			By:      "gemini-adapter",
+		}, {
+			Id:      "gemini-1.5-flash-latest",
 			Object:  "model",
 			Created: 1686935002,
 			By:      "gemini-adapter",

internal/middle/gemini/fetch.go

Lines changed: 19 additions & 23 deletions
@@ -14,6 +14,8 @@ import (
 	"strings"
 )
 
+const GOOGLE_BASE_FORMAT = "https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s"
+
 type funcDecl struct {
 	Name        string `json:"name"`
 	Description string `json:"description"`
@@ -25,35 +27,29 @@ type funcDecl struct {
 }
 
 // build the request and return the response
-func build(ctx context.Context, proxies, token string, messages []map[string]interface{}, req pkg.ChatCompletion) (*http.Response, error) {
-	var (
-		burl = fmt.Sprintf(GOOGLE_BASE, "v1beta/models/gemini-1.0-pro-latest:streamGenerateContent", token)
-	)
-
-	if req.Model == "gemini-1.5" {
-		burl = fmt.Sprintf(GOOGLE_BASE, "v1beta/models/gemini-1.5-pro-latest:streamGenerateContent", token)
-	}
+func build(ctx context.Context, proxies, token string, messages []map[string]interface{}, completion pkg.ChatCompletion) (*http.Response, error) {
+	gURL := fmt.Sprintf(GOOGLE_BASE_FORMAT, completion.Model, token)
 
-	if req.Temperature < 0.1 {
-		req.Temperature = 1
+	if completion.Temperature < 0.1 {
+		completion.Temperature = 1
 	}
 
-	if req.MaxTokens == 0 {
-		req.MaxTokens = 2048
+	if completion.MaxTokens == 0 {
+		completion.MaxTokens = 2048
 	}
 
-	if req.TopK == 0 {
-		req.TopK = 100
+	if completion.TopK == 0 {
+		completion.TopK = 100
 	}
 
-	if req.TopP == 0 {
-		req.TopP = 0.95
+	if completion.TopP == 0 {
+		completion.TopP = 0.95
 	}
 
 	// parameters are mostly aligned with openai
 	_funcDecls := make([]funcDecl, 0)
-	if toolsL := len(req.Tools); toolsL > 0 {
-		for _, v := range req.Tools {
+	if toolsL := len(completion.Tools); toolsL > 0 {
+		for _, v := range completion.Tools {
 			kv := v.GetKeyv("function").GetKeyv("parameters")
 			required, ok := kv.Get("required")
 			if !ok {
@@ -93,10 +89,10 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
 	payload := map[string]any{
 		"contents": messages, // [ { role: user, parts: [ { text: 'xxx' } ] } ]
 		"generationConfig": map[string]any{
-			"topK":            req.TopK,
-			"topP":            req.TopP,
-			"temperature":     req.Temperature, // 0.8
-			"maxOutputTokens": req.MaxTokens,
+			"topK":            completion.TopK,
+			"topP":            completion.TopP,
+			"temperature":     completion.Temperature, // 0.8
+			"maxOutputTokens": completion.MaxTokens,
 			"stopSequences":   []string{},
 		},
 		// safety levels
@@ -137,7 +133,7 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
 	res, err := emit.ClientBuilder().
 		Proxies(proxies).
 		Context(ctx).
-		POST(burl).
+		POST(gURL).
 		JHeader().
 		Bytes(marshal).
 		Do()
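Because the model id is now interpolated directly into the endpoint path, the request URL follows mechanically from the model name for all three advertised models. A small sketch using the new format string (the API key value is a placeholder, not a real credential):

```go
package main

import "fmt"

const GOOGLE_BASE_FORMAT = "https://generativelanguage.googleapis.com/v1beta/models/%s:streamGenerateContent?alt=sse&key=%s"

func main() {
	// "YOUR_API_KEY" stands in for the token taken from the incoming request.
	gURL := fmt.Sprintf(GOOGLE_BASE_FORMAT, "gemini-1.5-flash-latest", "YOUR_API_KEY")
	fmt.Println(gURL)
	// https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:streamGenerateContent?alt=sse&key=YOUR_API_KEY
}
```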

internal/middle/lmsys/fetch.go

Lines changed: 1 addition & 0 deletions
@@ -316,6 +316,7 @@ func fetchCookies(ctx context.Context, proxies string) (cookies string) {
 		DoS(http.StatusOK)
 	if err != nil {
 		var e emit.Error
+		logrus.Errorf("retry[%d]: %v", index, err)
 		if errors.As(err, &e) && e.Code == 429 {
 			return
 		}

internal/middle/toolcall.go

Lines changed: 45 additions & 22 deletions
@@ -2,6 +2,7 @@ package middle
 
 import (
 	"encoding/json"
+	"fmt"
 	"github.com/bincooo/chatgpt-adapter/v2/internal/agent"
 	"github.com/bincooo/chatgpt-adapter/v2/internal/common"
 	"github.com/bincooo/chatgpt-adapter/v2/internal/vars"
@@ -12,31 +13,56 @@ import (
 	"time"
 )
 
-func buildTemplate(tools []pkg.Keyv[interface{}], messages []pkg.Keyv[interface{}], template string, max int) (message string, err error) {
-	pMessages := messages
+func buildTemplate(ctx *gin.Context, completion pkg.ChatCompletion, template string, max int) (message string, err error) {
+	toolDef := ctx.GetString("tool")
+	if toolDef == "" {
+		toolDef = "-1"
+	}
+
+	pMessages := completion.Messages
 	content := "continue"
-	if messageL := len(messages); messageL > 0 && messages[messageL-1]["role"] == "user" {
-		content = messages[messageL-1].GetString("content")
+	if messageL := len(pMessages); messageL > 0 && pMessages[messageL-1]["role"] == "user" {
+		content = pMessages[messageL-1].GetString("content")
 		if max == 0 {
 			pMessages = make([]pkg.Keyv[interface{}], 0)
 		} else if max > 0 && messageL > max {
-			pMessages = messages[messageL-max : messageL-1]
+			pMessages = pMessages[messageL-max : messageL-1]
 		} else {
-			pMessages = messages[:messageL-1]
+			pMessages = pMessages[:messageL-1]
 		}
 	}
 
-	for _, t := range tools {
-		if !t.GetKeyv("function").Has("id") {
-			t.GetKeyv("function").Set("id", common.RandStr(5))
+	for _, t := range completion.Tools {
+		id := common.RandStr(5)
+		fn := t.GetKeyv("function")
+		if !fn.Has("id") {
+			t.GetKeyv("function").Set("id", id)
+		} else {
+			id = fn.GetString("id")
+		}
+
+		if toolDef != "-1" && fn.Has("name") {
+			if toolDef == fn.GetString("name") {
+				toolDef = id
+			}
 		}
 	}
 
 	parser := templateBuilder().
-		Vars("tools", tools).
+		Vars("toolDef", toolDef).
+		Vars("tools", completion.Tools).
 		Vars("pMessages", pMessages).
 		Vars("content", content).
-		Do()
+		Func("join", func(slice []interface{}, sep string) string {
+			if len(slice) == 0 {
+				return ""
+			}
+			var result []string
+			for _, v := range slice {
+				result = append(result, fmt.Sprintf("\"%v\"", v))
+			}
			return strings.Join(result, sep)
+		}).Do()
 	return parser(template)
 }
 
@@ -45,11 +71,8 @@ func buildTemplate(tools []pkg.Keyv[interface{}], messages []pkg.Keyv[interface{
 // return:
 // bool > whether a tool call was executed
 // error > execution error
-func CompleteToolCalls(ctx *gin.Context, req pkg.ChatCompletion, callback func(message string) (string, error)) (bool, error) {
-	message, err := buildTemplate(
-		req.Tools,
-		req.Messages,
-		agent.ToolCall, 5)
+func CompleteToolCalls(ctx *gin.Context, completion pkg.ChatCompletion, callback func(message string) (string, error)) (bool, error) {
+	message, err := buildTemplate(ctx, completion, agent.ToolCall, 5)
 	if err != nil {
 		return false, err
 	}
@@ -64,10 +87,10 @@ func CompleteToolCalls(ctx *gin.Context, req pkg.ChatCompletion, callback func(m
 	ctx.Set(vars.GinCompletionUsage, common.CalcUsageTokens(content, previousTokens))
 
 	// parse the arguments
-	return parseToToolCall(ctx, content, req), nil
+	return parseToToolCall(ctx, content, completion), nil
 }
 
-func parseToToolCall(ctx *gin.Context, content string, req pkg.ChatCompletion) bool {
+func parseToToolCall(ctx *gin.Context, content string, completion pkg.ChatCompletion) bool {
 	j := ""
 	created := time.Now().Unix()
 	slice := strings.Split(content, "TOOL_RESPONSE")
@@ -87,7 +110,7 @@ func parseToToolCall(ctx *gin.Context, content string, req pkg.ChatCompletion) b
 	}
 
 	name := ""
-	for _, t := range req.Tools {
+	for _, t := range completion.Tools {
 		if strings.Contains(j, t.GetKeyv("function").GetString("id")) {
 			name = t.GetKeyv("function").GetString("name")
 			break
@@ -113,11 +136,11 @@ func parseToToolCall(ctx *gin.Context, content string, req pkg.ChatCompletion) b
 	}
 	bytes, _ := json.Marshal(obj)
 
-	if req.Stream {
-		SSEToolCallResponse(ctx, req.Model, name, string(bytes), created)
+	if completion.Stream {
+		SSEToolCallResponse(ctx, completion.Model, name, string(bytes), created)
 		return true
 	} else {
-		ToolCallResponse(ctx, req.Model, name, string(bytes))
+		ToolCallResponse(ctx, completion.Model, name, string(bytes))
 		return true
 	}
 }
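The heart of the buildTemplate change is how toolDef travels: it starts life as a tool name taken from the tool flag on the gin context, and is swapped for the matching function id so the prompt template can emit {"toolId":"<id>", ...} directly. A simplified standalone sketch of that resolution, using plain structs instead of pkg.Keyv (the id and name values below are made up):

```go
package main

import "fmt"

// tool is a stand-in for the function entry carried in completion.Tools.
type tool struct {
	ID   string // assigned (or randomly generated) function id
	Name string // function name as declared by the client
}

// resolveToolDef mirrors the new loop in buildTemplate: "" or "-1" means the
// flag was not set; otherwise a matching name is replaced by the function id,
// which the template then uses to force ANSWER: 1: {"toolId":"<id>",...}.
func resolveToolDef(toolDef string, tools []tool) string {
	if toolDef == "" {
		toolDef = "-1"
	}
	for _, t := range tools {
		if toolDef != "-1" && toolDef == t.Name {
			return t.ID
		}
	}
	return toolDef
}

func main() {
	tools := []tool{{ID: "aB3xZ", Name: "weather_report"}} // hypothetical tool
	fmt.Println(resolveToolDef("weather_report", tools))   // aB3xZ
	fmt.Println(resolveToolDef("", tools))                 // -1 (flag not set)
}
```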
