@@ -34,6 +34,7 @@ import {
34
34
Part ,
35
35
SafetySetting ,
36
36
StreamGenerateContentResult ,
37
+ Tool ,
37
38
VertexInit ,
38
39
} from './types/content' ;
39
40
import {
@@ -134,7 +135,8 @@ export class VertexAI_Preview {
134
135
this ,
135
136
modelParams . model ,
136
137
modelParams . generation_config ,
137
- modelParams . safety_settings
138
+ modelParams . safety_settings ,
139
+ modelParams . tools
138
140
) ;
139
141
}
140
142
@@ -185,6 +187,7 @@ export declare interface StartChatParams {
185
187
history ?: Content [ ] ;
186
188
safety_settings ?: SafetySetting [ ] ;
187
189
generation_config ?: GenerationConfig ;
190
+ tools ?: Tool [ ] ;
188
191
}
189
192
190
193
// StartChatSessionRequest and ChatSession are defined here instead of in
@@ -216,6 +219,7 @@ export class ChatSession {
216
219
private _send_stream_promise : Promise < void > = Promise . resolve ( ) ;
217
220
generation_config ?: GenerationConfig ;
218
221
safety_settings ?: SafetySetting [ ] ;
222
+ tools ?: Tool [ ] ;
219
223
220
224
get history ( ) : Content [ ] {
221
225
return this . historyInternal ;
@@ -231,6 +235,9 @@ export class ChatSession {
231
235
this . _model_instance = request . _model_instance ;
232
236
this . historyInternal = request . history ?? [ ] ;
233
237
this . _vertex_instance = request . _vertex_instance ;
238
+ this . generation_config = request . generation_config ;
239
+ this . safety_settings = request . safety_settings ;
240
+ this . tools = request . tools ;
234
241
}
235
242
236
243
/**
@@ -241,11 +248,12 @@ export class ChatSession {
241
248
async sendMessage (
242
249
request : string | Array < string | Part >
243
250
) : Promise < GenerateContentResult > {
244
- const newContent : Content = formulateNewContent ( request ) ;
251
+ const newContent : Content [ ] = formulateNewContent ( request ) ;
245
252
const generateContentrequest : GenerateContentRequest = {
246
- contents : this . historyInternal . concat ( [ newContent ] ) ,
253
+ contents : this . historyInternal . concat ( newContent ) ,
247
254
safety_settings : this . safety_settings ,
248
255
generation_config : this . generation_config ,
256
+ tools : this . tools ,
249
257
} ;
250
258
251
259
const generateContentResult : GenerateContentResult =
@@ -257,7 +265,7 @@ export class ChatSession {
257
265
const generateContentResponse = generateContentResult . response ;
258
266
// Only push the latest message to history if the response returned a result
259
267
if ( generateContentResponse . candidates . length !== 0 ) {
260
- this . historyInternal . push ( newContent ) ;
268
+ this . historyInternal = this . historyInternal . concat ( newContent ) ;
261
269
const contentFromAssistant =
262
270
generateContentResponse . candidates [ 0 ] . content ;
263
271
if ( ! contentFromAssistant . role ) {
@@ -274,15 +282,15 @@ export class ChatSession {
274
282
275
283
async appendHistory (
276
284
streamGenerateContentResultPromise : Promise < StreamGenerateContentResult > ,
277
- newContent : Content
285
+ newContent : Content [ ]
278
286
) : Promise < void > {
279
287
const streamGenerateContentResult =
280
288
await streamGenerateContentResultPromise ;
281
289
const streamGenerateContentResponse =
282
290
await streamGenerateContentResult . response ;
283
291
// Only push the latest message to history if the response returned a result
284
292
if ( streamGenerateContentResponse . candidates . length !== 0 ) {
285
- this . historyInternal . push ( newContent ) ;
293
+ this . historyInternal = this . historyInternal . concat ( newContent ) ;
286
294
const contentFromAssistant =
287
295
streamGenerateContentResponse . candidates [ 0 ] . content ;
288
296
if ( ! contentFromAssistant . role ) {
@@ -303,11 +311,12 @@ export class ChatSession {
303
311
async sendMessageStream (
304
312
request : string | Array < string | Part >
305
313
) : Promise < StreamGenerateContentResult > {
306
- const newContent : Content = formulateNewContent ( request ) ;
314
+ const newContent : Content [ ] = formulateNewContent ( request ) ;
307
315
const generateContentrequest : GenerateContentRequest = {
308
- contents : this . historyInternal . concat ( [ newContent ] ) ,
316
+ contents : this . historyInternal . concat ( newContent ) ,
309
317
safety_settings : this . safety_settings ,
310
318
generation_config : this . generation_config ,
319
+ tools : this . tools ,
311
320
} ;
312
321
313
322
const streamGenerateContentResultPromise = this . _model_instance
@@ -335,6 +344,7 @@ export class GenerativeModel {
335
344
model : string ;
336
345
generation_config ?: GenerationConfig ;
337
346
safety_settings ?: SafetySetting [ ] ;
347
+ tools ?: Tool [ ] ;
338
348
private _vertex_instance : VertexAI_Preview ;
339
349
private _use_non_stream = false ;
340
350
private publisherModelEndpoint : string ;
@@ -351,12 +361,14 @@ export class GenerativeModel {
351
361
vertex_instance : VertexAI_Preview ,
352
362
model : string ,
353
363
generation_config ?: GenerationConfig ,
354
- safety_settings ?: SafetySetting [ ]
364
+ safety_settings ?: SafetySetting [ ] ,
365
+ tools ?: Tool [ ]
355
366
) {
356
367
this . _vertex_instance = vertex_instance ;
357
368
this . model = model ;
358
369
this . generation_config = generation_config ;
359
370
this . safety_settings = safety_settings ;
371
+ this . tools = tools ;
360
372
if ( model . startsWith ( 'models/' ) ) {
361
373
this . publisherModelEndpoint = `publishers/google/${ this . model } ` ;
362
374
} else {
@@ -401,6 +413,7 @@ export class GenerativeModel {
401
413
contents : request . contents ,
402
414
generation_config : request . generation_config ?? this . generation_config ,
403
415
safety_settings : request . safety_settings ?? this . safety_settings ,
416
+ tools : request . tools ?? [ ] ,
404
417
} ;
405
418
406
419
const response : Response | undefined = await postRequest ( {
@@ -444,6 +457,7 @@ export class GenerativeModel {
444
457
contents : request . contents ,
445
458
generation_config : request . generation_config ?? this . generation_config ,
446
459
safety_settings : request . safety_settings ?? this . safety_settings ,
460
+ tools : request . tools ?? [ ] ,
447
461
} ;
448
462
const response = await postRequest ( {
449
463
region : this . _vertex_instance . location ,
@@ -501,12 +515,15 @@ export class GenerativeModel {
501
515
request . generation_config ?? this . generation_config ;
502
516
startChatRequest . safety_settings =
503
517
request . safety_settings ?? this . safety_settings ;
518
+ startChatRequest . tools = request . tools ?? this . tools ;
504
519
}
505
520
return new ChatSession ( startChatRequest ) ;
506
521
}
507
522
}
508
523
509
- function formulateNewContent ( request : string | Array < string | Part > ) : Content {
524
+ function formulateNewContent (
525
+ request : string | Array < string | Part >
526
+ ) : Content [ ] {
510
527
let newParts : Part [ ] = [ ] ;
511
528
512
529
if ( typeof request === 'string' ) {
@@ -521,8 +538,38 @@ function formulateNewContent(request: string | Array<string | Part>): Content {
521
538
}
522
539
}
523
540
524
- const newContent : Content = { role : constants . USER_ROLE , parts : newParts } ;
525
- return newContent ;
541
+ return formatPartsByRole ( newParts ) ;
542
+ }
543
+
544
+ /**
545
+ * When multiple Part types (i.e. FunctionResponsePart and TextPart) are
546
+ * passed in a single Part array, we may need to assign different roles to each
547
+ * part. Currently only FunctionResponsePart requires a role other than 'user'.
548
+ * @ignore
549
+ * @param {Array<Part> } parts Array of parts to pass to the model
550
+ * @return {Content[] } Array of content items
551
+ */
552
+ function formatPartsByRole ( parts : Array < Part > ) : Content [ ] {
553
+ const partsByRole : Content [ ] = [ ] ;
554
+ const userContent : Content = { role : constants . USER_ROLE , parts : [ ] } ;
555
+ const functionContent : Content = { role : constants . FUNCTION_ROLE , parts : [ ] } ;
556
+
557
+ for ( const part of parts ) {
558
+ if ( 'functionResponse' in part ) {
559
+ functionContent . parts . push ( part ) ;
560
+ } else {
561
+ userContent . parts . push ( part ) ;
562
+ }
563
+ }
564
+
565
+ if ( userContent . parts . length > 0 ) {
566
+ partsByRole . push ( userContent ) ;
567
+ }
568
+ if ( functionContent . parts . length > 0 ) {
569
+ partsByRole . push ( functionContent ) ;
570
+ }
571
+
572
+ return partsByRole ;
526
573
}
527
574
528
575
function throwErrorIfNotOK ( response : Response | undefined ) {
0 commit comments