Skip to content

Commit 6007292

Browse files
bincooobincooo
bincooo
authored and
bincooo
committed
fix: cursor claude role error
1 parent 4f873c7 commit 6007292

File tree

4 files changed

+16
-4
lines changed

4 files changed

+16
-4
lines changed

relay/llm/cursor/adapter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func (api *api) Completion(ctx *gin.Context) (err error) {
123123
cookie = strings.Split(cookie, "%3A%3A")[1]
124124
}
125125

126-
buffer, err := convertRequest(completion)
126+
buffer, err := convertRequest(ctx, completion)
127127
if err != nil {
128128
return
129129
}

relay/llm/cursor/fetch.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package cursor
22

33
import (
4+
"chatgpt-adapter/core/gin/response"
45
"context"
56
"encoding/binary"
67
"fmt"
8+
"github.com/gin-gonic/gin"
79
"github.com/iocgo/sdk/stream"
810
"math/rand"
911
"net/http"
@@ -40,7 +42,12 @@ func fetch(ctx context.Context, proxied string, cookie string, buffer []byte) (r
4042
return
4143
}
4244

43-
func convertRequest(completion model.Completion) (buffer []byte, err error) {
45+
func convertRequest(ctx *gin.Context, completion model.Completion) (buffer []byte, err error) {
46+
specialized := ctx.GetBool("specialized")
47+
if specialized && response.IsClaude(ctx, completion.Model) {
48+
completion.Messages = completion.Messages[:1]
49+
completion.Messages[0].Set("role", "user")
50+
}
4451
messages := stream.Map(stream.OfSlice(completion.Messages), func(message model.Keyv[interface{}]) *ChatMessage_UserMessage {
4552
return &ChatMessage_UserMessage{
4653
MessageId: uuid.NewString(),

relay/llm/cursor/message.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func echoMessages(ctx *gin.Context, completion model.Completion) {
208208

209209
func newScanner(body io.ReadCloser) (scanner *bufio.Scanner) {
210210
// 每个字节占8位
211-
// 00000011 第一个字节是占位符,应该是用来代表消息类型的 假定 0: 消息体/proto, 1: 系统提示词/gzip, 3: 错误标记/gzip
211+
// 00000011 第一个字节是占位符,应该是用来代表消息类型的 假定 0: 消息体/proto, 1: 系统提示词/gzip, 2、3: 错误标记/gzip
212212
// 00000000 00000000 00000010 11011000 4个字节描述包体大小
213213
scanner = bufio.NewScanner(body)
214214
var (
@@ -248,9 +248,14 @@ func newScanner(body io.ReadCloser) (scanner *bufio.Scanner) {
248248
return setup, []byte("event: error"), err
249249
}
250250

251+
if magic == 2 { // 内部异常信息
252+
return setup, []byte("event: error"), err
253+
}
254+
251255
if magic == 1 { // 系统提示词标记?
252256
return setup, []byte("event: system"), err
253257
}
258+
254259
// magic == 0
255260
return setup, []byte("event: message"), err
256261
}

relay/llm/cursor/toolcall.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func toolChoice(ctx *gin.Context, proxies, cookie string, completion model.Compl
2525
"content": message,
2626
},
2727
}
28-
messageBuffer, err := convertRequest(completion)
28+
messageBuffer, err := convertRequest(ctx, completion)
2929
if err != nil {
3030
return "", err
3131
}

0 commit comments

Comments
 (0)