Skip to content

Commit bbaf78a

Browse files
yyyu-googlecopybara-github
authored andcommitted
fix: pass tools from getGenerativeModel and startChat methods to top level functions
PiperOrigin-RevId: 613263123
1 parent 4e46bc4 commit bbaf78a

File tree

6 files changed

+78
-15
lines changed

6 files changed

+78
-15
lines changed

src/functions/generate_content.ts

+7-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import {
2828
RequestOptions,
2929
SafetySetting,
3030
StreamGenerateContentResult,
31+
Tool,
3132
} from '../types/content';
3233
import {GoogleGenerativeAIError} from '../types/errors';
3334
import * as constants from '../util/constants';
@@ -53,6 +54,7 @@ export async function generateContent(
5354
apiEndpoint?: string,
5455
generation_config?: GenerationConfig,
5556
safety_settings?: SafetySetting[],
57+
tools?: Tool[],
5658
requestOptions?: RequestOptions
5759
): Promise<GenerateContentResult> {
5860
request = formatContentRequest(request, generation_config, safety_settings);
@@ -69,9 +71,9 @@ export async function generateContent(
6971
contents: request.contents,
7072
generation_config: request.generation_config ?? generation_config,
7173
safety_settings: request.safety_settings ?? safety_settings,
72-
tools: request.tools ?? [],
74+
tools: request.tools ?? tools,
7375
};
74-
const apiVersion = request.tools ? 'v1beta1' : 'v1';
76+
const apiVersion = generateContentRequest.tools ? 'v1beta1' : 'v1';
7577
const response: Response | undefined = await postRequest({
7678
region: location,
7779
project: project,
@@ -107,6 +109,7 @@ export async function generateContentStream(
107109
apiEndpoint?: string,
108110
generation_config?: GenerationConfig,
109111
safety_settings?: SafetySetting[],
112+
tools?: Tool[],
110113
requestOptions?: RequestOptions
111114
): Promise<StreamGenerateContentResult> {
112115
request = formatContentRequest(request, generation_config, safety_settings);
@@ -122,9 +125,9 @@ export async function generateContentStream(
122125
contents: request.contents,
123126
generation_config: request.generation_config ?? generation_config,
124127
safety_settings: request.safety_settings ?? safety_settings,
125-
tools: request.tools ?? [],
128+
tools: request.tools ?? tools,
126129
};
127-
const apiVersion = request.tools ? 'v1beta1' : 'v1';
130+
const apiVersion = generateContentRequest.tools ? 'v1beta1' : 'v1';
128131
const response = await postRequest({
129132
region: location,
130133
project: project,

src/functions/test/functions_test.ts

+4
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ const TEST_MULTIPART_MESSAGE_BASE64 = [
191191
},
192192
];
193193

194+
const TEST_EMPTY_TOOLS: Tool[] = [];
195+
194196
const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [
195197
{
196198
function_declarations: [
@@ -393,6 +395,7 @@ describe('generateContent', () => {
393395
TEST_API_ENDPOINT,
394396
TEST_GENERATION_CONFIG,
395397
TEST_SAFETY_SETTINGS,
398+
TEST_EMPTY_TOOLS,
396399
TEST_REQUEST_OPTIONS
397400
)
398401
).toBeRejected();
@@ -702,6 +705,7 @@ describe('generateContentStream', () => {
702705
TEST_API_ENDPOINT,
703706
TEST_GENERATION_CONFIG,
704707
TEST_SAFETY_SETTINGS,
708+
TEST_EMPTY_TOOLS,
705709
TEST_REQUEST_OPTIONS
706710
)
707711
).toBeRejected();

src/models/chat_session.ts

+4
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ export class ChatSession {
120120
this.api_endpoint,
121121
this.generation_config,
122122
this.safety_settings,
123+
this.tools,
123124
this.requestOptions
124125
).catch(e => {
125126
throw e;
@@ -192,6 +193,7 @@ export class ChatSession {
192193
this.api_endpoint,
193194
this.generation_config,
194195
this.safety_settings,
196+
this.tools,
195197
this.requestOptions
196198
).catch(e => {
197199
throw e;
@@ -286,6 +288,7 @@ export class ChatSessionPreview {
286288
this.api_endpoint,
287289
this.generation_config,
288290
this.safety_settings,
291+
this.tools,
289292
this.requestOptions
290293
).catch(e => {
291294
throw e;
@@ -358,6 +361,7 @@ export class ChatSessionPreview {
358361
this.api_endpoint,
359362
this.generation_config,
360363
this.safety_settings,
364+
this.tools,
361365
this.requestOptions
362366
).catch(e => {
363367
throw e;

src/models/generative_models.ts

+5
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ export class GenerativeModel {
106106
this.apiEndpoint,
107107
this.generation_config,
108108
this.safety_settings,
109+
this.tools,
109110
this.requestOptions
110111
);
111112
}
@@ -128,6 +129,7 @@ export class GenerativeModel {
128129
this.apiEndpoint,
129130
this.generation_config,
130131
this.safety_settings,
132+
this.tools,
131133
this.requestOptions
132134
);
133135
}
@@ -162,6 +164,7 @@ export class GenerativeModel {
162164
location: this.location,
163165
googleAuth: this.googleAuth,
164166
publisher_model_endpoint: this.publisherModelEndpoint,
167+
tools: this.tools,
165168
};
166169

167170
if (request) {
@@ -243,6 +246,7 @@ export class GenerativeModelPreview {
243246
this.apiEndpoint,
244247
this.generation_config,
245248
this.safety_settings,
249+
this.tools,
246250
this.requestOptions
247251
);
248252
}
@@ -265,6 +269,7 @@ export class GenerativeModelPreview {
265269
this.apiEndpoint,
266270
this.generation_config,
267271
this.safety_settings,
272+
this.tools,
268273
this.requestOptions
269274
);
270275
}

src/models/test/models_test.ts

+8-8
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ describe('GenerativeModel generateContent', () => {
377377
);
378378
await modelWithRequestOptions.generateContent(req);
379379
// @ts-ignore
380-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
380+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
381381
});
382382
it('returns a GenerateContentResponse when passed a string', async () => {
383383
const expectedResult: GenerateContentResult = {
@@ -615,7 +615,7 @@ describe('GenerativeModelPreview generateContent', () => {
615615
);
616616
await modelWithRequestOptions.generateContent(req);
617617
// @ts-ignore
618-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
618+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
619619
});
620620
it('returns a GenerateContentResponse when passed a string', async () => {
621621
const expectedResult: GenerateContentResult = {
@@ -858,7 +858,7 @@ describe('GenerativeModel generateContentStream', () => {
858858
);
859859
await modelWithRequestOptions.generateContentStream(req);
860860
// @ts-ignore
861-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
861+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
862862
});
863863
it('returns a GenerateContentResponse when passed a string', async () => {
864864
const expectedResult: StreamGenerateContentResult = {
@@ -1017,7 +1017,7 @@ describe('GenerativeModelPreview generateContentStream', () => {
10171017
);
10181018
await modelWithRequestOptions.generateContentStream(req);
10191019
// @ts-ignore
1020-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
1020+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
10211021
});
10221022

10231023
it('returns a GenerateContentResponse when passed a string', async () => {
@@ -1189,7 +1189,7 @@ describe('ChatSession', () => {
11891189
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
11901190
TEST_REQUEST_OPTIONS
11911191
);
1192-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
1192+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
11931193
});
11941194

11951195
it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => {
@@ -1364,7 +1364,7 @@ describe('ChatSession', () => {
13641364
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
13651365
TEST_REQUEST_OPTIONS
13661366
);
1367-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
1367+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
13681368
});
13691369
it('returns a StreamGenerateContentResponse and appends role if missing', async () => {
13701370
const req = 'How are you doing today?';
@@ -1528,7 +1528,7 @@ describe('ChatSessionPreview', () => {
15281528
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
15291529
TEST_REQUEST_OPTIONS
15301530
);
1531-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
1531+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
15321532
});
15331533

15341534
it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => {
@@ -1703,7 +1703,7 @@ describe('ChatSessionPreview', () => {
17031703
expect(chatSessionWithRequestOptions.requestOptions).toEqual(
17041704
TEST_REQUEST_OPTIONS
17051705
);
1706-
expect(generateContentSpy.calls.allArgs()[0][8].timeoutMillis).toEqual(0);
1706+
expect(generateContentSpy.calls.allArgs()[0][9].timeoutMillis).toEqual(0);
17071707
});
17081708
it('returns a StreamGenerateContentResponse and appends role if missing', async () => {
17091709
const req = 'How are you doing today?';

system_test/end_to_end_sample_test.ts

+50-3
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,57 @@ describe('generateContent', () => {
424424
`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}`
425425
);
426426
});
427-
xit('should return grounding metadata when passed GoogleSearchRetriever or Retriever', async () => {
427+
it('should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
428428
const generativeTextModel = vertex_ai.getGenerativeModel({
429429
model: 'gemini-pro',
430-
//tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
430+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
431+
});
432+
const result = await generativeTextModel.generateContent({
433+
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
434+
});
435+
const response = result.response;
436+
const groundingMetadata = response.candidates[0].groundingMetadata;
437+
expect(groundingMetadata).toBeDefined();
438+
if (groundingMetadata) {
439+
expect(groundingMetadata.groundingAttributions).toBeTruthy();
440+
expect(groundingMetadata.webSearchQueries).toBeTruthy();
441+
}
442+
});
443+
it('should return grounding metadata when passed GoogleSearchRetriever in generateContent', async () => {
444+
const generativeTextModel = vertex_ai.getGenerativeModel({
445+
model: 'gemini-pro',
446+
});
447+
const result = await generativeTextModel.generateContent({
448+
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
449+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
450+
});
451+
const response = result.response;
452+
const groundingMetadata = response.candidates[0].groundingMetadata;
453+
expect(groundingMetadata).toBeDefined();
454+
if (groundingMetadata) {
455+
expect(groundingMetadata.groundingAttributions).toBeTruthy();
456+
expect(groundingMetadata.webSearchQueries).toBeTruthy();
457+
}
458+
});
459+
it('in preview should return grounding metadata when passed GoogleSearchRetriever in getGenerativeModel', async () => {
460+
const generativeTextModel = vertex_ai.preview.getGenerativeModel({
461+
model: 'gemini-pro',
462+
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL,
463+
});
464+
const result = await generativeTextModel.generateContent({
465+
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
466+
});
467+
const response = result.response;
468+
const groundingMetadata = response.candidates[0].groundingMetadata;
469+
expect(groundingMetadata).toBeDefined();
470+
if (groundingMetadata) {
471+
expect(groundingMetadata.groundingAttributions).toBeTruthy();
472+
expect(groundingMetadata.webSearchQueries).toBeTruthy();
473+
}
474+
});
475+
it('in preview should return grounding metadata when passed GoogleSearchRetriever in generateContent', async () => {
476+
const generativeTextModel = vertex_ai.preview.getGenerativeModel({
477+
model: 'gemini-pro',
431478
});
432479
const result = await generativeTextModel.generateContent({
433480
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}],
@@ -437,7 +484,7 @@ describe('generateContent', () => {
437484
const groundingMetadata = response.candidates[0].groundingMetadata;
438485
expect(groundingMetadata).toBeDefined();
439486
if (groundingMetadata) {
440-
// expect(groundingMetadata.groundingAttributions).toBeTruthy();
487+
expect(groundingMetadata.groundingAttributions).toBeTruthy();
441488
expect(groundingMetadata.webSearchQueries).toBeTruthy();
442489
}
443490
});

0 commit comments

Comments
 (0)