Skip to content

Commit 1deb4e9

Browse files
sararob authored and copybara-github committed
feat: add function calling support
PiperOrigin-RevId: 599199680
1 parent 558aee9 commit 1deb4e9

File tree

7 files changed

+792
-190
lines changed

7 files changed

+792
-190
lines changed

README.md

+92
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,98 @@ async function countTokens() {
167167
countTokens();
168168
```
169169

170+
## Function calling
171+
172+
The Node SDK supports
173+
[function calling](https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/function-calling) via `sendMessage`, `sendMessageStream`, `generateContent`, and `generateContentStream`. We recommend using it through chat methods
174+
(`sendMessage` or `sendMessageStream`) but have included examples of both
175+
approaches below.
176+
177+
### Function declarations and response
178+
179+
This is an example of a function declaration and function response, which are
180+
passed to the model in the snippets that follow.
181+
182+
```typescript
183+
const functionDeclarations = [
184+
{
185+
function_declarations: [
186+
{
187+
name: "get_current_weather",
188+
description: 'get weather in a given location',
189+
parameters: {
190+
type: FunctionDeclarationSchemaType.OBJECT,
191+
properties: {
192+
location: {type: FunctionDeclarationSchemaType.STRING},
193+
unit: {
194+
type: FunctionDeclarationSchemaType.STRING,
195+
enum: ['celsius', 'fahrenheit'],
196+
},
197+
},
198+
required: ['location'],
199+
},
200+
},
201+
],
202+
},
203+
];
204+
205+
const functionResponseParts = [
206+
{
207+
functionResponse: {
208+
name: "get_current_weather",
209+
response:
210+
{name: "get_current_weather", content: {weather: "super nice"}},
211+
},
212+
},
213+
];
214+
```
215+
216+
### Function calling with chat
217+
218+
```typescript
219+
// Create a chat session and pass your function declarations
220+
const chat = generativeModel.startChat({
221+
tools: functionDeclarations,
222+
});
223+
224+
const chatInput1 = 'What is the weather in Boston?';
225+
226+
// This should include a functionCall response from the model
227+
const result1 = await chat.sendMessageStream(chatInput1);
228+
for await (const item of result1.stream) {
229+
console.log(item.candidates[0]);
230+
}
231+
const response1 = await result1.response;
232+
233+
// Send a follow up message with a FunctionResponse
234+
const result2 = await chat.sendMessageStream(functionResponseParts);
235+
for await (const item of result2.stream) {
236+
console.log(item.candidates[0]);
237+
}
238+
239+
// This should include a text response from the model using the response content
240+
// provided above
241+
const response2 = await result2.response;
242+
```
243+
244+
### Function calling with generateContentStream
245+
246+
```typescript
247+
const request = {
248+
contents: [
249+
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
250+
{role: 'model', parts: [{functionCall: {name: 'get_current_weather', args: {'location': 'Boston'}}}]},
251+
{role: 'function', parts: functionResponseParts}
252+
],
253+
tools: functionDeclarations,
254+
};
255+
const streamingResp =
256+
await generativeModel.generateContentStream(request);
257+
for await (const item of streamingResp.stream) {
258+
console.log(item.candidates[0]);
259+
}
260+
```
261+
170262
## License
171263

172264
The contents of this repository are licensed under the

src/index.ts

+59-12
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import {
3434
Part,
3535
SafetySetting,
3636
StreamGenerateContentResult,
37+
Tool,
3738
VertexInit,
3839
} from './types/content';
3940
import {
@@ -134,7 +135,8 @@ export class VertexAI_Preview {
134135
this,
135136
modelParams.model,
136137
modelParams.generation_config,
137-
modelParams.safety_settings
138+
modelParams.safety_settings,
139+
modelParams.tools
138140
);
139141
}
140142

@@ -185,6 +187,7 @@ export declare interface StartChatParams {
185187
history?: Content[];
186188
safety_settings?: SafetySetting[];
187189
generation_config?: GenerationConfig;
190+
tools?: Tool[];
188191
}
189192

190193
// StartChatSessionRequest and ChatSession are defined here instead of in
@@ -216,6 +219,7 @@ export class ChatSession {
216219
private _send_stream_promise: Promise<void> = Promise.resolve();
217220
generation_config?: GenerationConfig;
218221
safety_settings?: SafetySetting[];
222+
tools?: Tool[];
219223

220224
get history(): Content[] {
221225
return this.historyInternal;
@@ -231,6 +235,9 @@ export class ChatSession {
231235
this._model_instance = request._model_instance;
232236
this.historyInternal = request.history ?? [];
233237
this._vertex_instance = request._vertex_instance;
238+
this.generation_config = request.generation_config;
239+
this.safety_settings = request.safety_settings;
240+
this.tools = request.tools;
234241
}
235242

236243
/**
@@ -241,11 +248,12 @@ export class ChatSession {
241248
async sendMessage(
242249
request: string | Array<string | Part>
243250
): Promise<GenerateContentResult> {
244-
const newContent: Content = formulateNewContent(request);
251+
const newContent: Content[] = formulateNewContent(request);
245252
const generateContentrequest: GenerateContentRequest = {
246-
contents: this.historyInternal.concat([newContent]),
253+
contents: this.historyInternal.concat(newContent),
247254
safety_settings: this.safety_settings,
248255
generation_config: this.generation_config,
256+
tools: this.tools,
249257
};
250258

251259
const generateContentResult: GenerateContentResult =
@@ -257,7 +265,7 @@ export class ChatSession {
257265
const generateContentResponse = generateContentResult.response;
258266
// Only push the latest message to history if the response returned a result
259267
if (generateContentResponse.candidates.length !== 0) {
260-
this.historyInternal.push(newContent);
268+
this.historyInternal = this.historyInternal.concat(newContent);
261269
const contentFromAssistant =
262270
generateContentResponse.candidates[0].content;
263271
if (!contentFromAssistant.role) {
@@ -274,15 +282,15 @@ export class ChatSession {
274282

275283
async appendHistory(
276284
streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>,
277-
newContent: Content
285+
newContent: Content[]
278286
): Promise<void> {
279287
const streamGenerateContentResult =
280288
await streamGenerateContentResultPromise;
281289
const streamGenerateContentResponse =
282290
await streamGenerateContentResult.response;
283291
// Only push the latest message to history if the response returned a result
284292
if (streamGenerateContentResponse.candidates.length !== 0) {
285-
this.historyInternal.push(newContent);
293+
this.historyInternal = this.historyInternal.concat(newContent);
286294
const contentFromAssistant =
287295
streamGenerateContentResponse.candidates[0].content;
288296
if (!contentFromAssistant.role) {
@@ -303,11 +311,12 @@ export class ChatSession {
303311
async sendMessageStream(
304312
request: string | Array<string | Part>
305313
): Promise<StreamGenerateContentResult> {
306-
const newContent: Content = formulateNewContent(request);
314+
const newContent: Content[] = formulateNewContent(request);
307315
const generateContentrequest: GenerateContentRequest = {
308-
contents: this.historyInternal.concat([newContent]),
316+
contents: this.historyInternal.concat(newContent),
309317
safety_settings: this.safety_settings,
310318
generation_config: this.generation_config,
319+
tools: this.tools,
311320
};
312321

313322
const streamGenerateContentResultPromise = this._model_instance
@@ -335,6 +344,7 @@ export class GenerativeModel {
335344
model: string;
336345
generation_config?: GenerationConfig;
337346
safety_settings?: SafetySetting[];
347+
tools?: Tool[];
338348
private _vertex_instance: VertexAI_Preview;
339349
private _use_non_stream = false;
340350
private publisherModelEndpoint: string;
@@ -351,12 +361,14 @@ export class GenerativeModel {
351361
vertex_instance: VertexAI_Preview,
352362
model: string,
353363
generation_config?: GenerationConfig,
354-
safety_settings?: SafetySetting[]
364+
safety_settings?: SafetySetting[],
365+
tools?: Tool[]
355366
) {
356367
this._vertex_instance = vertex_instance;
357368
this.model = model;
358369
this.generation_config = generation_config;
359370
this.safety_settings = safety_settings;
371+
this.tools = tools;
360372
if (model.startsWith('models/')) {
361373
this.publisherModelEndpoint = `publishers/google/${this.model}`;
362374
} else {
@@ -401,6 +413,7 @@ export class GenerativeModel {
401413
contents: request.contents,
402414
generation_config: request.generation_config ?? this.generation_config,
403415
safety_settings: request.safety_settings ?? this.safety_settings,
416+
tools: request.tools ?? [],
404417
};
405418

406419
const response: Response | undefined = await postRequest({
@@ -444,6 +457,7 @@ export class GenerativeModel {
444457
contents: request.contents,
445458
generation_config: request.generation_config ?? this.generation_config,
446459
safety_settings: request.safety_settings ?? this.safety_settings,
460+
tools: request.tools ?? [],
447461
};
448462
const response = await postRequest({
449463
region: this._vertex_instance.location,
@@ -501,12 +515,15 @@ export class GenerativeModel {
501515
request.generation_config ?? this.generation_config;
502516
startChatRequest.safety_settings =
503517
request.safety_settings ?? this.safety_settings;
518+
startChatRequest.tools = request.tools ?? this.tools;
504519
}
505520
return new ChatSession(startChatRequest);
506521
}
507522
}
508523

509-
function formulateNewContent(request: string | Array<string | Part>): Content {
524+
function formulateNewContent(
525+
request: string | Array<string | Part>
526+
): Content[] {
510527
let newParts: Part[] = [];
511528

512529
if (typeof request === 'string') {
@@ -521,8 +538,38 @@ function formulateNewContent(request: string | Array<string | Part>): Content {
521538
}
522539
}
523540

524-
const newContent: Content = {role: constants.USER_ROLE, parts: newParts};
525-
return newContent;
541+
return formatPartsByRole(newParts);
542+
}
543+
544+
/**
545+
* When multiple Part types (i.e. FunctionResponsePart and TextPart) are
546+
* passed in a single Part array, we may need to assign different roles to each
547+
* part. Currently only FunctionResponsePart requires a role other than 'user'.
548+
* @ignore
549+
* @param {Array<Part>} parts Array of parts to pass to the model
550+
* @return {Content[]} Array of content items
551+
*/
552+
function formatPartsByRole(parts: Array<Part>): Content[] {
553+
const partsByRole: Content[] = [];
554+
const userContent: Content = {role: constants.USER_ROLE, parts: []};
555+
const functionContent: Content = {role: constants.FUNCTION_ROLE, parts: []};
556+
557+
for (const part of parts) {
558+
if ('functionResponse' in part) {
559+
functionContent.parts.push(part);
560+
} else {
561+
userContent.parts.push(part);
562+
}
563+
}
564+
565+
if (userContent.parts.length > 0) {
566+
partsByRole.push(userContent);
567+
}
568+
if (functionContent.parts.length > 0) {
569+
partsByRole.push(functionContent);
570+
}
571+
572+
return partsByRole;
526573
}
527574

528575
function throwErrorIfNotOK(response: Response | undefined) {

src/process_stream.ts

+7
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,13 @@ function aggregateResponses(
193193
if (part.text) {
194194
aggregatedResponse.candidates[i].content.parts[0].text += part.text;
195195
}
196+
if (part.functionCall) {
197+
aggregatedResponse.candidates[i].content.parts[0].functionCall =
198+
part.functionCall;
199+
// the empty 'text' key should be removed if functionCall is in the
200+
// response
201+
delete aggregatedResponse.candidates[i].content.parts[0].text;
202+
}
196203
}
197204
}
198205
}

0 commit comments

Comments
 (0)