Skip to content

Commit 7a366ab

Browse files
yyyu-googlecopybara-github
authored andcommitted
fix: correct GenerateContentCandidate interface and GenerateContentResponse interface
PiperOrigin-RevId: 617875952
1 parent 0d3754a commit 7a366ab

File tree

8 files changed

+390
-213
lines changed

8 files changed

+390
-213
lines changed

src/functions/post_fetch_processing.ts

+56-12
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
*/
1717

1818
import {
19-
Citation,
2019
CitationMetadata,
20+
Content,
2121
CountTokensResponse,
2222
GenerateContentCandidate,
2323
GenerateContentResponse,
@@ -26,6 +26,7 @@ import {
2626
Part,
2727
StreamGenerateContentResult,
2828
} from '../types/content';
29+
import {constants} from '../util';
2930
import {ClientError, GoogleGenerativeAIError} from '../types/errors';
3031

3132
export async function throwErrorIfNotOK(response: Response | undefined) {
@@ -57,7 +58,7 @@ async function* generateResponseSequence(
5758
if (done) {
5859
break;
5960
}
60-
yield addCandidateFunctionCalls(value);
61+
yield addMissingFields(value);
6162
}
6263
}
6364

@@ -183,18 +184,28 @@ export function aggregateResponses(
183184
);
184185
}
185186

186-
const aggregatedResponse: GenerateContentResponse = {
187-
candidates: [],
188-
promptFeedback: lastResponse.promptFeedback,
189-
usageMetadata: lastResponse.usageMetadata,
190-
};
187+
const aggregatedResponse: GenerateContentResponse = {};
188+
189+
if (lastResponse.promptFeedback) {
190+
aggregatedResponse.promptFeedback = lastResponse.promptFeedback;
191+
}
192+
if (lastResponse.usageMetadata) {
193+
aggregatedResponse.usageMetadata = lastResponse.usageMetadata;
194+
}
195+
191196
for (const response of responses) {
197+
if (!response.candidates || response.candidates.length === 0) {
198+
continue;
199+
}
192200
for (let i = 0; i < response.candidates.length; i++) {
201+
if (!aggregatedResponse.candidates) {
202+
aggregatedResponse.candidates = [];
203+
}
193204
if (!aggregatedResponse.candidates[i]) {
194205
aggregatedResponse.candidates[i] = {
195-
index: response.candidates[i].index,
206+
index: response.candidates[i].index ?? i,
196207
content: {
197-
role: response.candidates[i].content.role,
208+
role: response.candidates[i].content.role ?? constants.MODEL_ROLE,
198209
parts: [{text: ''}],
199210
},
200211
} as GenerateContentCandidate;
@@ -246,8 +257,6 @@ export function aggregateResponses(
246257
}
247258
}
248259
}
249-
aggregatedResponse.promptFeedback =
250-
responses[responses.length - 1].promptFeedback;
251260
return aggregatedResponse;
252261
}
253262

@@ -304,6 +313,33 @@ function aggregateGroundingMetadataForCandidate(
304313
return groundingMetadataAggregated;
305314
}
306315

316+
function addMissingIndexAndRole(
317+
response: GenerateContentResponse
318+
): GenerateContentResponse {
319+
const generateContentResponse = response as GenerateContentResponse;
320+
if (
321+
generateContentResponse.candidates &&
322+
generateContentResponse.candidates.length > 0
323+
) {
324+
generateContentResponse.candidates.forEach((candidate, index) => {
325+
if (candidate.index === undefined) {
326+
generateContentResponse.candidates![index].index = index;
327+
}
328+
329+
if (candidate.content === undefined) {
330+
generateContentResponse.candidates![index].content = {} as Content;
331+
}
332+
333+
if (candidate.content.role === undefined) {
334+
generateContentResponse.candidates![index].content.role =
335+
constants.MODEL_ROLE;
336+
}
337+
});
338+
}
339+
340+
return generateContentResponse;
341+
}
342+
307343
function addCandidateFunctionCalls(
308344
response: GenerateContentResponse
309345
): GenerateContentResponse {
@@ -328,6 +364,13 @@ function addCandidateFunctionCalls(
328364
return response;
329365
}
330366

367+
function addMissingFields(
368+
response: GenerateContentResponse
369+
): GenerateContentResponse {
370+
const generateContentResponse = addMissingIndexAndRole(response);
371+
return addCandidateFunctionCalls(generateContentResponse);
372+
}
373+
331374
/**
332375
* Process model responses from generateContent
333376
* @ignore
@@ -338,8 +381,9 @@ export async function processUnary(
338381
if (response !== undefined) {
339382
// ts-ignore
340383
const responseJson = await response.json();
384+
const generateContentResponse = addMissingIndexAndRole(responseJson);
341385
return Promise.resolve({
342-
response: addCandidateFunctionCalls(responseJson),
386+
response: addCandidateFunctionCalls(generateContentResponse),
343387
});
344388
}
345389

src/functions/test/functions_test.ts

+7-5
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ describe('generateContent', () => {
585585
TEST_API_ENDPOINT
586586
);
587587
expect(
588-
resp.response.candidates[0].citationMetadata?.citations.length
588+
resp.response.candidates![0].citationMetadata?.citations.length
589589
).toEqual(
590590
TEST_MODEL_RESPONSE.candidates[0].citationMetadata.citations.length
591591
);
@@ -617,9 +617,9 @@ describe('generateContent', () => {
617617
TEST_API_ENDPOINT
618618
);
619619
expect(actualResult).toEqual(expectedResult);
620-
expect(actualResult.response.candidates[0].functionCalls).toHaveSize(1);
621-
expect(actualResult.response.candidates[0].functionCalls).toEqual([
622-
expectedResult.response.candidates[0].content.parts[0].functionCall!,
620+
expect(actualResult.response.candidates![0].functionCalls).toHaveSize(1);
621+
expect(actualResult.response.candidates![0].functionCalls).toEqual([
622+
expectedResult.response.candidates![0].content.parts[0].functionCall!,
623623
]);
624624
});
625625

@@ -652,7 +652,9 @@ describe('generateContent', () => {
652652
TEST_API_ENDPOINT
653653
);
654654
expect(actualResult).toEqual(expectedResult);
655-
expect(actualResult.response.candidates[0].functionCalls).not.toBeDefined();
655+
expect(
656+
actualResult.response.candidates![0].functionCalls
657+
).not.toBeDefined();
656658
});
657659

658660
it('returns empty candidates when response is empty', async () => {

src/functions/test/post_fetch_processing_test.ts

+50
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
import {
1919
AGGREGATED_RESPONSE_STREAM_RESPONSE_CHUNKS_1,
2020
AGGREGATED_RESPONSE_STREAM_RESPONSE_CHUNKS_2,
21+
AGGREGATED_RESPONSE_STREAM_RESPONSE_CHUNKS_3,
2122
COUNT_TOKENS_RESPONSE_1,
2223
STREAM_RESPONSE_CHUNKS_1,
2324
STREAM_RESPONSE_CHUNKS_2,
25+
STREAM_RESPONSE_CHUNKS_3,
2426
UNARY_RESPONSE_1,
27+
UNARY_RESPONSE_MISSING_ROLE_INDEX,
2528
} from './test_data';
2629
import * as PostFetchFunctions from '../post_fetch_processing';
2730
import {aggregateResponses} from '../post_fetch_processing';
@@ -60,6 +63,18 @@ describe('aggregateResponses', () => {
6063
JSON.stringify(AGGREGATED_RESPONSE_STREAM_RESPONSE_CHUNKS_2)
6164
);
6265
});
66+
67+
it('missing candidates, should return {}', () => {
68+
expect(aggregateResponses([{}, {}])).toEqual({});
69+
});
70+
71+
it('missing role and index, should add role and index', () => {
72+
const actualResult = aggregateResponses(STREAM_RESPONSE_CHUNKS_3);
73+
74+
expect(JSON.stringify(actualResult)).toEqual(
75+
JSON.stringify(AGGREGATED_RESPONSE_STREAM_RESPONSE_CHUNKS_3)
76+
);
77+
});
6378
});
6479

6580
describe('processUnary', () => {
@@ -80,6 +95,41 @@ describe('processUnary', () => {
8095

8196
expect(actualResponse).toEqual(UNARY_RESPONSE_1);
8297
});
98+
99+
it('response missing role and index, should add role and index', async () => {
100+
const fetchResult = new Response(
101+
JSON.stringify(UNARY_RESPONSE_MISSING_ROLE_INDEX),
102+
fetchResponseObj
103+
);
104+
const expectedResult = UNARY_RESPONSE_1;
105+
spyOn(global, 'fetch').and.resolveTo(fetchResult);
106+
const actualResult = await generateContent(
107+
LOCATION,
108+
PROJECT,
109+
PUBLISHER_MODEL_ENDPOINT,
110+
TOKEN,
111+
GENERATE_CONTENT_REQUEST
112+
);
113+
const actualResponse = actualResult.response;
114+
115+
expect(actualResponse).toEqual(expectedResult);
116+
});
117+
118+
it('candidate undefined, should return empty response', async () => {
119+
spyOn(global, 'fetch').and.resolveTo(
120+
new Response(JSON.stringify({}), fetchResponseObj)
121+
);
122+
const actualResult = await generateContent(
123+
LOCATION,
124+
PROJECT,
125+
PUBLISHER_MODEL_ENDPOINT,
126+
TOKEN,
127+
GENERATE_CONTENT_REQUEST
128+
);
129+
const actualResponse = actualResult.response;
130+
131+
expect(actualResponse).toEqual({});
132+
});
83133
});
84134

85135
describe('processStream', () => {

0 commit comments

Comments
 (0)