Skip to content

Commit 7e71f75

Browse files
yyyu-googlecopybara-github
authored andcommitted
feat: enable system instruction in chat experience
PiperOrigin-RevId: 622997107
1 parent 590ca5a commit 7e71f75

File tree

4 files changed

+254
-0
lines changed

4 files changed

+254
-0
lines changed

src/models/chat_session.ts

+14
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ export class ChatSession {
6161
private readonly safetySettings?: SafetySetting[];
6262
private readonly tools?: Tool[];
6363
private readonly apiEndpoint?: string;
64+
private readonly systemInstruction?: Content;
6465

6566
async getHistory(): Promise<Content[]> {
6667
return Promise.resolve(this.historyInternal);
@@ -84,6 +85,10 @@ export class ChatSession {
8485
this.tools = request.tools;
8586
this.apiEndpoint = request.apiEndpoint;
8687
this.requestOptions = requestOptions ?? {};
88+
if (request.systemInstruction) {
89+
request.systemInstruction.role = constants.SYSTEM_ROLE;
90+
}
91+
this.systemInstruction = request.systemInstruction;
8792
}
8893

8994
/**
@@ -127,6 +132,7 @@ export class ChatSession {
127132
safetySettings: this.safetySettings,
128133
generationConfig: this.generationConfig,
129134
tools: this.tools,
135+
systemInstruction: this.systemInstruction,
130136
};
131137

132138
const generateContentResult: GenerateContentResult = await generateContent(
@@ -210,6 +216,7 @@ export class ChatSession {
210216
safetySettings: this.safetySettings,
211217
generationConfig: this.generationConfig,
212218
tools: this.tools,
219+
systemInstruction: this.systemInstruction,
213220
};
214221

215222
const streamGenerateContentResultPromise = generateContentStream(
@@ -257,6 +264,7 @@ export class ChatSessionPreview {
257264
private readonly safetySettings?: SafetySetting[];
258265
private readonly tools?: Tool[];
259266
private readonly apiEndpoint?: string;
267+
private readonly systemInstruction?: Content;
260268

261269
async getHistory(): Promise<Content[]> {
262270
return Promise.resolve(this.historyInternal);
@@ -280,6 +288,10 @@ export class ChatSessionPreview {
280288
this.tools = request.tools;
281289
this.apiEndpoint = request.apiEndpoint;
282290
this.requestOptions = requestOptions ?? {};
291+
if (request.systemInstruction) {
292+
request.systemInstruction.role = constants.SYSTEM_ROLE;
293+
}
294+
this.systemInstruction = request.systemInstruction;
283295
}
284296

285297
/**
@@ -322,6 +334,7 @@ export class ChatSessionPreview {
322334
safetySettings: this.safetySettings,
323335
generationConfig: this.generationConfig,
324336
tools: this.tools,
337+
systemInstruction: this.systemInstruction,
325338
};
326339

327340
const generateContentResult: GenerateContentResult = await generateContent(
@@ -406,6 +419,7 @@ export class ChatSessionPreview {
406419
safetySettings: this.safetySettings,
407420
generationConfig: this.generationConfig,
408421
tools: this.tools,
422+
systemInstruction: this.systemInstruction,
409423
};
410424

411425
const streamGenerateContentResultPromise = generateContentStream(

src/models/generative_models.ts

+4
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ export class GenerativeModel {
258258
request.safetySettings ?? this.safetySettings;
259259
startChatRequest.tools = request.tools ?? this.tools;
260260
startChatRequest.apiEndpoint = request.apiEndpoint ?? this.apiEndpoint;
261+
startChatRequest.systemInstruction =
262+
request.systemInstruction ?? this.systemInstruction;
261263
}
262264
return new ChatSession(startChatRequest, this.requestOptions);
263265
}
@@ -476,6 +478,8 @@ export class GenerativeModelPreview {
476478
startChatRequest.safetySettings =
477479
request.safetySettings ?? this.safetySettings;
478480
startChatRequest.tools = request.tools ?? this.tools;
481+
startChatRequest.systemInstruction =
482+
request.systemInstruction ?? this.systemInstruction;
479483
}
480484
return new ChatSessionPreview(startChatRequest, this.requestOptions);
481485
}

src/models/test/models_test.ts

+228
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ const TEST_SYSTEM_INSTRUCTION = {
223223
role: constants.SYSTEM_ROLE,
224224
parts: [{text: 'system instruction'}],
225225
};
226+
const TEST_SYSTEM_INSTRUCTION_1 = {
227+
role: constants.SYSTEM_ROLE,
228+
parts: [{text: 'system instruction1'}],
229+
};
226230
const TEST_SYSTEM_INSTRUCTION_WRONG_ROLE = {
227231
role: 'WRONG_ROLE',
228232
parts: [{text: 'system instruction'}],
@@ -339,6 +343,118 @@ describe('GenerativeModel startChat', () => {
339343
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
340344
expect(actualBody).toEqual(expectedBody);
341345
});
346+
it('pass system instruction to remote endpoint from GenerativeModel constructor', async () => {
347+
const expectedResult = TEST_MODEL_RESPONSE;
348+
const fetchResult = Promise.resolve(
349+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
350+
);
351+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
352+
const req = 'How are you doing today?';
353+
const model = new GenerativeModel({
354+
model: 'gemini-pro',
355+
project: PROJECT,
356+
location: LOCATION,
357+
googleAuth: FAKE_GOOGLE_AUTH,
358+
systemInstruction: TEST_SYSTEM_INSTRUCTION,
359+
});
360+
const chat = model.startChat({
361+
history: TEST_USER_CHAT_MESSAGE,
362+
});
363+
const expectedBody =
364+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
365+
await chat.sendMessage(req);
366+
// @ts-ignore
367+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
368+
expect(actualBody).toEqual(
369+
expectedBody,
370+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
371+
);
372+
});
373+
it('pass system instruction to remote endpoint from startChat', async () => {
374+
const expectedResult = TEST_MODEL_RESPONSE;
375+
const fetchResult = Promise.resolve(
376+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
377+
);
378+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
379+
const req = 'How are you doing today?';
380+
const model = new GenerativeModel({
381+
model: 'gemini-pro',
382+
project: PROJECT,
383+
location: LOCATION,
384+
googleAuth: FAKE_GOOGLE_AUTH,
385+
systemInstruction: TEST_SYSTEM_INSTRUCTION_1,
386+
});
387+
const chat = model.startChat({
388+
history: TEST_USER_CHAT_MESSAGE,
389+
// this is different from constructor
390+
systemInstruction: TEST_SYSTEM_INSTRUCTION,
391+
});
392+
const expectedBody =
393+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
394+
await chat.sendMessage(req);
395+
// @ts-ignore
396+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
397+
expect(actualBody).toEqual(
398+
expectedBody,
399+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
400+
);
401+
});
402+
it('pass system instruction with wrong role to remote endpoint from GenerativeModel constructor', async () => {
403+
const expectedResult = TEST_MODEL_RESPONSE;
404+
const fetchResult = Promise.resolve(
405+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
406+
);
407+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
408+
const req = 'How are you doing today?';
409+
const model = new GenerativeModel({
410+
model: 'gemini-pro',
411+
project: PROJECT,
412+
location: LOCATION,
413+
googleAuth: FAKE_GOOGLE_AUTH,
414+
systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE,
415+
});
416+
const chat = model.startChat({
417+
history: TEST_USER_CHAT_MESSAGE,
418+
});
419+
const expectedBody =
420+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
421+
await chat.sendMessage(req);
422+
// @ts-ignore
423+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
424+
expect(actualBody).toEqual(
425+
expectedBody,
426+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
427+
);
428+
});
429+
it('pass system instruction with wrong role to remote endpoint from startChat', async () => {
430+
const expectedResult = TEST_MODEL_RESPONSE;
431+
const fetchResult = Promise.resolve(
432+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
433+
);
434+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
435+
const req = 'How are you doing today?';
436+
const model = new GenerativeModel({
437+
model: 'gemini-pro',
438+
project: PROJECT,
439+
location: LOCATION,
440+
googleAuth: FAKE_GOOGLE_AUTH,
441+
systemInstruction: TEST_SYSTEM_INSTRUCTION_1,
442+
});
443+
const chat = model.startChat({
444+
history: TEST_USER_CHAT_MESSAGE,
445+
// this is different from constructor
446+
systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE,
447+
});
448+
const expectedBody =
449+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
450+
await chat.sendMessage(req);
451+
// @ts-ignore
452+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
453+
expect(actualBody).toEqual(
454+
expectedBody,
455+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
456+
);
457+
});
342458
});
343459

344460
describe('GenerativeModelPreview startChat', () => {
@@ -429,6 +545,118 @@ describe('GenerativeModelPreview startChat', () => {
429545
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
430546
expect(actualBody).toEqual(expectedBody);
431547
});
548+
it('pass system instruction to remote endpoint from GenerativeModelPreview constructor', async () => {
549+
const expectedResult = TEST_MODEL_RESPONSE;
550+
const fetchResult = Promise.resolve(
551+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
552+
);
553+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
554+
const req = 'How are you doing today?';
555+
const model = new GenerativeModelPreview({
556+
model: 'gemini-pro',
557+
project: PROJECT,
558+
location: LOCATION,
559+
googleAuth: FAKE_GOOGLE_AUTH,
560+
systemInstruction: TEST_SYSTEM_INSTRUCTION,
561+
});
562+
const chat = model.startChat({
563+
history: TEST_USER_CHAT_MESSAGE,
564+
});
565+
const expectedBody =
566+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
567+
await chat.sendMessage(req);
568+
// @ts-ignore
569+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
570+
expect(actualBody).toEqual(
571+
expectedBody,
572+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
573+
);
574+
});
575+
it('pass system instruction to remote endpoint from startChat', async () => {
576+
const expectedResult = TEST_MODEL_RESPONSE;
577+
const fetchResult = Promise.resolve(
578+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
579+
);
580+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
581+
const req = 'How are you doing today?';
582+
const model = new GenerativeModelPreview({
583+
model: 'gemini-pro',
584+
project: PROJECT,
585+
location: LOCATION,
586+
googleAuth: FAKE_GOOGLE_AUTH,
587+
systemInstruction: TEST_SYSTEM_INSTRUCTION_1,
588+
});
589+
const chat = model.startChat({
590+
history: TEST_USER_CHAT_MESSAGE,
591+
// this is different from constructor
592+
systemInstruction: TEST_SYSTEM_INSTRUCTION,
593+
});
594+
const expectedBody =
595+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
596+
await chat.sendMessage(req);
597+
// @ts-ignore
598+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
599+
expect(actualBody).toEqual(
600+
expectedBody,
601+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
602+
);
603+
});
604+
it('pass system instruction with wrong role to remote endpoint from GenerativeModelPreview constructor', async () => {
605+
const expectedResult = TEST_MODEL_RESPONSE;
606+
const fetchResult = Promise.resolve(
607+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
608+
);
609+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
610+
const req = 'How are you doing today?';
611+
const model = new GenerativeModelPreview({
612+
model: 'gemini-pro',
613+
project: PROJECT,
614+
location: LOCATION,
615+
googleAuth: FAKE_GOOGLE_AUTH,
616+
systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE,
617+
});
618+
const chat = model.startChat({
619+
history: TEST_USER_CHAT_MESSAGE,
620+
});
621+
const expectedBody =
622+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
623+
await chat.sendMessage(req);
624+
// @ts-ignore
625+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
626+
expect(actualBody).toEqual(
627+
expectedBody,
628+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
629+
);
630+
});
631+
it('pass system instruction with wrong role to remote endpoint from startChat', async () => {
632+
const expectedResult = TEST_MODEL_RESPONSE;
633+
const fetchResult = Promise.resolve(
634+
new Response(JSON.stringify(expectedResult), fetchResponseObj)
635+
);
636+
const fetchSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
637+
const req = 'How are you doing today?';
638+
const model = new GenerativeModelPreview({
639+
model: 'gemini-pro',
640+
project: PROJECT,
641+
location: LOCATION,
642+
googleAuth: FAKE_GOOGLE_AUTH,
643+
systemInstruction: TEST_SYSTEM_INSTRUCTION_1,
644+
});
645+
const chat = model.startChat({
646+
history: TEST_USER_CHAT_MESSAGE,
647+
// this is different from constructor
648+
systemInstruction: TEST_SYSTEM_INSTRUCTION_WRONG_ROLE,
649+
});
650+
const expectedBody =
651+
'{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"systemInstruction":{"role":"system","parts":[{"text":"system instruction"}]}}';
652+
await chat.sendMessage(req);
653+
// @ts-ignore
654+
const actualBody = fetchSpy.calls.allArgs()[0][1].body;
655+
expect(actualBody).toEqual(
656+
expectedBody,
657+
`unit test failed in chat.sendMessage with ${actualBody} not equal to ${expectedBody}`
658+
);
659+
});
432660
});
433661

434662
describe('GenerativeModel generateContent', () => {

src/types/content.ts

+8
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,10 @@ export declare interface StartChatParams {
880880
tools?: Tool[];
881881
/** Optional. The base Vertex AI endpoint to use for the request. */
882882
apiEndpoint?: string;
883+
/** Optional. The user provided system instructions for the model.
884+
* Note: only text should be used in parts of {@link Content}
885+
*/
886+
systemInstruction?: Content;
883887
}
884888

885889
/**
@@ -894,6 +898,10 @@ export declare interface StartChatSessionRequest extends StartChatParams {
894898
googleAuth: GoogleAuth;
895899
/** The publisher model endpoint to use for the request. */
896900
publisherModelEndpoint: string;
901+
/** Optional. The user provided system instructions for the model.
902+
* Note: only text should be used in parts of {@link Content}
903+
*/
904+
systemInstruction?: Content;
897905
}
898906

899907
/**

0 commit comments

Comments
 (0)