16
16
*/
17
17
18
18
import {
19
- Citation ,
20
19
CitationMetadata ,
20
+ Content ,
21
21
CountTokensResponse ,
22
22
GenerateContentCandidate ,
23
23
GenerateContentResponse ,
@@ -26,6 +26,7 @@ import {
26
26
Part ,
27
27
StreamGenerateContentResult ,
28
28
} from '../types/content' ;
29
+ import { constants } from '../util' ;
29
30
import { ClientError , GoogleGenerativeAIError } from '../types/errors' ;
30
31
31
32
export async function throwErrorIfNotOK ( response : Response | undefined ) {
@@ -57,7 +58,7 @@ async function* generateResponseSequence(
57
58
if ( done ) {
58
59
break ;
59
60
}
60
- yield addCandidateFunctionCalls ( value ) ;
61
+ yield addMissingFields ( value ) ;
61
62
}
62
63
}
63
64
@@ -183,18 +184,28 @@ export function aggregateResponses(
183
184
) ;
184
185
}
185
186
186
- const aggregatedResponse : GenerateContentResponse = {
187
- candidates : [ ] ,
188
- promptFeedback : lastResponse . promptFeedback ,
189
- usageMetadata : lastResponse . usageMetadata ,
190
- } ;
187
+ const aggregatedResponse : GenerateContentResponse = { } ;
188
+
189
+ if ( lastResponse . promptFeedback ) {
190
+ aggregatedResponse . promptFeedback = lastResponse . promptFeedback ;
191
+ }
192
+ if ( lastResponse . usageMetadata ) {
193
+ aggregatedResponse . usageMetadata = lastResponse . usageMetadata ;
194
+ }
195
+
191
196
for ( const response of responses ) {
197
+ if ( ! response . candidates || response . candidates . length === 0 ) {
198
+ continue ;
199
+ }
192
200
for ( let i = 0 ; i < response . candidates . length ; i ++ ) {
201
+ if ( ! aggregatedResponse . candidates ) {
202
+ aggregatedResponse . candidates = [ ] ;
203
+ }
193
204
if ( ! aggregatedResponse . candidates [ i ] ) {
194
205
aggregatedResponse . candidates [ i ] = {
195
- index : response . candidates [ i ] . index ,
206
+ index : response . candidates [ i ] . index ?? i ,
196
207
content : {
197
- role : response . candidates [ i ] . content . role ,
208
+ role : response . candidates [ i ] . content . role ?? constants . MODEL_ROLE ,
198
209
parts : [ { text : '' } ] ,
199
210
} ,
200
211
} as GenerateContentCandidate ;
@@ -246,8 +257,6 @@ export function aggregateResponses(
246
257
}
247
258
}
248
259
}
249
- aggregatedResponse . promptFeedback =
250
- responses [ responses . length - 1 ] . promptFeedback ;
251
260
return aggregatedResponse ;
252
261
}
253
262
@@ -304,6 +313,33 @@ function aggregateGroundingMetadataForCandidate(
304
313
return groundingMetadataAggregated ;
305
314
}
306
315
316
+ function addMissingIndexAndRole (
317
+ response : GenerateContentResponse
318
+ ) : GenerateContentResponse {
319
+ const generateContentResponse = response as GenerateContentResponse ;
320
+ if (
321
+ generateContentResponse . candidates &&
322
+ generateContentResponse . candidates . length > 0
323
+ ) {
324
+ generateContentResponse . candidates . forEach ( ( candidate , index ) => {
325
+ if ( candidate . index === undefined ) {
326
+ generateContentResponse . candidates ! [ index ] . index = index ;
327
+ }
328
+
329
+ if ( candidate . content === undefined ) {
330
+ generateContentResponse . candidates ! [ index ] . content = { } as Content ;
331
+ }
332
+
333
+ if ( candidate . content . role === undefined ) {
334
+ generateContentResponse . candidates ! [ index ] . content . role =
335
+ constants . MODEL_ROLE ;
336
+ }
337
+ } ) ;
338
+ }
339
+
340
+ return generateContentResponse ;
341
+ }
342
+
307
343
function addCandidateFunctionCalls (
308
344
response : GenerateContentResponse
309
345
) : GenerateContentResponse {
@@ -328,6 +364,13 @@ function addCandidateFunctionCalls(
328
364
return response ;
329
365
}
330
366
367
+ function addMissingFields (
368
+ response : GenerateContentResponse
369
+ ) : GenerateContentResponse {
370
+ const generateContentResponse = addMissingIndexAndRole ( response ) ;
371
+ return addCandidateFunctionCalls ( generateContentResponse ) ;
372
+ }
373
+
331
374
/**
332
375
* Process model responses from generateContent
333
376
* @ignore
@@ -338,8 +381,9 @@ export async function processUnary(
338
381
if ( response !== undefined ) {
339
382
// ts-ignore
340
383
const responseJson = await response . json ( ) ;
384
+ const generateContentResponse = addMissingIndexAndRole ( responseJson ) ;
341
385
return Promise . resolve ( {
342
- response : addCandidateFunctionCalls ( responseJson ) ,
386
+ response : addCandidateFunctionCalls ( generateContentResponse ) ,
343
387
} ) ;
344
388
}
345
389
0 commit comments