Skip to content

Commit 6762c99

Browse files
yyyu-google authored and copybara-github committed
fix: decouple dependency between VertexAI_Preview and GenerativeModel classes
PiperOrigin-RevId: 604684922
1 parent f29c273 commit 6762c99

File tree

2 files changed

+97
-81
lines changed

2 files changed

+97
-81
lines changed

src/index.ts

+88-64
Original file line numberDiff line numberDiff line change
@@ -101,43 +101,27 @@ export class VertexAI_Preview {
101101
this.googleAuth = new GoogleAuth(opts);
102102
}
103103

104-
/**
105-
* Get access token from GoogleAuth. Throws GoogleAuthError when fails.
106-
* @return {Promise<any>} Promise of token
107-
*/
108-
get token(): Promise<any> {
109-
const credential_error_message =
110-
'\nUnable to authenticate your request\
111-
\nDepending on your run time environment, you can get authentication by\
112-
\n- if in local instance or cloud shell: `!gcloud auth login`\
113-
\n- if in Colab:\
114-
\n -`from google.colab import auth`\
115-
\n -`auth.authenticate_user()`\
116-
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication';
117-
const tokenPromise = this.googleAuth.getAccessToken().catch(e => {
118-
throw new GoogleAuthError(credential_error_message, e);
119-
});
120-
return tokenPromise;
121-
}
122-
123104
/**
124105
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model.
125106
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel}
126107
*/
127108
getGenerativeModel(modelParams: ModelParams): GenerativeModel {
109+
const getGenerativeModelParams: GetGenerativeModelParams = {
110+
model: modelParams.model,
111+
project: this.project,
112+
location: this.location,
113+
googleAuth: this.googleAuth,
114+
apiEndpoint: this.apiEndpoint,
115+
safety_settings: modelParams.safety_settings,
116+
tools: modelParams.tools,
117+
};
128118
if (modelParams.generation_config) {
129-
modelParams.generation_config = validateGenerationConfig(
119+
getGenerativeModelParams.generation_config = validateGenerationConfig(
130120
modelParams.generation_config
131121
);
132122
}
133123

134-
return new GenerativeModel(
135-
this,
136-
modelParams.model,
137-
modelParams.generation_config,
138-
modelParams.safety_settings,
139-
modelParams.tools
140-
);
124+
return new GenerativeModel(getGenerativeModelParams);
141125
}
142126

143127
validateGoogleAuthOptions(
@@ -200,10 +184,36 @@ export declare interface StartChatParams {
200184
* @property {GenerativeModel} - _model_instance {@link GenerativeModel}
201185
*/
202186
export declare interface StartChatSessionRequest extends StartChatParams {
203-
_vertex_instance: VertexAI_Preview;
187+
project: string;
188+
location: string;
204189
_model_instance: GenerativeModel;
205190
}
206191

192+
/**
193+
* @property {string} model - model name
194+
* @property {string} project - project The Google Cloud project to use for the request
195+
* @property {string} location - The Google Cloud project location to use for the request
196+
* @property {GoogleAuth} googleAuth - GoogleAuth class instance that handles authentication.
197+
* Details about GoogleAuth is referred to https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts
198+
* @property {string} - [apiEndpoint] The base Vertex AI endpoint to use for the request. If
199+
* not provided, the default regionalized endpoint
200+
* (i.e. us-central1-aiplatform.googleapis.com) will be used.
201+
* @property {GenerationConfig} [generation_config] - {@link
202+
* GenerationConfig}
203+
* @property {SafetySetting[]} [safety_settings] - {@link SafetySetting}
204+
* @property {Tool[]} [tools] - {@link Tool}
205+
*/
206+
export declare interface GetGenerativeModelParams extends ModelParams {
207+
model: string;
208+
project: string;
209+
location: string;
210+
googleAuth: GoogleAuth;
211+
apiEndpoint?: string;
212+
generation_config?: GenerationConfig;
213+
safety_settings?: SafetySetting[];
214+
tools?: Tool[];
215+
}
216+
207217
/**
208218
* Chat session to make multi-turn send message request.
209219
* `sendMessage` method makes async call to get response of a chat message.
@@ -214,7 +224,6 @@ export class ChatSession {
214224
private location: string;
215225

216226
private historyInternal: Content[];
217-
private _vertex_instance: VertexAI_Preview;
218227
private _model_instance: GenerativeModel;
219228
private _send_stream_promise: Promise<void> = Promise.resolve();
220229
generation_config?: GenerationConfig;
@@ -230,11 +239,10 @@ export class ChatSession {
230239
* @param {StartChatSessionRequest} request - {@link StartChatSessionRequest}
231240
*/
232241
constructor(request: StartChatSessionRequest) {
233-
this.project = request._vertex_instance.project;
234-
this.location = request._vertex_instance.location;
242+
this.project = request.project;
243+
this.location = request.location;
235244
this._model_instance = request._model_instance;
236245
this.historyInternal = request.history ?? [];
237-
this._vertex_instance = request._vertex_instance;
238246
this.generation_config = request.generation_config;
239247
this.safety_settings = request.safety_settings;
240248
this.tools = request.tools;
@@ -347,36 +355,51 @@ export class GenerativeModel {
347355
generation_config?: GenerationConfig;
348356
safety_settings?: SafetySetting[];
349357
tools?: Tool[];
350-
private _vertex_instance: VertexAI_Preview;
358+
private project: string;
359+
private location: string;
360+
private googleAuth: GoogleAuth;
351361
private publisherModelEndpoint: string;
362+
private apiEndpoint?: string;
352363

353364
/**
354365
* @constructor
355-
* @param {VertexAI_Preview} vertex_instance - {@link VertexAI_Preview}
356-
* @param {string} model - model name
357-
* @param {GenerationConfig} generation_config - Optional. {@link
358-
* GenerationConfig}
359-
* @param {SafetySetting[]} safety_settings - Optional. {@link SafetySetting}
366+
* @param {GetGenerativeModelParams} getGenerativeModelParams - {@link GetGenerativeModelParams}
360367
*/
361-
constructor(
362-
vertex_instance: VertexAI_Preview,
363-
model: string,
364-
generation_config?: GenerationConfig,
365-
safety_settings?: SafetySetting[],
366-
tools?: Tool[]
367-
) {
368-
this._vertex_instance = vertex_instance;
369-
this.model = model;
370-
this.generation_config = generation_config;
371-
this.safety_settings = safety_settings;
372-
this.tools = tools;
373-
if (model.startsWith('models/')) {
368+
constructor(getGenerativeModelParams: GetGenerativeModelParams) {
369+
this.project = getGenerativeModelParams.project;
370+
this.location = getGenerativeModelParams.location;
371+
this.apiEndpoint = getGenerativeModelParams.apiEndpoint;
372+
this.googleAuth = getGenerativeModelParams.googleAuth;
373+
this.model = getGenerativeModelParams.model;
374+
this.generation_config = getGenerativeModelParams.generation_config;
375+
this.safety_settings = getGenerativeModelParams.safety_settings;
376+
this.tools = getGenerativeModelParams.tools;
377+
if (this.model.startsWith('models/')) {
374378
this.publisherModelEndpoint = `publishers/google/${this.model}`;
375379
} else {
376380
this.publisherModelEndpoint = `publishers/google/models/${this.model}`;
377381
}
378382
}
379383

384+
/**
385+
* Get access token from GoogleAuth. Throws GoogleAuthError when fails.
386+
* @return {Promise<any>} Promise of token
387+
*/
388+
get token(): Promise<any> {
389+
const credential_error_message =
390+
'\nUnable to authenticate your request\
391+
\nDepending on your run time environment, you can get authentication by\
392+
\n- if in local instance or cloud shell: `!gcloud auth login`\
393+
\n- if in Colab:\
394+
\n -`from google.colab import auth`\
395+
\n -`auth.authenticate_user()`\
396+
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication';
397+
const tokenPromise = this.googleAuth.getAccessToken().catch(e => {
398+
throw new GoogleAuthError(credential_error_message, e);
399+
});
400+
return tokenPromise;
401+
}
402+
380403
/**
381404
* Make a async call to generate content.
382405
* @param request A GenerateContentRequest object with the request contents.
@@ -407,13 +430,13 @@ export class GenerativeModel {
407430
};
408431

409432
const response: Response | undefined = await postRequest({
410-
region: this._vertex_instance.location,
411-
project: this._vertex_instance.project,
433+
region: this.location,
434+
project: this.project,
412435
resourcePath: this.publisherModelEndpoint,
413436
resourceMethod: constants.GENERATE_CONTENT_METHOD,
414-
token: await this._vertex_instance.token,
437+
token: await this.token,
415438
data: generateContentRequest,
416-
apiEndpoint: this._vertex_instance.apiEndpoint,
439+
apiEndpoint: this.apiEndpoint,
417440
}).catch(e => {
418441
throw new GoogleGenerativeAIError('exception posting request', e);
419442
});
@@ -450,13 +473,13 @@ export class GenerativeModel {
450473
tools: request.tools ?? [],
451474
};
452475
const response = await postRequest({
453-
region: this._vertex_instance.location,
454-
project: this._vertex_instance.project,
476+
region: this.location,
477+
project: this.project,
455478
resourcePath: this.publisherModelEndpoint,
456479
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
457-
token: await this._vertex_instance.token,
480+
token: await this.token,
458481
data: generateContentRequest,
459-
apiEndpoint: this._vertex_instance.apiEndpoint,
482+
apiEndpoint: this.apiEndpoint,
460483
}).catch(e => {
461484
throw new GoogleGenerativeAIError('exception posting request', e);
462485
});
@@ -472,13 +495,13 @@ export class GenerativeModel {
472495
*/
473496
async countTokens(request: CountTokensRequest): Promise<CountTokensResponse> {
474497
const response = await postRequest({
475-
region: this._vertex_instance.location,
476-
project: this._vertex_instance.project,
498+
region: this.location,
499+
project: this.project,
477500
resourcePath: this.publisherModelEndpoint,
478501
resourceMethod: 'countTokens',
479-
token: await this._vertex_instance.token,
502+
token: await this.token,
480503
data: request,
481-
apiEndpoint: this._vertex_instance.apiEndpoint,
504+
apiEndpoint: this.apiEndpoint,
482505
}).catch(e => {
483506
throw new GoogleGenerativeAIError('exception posting request', e);
484507
});
@@ -495,7 +518,8 @@ export class GenerativeModel {
495518
*/
496519
startChat(request?: StartChatParams): ChatSession {
497520
const startChatRequest: StartChatSessionRequest = {
498-
_vertex_instance: this._vertex_instance,
521+
project: this.project,
522+
location: this.location,
499523
_model_instance: this,
500524
};
501525

test/index_test.ts

+9-17
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ describe('VertexAI', () => {
255255
project: PROJECT,
256256
location: LOCATION,
257257
});
258-
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
259258
model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'});
259+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
260260
expectedStreamResult = {
261261
response: Promise.resolve(TEST_MODEL_RESPONSE),
262262
stream: testGenerator(),
@@ -406,13 +406,10 @@ describe('VertexAI', () => {
406406
location: LOCATION,
407407
apiEndpoint: TEST_ENDPOINT_BASE_PATH,
408408
});
409-
spyOnProperty(vertexaiWithBasePath.preview, 'token', 'get').and.resolveTo(
410-
TEST_TOKEN
411-
);
412409
model = vertexaiWithBasePath.preview.getGenerativeModel({
413410
model: 'gemini-pro',
414411
});
415-
412+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
416413
const req: GenerateContentRequest = {
417414
contents: TEST_USER_CHAT_MESSAGE,
418415
};
@@ -433,15 +430,10 @@ describe('VertexAI', () => {
433430
project: PROJECT,
434431
location: LOCATION,
435432
});
436-
spyOnProperty(
437-
vertexaiWithoutBasePath.preview,
438-
'token',
439-
'get'
440-
).and.resolveTo(TEST_TOKEN);
441433
model = vertexaiWithoutBasePath.preview.getGenerativeModel({
442434
model: 'gemini-pro',
443435
});
444-
436+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
445437
const req: GenerateContentRequest = {
446438
contents: TEST_USER_CHAT_MESSAGE,
447439
};
@@ -691,8 +683,8 @@ describe('countTokens', () => {
691683
project: PROJECT,
692684
location: LOCATION,
693685
});
694-
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
695686
const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'});
687+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
696688
const req: CountTokensRequest = {
697689
contents: TEST_USER_CHAT_MESSAGE,
698690
};
@@ -720,7 +712,6 @@ describe('ChatSession', () => {
720712

721713
beforeEach(() => {
722714
vertexai = new VertexAI({project: PROJECT, location: LOCATION});
723-
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
724715
model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'});
725716
chatSession = model.startChat({
726717
history: TEST_USER_CHAT_MESSAGE,
@@ -739,6 +730,7 @@ describe('ChatSession', () => {
739730
new Response(JSON.stringify(expectedStreamResult), fetchResponseObj)
740731
);
741732
spyOn(global, 'fetch').and.returnValue(fetchResult);
733+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
742734
});
743735

744736
describe('sendMessage', () => {
@@ -970,7 +962,7 @@ describe('when exception at fetch', () => {
970962
contents: TEST_USER_CHAT_MESSAGE,
971963
};
972964
beforeEach(() => {
973-
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
965+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
974966
spyOn(global, 'fetch').and.throwError('error');
975967
});
976968

@@ -1008,8 +1000,8 @@ describe('when response is undefined', () => {
10081000
contents: TEST_USER_CHAT_MESSAGE,
10091001
};
10101002
beforeEach(() => {
1011-
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
10121003
spyOn(global, 'fetch').and.resolveTo();
1004+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
10131005
});
10141006

10151007
it('generateContent should throw GoogleGenerativeAI error', async () => {
@@ -1065,7 +1057,7 @@ describe('when response is 4XX', () => {
10651057
contents: TEST_USER_CHAT_MESSAGE,
10661058
};
10671059
beforeEach(() => {
1068-
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
1060+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
10691061
spyOn(global, 'fetch').and.resolveTo(response);
10701062
});
10711063

@@ -1122,7 +1114,7 @@ describe('when response is not OK and not 4XX', () => {
11221114
contents: TEST_USER_CHAT_MESSAGE,
11231115
};
11241116
beforeEach(() => {
1125-
spyOnProperty(vertexai.preview, 'token', 'get').and.resolveTo(TEST_TOKEN);
1117+
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN);
11261118
spyOn(global, 'fetch').and.resolveTo(response);
11271119
});
11281120

0 commit comments

Comments
 (0)