Skip to content

Commit 11f217f

Browse files
committed
feat: gemini模型适配图片对话
1 parent 07b01bc commit 11f217f

File tree

4 files changed

+61
-10
lines changed

4 files changed

+61
-10
lines changed

internal/common/messages.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func TextMessageCombiner[T any](
182182
Message map[string]string
183183
Buffer *bytes.Buffer
184184
Initial func() pkg.Keyv[interface{}]
185-
}{Previous: previous, Next: previous, Message: message, Buffer: buffer, Initial: func() pkg.Keyv[interface{}] {
185+
}{Previous: previous, Next: next, Message: message, Buffer: buffer, Initial: func() pkg.Keyv[interface{}] {
186186
if id, ok := message["id"]; ok {
187187
return sources[id]
188188
}
@@ -193,10 +193,7 @@ func TextMessageCombiner[T any](
193193
return nil, err
194194
}
195195

196-
if len(next) > 0 {
197-
newMessages = append(newMessages, nextMessages...)
198-
}
199-
196+
newMessages = append(newMessages, nextMessages...)
200197
previous = message["role"]
201198
}
202199
return

internal/plugin/llm/gemini/adapter.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,12 @@ func (API) Completion(ctx *gin.Context) {
7979
matchers = com.GetGinMatchers(ctx)
8080
)
8181

82-
newMessages, tokens := mergeMessages(completion.Messages)
82+
newMessages, tokens, err := mergeMessages(completion.Messages)
83+
if err != nil {
84+
response.Error(ctx, -1, err)
85+
return
86+
}
87+
8388
ctx.Set(ginTokens, tokens)
8489
r, err := build(ctx.Request.Context(), proxies, cookie, newMessages, completion)
8590
if err != nil {

internal/plugin/llm/gemini/fetch.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ func build(ctx context.Context, proxies, token string, messages []map[string]int
187187
if res.StatusCode != http.StatusOK {
188188
h := res.Header
189189
if c := h.Get("content-type"); !strings.Contains(c, "text/html") {
190-
bts, e := io.ReadAll(res.Body)
190+
dataBytes, e := io.ReadAll(res.Body)
191191
if e == nil {
192-
logger.Errorf("%s", bts)
192+
logger.Errorf("%s", dataBytes)
193193
}
194194
}
195195
return nil, errors.New(res.Status)

internal/plugin/llm/gemini/message.go

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func waitResponse(ctx *gin.Context, matchers []com.Matcher, partialResponse *htt
122122
}
123123
}
124124

125-
func mergeMessages(messages []pkg.Keyv[interface{}]) (newMessages []map[string]interface{}, tokens int) {
125+
func mergeMessages(messages []pkg.Keyv[interface{}]) (newMessages []map[string]interface{}, tokens int, err error) {
126126
// role类型转换
127127
condition := func(expr string) string {
128128
switch expr {
@@ -195,6 +195,55 @@ func mergeMessages(messages []pkg.Keyv[interface{}]) (newMessages []map[string]i
195195
return
196196
}
197197

198+
// 复合消息
199+
if _, ok := opts.Message["multi"]; ok && role == "user" {
200+
message := opts.Initial()
201+
values := message.GetSlice("content")
202+
if len(values) == 0 {
203+
return
204+
}
205+
206+
var multi []interface{}
207+
for _, value := range values {
208+
var keyv pkg.Keyv[interface{}]
209+
keyv, ok = value.(map[string]interface{})
210+
if !ok {
211+
continue
212+
}
213+
214+
if keyv.Is("type", "text") {
215+
multi = append(multi, map[string]interface{}{
216+
"text": keyv.GetString("text"),
217+
})
218+
}
219+
220+
if keyv.Is("type", "image_url") {
221+
o := keyv.GetKeyv("image_url")
222+
mime, data, e := com.LoadImageMeta(o.GetString("url"))
223+
if e != nil {
224+
err = e
225+
return
226+
}
227+
multi = append(multi, map[string]interface{}{
228+
"inlineData": map[string]interface{}{
229+
"mimeType": mime,
230+
"data": data,
231+
},
232+
})
233+
}
234+
}
235+
236+
if len(multi) == 0 {
237+
return
238+
}
239+
240+
result = append(result, map[string]interface{}{
241+
"role": condition("user"),
242+
"parts": multi,
243+
})
244+
return
245+
}
246+
198247
if role == "system" {
199248
result = append(result, map[string]interface{}{
200249
"role": "user",
@@ -228,6 +277,6 @@ func mergeMessages(messages []pkg.Keyv[interface{}]) (newMessages []map[string]i
228277
return
229278
}
230279

231-
newMessages, _ = com.TextMessageCombiner(messages, iterator)
280+
newMessages, err = com.TextMessageCombiner(messages, iterator)
232281
return
233282
}

0 commit comments

Comments
 (0)