Skip to content

Commit d32755e

Browse files
yyyu-googlecopybara-github
authored andcommitted
feat: include grounding metadata to stream aggregated response.
PiperOrigin-RevId: 615967064
1 parent b7b79fa commit d32755e

File tree

2 files changed

+177
-16
lines changed

2 files changed

+177
-16
lines changed

src/functions/post_fetch_processing.ts

+40
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
GenerateContentCandidate,
2323
GenerateContentResponse,
2424
GenerateContentResult,
25+
GroundingMetadata,
2526
Part,
2627
StreamGenerateContentResult,
2728
} from '../types/content';
@@ -235,13 +236,52 @@ function aggregateResponses(
235236
}
236237
}
237238
}
239+
const groundingMetadataAggregated: GroundingMetadata | undefined =
240+
aggregateGroundingMetadataForCandidate(
241+
response.candidates[i],
242+
aggregatedResponse.candidates[i]
243+
);
244+
if (groundingMetadataAggregated) {
245+
aggregatedResponse.candidates[i].groundingMetadata =
246+
groundingMetadataAggregated;
247+
}
238248
}
239249
}
240250
aggregatedResponse.promptFeedback =
241251
responses[responses.length - 1].promptFeedback;
242252
return aggregatedResponse;
243253
}
244254

255+
function aggregateGroundingMetadataForCandidate(
256+
candidateChunk: GenerateContentCandidate,
257+
aggregatedCandidate: GenerateContentCandidate
258+
): GroundingMetadata | undefined {
259+
if (!candidateChunk.groundingMetadata) {
260+
return;
261+
}
262+
const emptyGroundingMetadata: GroundingMetadata = {
263+
webSearchQueries: [],
264+
groundingAttributions: [],
265+
};
266+
const groundingMetadataAggregated: GroundingMetadata =
267+
aggregatedCandidate.groundingMetadata ?? emptyGroundingMetadata;
268+
const groundingMetadataChunk: GroundingMetadata =
269+
candidateChunk.groundingMetadata!;
270+
if (groundingMetadataChunk.webSearchQueries) {
271+
groundingMetadataAggregated.webSearchQueries =
272+
groundingMetadataAggregated.webSearchQueries!.concat(
273+
groundingMetadataChunk.webSearchQueries
274+
);
275+
}
276+
if (groundingMetadataChunk.groundingAttributions) {
277+
groundingMetadataAggregated.groundingAttributions =
278+
groundingMetadataAggregated.groundingAttributions!.concat(
279+
groundingMetadataChunk.groundingAttributions
280+
);
281+
}
282+
return groundingMetadataAggregated;
283+
}
284+
245285
function addCandidateFunctionCalls(
246286
response: GenerateContentResponse
247287
): GenerateContentResponse {

system_test/end_to_end_sample_test.ts

+137-16
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,15 @@ const vertexAI = new VertexAI({
113113
location: LOCATION,
114114
});
115115

116+
const TEXT_MODEL_NAME = 'gemini-1.0-pro';
116117
const generativeTextModel = vertexAI.getGenerativeModel({
117-
model: 'gemini-1.0-pro',
118+
model: TEXT_MODEL_NAME,
118119
generationConfig: {
119120
maxOutputTokens: 256,
120121
},
121122
});
122123
const generativeTextModelPreview = vertexAI.preview.getGenerativeModel({
123-
model: 'gemini-1.0-pro',
124+
model: TEXT_MODEL_NAME,
124125
generationConfig: {
125126
maxOutputTokens: 256,
126127
},
@@ -138,12 +139,6 @@ const generativeTextModelWithPrefixPreview =
138139
maxOutputTokens: 256,
139140
},
140141
});
141-
const textModelNoOutputLimit = vertexAI.getGenerativeModel({
142-
model: 'gemini-1.0-pro',
143-
});
144-
const textModelNoOutputLimitPreview = vertexAI.preview.getGenerativeModel({
145-
model: 'gemini-1.0-pro',
146-
});
147142
const generativeVisionModel = vertexAI.getGenerativeModel({
148143
model: 'gemini-1.0-pro-vision',
149144
});
@@ -442,6 +437,70 @@ describe('generateContentStream', () => {
442437
expect(item.candidates[0].functionCalls!).toEqual(functionCalls!);
443438
}
444439
});
440+
it('should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
441+
const generativeTextModel = vertexAI.getGenerativeModel({
442+
model: TEXT_MODEL_NAME,
443+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
444+
});
445+
const result = await generativeTextModel.generateContentStream({
446+
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
447+
});
448+
const response = await result.response;
449+
const groundingMetadata = response.candidates[0].groundingMetadata;
450+
expect(!!groundingMetadata).toBeTruthy();
451+
if (groundingMetadata) {
452+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
453+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
454+
}
455+
});
456+
it('should return grounding metadata when passed GoogleSearchRetriever in generateContent', async () => {
457+
const generativeTextModel = vertexAI.getGenerativeModel({
458+
model: TEXT_MODEL_NAME,
459+
});
460+
const result = await generativeTextModel.generateContentStream({
461+
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
462+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
463+
});
464+
const response = await result.response;
465+
const groundingMetadata = response.candidates[0].groundingMetadata;
466+
expect(!!groundingMetadata).toBeTruthy();
467+
if (groundingMetadata) {
468+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
469+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
470+
}
471+
});
472+
it('in preview should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
473+
const generativeTextModel = vertexAI.preview.getGenerativeModel({
474+
model: TEXT_MODEL_NAME,
475+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
476+
});
477+
const result = await generativeTextModel.generateContentStream({
478+
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
479+
});
480+
const response = await result.response;
481+
const groundingMetadata = response.candidates[0].groundingMetadata;
482+
expect(!!groundingMetadata).toBeTruthy();
483+
if (groundingMetadata) {
484+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
485+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
486+
}
487+
});
488+
it('in preview should return grounding metadata when passed GoogleSearchRetriever in generateContent', async () => {
489+
const generativeTextModel = vertexAI.preview.getGenerativeModel({
490+
model: TEXT_MODEL_NAME,
491+
});
492+
const result = await generativeTextModel.generateContentStream({
493+
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
494+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
495+
});
496+
const response = await result.response;
497+
const groundingMetadata = response.candidates[0].groundingMetadata;
498+
expect(!!groundingMetadata).toBeTruthy();
499+
if (groundingMetadata) {
500+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
501+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
502+
}
503+
});
445504
});
446505

447506
describe('generateContent', () => {
@@ -467,7 +526,7 @@ describe('generateContent', () => {
467526
});
468527
it('should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
469528
const generativeTextModel = vertexAI.getGenerativeModel({
470-
model: 'gemini-pro',
529+
model: TEXT_MODEL_NAME,
471530
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
472531
});
473532
const result = await generativeTextModel.generateContent({
@@ -483,7 +542,7 @@ describe('generateContent', () => {
483542
});
484543
it('should return grounding metadata when passed GoogleSearchRetriever in generateContent', async () => {
485544
const generativeTextModel = vertexAI.getGenerativeModel({
486-
model: 'gemini-pro',
545+
model: TEXT_MODEL_NAME,
487546
});
488547
const result = await generativeTextModel.generateContent({
489548
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
@@ -499,7 +558,7 @@ describe('generateContent', () => {
499558
});
500559
it('in preview should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
501560
const generativeTextModel = vertexAI.preview.getGenerativeModel({
502-
model: 'gemini-pro',
561+
model: TEXT_MODEL_NAME,
503562
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
504563
});
505564
const result = await generativeTextModel.generateContent({
@@ -515,7 +574,7 @@ describe('generateContent', () => {
515574
});
516575
it('in preview should return grounding metadata when passed GoogleSearchRetriever in generateContent', async () => {
517576
const generativeTextModel = vertexAI.preview.getGenerativeModel({
518-
model: 'gemini-pro',
577+
model: TEXT_MODEL_NAME,
519578
});
520579
const result = await generativeTextModel.generateContent({
521580
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
@@ -630,7 +689,7 @@ describe('sendMessage', () => {
630689
});
631690
it('should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
632691
const generativeTextModel = vertexAI.getGenerativeModel({
633-
model: 'gemini-pro',
692+
model: TEXT_MODEL_NAME,
634693
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
635694
});
636695
const chat = generativeTextModel.startChat();
@@ -645,7 +704,7 @@ describe('sendMessage', () => {
645704
});
646705
it('should return grounding metadata when passed GoogleSearchRetriever in startChat', async () => {
647706
const generativeTextModel = vertexAI.getGenerativeModel({
648-
model: 'gemini-pro',
707+
model: TEXT_MODEL_NAME,
649708
});
650709
const chat = generativeTextModel.startChat({
651710
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
@@ -661,7 +720,7 @@ describe('sendMessage', () => {
661720
});
662721
it('in preview should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
663722
const generativeTextModel = vertexAI.preview.getGenerativeModel({
664-
model: 'gemini-pro',
723+
model: TEXT_MODEL_NAME,
665724
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
666725
});
667726
const chat = generativeTextModel.startChat();
@@ -676,7 +735,7 @@ describe('sendMessage', () => {
676735
});
677736
it('in preview should return grounding metadata when passed GoogleSearchRetriever in startChat', async () => {
678737
const generativeTextModel = vertexAI.preview.getGenerativeModel({
679-
model: 'gemini-pro',
738+
model: TEXT_MODEL_NAME,
680739
});
681740
const chat = generativeTextModel.startChat({
682741
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
@@ -828,6 +887,68 @@ describe('sendMessageStream', () => {
828887
JSON.stringify(response2.candidates[0].content.parts[0].text)
829888
).toContain(WEATHER_FORECAST);
830889
});
890+
it('should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
891+
const generativeTextModel = vertexAI.getGenerativeModel({
892+
model: TEXT_MODEL_NAME,
893+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
894+
});
895+
const chat = generativeTextModel.startChat();
896+
const result = await chat.sendMessageStream('Why is the sky blue?');
897+
const response = await result.response;
898+
const groundingMetadata = response.candidates[0].groundingMetadata;
899+
expect(!!groundingMetadata).toBeTruthy();
900+
if (groundingMetadata) {
901+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
902+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
903+
}
904+
});
905+
it('should return grounding metadata when passed GoogleSearchRetriever in startChat', async () => {
906+
const generativeTextModel = vertexAI.getGenerativeModel({
907+
model: TEXT_MODEL_NAME,
908+
});
909+
const chat = generativeTextModel.startChat({
910+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
911+
});
912+
const result = await chat.sendMessageStream('Why is the sky blue?');
913+
const response = await result.response;
914+
const groundingMetadata = response.candidates[0].groundingMetadata;
915+
expect(!!groundingMetadata).toBeTruthy();
916+
if (groundingMetadata) {
917+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
918+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
919+
}
920+
});
921+
it('in preview should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
922+
const generativeTextModel = vertexAI.preview.getGenerativeModel({
923+
model: TEXT_MODEL_NAME,
924+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
925+
});
926+
const chat = generativeTextModel.startChat();
927+
const result = await chat.sendMessageStream('Why is the sky blue?');
928+
const response = await result.response;
929+
const groundingMetadata = response.candidates[0].groundingMetadata;
930+
expect(!!groundingMetadata).toBeTruthy();
931+
if (groundingMetadata) {
932+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
933+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
934+
}
935+
});
936+
it('in preview should return grounding metadata when passed GoogleSearchRetriever in startChat', async () => {
937+
const generativeTextModel = vertexAI.preview.getGenerativeModel({
938+
model: TEXT_MODEL_NAME,
939+
});
940+
const chat = generativeTextModel.startChat({
941+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
942+
});
943+
const result = await chat.sendMessageStream('Why is the sky blue?');
944+
const response = await result.response;
945+
const groundingMetadata = response.candidates[0].groundingMetadata;
946+
expect(!!groundingMetadata).toBeTruthy();
947+
if (groundingMetadata) {
948+
expect(!!groundingMetadata.groundingAttributions).toBeTruthy();
949+
expect(!!groundingMetadata.webSearchQueries).toBeTruthy();
950+
}
951+
});
831952
});
832953

833954
describe('countTokens', () => {

0 commit comments

Comments
 (0)