Skip to content

Commit c50811e

Browse files
sararobcopybara-github
authored andcommitted
fix: enable passing only a string to generateContent and generateContentStream
PiperOrigin-RevId: 596067266
1 parent ab2fd05 commit c50811e

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

src/index.ts

+29-2
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,14 @@ export class GenerativeModel {
347347
* @return The GenerateContentResponse object with the response candidates.
348348
*/
349349
async generateContent(
350-
request: GenerateContentRequest
350+
request: GenerateContentRequest | string
351351
): Promise<GenerateContentResult> {
352+
request = formatContentRequest(
353+
request,
354+
this.generation_config,
355+
this.safety_settings
356+
);
357+
352358
validateGcsInput(request.contents);
353359

354360
if (request.generation_config) {
@@ -396,8 +402,13 @@ export class GenerativeModel {
396402
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult}
397403
*/
398404
async generateContentStream(
399-
request: GenerateContentRequest
405+
request: GenerateContentRequest | string
400406
): Promise<StreamGenerateContentResult> {
407+
request = formatContentRequest(
408+
request,
409+
this.generation_config,
410+
this.safety_settings
411+
);
401412
validateGcsInput(request.contents);
402413

403414
if (request.generation_config) {
@@ -531,3 +542,19 @@ function validateGenerationConfig(
531542
}
532543
return generation_config;
533544
}
545+
546+
function formatContentRequest(
547+
request: GenerateContentRequest | string,
548+
generation_config?: GenerationConfig,
549+
safety_settings?: SafetySetting[]
550+
): GenerateContentRequest {
551+
if (typeof request === 'string') {
552+
return {
553+
contents: [{role: constants.USER_ROLE, parts: [{text: request}]}],
554+
generation_config: generation_config,
555+
safety_settings: safety_settings,
556+
};
557+
} else {
558+
return request;
559+
}
560+
}

test/index_test.ts

+23-3
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,17 @@ import {constants} from '../src/util';
4141

4242
const PROJECT = 'test_project';
4343
const LOCATION = 'test_location';
44+
const TEST_CHAT_MESSSAGE_TEXT = 'How are you doing today?';
4445
const TEST_USER_CHAT_MESSAGE = [
45-
{role: constants.USER_ROLE, parts: [{text: 'How are you doing today?'}]},
46+
{role: constants.USER_ROLE, parts: [{text: TEST_CHAT_MESSSAGE_TEXT}]},
4647
];
4748
const TEST_TOKEN = 'testtoken';
4849

4950
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [
5051
{
5152
role: constants.USER_ROLE,
5253
parts: [
53-
{text: 'How are you doing today?'},
54+
{text: TEST_CHAT_MESSSAGE_TEXT},
5455
{
5556
file_data: {
5657
file_uri: 'gs://test_bucket/test_image.jpeg',
@@ -65,7 +66,7 @@ const TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE = [
6566
{
6667
role: constants.USER_ROLE,
6768
parts: [
68-
{text: 'How are you doing today?'},
69+
{text: TEST_CHAT_MESSSAGE_TEXT},
6970
{file_data: {file_uri: 'test_image.jpeg', mime_type: 'image/jpeg'}},
7071
],
7172
},
@@ -234,6 +235,16 @@ describe('VertexAI', () => {
234235
const resp = await model.generateContent(req);
235236
expect(resp).toEqual(expectedResult);
236237
});
238+
it('returns a GenerateContentResponse when passed a string', async () => {
239+
const expectedResult: GenerateContentResult = {
240+
response: TEST_MODEL_RESPONSE,
241+
};
242+
spyOn(StreamFunctions, 'processStream').and.returnValue(
243+
expectedStreamResult
244+
);
245+
const resp = await model.generateContent(TEST_CHAT_MESSSAGE_TEXT);
246+
expect(resp).toEqual(expectedResult);
247+
});
237248
});
238249

239250
describe('generateContent', () => {
@@ -450,6 +461,15 @@ describe('VertexAI', () => {
450461
const resp = await model.generateContentStream(req);
451462
expect(resp).toEqual(expectedResult);
452463
});
464+
it('returns a GenerateContentResponse when passed a string', async () => {
465+
const expectedResult: StreamGenerateContentResult = {
466+
response: Promise.resolve(TEST_MODEL_RESPONSE),
467+
stream: testGenerator(),
468+
};
469+
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult);
470+
const resp = await model.generateContentStream(TEST_CHAT_MESSSAGE_TEXT);
471+
expect(resp).toEqual(expectedResult);
472+
});
453473
});
454474

455475
describe('generateContentStream', () => {

0 commit comments

Comments
 (0)