@@ -24,6 +24,7 @@ import {
24
24
generateContentStream ,
25
25
} from '../functions/generate_content' ;
26
26
import {
27
+ Content ,
27
28
CountTokensRequest ,
28
29
CountTokensResponse ,
29
30
GenerateContentRequest ,
@@ -54,6 +55,7 @@ export class GenerativeModel {
54
55
private readonly safetySettings ?: SafetySetting [ ] ;
55
56
private readonly tools ?: Tool [ ] ;
56
57
private readonly requestOptions ?: RequestOptions ;
58
+ private readonly systemInstruction ?: Content ;
57
59
private readonly project : string ;
58
60
private readonly location : string ;
59
61
private readonly googleAuth : GoogleAuth ;
@@ -74,6 +76,10 @@ export class GenerativeModel {
74
76
this . safetySettings = getGenerativeModelParams . safetySettings ;
75
77
this . tools = getGenerativeModelParams . tools ;
76
78
this . requestOptions = getGenerativeModelParams . requestOptions ?? { } ;
79
+ if ( getGenerativeModelParams . systemInstruction ) {
80
+ getGenerativeModelParams . systemInstruction . role = constants . SYSTEM_ROLE ;
81
+ }
82
+ this . systemInstruction = getGenerativeModelParams . systemInstruction ;
77
83
if ( this . model . startsWith ( 'models/' ) ) {
78
84
this . publisherModelEndpoint = `publishers/google/${ this . model } ` ;
79
85
} else {
@@ -114,12 +120,18 @@ export class GenerativeModel {
114
120
async generateContent (
115
121
request : GenerateContentRequest | string
116
122
) : Promise < GenerateContentResult > {
123
+ request = formulateRequestToGenerateContentRequest ( request ) ;
124
+ const formulatedRequest =
125
+ formulateSystemInstructionIntoGenerateContentRequest (
126
+ request ,
127
+ this . systemInstruction
128
+ ) ;
117
129
return generateContent (
118
130
this . location ,
119
131
this . project ,
120
132
this . publisherModelEndpoint ,
121
133
this . fetchToken ( ) ,
122
- request ,
134
+ formulatedRequest ,
123
135
this . apiEndpoint ,
124
136
this . generationConfig ,
125
137
this . safetySettings ,
@@ -155,12 +167,18 @@ export class GenerativeModel {
155
167
async generateContentStream (
156
168
request : GenerateContentRequest | string
157
169
) : Promise < StreamGenerateContentResult > {
170
+ request = formulateRequestToGenerateContentRequest ( request ) ;
171
+ const formulatedRequest =
172
+ formulateSystemInstructionIntoGenerateContentRequest (
173
+ request ,
174
+ this . systemInstruction
175
+ ) ;
158
176
return generateContentStream (
159
177
this . location ,
160
178
this . project ,
161
179
this . publisherModelEndpoint ,
162
180
this . fetchToken ( ) ,
163
- request ,
181
+ formulatedRequest ,
164
182
this . apiEndpoint ,
165
183
this . generationConfig ,
166
184
this . safetySettings ,
@@ -257,6 +275,7 @@ export class GenerativeModelPreview {
257
275
private readonly safetySettings ?: SafetySetting [ ] ;
258
276
private readonly tools ?: Tool [ ] ;
259
277
private readonly requestOptions ?: RequestOptions ;
278
+ private readonly systemInstruction ?: Content ;
260
279
private readonly project : string ;
261
280
private readonly location : string ;
262
281
private readonly googleAuth : GoogleAuth ;
@@ -277,6 +296,10 @@ export class GenerativeModelPreview {
277
296
this . safetySettings = getGenerativeModelParams . safetySettings ;
278
297
this . tools = getGenerativeModelParams . tools ;
279
298
this . requestOptions = getGenerativeModelParams . requestOptions ?? { } ;
299
+ if ( getGenerativeModelParams . systemInstruction ) {
300
+ getGenerativeModelParams . systemInstruction . role = constants . SYSTEM_ROLE ;
301
+ }
302
+ this . systemInstruction = getGenerativeModelParams . systemInstruction ;
280
303
if ( this . model . startsWith ( 'models/' ) ) {
281
304
this . publisherModelEndpoint = `publishers/google/${ this . model } ` ;
282
305
} else {
@@ -316,12 +339,18 @@ export class GenerativeModelPreview {
316
339
async generateContent (
317
340
request : GenerateContentRequest | string
318
341
) : Promise < GenerateContentResult > {
342
+ request = formulateRequestToGenerateContentRequest ( request ) ;
343
+ const formulatedRequest =
344
+ formulateSystemInstructionIntoGenerateContentRequest (
345
+ request ,
346
+ this . systemInstruction
347
+ ) ;
319
348
return generateContent (
320
349
this . location ,
321
350
this . project ,
322
351
this . publisherModelEndpoint ,
323
352
this . fetchToken ( ) ,
324
- request ,
353
+ formulatedRequest ,
325
354
this . apiEndpoint ,
326
355
this . generationConfig ,
327
356
this . safetySettings ,
@@ -357,12 +386,18 @@ export class GenerativeModelPreview {
357
386
async generateContentStream (
358
387
request : GenerateContentRequest | string
359
388
) : Promise < StreamGenerateContentResult > {
389
+ request = formulateRequestToGenerateContentRequest ( request ) ;
390
+ const formulatedRequest =
391
+ formulateSystemInstructionIntoGenerateContentRequest (
392
+ request ,
393
+ this . systemInstruction
394
+ ) ;
360
395
return generateContentStream (
361
396
this . location ,
362
397
this . project ,
363
398
this . publisherModelEndpoint ,
364
399
this . fetchToken ( ) ,
365
- request ,
400
+ formulatedRequest ,
366
401
this . apiEndpoint ,
367
402
this . generationConfig ,
368
403
this . safetySettings ,
@@ -445,3 +480,28 @@ export class GenerativeModelPreview {
445
480
return new ChatSessionPreview ( startChatRequest , this . requestOptions ) ;
446
481
}
447
482
}
483
+
484
+ function formulateRequestToGenerateContentRequest (
485
+ request : GenerateContentRequest | string
486
+ ) : GenerateContentRequest {
487
+ if ( typeof request === 'string' ) {
488
+ return {
489
+ contents : [ { role : constants . USER_ROLE , parts : [ { text : request } ] } ] ,
490
+ } as GenerateContentRequest ;
491
+ }
492
+ return request ;
493
+ }
494
+
495
+ function formulateSystemInstructionIntoGenerateContentRequest (
496
+ methodRequest : GenerateContentRequest ,
497
+ classSystemInstruction ?: Content
498
+ ) : GenerateContentRequest {
499
+ if ( methodRequest . systemInstruction ) {
500
+ methodRequest . systemInstruction . role = constants . SYSTEM_ROLE ;
501
+ return methodRequest ;
502
+ }
503
+ if ( classSystemInstruction ) {
504
+ methodRequest . systemInstruction = classSystemInstruction ;
505
+ }
506
+ return methodRequest ;
507
+ }
0 commit comments