@@ -16,7 +16,7 @@ class VertexAI extends BaseLLM {
   declare geminiInstance: Gemini;
 
   static defaultOptions: Partial<LLMOptions> | undefined = {
-    maxEmbeddingBatchSize: 5,
+    maxEmbeddingBatchSize: 250,
     region: "us-central1",
   };
 
@@ -35,6 +35,13 @@ class VertexAI extends BaseLLM {
   }
 
   constructor(_options: LLMOptions) {
+    if (_options.region !== "us-central1") {
+      // Any region outside of us-central1 has a max batch size of 5.
+      _options.maxEmbeddingBatchSize = Math.min(
+        _options.maxEmbeddingBatchSize ?? 5,
+        5,
+      );
+    }
     super(_options);
     this.apiBase ??= VertexAI.getDefaultApiBaseFrom(_options);
     this.vertexProvider =
@@ -143,97 +150,16 @@ class VertexAI extends BaseLLM {
       `publishers/google/models/${options.model}:streamGenerateContent`,
       this.apiBase,
     );
-    // This feels hacky to repeat code from above function but was the quickest
-    // way to ensure system message re-formatting isn't done if user has specified v1
-    const isV1API = this.apiBase.includes("/v1/");
 
-    const contents = messages
-      .map((msg) => {
-        if (msg.role === "system" && !isV1API) {
-          return null; // Don't include system message in contents
-        }
-        if (msg.role === "tool") {
-          return null;
-        }
-
-        return {
-          role: msg.role === "assistant" ? "model" : "user",
-          parts:
-            typeof msg.content === "string"
-              ? [{ text: msg.content }]
-              : msg.content.map(this.geminiInstance.continuePartToGeminiPart),
-        };
-      })
-      .filter((c) => c !== null);
-
-    const body = {
-      ...this.geminiInstance.convertArgs(options),
-      contents,
-      // if this.systemMessage is defined, reformat it for Gemini API
-      ...(this.systemMessage &&
-        !isV1API && {
-          systemInstruction: { parts: [{ text: this.systemMessage }] },
-        }),
-    };
+    const body = this.geminiInstance.prepareBody(messages, options, false);
     const response = await this.fetch(apiURL, {
       method: "POST",
       body: JSON.stringify(body),
     });
-
-    let buffer = "";
-    for await (const chunk of streamResponse(response)) {
-      buffer += chunk;
-      if (buffer.startsWith("[")) {
-        buffer = buffer.slice(1);
-      }
-      if (buffer.endsWith("]")) {
-        buffer = buffer.slice(0, -1);
-      }
-      if (buffer.startsWith(",")) {
-        buffer = buffer.slice(1);
-      }
-
-      const parts = buffer.split("\n,");
-
-      let foundIncomplete = false;
-      for (let i = 0; i < parts.length; i++) {
-        const part = parts[i];
-        let data;
-        try {
-          data = JSON.parse(part);
-        } catch (e) {
-          foundIncomplete = true;
-          continue; // yo!
-        }
-        if (data.error) {
-          throw new Error(data.error.message);
-        }
-        // Check for existence of each level before accessing the final 'text' property
-        if (data?.candidates?.[0]?.content?.parts?.[0]?.text) {
-          // Incrementally stream the content to make it smoother
-          const content = data.candidates[0].content.parts[0].text;
-          const words = content.split(/(\s+)/);
-          const delaySeconds = Math.min(4.0 / (words.length + 1), 0.1);
-          while (words.length > 0) {
-            const wordsToYield = Math.min(3, words.length);
-            yield {
-              role: "assistant",
-              content: words.splice(0, wordsToYield).join(""),
-            };
-            await delay(delaySeconds);
-          }
-        } else {
-          // Handle the case where the expected data structure is not found
-          if (data?.candidates?.[0]?.finishReason !== "STOP") {
-            console.warn("Unexpected response format:", data);
-          }
-        }
-      }
-      if (foundIncomplete) {
-        buffer = parts[parts.length - 1];
-      } else {
-        buffer = "";
-      }
+    for await (const message of this.geminiInstance.processGeminiResponse(
+      streamResponse(response),
+    )) {
+      yield message;
     }
   }
 
@@ -337,7 +263,9 @@ class VertexAI extends BaseLLM {
     });
 
     for await (const chunk of streamSse(response)) {
-      yield chunk.choices[0].delta.content;
+      if (chunk.choices?.[0].delta) {
+        yield chunk.choices[0].delta.content;
+      }
     }
   }
 
@@ -432,7 +360,9 @@ class VertexAI extends BaseLLM {
   }
 
   supportsFim(): boolean {
-    return ["code-gecko", "codestral-latest"].includes(this.model);
+    return (
+      this.model.includes("code-gecko") || this.model.includes("codestral")
+    );
   }
 
   protected async _embed(chunks: string[]): Promise<number[][]> {