Skip to content

Commit 590ca5a

Browse files
yyyu-googlecopybara-github
authored andcommitted
feat: enable system instruction for GenerativeModel
PiperOrigin-RevId: 622995958
1 parent 44b8884 commit 590ca5a

File tree

6 files changed

+557
-4
lines changed

6 files changed

+557
-4
lines changed

src/functions/generate_content.ts

+2
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ export async function generateContent(
6969

7070
const generateContentRequest: GenerateContentRequest = {
7171
contents: request.contents,
72+
systemInstruction: request.systemInstruction,
7273
generationConfig: request.generationConfig ?? generationConfig,
7374
safetySettings: request.safetySettings ?? safetySettings,
7475
tools: request.tools ?? tools,
@@ -121,6 +122,7 @@ export async function generateContentStream(
121122

122123
const generateContentRequest: GenerateContentRequest = {
123124
contents: request.contents,
125+
systemInstruction: request.systemInstruction,
124126
generationConfig: request.generationConfig ?? generationConfig,
125127
safetySettings: request.safetySettings ?? safetySettings,
126128
tools: request.tools ?? tools,

src/models/generative_models.ts

+64-4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import {
2424
generateContentStream,
2525
} from '../functions/generate_content';
2626
import {
27+
Content,
2728
CountTokensRequest,
2829
CountTokensResponse,
2930
GenerateContentRequest,
@@ -54,6 +55,7 @@ export class GenerativeModel {
5455
private readonly safetySettings?: SafetySetting[];
5556
private readonly tools?: Tool[];
5657
private readonly requestOptions?: RequestOptions;
58+
private readonly systemInstruction?: Content;
5759
private readonly project: string;
5860
private readonly location: string;
5961
private readonly googleAuth: GoogleAuth;
@@ -74,6 +76,10 @@ export class GenerativeModel {
7476
this.safetySettings = getGenerativeModelParams.safetySettings;
7577
this.tools = getGenerativeModelParams.tools;
7678
this.requestOptions = getGenerativeModelParams.requestOptions ?? {};
79+
if (getGenerativeModelParams.systemInstruction) {
80+
getGenerativeModelParams.systemInstruction.role = constants.SYSTEM_ROLE;
81+
}
82+
this.systemInstruction = getGenerativeModelParams.systemInstruction;
7783
if (this.model.startsWith('models/')) {
7884
this.publisherModelEndpoint = `publishers/google/${this.model}`;
7985
} else {
@@ -114,12 +120,18 @@ export class GenerativeModel {
114120
async generateContent(
115121
request: GenerateContentRequest | string
116122
): Promise<GenerateContentResult> {
123+
request = formulateRequestToGenerateContentRequest(request);
124+
const formulatedRequest =
125+
formulateSystemInstructionIntoGenerateContentRequest(
126+
request,
127+
this.systemInstruction
128+
);
117129
return generateContent(
118130
this.location,
119131
this.project,
120132
this.publisherModelEndpoint,
121133
this.fetchToken(),
122-
request,
134+
formulatedRequest,
123135
this.apiEndpoint,
124136
this.generationConfig,
125137
this.safetySettings,
@@ -155,12 +167,18 @@ export class GenerativeModel {
155167
async generateContentStream(
156168
request: GenerateContentRequest | string
157169
): Promise<StreamGenerateContentResult> {
170+
request = formulateRequestToGenerateContentRequest(request);
171+
const formulatedRequest =
172+
formulateSystemInstructionIntoGenerateContentRequest(
173+
request,
174+
this.systemInstruction
175+
);
158176
return generateContentStream(
159177
this.location,
160178
this.project,
161179
this.publisherModelEndpoint,
162180
this.fetchToken(),
163-
request,
181+
formulatedRequest,
164182
this.apiEndpoint,
165183
this.generationConfig,
166184
this.safetySettings,
@@ -257,6 +275,7 @@ export class GenerativeModelPreview {
257275
private readonly safetySettings?: SafetySetting[];
258276
private readonly tools?: Tool[];
259277
private readonly requestOptions?: RequestOptions;
278+
private readonly systemInstruction?: Content;
260279
private readonly project: string;
261280
private readonly location: string;
262281
private readonly googleAuth: GoogleAuth;
@@ -277,6 +296,10 @@ export class GenerativeModelPreview {
277296
this.safetySettings = getGenerativeModelParams.safetySettings;
278297
this.tools = getGenerativeModelParams.tools;
279298
this.requestOptions = getGenerativeModelParams.requestOptions ?? {};
299+
if (getGenerativeModelParams.systemInstruction) {
300+
getGenerativeModelParams.systemInstruction.role = constants.SYSTEM_ROLE;
301+
}
302+
this.systemInstruction = getGenerativeModelParams.systemInstruction;
280303
if (this.model.startsWith('models/')) {
281304
this.publisherModelEndpoint = `publishers/google/${this.model}`;
282305
} else {
@@ -316,12 +339,18 @@ export class GenerativeModelPreview {
316339
async generateContent(
317340
request: GenerateContentRequest | string
318341
): Promise<GenerateContentResult> {
342+
request = formulateRequestToGenerateContentRequest(request);
343+
const formulatedRequest =
344+
formulateSystemInstructionIntoGenerateContentRequest(
345+
request,
346+
this.systemInstruction
347+
);
319348
return generateContent(
320349
this.location,
321350
this.project,
322351
this.publisherModelEndpoint,
323352
this.fetchToken(),
324-
request,
353+
formulatedRequest,
325354
this.apiEndpoint,
326355
this.generationConfig,
327356
this.safetySettings,
@@ -357,12 +386,18 @@ export class GenerativeModelPreview {
357386
async generateContentStream(
358387
request: GenerateContentRequest | string
359388
): Promise<StreamGenerateContentResult> {
389+
request = formulateRequestToGenerateContentRequest(request);
390+
const formulatedRequest =
391+
formulateSystemInstructionIntoGenerateContentRequest(
392+
request,
393+
this.systemInstruction
394+
);
360395
return generateContentStream(
361396
this.location,
362397
this.project,
363398
this.publisherModelEndpoint,
364399
this.fetchToken(),
365-
request,
400+
formulatedRequest,
366401
this.apiEndpoint,
367402
this.generationConfig,
368403
this.safetySettings,
@@ -445,3 +480,28 @@ export class GenerativeModelPreview {
445480
return new ChatSessionPreview(startChatRequest, this.requestOptions);
446481
}
447482
}
483+
484+
function formulateRequestToGenerateContentRequest(
485+
request: GenerateContentRequest | string
486+
): GenerateContentRequest {
487+
if (typeof request === 'string') {
488+
return {
489+
contents: [{role: constants.USER_ROLE, parts: [{text: request}]}],
490+
} as GenerateContentRequest;
491+
}
492+
return request;
493+
}
494+
495+
function formulateSystemInstructionIntoGenerateContentRequest(
496+
methodRequest: GenerateContentRequest,
497+
classSystemInstruction?: Content
498+
): GenerateContentRequest {
499+
if (methodRequest.systemInstruction) {
500+
methodRequest.systemInstruction.role = constants.SYSTEM_ROLE;
501+
return methodRequest;
502+
}
503+
if (classSystemInstruction) {
504+
methodRequest.systemInstruction = classSystemInstruction;
505+
}
506+
return methodRequest;
507+
}

0 commit comments

Comments
 (0)