Skip to content

Commit e94b285

Browse files
yyyu-google authored and copybara-github committed
feat: allow user to pass "models/model-ID" to instantiate model
PiperOrigin-RevId: 592906210
1 parent de477ce commit e94b285

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

src/index.ts

+9-7
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ export class GenerativeModel {
270270
safety_settings?: SafetySetting[];
271271
private _vertex_instance: VertexAI_Internal;
272272
private _use_non_stream = false;
273+
private publisherModelEndpoint: string;
273274

274275
/**
275276
* @constructor
@@ -288,6 +289,11 @@ export class GenerativeModel {
288289
this.model = model;
289290
this.generation_config = generation_config;
290291
this.safety_settings = safety_settings;
292+
if (model.startsWith("models/")) {
293+
this.publisherModelEndpoint = `publishers/google/${this.model}`;
294+
} else {
295+
this.publisherModelEndpoint = `publishers/google/models/${this.model}`;
296+
}
291297
}
292298

293299
/**
@@ -314,8 +320,6 @@ export class GenerativeModel {
314320
return Promise.resolve(result);
315321
}
316322

317-
const publisherModelEndpoint = `publishers/google/models/${this.model}`;
318-
319323
const generateContentRequest: GenerateContentRequest = {
320324
contents: request.contents,
321325
generation_config: request.generation_config ?? this.generation_config,
@@ -327,7 +331,7 @@ export class GenerativeModel {
327331
response = await postRequest({
328332
region: this._vertex_instance.location,
329333
project: this._vertex_instance.project,
330-
resourcePath: publisherModelEndpoint,
334+
resourcePath: this.publisherModelEndpoint,
331335
resourceMethod: constants.GENERATE_CONTENT_METHOD,
332336
token: await this._vertex_instance.token,
333337
data: generateContentRequest,
@@ -361,8 +365,6 @@ export class GenerativeModel {
361365
validateGenerationConfig(request.generation_config);
362366
}
363367

364-
const publisherModelEndpoint = `publishers/google/models/${this.model}`;
365-
366368
const generateContentRequest: GenerateContentRequest = {
367369
contents: request.contents,
368370
generation_config: request.generation_config ?? this.generation_config,
@@ -374,7 +376,7 @@ export class GenerativeModel {
374376
response = await postRequest({
375377
region: this._vertex_instance.location,
376378
project: this._vertex_instance.project,
377-
resourcePath: publisherModelEndpoint,
379+
resourcePath: this.publisherModelEndpoint,
378380
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
379381
token: await this._vertex_instance.token,
380382
data: generateContentRequest,
@@ -405,7 +407,7 @@ export class GenerativeModel {
405407
response = await postRequest({
406408
region: this._vertex_instance.location,
407409
project: this._vertex_instance.project,
408-
resourcePath: `publishers/google/models/${this.model}`,
410+
resourcePath: this.publisherModelEndpoint,
409411
resourceMethod: 'countTokens',
410412
token: await this._vertex_instance.token,
411413
data: request,

system_test/end_to_end_sample_test.ts

+53-2
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,22 @@ const generativeTextModel = vertex_ai.preview.getGenerativeModel({
6262
max_output_tokens: 256,
6363
},
6464
});
65-
65+
const generativeTextModelWithPrefix = vertex_ai.preview.getGenerativeModel({
66+
model: 'models/gemini-pro',
67+
generation_config: {
68+
max_output_tokens: 256,
69+
},
70+
});
6671
const textModelNoOutputLimit = vertex_ai.preview.getGenerativeModel({
6772
model: 'gemini-pro',
6873
});
6974

7075
const generativeVisionModel = vertex_ai.preview.getGenerativeModel({
7176
model: 'gemini-pro-vision',
7277
});
78+
const generativeVisionModelWithPrefix = vertex_ai.preview.getGenerativeModel({
79+
model: 'models/gemini-pro-vision',
80+
});
7381

7482
// TODO (b/316599049): update tests to use jasmine expect syntax:
7583
// expect(...).toBeInstanceOf(...)
@@ -92,7 +100,7 @@ describe('generateContentStream', () => {
92100
const aggregatedResp = await streamingResp.response;
93101
assert(
94102
aggregatedResp.candidates[0],
95-
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp.candidates[0]}`
103+
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`
96104
);
97105
});
98106
it('should not return a invalid unicode', async () => {
@@ -213,3 +221,46 @@ describe('countTokens', () => {
213221
);
214222
});
215223
});
224+
225+
describe('generateContentStream using models/model-id', () => {
226+
beforeEach(() => {
227+
jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
228+
});
229+
230+
it('should should return a stream and aggregated response when passed text', async () => {
231+
const streamingResp =
232+
await generativeTextModelWithPrefix.generateContentStream(TEXT_REQUEST);
233+
234+
for await (const item of streamingResp.stream) {
235+
assert(
236+
item.candidates[0],
237+
`sys test failure on generateContentStream using models/gemini-pro, for item ${item}`
238+
);
239+
}
240+
241+
const aggregatedResp = await streamingResp.response;
242+
assert(
243+
aggregatedResp.candidates[0],
244+
`sys test failure on generateContentStream using models/gemini-pro for aggregated response: ${aggregatedResp}`
245+
);
246+
});
247+
248+
it('should should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => {
249+
const streamingResp = await generativeVisionModelWithPrefix.generateContentStream(
250+
MULTI_PART_BASE64_REQUEST
251+
);
252+
253+
for await (const item of streamingResp.stream) {
254+
assert(
255+
item.candidates[0],
256+
`sys test failure on generateContentStream using models/gemini-pro-vision, for item ${item}`
257+
);
258+
}
259+
260+
const aggregatedResp = await streamingResp.response;
261+
assert(
262+
aggregatedResp.candidates[0],
263+
`sys test failure on generateContentStream using models/gemini-pro-visionfor aggregated response: ${aggregatedResp}`
264+
);
265+
});
266+
});

0 commit comments

Comments
 (0)