
Commit c1699c9

Merge pull request continuedev#4016 from Lash-L/vertexai_fixes
Fix VertexAI bugs and add function calling
2 parents 47fd482 + cef6296 commit c1699c9

File tree: 6 files changed, +72 -120 lines changed

core/llm/llms/Gemini.ts (+38 -20)
@@ -134,21 +134,11 @@ class Gemini extends BaseLLM {
     };
   }
 
-  private async *streamChatGemini(
+  public prepareBody(
     messages: ChatMessage[],
-    signal: AbortSignal,
     options: CompletionOptions,
-  ): AsyncGenerator<ChatMessage> {
-    const apiURL = new URL(
-      `models/${options.model}:streamGenerateContent?key=${this.apiKey}`,
-      this.apiBase,
-    );
-    // This feels hacky to repeat code from above function but was the quickest
-    // way to ensure system message re-formatting isn't done if user has specified v1
-    const apiBase = this.apiBase || Gemini.defaultOptions.apiBase!; // Determine if it's a v1 API call based on apiBase
-    const isV1API = apiBase.includes("/v1/");
-
-    // Convert chat messages to contents
+    isV1API: boolean,
+  ): GeminiChatRequestBody {
     const body: GeminiChatRequestBody = {
       contents: messages
         .filter((msg) => !(msg.role === "system" && isV1API))
@@ -303,15 +293,14 @@ class Gemini extends BaseLLM {
         }
       }
     }
+    return body;
+  }
 
-    const response = await this.fetch(apiURL, {
-      method: "POST",
-      body: JSON.stringify(body),
-      signal,
-    });
-
+  public async *processGeminiResponse(
+    stream: AsyncIterable<string>,
+  ): AsyncGenerator<ChatMessage> {
     let buffer = "";
-    for await (const chunk of streamResponse(response)) {
+    for await (const chunk of stream) {
       buffer += chunk;
       if (buffer.startsWith("[")) {
         buffer = buffer.slice(1);
@@ -425,6 +414,35 @@ class Gemini extends BaseLLM {
       }
     }
   }
+
+  private async *streamChatGemini(
+    messages: ChatMessage[],
+    signal: AbortSignal,
+    options: CompletionOptions,
+  ): AsyncGenerator<ChatMessage> {
+    const apiURL = new URL(
+      `models/${options.model}:streamGenerateContent?key=${this.apiKey}`,
+      this.apiBase,
+    );
+    // This feels hacky to repeat code from above function but was the quickest
+    // way to ensure system message re-formatting isn't done if user has specified v1
+    const apiBase = this.apiBase || Gemini.defaultOptions.apiBase!; // Determine if it's a v1 API call based on apiBase
+    const isV1API = apiBase.includes("/v1/");
+
+    // Convert chat messages to contents
+    const body = this.prepareBody(messages, options, isV1API);
+
+    const response = await this.fetch(apiURL, {
+      method: "POST",
+      body: JSON.stringify(body),
+      signal,
+    });
+    for await (const message of this.processGeminiResponse(
+      streamResponse(response),
+    )) {
+      yield message;
+    }
+  }
   private async *streamChatBison(
     messages: ChatMessage[],
     signal: AbortSignal,
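
Taken together, these three hunks split the old monolithic `streamChatGemini` into a pure request builder (`prepareBody`) and a transport-agnostic stream parser (`processGeminiResponse`), then rebuild `streamChatGemini` as a thin composition of the two. A minimal sketch of the resulting shape — the types and helper names below are simplified stand-ins, not the project's actual `ChatMessage` or `GeminiChatRequestBody`:

```ts
// Sketch only: simplified stand-ins for the real types in core/llm.
type Message = { role: "user" | "assistant" | "system"; content: string };

// 1) Pure request builder: messages in, JSON body out. No URL, key, or
//    fetch, so any provider with a Gemini-shaped API can reuse it.
function prepareBody(messages: Message[], isV1API: boolean) {
  return {
    contents: messages
      .filter((msg) => !(msg.role === "system" && isV1API))
      .map((msg) => ({
        role: msg.role === "assistant" ? "model" : "user",
        parts: [{ text: msg.content }],
      })),
  };
}

// 2) Stream parser: consumes any AsyncIterable<string>, so the caller
//    decides how the HTTP response is fetched and authenticated.
async function* processResponse(
  stream: AsyncIterable<string>,
): AsyncGenerator<Message> {
  for await (const chunk of stream) {
    // The real implementation buffers partial JSON across chunks;
    // this sketch assumes each chunk is one complete JSON object.
    const data = JSON.parse(chunk);
    const text = data?.candidates?.[0]?.content?.parts?.[0]?.text;
    if (text) {
      yield { role: "assistant", content: text };
    }
  }
}
```

The explicit `for await … yield` loop in the new `streamChatGemini` behaves the same as `yield* this.processGeminiResponse(…)` here; the payoff of the split is that VertexAI.ts below can reuse both halves against its own endpoint and auth.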

core/llm/llms/VertexAI.ts (+19 -89)
@@ -16,7 +16,7 @@ class VertexAI extends BaseLLM {
   declare geminiInstance: Gemini;
 
   static defaultOptions: Partial<LLMOptions> | undefined = {
-    maxEmbeddingBatchSize: 5,
+    maxEmbeddingBatchSize: 250,
     region: "us-central1",
   };
 
@@ -35,6 +35,13 @@ class VertexAI extends BaseLLM {
   }
 
   constructor(_options: LLMOptions) {
+    if (_options.region !== "us-central1") {
+      // Any region outside of us-central1 has a max batch size of 5.
+      _options.maxEmbeddingBatchSize = Math.min(
+        _options.maxEmbeddingBatchSize ?? 5,
+        5,
+      );
+    }
     super(_options);
     this.apiBase ??= VertexAI.getDefaultApiBaseFrom(_options);
     this.vertexProvider =
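
The clamp only fires outside us-central1, the one region that accepts the larger default batch of 250. A standalone sketch of the effective behavior (the helper name and region strings are illustrative, not project code):

```ts
// Illustrative helper: what embedding batch size the constructor
// effectively allows for a given region and requested size.
function effectiveBatchSize(region: string, requested?: number): number {
  if (region !== "us-central1") {
    // Regions other than us-central1 cap embedding batches at 5.
    return Math.min(requested ?? 5, 5);
  }
  return requested ?? 250; // us-central1 default from defaultOptions
}

console.assert(effectiveBatchSize("us-central1") === 250);
console.assert(effectiveBatchSize("europe-west4", 100) === 5);
console.assert(effectiveBatchSize("europe-west4", 3) === 3);
```
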
@@ -143,97 +150,16 @@ class VertexAI extends BaseLLM {
       `publishers/google/models/${options.model}:streamGenerateContent`,
       this.apiBase,
     );
-    // This feels hacky to repeat code from above function but was the quickest
-    // way to ensure system message re-formatting isn't done if user has specified v1
-    const isV1API = this.apiBase.includes("/v1/");
 
-    const contents = messages
-      .map((msg) => {
-        if (msg.role === "system" && !isV1API) {
-          return null; // Don't include system message in contents
-        }
-        if (msg.role === "tool") {
-          return null;
-        }
-
-        return {
-          role: msg.role === "assistant" ? "model" : "user",
-          parts:
-            typeof msg.content === "string"
-              ? [{ text: msg.content }]
-              : msg.content.map(this.geminiInstance.continuePartToGeminiPart),
-        };
-      })
-      .filter((c) => c !== null);
-
-    const body = {
-      ...this.geminiInstance.convertArgs(options),
-      contents,
-      // if this.systemMessage is defined, reformat it for Gemini API
-      ...(this.systemMessage &&
-        !isV1API && {
-          systemInstruction: { parts: [{ text: this.systemMessage }] },
-        }),
-    };
+    const body = this.geminiInstance.prepareBody(messages, options, false);
     const response = await this.fetch(apiURL, {
       method: "POST",
       body: JSON.stringify(body),
     });
-
-    let buffer = "";
-    for await (const chunk of streamResponse(response)) {
-      buffer += chunk;
-      if (buffer.startsWith("[")) {
-        buffer = buffer.slice(1);
-      }
-      if (buffer.endsWith("]")) {
-        buffer = buffer.slice(0, -1);
-      }
-      if (buffer.startsWith(",")) {
-        buffer = buffer.slice(1);
-      }
-
-      const parts = buffer.split("\n,");
-
-      let foundIncomplete = false;
-      for (let i = 0; i < parts.length; i++) {
-        const part = parts[i];
-        let data;
-        try {
-          data = JSON.parse(part);
-        } catch (e) {
-          foundIncomplete = true;
-          continue;
-        }
-        if (data.error) {
-          throw new Error(data.error.message);
-        }
-        // Check for existence of each level before accessing the final 'text' property
-        if (data?.candidates?.[0]?.content?.parts?.[0]?.text) {
-          // Incrementally stream the content to make it smoother
-          const content = data.candidates[0].content.parts[0].text;
-          const words = content.split(/(\s+)/);
-          const delaySeconds = Math.min(4.0 / (words.length + 1), 0.1);
-          while (words.length > 0) {
-            const wordsToYield = Math.min(3, words.length);
-            yield {
-              role: "assistant",
-              content: words.splice(0, wordsToYield).join(""),
-            };
-            await delay(delaySeconds);
-          }
-        } else {
-          // Handle the case where the expected data structure is not found
-          if (data?.candidates?.[0]?.finishReason !== "STOP") {
-            console.warn("Unexpected response format:", data);
-          }
-        }
-      }
-      if (foundIncomplete) {
-        buffer = parts[parts.length - 1];
-      } else {
-        buffer = "";
-      }
+    for await (const message of this.geminiInstance.processGeminiResponse(
+      streamResponse(response),
+    )) {
+      yield message;
     }
   }
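
The deleted block above was a duplicated copy of the streaming-array parser that now lives in `Gemini.processGeminiResponse`. For reference, the mechanism it implements: the REST streaming endpoint returns one JSON array delivered incrementally, so the parser strips the enclosing `[`, `]`, and any leading `,`, splits on `\n,`, and carries an unparseable tail over to the next chunk. A standalone sketch of that idea (simplified, not the project's code):

```ts
// Split an accumulating buffer of a streamed JSON array into complete
// objects plus an unparsed remainder to retry once more data arrives.
function splitStreamedArray(buffer: string): {
  parsed: unknown[];
  rest: string;
} {
  let buf = buffer;
  if (buf.startsWith("[")) buf = buf.slice(1);
  if (buf.endsWith("]")) buf = buf.slice(0, -1);
  if (buf.startsWith(",")) buf = buf.slice(1);

  const parsed: unknown[] = [];
  let rest = "";
  for (const part of buf.split("\n,")) {
    try {
      parsed.push(JSON.parse(part));
    } catch {
      rest = part; // incomplete trailing JSON; wait for the next chunk
    }
  }
  return { parsed, rest };
}
```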

@@ -337,7 +263,9 @@ class VertexAI extends BaseLLM {
     });
 
     for await (const chunk of streamSse(response)) {
-      yield chunk.choices[0].delta.content;
+      if (chunk.choices?.[0].delta) {
+        yield chunk.choices[0].delta.content;
+      }
     }
   }
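
One note on the guard as written: optional chaining only short-circuits at the `?.` itself, so `chunk.choices?.[0].delta` still throws if `choices` is present but empty. A fully defensive variant (an editorial suggestion, not what the diff ships) would chain the element access as well:

```ts
if (chunk.choices?.[0]?.delta) {
  yield chunk.choices[0].delta.content;
}
```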

@@ -432,7 +360,9 @@ class VertexAI extends BaseLLM {
   }
 
   supportsFim(): boolean {
-    return ["code-gecko", "codestral-latest"].includes(this.model);
+    return (
+      this.model.includes("code-gecko") || this.model.includes("codestral")
+    );
   }
 
   protected async _embed(chunks: string[]): Promise<number[][]> {
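
Switching from an exact-match list to substring checks means versioned model IDs now qualify as well. A quick sketch of the new predicate against a few illustrative IDs:

```ts
const supportsFim = (model: string) =>
  model.includes("code-gecko") || model.includes("codestral");

console.assert(supportsFim("code-gecko@002")); // versioned IDs now match
console.assert(supportsFim("codestral-latest"));
console.assert(!supportsFim("gemini-1.5-pro"));
```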

core/llm/toolSupport.ts (+4 -0)
@@ -57,6 +57,10 @@ export const PROVIDER_TOOL_SUPPORT: Record<
     // All gemini models support function calling
     return model.toLowerCase().includes("gemini");
   },
+  vertexai: (model) => {
+    // All gemini models except flash 2.0 lite support function calling
+    return model.toLowerCase().includes("gemini") && !model.toLowerCase().includes("lite");
+  },
   bedrock: (model) => {
     // For Bedrock, only support Claude Sonnet models with versions 3.5/3-5 and 3.7/3-7
     if (
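
How the new `vertexai` predicate classifies a few model IDs (the IDs are illustrative):

```ts
const vertexSupportsTools = (model: string) =>
  model.toLowerCase().includes("gemini") &&
  !model.toLowerCase().includes("lite");

console.assert(vertexSupportsTools("gemini-1.5-pro"));
console.assert(!vertexSupportsTools("gemini-2.0-flash-lite")); // Lite excluded
console.assert(!vertexSupportsTools("claude-3-5-sonnet"));
```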

docs/i18n/zh-CN/docusaurus-plugin-content-docs/current/customize/model-providers/top-level/vertexai.md (+2 -2)
@@ -45,13 +45,13 @@ We recommend configuring **Codestral** or **code-gecko** as your autocomplete mo
 
 ## Embeddings model
 
-We recommend configuring **text-embedding-004** as your embeddings model.
+We recommend configuring **text-embedding-005** as your embeddings model.
 
 ```json title="config.json"
 {
   "embeddingsProvider": {
     "provider": "vertexai",
-    "model": "text-embedding-004",
+    "model": "text-embedding-005",
     "projectId": "[PROJECT_ID]",
     "region": "us-central1"
   }

extensions/vscode/config_schema.json (+3 -3)
@@ -2855,11 +2855,11 @@
       "description": "The name of your VertexAI project"
     },
     "region": {
-      "description": "The region your VertexAI model is hosted in - typically central1",
-      "default": "central1"
+      "description": "The region your VertexAI model is hosted in - typically us-central1",
+      "default": "us-central1"
     },
     "model": {
-      "default": "text-embedding-004"
+      "default": "text-embedding-005"
     }
   },
   "required": ["projectId", "model", "region"]

packages/llm-info/src/providers/vertexai.ts (+6 -6)
@@ -30,17 +30,17 @@ export const Gemini: ModelProvider = {
     },
     // embed
     {
-      model: "text-embedding-004",
+      model: "text-embedding-005",
       displayName: "Vertex Text Embedding",
       recommendedFor: ["embed"],
     },
     //autocomplete
     {
-    model: "code-gecko",
-    displayName: "VertexAI Code Gecko",
-    recommendedFor: ["autocomplete"],
-    maxCompletionTokens: 64,
-    }
+      model: "code-gecko",
+      displayName: "VertexAI Code Gecko",
+      recommendedFor: ["autocomplete"],
+      maxCompletionTokens: 64,
+    },
   ],
   id: "gemini",
   displayName: "Gemini",
