Skip to content

Commit bdab535

Browse files
committed
feat: Add Embedding API (gemini and custom)
1 parent d8eed13 commit bdab535

File tree

10 files changed

+266
-7
lines changed

10 files changed

+266
-7
lines changed

cmd/command.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ package main
22

33
import (
44
"chatgpt-adapter/internal/common"
5-
"chatgpt-adapter/internal/gin.handler"
5+
handler "chatgpt-adapter/internal/gin.handler"
66
"chatgpt-adapter/internal/vars"
77
"chatgpt-adapter/logger"
88
"chatgpt-adapter/pkg"
99
"fmt"
10+
1011
"github.com/sirupsen/logrus"
1112
"github.com/spf13/cobra"
1213
)

internal/common/gin.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import (
44
"chatgpt-adapter/internal/vars"
55
"chatgpt-adapter/pkg"
66
"context"
7-
"github.com/gin-gonic/gin"
87
"time"
8+
9+
"github.com/gin-gonic/gin"
910
)
1011

1112
func GinDebugger(ctx *gin.Context) bool {
@@ -21,6 +22,11 @@ func GetGinCompletion(ctx *gin.Context) (value pkg.ChatCompletion) {
2122
return
2223
}
2324

25+
func GetGinEmbedding(ctx *gin.Context) (value pkg.EmbedRequest) {
26+
value, _ = GetGinValue[pkg.EmbedRequest](ctx, vars.GinEmbedding)
27+
return
28+
}
29+
2430
func GetGinGeneration(ctx *gin.Context) (value pkg.ChatGeneration) {
2531
value, _ = GetGinValue[pkg.ChatGeneration](ctx, vars.GinGeneration)
2632
return

internal/gin.handler/basic.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@ import (
1414
"encoding/hex"
1515
"encoding/json"
1616
"fmt"
17-
"github.com/gin-gonic/gin"
18-
"github.com/google/uuid"
1917
"io"
2018
"net/http"
2119
"net/http/httputil"
2220
"os"
2321
"slices"
2422
"strconv"
2523
"strings"
24+
25+
"github.com/gin-gonic/gin"
26+
"github.com/google/uuid"
2627
)
2728

2829
func Bind(port int, version, proxies string) {
@@ -43,6 +44,7 @@ func Bind(port int, version, proxies string) {
4344
route.POST("/v1/chat/completions", completions)
4445
route.POST("/v1/object/completions", completions)
4546
route.POST("/proxies/v1/chat/completions", completions)
47+
route.POST("/v1/embeddings", embedding)
4648
route.POST("v1/images/generations", generations)
4749
route.POST("v1/object/generations", generations)
4850
route.POST("proxies/v1/images/generations", generations)

internal/gin.handler/embedding.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package handler
2+
3+
import (
4+
"chatgpt-adapter/internal/gin.handler/response"
5+
"chatgpt-adapter/internal/vars"
6+
"chatgpt-adapter/logger"
7+
"chatgpt-adapter/pkg"
8+
"fmt"
9+
10+
"github.com/gin-gonic/gin"
11+
)
12+
13+
func embedding(ctx *gin.Context) {
14+
15+
var embedding pkg.EmbedRequest
16+
if err := ctx.BindJSON(&embedding); err != nil {
17+
logger.Error(err)
18+
response.Error(ctx, -1, err)
19+
return
20+
}
21+
_ = ctx.Request.Body.Close()
22+
ctx.Set(vars.GinEmbedding, embedding)
23+
24+
if !GlobalExtension.Match(ctx, embedding.Model) {
25+
response.Error(ctx, -1, fmt.Sprintf("model '%s' is not not yet supported", embedding.Model))
26+
return
27+
}
28+
29+
GlobalExtension.Embedding(ctx)
30+
}

internal/plugin/adapter.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"chatgpt-adapter/logger"
88
"chatgpt-adapter/pkg"
99
"fmt"
10+
1011
"github.com/bincooo/emit.io"
1112
"github.com/gin-gonic/gin"
1213
socketio "github.com/zishang520/socket.io/socket"
@@ -71,6 +72,7 @@ type Adapter interface {
7172
Models() []Model
7273
Completion(ctx *gin.Context)
7374
Generation(ctx *gin.Context)
75+
Embedding(ctx *gin.Context)
7476
}
7577

7678
type BaseAdapter struct {
@@ -90,6 +92,8 @@ func (BaseAdapter) Completion(*gin.Context) {
9092
func (BaseAdapter) Generation(*gin.Context) {
9193
}
9294

95+
// Embedding is a no-op default so adapters that do not support
// embeddings still satisfy the Adapter interface.
func (BaseAdapter) Embedding(*gin.Context) {}
96+
9397
func (adapter ExtensionAdapter) Match(ctx *gin.Context, model string) bool {
9498
for _, extension := range adapter.Extensions {
9599
if extension.Match(ctx, model) {
@@ -117,6 +121,17 @@ func (adapter ExtensionAdapter) Completion(ctx *gin.Context) {
117121
response.Error(ctx, -1, fmt.Sprintf("model '%s' is not not yet supported", completion.Model))
118122
}
119123

124+
func (adapter ExtensionAdapter) Embedding(ctx *gin.Context) {
125+
embedding := common.GetGinEmbedding(ctx)
126+
for _, extension := range adapter.Extensions {
127+
if extension.Match(ctx, embedding.Model) {
128+
extension.Embedding(ctx)
129+
return
130+
}
131+
}
132+
response.Error(ctx, -1, fmt.Sprintf("model '%s' is not not yet supported", embedding.Model))
133+
}
134+
120135
func (adapter ExtensionAdapter) Messages(ctx *gin.Context) {
121136
completion := common.GetGinCompletion(ctx)
122137
for _, extension := range adapter.Extensions {

internal/plugin/llm/gemini/adapter.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ import (
88
"chatgpt-adapter/logger"
99
"encoding/json"
1010
"errors"
11-
"github.com/gin-gonic/gin"
1211
"net/url"
1312
"strings"
13+
14+
"github.com/gin-gonic/gin"
1415
)
1516

1617
const MODEL = "gemini"
@@ -38,7 +39,7 @@ type API struct {
3839

3940
func (API) Match(_ *gin.Context, model string) bool {
4041
switch model {
41-
case "gemini-1.0-pro-latest", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest":
42+
case "gemini-1.0-pro-latest", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "models/text-embedding-004":
4243
return true
4344
default:
4445
return false
@@ -62,6 +63,11 @@ func (API) Models() []plugin.Model {
6263
Object: "model",
6364
Created: 1686935002,
6465
By: "gemini-adapter",
66+
}, {
67+
Id: "models/text-embedding-004",
68+
Object: "model",
69+
Created: 1686935002,
70+
By: "gemini-adapter",
6571
},
6672
}
6773
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package gemini
2+
3+
import (
4+
"chatgpt-adapter/internal/common"
5+
"chatgpt-adapter/internal/plugin"
6+
"chatgpt-adapter/pkg"
7+
"encoding/json"
8+
"io"
9+
"net/http"
10+
11+
"github.com/bincooo/emit.io"
12+
"github.com/gin-gonic/gin"
13+
"github.com/pkg/errors"
14+
)
15+
16+
func ConvertOpenAIRequestToGemini(openAIReq *pkg.EmbedRequest, model string) (*GeminiEmbedBatchReq, error) {
17+
if openAIReq.EncodingFormat != "" && openAIReq.EncodingFormat != "float" {
18+
return nil, errors.New("unsupported encoding format")
19+
}
20+
reqs := make([]GeminiEmbedReq, 0)
21+
switch v := openAIReq.Input.(type) {
22+
case string:
23+
reqs = append(reqs, GeminiEmbedReq{
24+
Model: model,
25+
Content: GeminiContent{
26+
Parts: []GeminiContPart{{Text: v}},
27+
},
28+
})
29+
case []interface{}:
30+
for _, text := range v {
31+
if t, ok := text.(string); ok {
32+
reqs = append(reqs, GeminiEmbedReq{
33+
Model: model,
34+
Content: GeminiContent{
35+
Parts: []GeminiContPart{{Text: t}},
36+
},
37+
})
38+
} else {
39+
return nil, errors.Errorf("unsupported input type: %T", t)
40+
}
41+
}
42+
default:
43+
return nil, errors.Errorf("unsupported input type: %T", v)
44+
}
45+
46+
return &GeminiEmbedBatchReq{Requests: reqs}, nil
47+
}
48+
49+
// ConvertGeminiResponseToOpenAI translates a Gemini batch-embed response
// into the OpenAI /v1/embeddings response shape. Gemini does not report
// token usage for this endpoint, so Usage is always zero.
func ConvertGeminiResponseToOpenAI(geminiResp *GeminiResp, model string) *EmbedResponse {
	openAIResp := &EmbedResponse{
		Object: "list",
		Model:  model,
		// pre-sized, non-nil slice so an empty result encodes as [] not null
		Data: make([]*EmbedResponseData, 0, len(geminiResp.Embeddings)),
	}

	// renamed loop variable — the original shadowed the geminiResp parameter
	for i, emb := range geminiResp.Embeddings {
		openAIResp.Data = append(openAIResp.Data, &EmbedResponseData{
			Object:    "embedding",
			Embedding: emb.Values,
			Index:     i,
		})
	}

	openAIResp.Usage = &Usage{
		PromptTokens: 0,
		TotalTokens:  0,
	}

	return openAIResp
}

// GeminiEmbedBatchReq is the body of a Gemini batchEmbedContents call.
type GeminiEmbedBatchReq struct {
	Requests []GeminiEmbedReq `json:"requests"`
}

// GeminiEmbedReq embeds a single piece of content with the given model.
type GeminiEmbedReq struct {
	Model   string        `json:"model"`
	Content GeminiContent `json:"content"`
}

// GeminiContent wraps the text parts of one embed request.
type GeminiContent struct {
	Parts []GeminiContPart `json:"parts"`
}

// GeminiContPart is a single text fragment to embed.
type GeminiContPart struct {
	Text string `json:"text"`
}

// EmbedResponseData is one embedding vector in an OpenAI-style response.
type EmbedResponseData struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// Usage reports token counts; always zero for the Gemini adapter.
type Usage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

// EmbedResponse mirrors the OpenAI /v1/embeddings response body.
type EmbedResponse struct {
	Object string               `json:"object"`
	Data   []*EmbedResponseData `json:"data"`
	Model  string               `json:"model"`
	Usage  *Usage               `json:"usage"`
}

// GeminiResp is the body returned by Gemini batchEmbedContents.
type GeminiResp struct {
	Embeddings []GeminiEmbedding `json:"embeddings"`
}

// GeminiEmbedding is one embedding vector returned by Gemini.
type GeminiEmbedding struct {
	Values []float32 `json:"values"`
}
113+
114+
func (API) Embedding(ctx *gin.Context) {
115+
116+
openAIReq := common.GetGinEmbedding(ctx)
117+
var (
118+
token = ctx.GetString("token")
119+
proxies = ctx.GetString("proxies")
120+
)
121+
122+
geminiReq, err := ConvertOpenAIRequestToGemini(&openAIReq, openAIReq.Model)
123+
if err != nil {
124+
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Bad Request"})
125+
return
126+
}
127+
url := "https://generativelanguage.googleapis.com/v1beta/" +
128+
openAIReq.Model + ":batchEmbedContents?key=" + token
129+
resp, err := emit.ClientBuilder(plugin.HTTPClient).
130+
Proxies(proxies).
131+
Context(common.GetGinContext(ctx)).
132+
POST(url).
133+
JHeader().
134+
Body(geminiReq).DoC(emit.Status(http.StatusOK))
135+
136+
if err != nil {
137+
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
138+
return
139+
}
140+
respBytes, err := io.ReadAll(resp.Body)
141+
if err != nil {
142+
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
143+
return
144+
}
145+
var geminiResp GeminiResp
146+
json.Unmarshal(respBytes, &geminiResp)
147+
openAIResp := ConvertGeminiResponseToOpenAI(&geminiResp, openAIReq.Model)
148+
149+
ctx.JSON(http.StatusOK, openAIResp)
150+
}

internal/plugin/llm/v1/adapter.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@ import (
55
"chatgpt-adapter/internal/gin.handler/response"
66
"chatgpt-adapter/internal/plugin"
77
"chatgpt-adapter/logger"
8-
"github.com/gin-gonic/gin"
8+
"chatgpt-adapter/pkg"
9+
"io"
10+
"net/http"
911
"strings"
12+
13+
"github.com/bincooo/emit.io"
14+
"github.com/gin-gonic/gin"
1015
)
1116

1217
var (
@@ -68,3 +73,38 @@ label:
6873
response.Error(ctx, -1, "EMPTY RESPONSE")
6974
}
7075
}
76+
77+
func (API) Embedding(ctx *gin.Context) {
78+
embedding := common.GetGinEmbedding(ctx)
79+
embedding.Model = embedding.Model[7:]
80+
var (
81+
token = ctx.GetString("token")
82+
proxies = ctx.GetString("proxies")
83+
baseUrl = pkg.Config.GetString("custom-llm.baseUrl")
84+
useProxy = pkg.Config.GetBool("custom-llm.useProxy")
85+
)
86+
if !useProxy {
87+
proxies = ""
88+
}
89+
resp, err := emit.ClientBuilder(plugin.HTTPClient).
90+
Proxies(proxies).
91+
Context(common.GetGinContext(ctx)).
92+
POST(baseUrl+"/v1/embeddings").
93+
Header("Authorization", "Bearer "+token).
94+
JHeader().
95+
Body(embedding).DoC(emit.Status(http.StatusOK))
96+
if err != nil {
97+
ctx.JSON(http.StatusBadGateway, gin.H{
98+
"error": "can't send request to upstream",
99+
})
100+
}
101+
ctx.Header("Content-Type", "application/json; charset=utf-8")
102+
content, err := io.ReadAll(resp.Body)
103+
if err != nil {
104+
ctx.JSON(http.StatusBadGateway, gin.H{
105+
"error": "can't read from upstream",
106+
})
107+
}
108+
ctx.Writer.Write(content)
109+
ctx.Writer.Flush()
110+
}

internal/vars/com.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const (
1111

1212
GinCompletion = "__completion__"
1313
GinGeneration = "__generation__"
14+
GinEmbedding = "__embedding__"
1415
GinMatchers = "__matchers__"
1516
GinCompletionUsage = "__completion-usage__"
1617
GinDebugger = "__debug__"

pkg/model.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ type ChatGeneration struct {
2828
Quality string `json:"quality"`
2929
}
3030

31+
// EmbedRequest mirrors the OpenAI /v1/embeddings request body.
type EmbedRequest struct {
	// Input is the text to embed: a single string or a JSON array of
	// strings (decoded as []interface{}).
	Input interface{} `json:"input"`
	// Model is the embedding model identifier, used for adapter routing.
	Model string `json:"model"`
	// EncodingFormat is the requested vector encoding per the OpenAI API
	// (e.g. "float"); empty means the default.
	EncodingFormat string `json:"encoding_format,omitempty"`
	// Dimensions is the OpenAI-API optional output dimensionality; not
	// consumed by the visible adapters — passed through as-is.
	Dimensions int `json:"dimensions,omitempty"`
	// User is an opaque end-user identifier passed through from clients.
	User string `json:"user,omitempty"`
}
38+
3139
type Keyv[V any] map[string]V
3240

3341
type ChatResponse struct {

0 commit comments

Comments
 (0)