Skip to content

Commit 89568a6

Browse files
happy-qiaocopybara-github
authored andcommitted
feat: Support functionCalls property in GenerationContentCandidate interface for non streaming mode
PiperOrigin-RevId: 613447415
1 parent f366687 commit 89568a6

File tree

4 files changed

+200
-9
lines changed

4 files changed

+200
-9
lines changed

src/functions/post_fetch_processing.ts

+25-2
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
import {
1919
CitationSource,
20+
Content,
2021
CountTokensResponse,
2122
GenerateContentCandidate,
2223
GenerateContentResponse,
2324
GenerateContentResult,
25+
Part,
2426
StreamGenerateContentResult,
2527
} from '../types/content';
2628
import {ClientError, GoogleGenerativeAIError} from '../types/errors';
@@ -54,7 +56,7 @@ async function* generateResponseSequence(
5456
if (done) {
5557
break;
5658
}
57-
yield value;
59+
yield addCandidateFunctionCalls(value);
5860
}
5961
}
6062

@@ -229,6 +231,27 @@ function aggregateResponses(
229231
return aggregatedResponse;
230232
}
231233

234+
function addCandidateFunctionCalls(
235+
response: GenerateContentResponse
236+
): GenerateContentResponse {
237+
for (const candidate of response.candidates) {
238+
if (
239+
!candidate.content ||
240+
!candidate.content.parts ||
241+
candidate.content.parts.length === 0
242+
) {
243+
continue;
244+
}
245+
const functionCalls = candidate.content.parts
246+
.filter((part: Part) => !!part.functionCall)
247+
.map((part: Part) => part.functionCall!);
248+
if (functionCalls.length > 0) {
249+
candidate.functionCalls = functionCalls;
250+
}
251+
}
252+
return response;
253+
}
254+
232255
/**
233256
* Process model responses from generateContent
234257
* @ignore
@@ -240,7 +263,7 @@ export async function processNonStream(
240263
// ts-ignore
241264
const responseJson = await response.json();
242265
return Promise.resolve({
243-
response: responseJson,
266+
response: addCandidateFunctionCalls(responseJson),
244267
});
245268
}
246269

src/functions/test/functions_test.ts

+57-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import {
3131
StreamGenerateContentResult,
3232
Tool,
3333
} from '../../types';
34+
import {FunctionCall} from '../../types/content';
3435
import {constants} from '../../util';
3536
import {countTokens} from '../count_tokens';
3637
import {generateContent, generateContentStream} from '../generate_content';
@@ -117,6 +118,18 @@ const TEST_MODEL_RESPONSE = {
117118
candidates: TEST_CANDIDATES,
118119
usageMetadata: {promptTokenCount: 0, candidatesTokenCount: 0},
119120
};
121+
const TEST_CANDIDATE_WITH_INVALID_DATA = [
122+
{
123+
index: 1,
124+
content: {
125+
role: constants.MODEL_ROLE,
126+
parts: [],
127+
},
128+
},
129+
];
130+
const TEST_MODEL_RESPONSE_WITH_INVALID_DATA = {
131+
candidates: TEST_CANDIDATE_WITH_INVALID_DATA,
132+
};
120133
const TEST_FUNCTION_CALL_RESPONSE = {
121134
functionCall: {
122135
name: 'get_current_weather',
@@ -137,6 +150,7 @@ const TEST_CANDIDATES_WITH_FUNCTION_CALL = [
137150
finishReason: FinishReason.STOP,
138151
finishMessage: '',
139152
safetyRatings: TEST_SAFETY_RATINGS,
153+
functionCalls: [TEST_FUNCTION_CALL_RESPONSE.functionCall],
140154
},
141155
];
142156
const TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL = {
@@ -588,17 +602,57 @@ describe('generateContent', () => {
588602
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL,
589603
};
590604
fetchSpy.and.resolveTo(
591-
buildFetchResponse(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL)
605+
new Response(
606+
JSON.stringify(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
607+
fetchResponseObj
608+
)
592609
);
593-
const resp = await generateContent(
610+
611+
const actualResult = await generateContent(
594612
TEST_LOCATION,
595613
TEST_PROJECT,
596614
TEST_PUBLISHER_MODEL_ENDPOINT,
597615
TEST_TOKEN_PROMISE,
598616
req,
599617
TEST_API_ENDPOINT
600618
);
601-
expect(resp).toEqual(expectedResult);
619+
expect(actualResult).toEqual(expectedResult);
620+
expect(actualResult.response.candidates[0].functionCalls).toHaveSize(1);
621+
expect(actualResult.response.candidates[0].functionCalls).toEqual([
622+
expectedResult.response.candidates[0].content.parts[0].functionCall!,
623+
]);
624+
});
625+
626+
it('returns a empty FunctionCall list when response contains invalid data', async () => {
627+
const req: GenerateContentRequest = {
628+
contents: [
629+
{
630+
role: 'user',
631+
parts: [{text: 'What is the weater like in Boston?'}],
632+
},
633+
],
634+
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
635+
};
636+
const expectedResult: GenerateContentResult = {
637+
response: TEST_MODEL_RESPONSE_WITH_INVALID_DATA,
638+
};
639+
fetchSpy.and.resolveTo(
640+
new Response(
641+
JSON.stringify(TEST_MODEL_RESPONSE_WITH_INVALID_DATA),
642+
fetchResponseObj
643+
)
644+
);
645+
646+
const actualResult = await generateContent(
647+
TEST_LOCATION,
648+
TEST_PROJECT,
649+
TEST_PUBLISHER_MODEL_ENDPOINT,
650+
TEST_TOKEN_PROMISE,
651+
req,
652+
TEST_API_ENDPOINT
653+
);
654+
expect(actualResult).toEqual(expectedResult);
655+
expect(actualResult.response.candidates[0].functionCalls).not.toBeDefined();
602656
});
603657

604658
it('throws ClientError when functionResponse is not immedidately following functionCall case1', async () => {

src/types/content.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,8 @@ export declare interface GenerateContentCandidate {
485485
citationMetadata?: CitationMetadata;
486486
/** Optional. {@link GroundingMetadata}. */
487487
groundingMetadata?: GroundingMetadata;
488-
/** Optional. {@link FunctionResponse}. */
489-
functionCall?: FunctionCall;
488+
/* Optional. Array of {@link FunctionCall}. */
489+
functionCalls?: FunctionCall[];
490490
}
491491

492492
/**

system_test/end_to_end_sample_test.ts

+116-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
ClientError,
2121
FunctionDeclarationsTool,
2222
GoogleSearchRetrievalTool,
23+
Part,
2324
TextPart,
2425
VertexAI,
2526
} from '../src';
@@ -361,7 +362,7 @@ describe('generateContentStream', () => {
361362
);
362363
});
363364

364-
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => {
365+
it('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => {
365366
const request = {
366367
contents: [
367368
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
@@ -381,7 +382,7 @@ describe('generateContentStream', () => {
381382
);
382383
}
383384
});
384-
it('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => {
385+
it('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => {
385386
const request = {
386387
contents: [
387388
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
@@ -401,6 +402,46 @@ describe('generateContentStream', () => {
401402
);
402403
}
403404
});
405+
it('should return a FunctionCall when passed a FunctionDeclaration', async () => {
406+
const request = {
407+
contents: [
408+
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
409+
],
410+
tools: TOOLS_WITH_FUNCTION_DECLARATION,
411+
};
412+
const streamingResp =
413+
await generativeTextModel.generateContentStream(request);
414+
for await (const item of streamingResp.stream) {
415+
expect(item.candidates[0]).toBeTruthy(
416+
`sys test failure on generateContentStream, for item ${item}`
417+
);
418+
const functionCalls = item.candidates[0].content.parts
419+
.filter(part => !!part.functionCall)
420+
.map(part => part.functionCall!);
421+
expect(functionCalls).toHaveSize(1);
422+
expect(item.candidates[0].functionCalls!).toEqual(functionCalls!);
423+
}
424+
});
425+
it('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => {
426+
const request = {
427+
contents: [
428+
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
429+
],
430+
tools: TOOLS_WITH_FUNCTION_DECLARATION,
431+
};
432+
const streamingResp =
433+
await generativeTextModelPreview.generateContentStream(request);
434+
for await (const item of streamingResp.stream) {
435+
expect(item.candidates[0]).toBeTruthy(
436+
`sys test failure on generateContentStream in preview, for item ${item}`
437+
);
438+
const functionCalls = item.candidates[0].content.parts
439+
.filter(part => !!part.functionCall)
440+
.map(part => part.functionCall!);
441+
expect(functionCalls).toHaveSize(1);
442+
expect(item.candidates[0].functionCalls!).toEqual(functionCalls!);
443+
}
444+
});
404445
});
405446

406447
describe('generateContent', () => {
@@ -488,6 +529,79 @@ describe('generateContent', () => {
488529
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
489530
}
490531
});
532+
it('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => {
533+
const request = {
534+
contents: [
535+
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
536+
{role: 'model', parts: FUNCTION_CALL},
537+
{role: 'function', parts: FUNCTION_RESPONSE_PART},
538+
],
539+
tools: TOOLS_WITH_FUNCTION_DECLARATION,
540+
};
541+
const resp = await generativeTextModel.generateContent(request);
542+
543+
expect(resp.response.candidates[0]).toBeTruthy(
544+
`sys test failure on generateContentStream, for resp ${resp}`
545+
);
546+
expect(
547+
resp.response.candidates[0].content.parts[0].text?.toLowerCase()
548+
).toContain(WEATHER_FORECAST);
549+
});
550+
it('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => {
551+
const request = {
552+
contents: [
553+
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
554+
{role: 'model', parts: FUNCTION_CALL},
555+
{role: 'function', parts: FUNCTION_RESPONSE_PART},
556+
],
557+
tools: TOOLS_WITH_FUNCTION_DECLARATION,
558+
};
559+
const resp = await generativeTextModelPreview.generateContent(request);
560+
expect(resp.response.candidates[0]).toBeTruthy(
561+
`sys test failure on generateContentStream in preview, for resp ${resp}`
562+
);
563+
const functionCalls = resp.response.candidates[0].content.parts
564+
.filter((part: Part) => !!part.functionCall)
565+
.map((part: Part) => part.functionCall!);
566+
expect(
567+
resp.response.candidates[0].content.parts[0].text?.toLowerCase()
568+
).toContain(WEATHER_FORECAST);
569+
});
570+
it('should return a FunctionCall when passed a FunctionDeclaration', async () => {
571+
const request = {
572+
contents: [
573+
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
574+
],
575+
tools: TOOLS_WITH_FUNCTION_DECLARATION,
576+
};
577+
const resp = await generativeTextModel.generateContent(request);
578+
579+
expect(resp.response.candidates[0]).toBeTruthy(
580+
`sys test failure on generateContentStream, for resp ${resp}`
581+
);
582+
const functionCalls = resp.response.candidates[0].content.parts
583+
.filter((part: Part) => !!part.functionCall)
584+
.map((part: Part) => part.functionCall!);
585+
expect(functionCalls).toHaveSize(1);
586+
expect(resp.response.candidates[0].functionCalls!).toEqual(functionCalls!);
587+
});
588+
it('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => {
589+
const request = {
590+
contents: [
591+
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
592+
],
593+
tools: TOOLS_WITH_FUNCTION_DECLARATION,
594+
};
595+
const resp = await generativeTextModelPreview.generateContent(request);
596+
expect(resp.response.candidates[0]).toBeTruthy(
597+
`sys test failure on generateContentStream in preview, for resp ${resp}`
598+
);
599+
const functionCalls = resp.response.candidates[0].content.parts
600+
.filter((part: Part) => !!part.functionCall)
601+
.map((part: Part) => part.functionCall!);
602+
expect(functionCalls).toHaveSize(1);
603+
expect(resp.response.candidates[0].functionCalls!).toEqual(functionCalls!);
604+
});
491605
});
492606

493607
describe('sendMessage', () => {

0 commit comments

Comments
 (0)