Skip to content

Commit ea0dcb7

Browse files
yyyu-google authored and copybara-github committed
fix: throw ClientError or GoogleGenerativeAIError according to response status so that users can catch them and handle them according to class name.
PiperOrigin-RevId: 595149601
1 parent 2a75efa commit ea0dcb7

File tree

7 files changed

+371
-89
lines changed

7 files changed

+371
-89
lines changed

package.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939
"@types/node": "^20.9.0",
4040
"gts": "^5.2.0",
4141
"jasmine": "^5.1.0",
42-
"typescript": "~5.2.0",
4342
"jsdoc": "^4.0.0",
4443
"jsdoc-fresh": "^3.0.0",
4544
"jsdoc-region-tag": "^3.0.0",
46-
"linkinator": "^4.0.0"
45+
"linkinator": "^4.0.0",
46+
"typescript": "~5.2.0"
4747
}
4848
}

src/index.ts

+81-79
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
/* tslint:disable */
1919
import {GoogleAuth} from 'google-auth-library';
2020

21-
import {processNonStream, processStream} from './process_stream';
21+
import {
22+
processCountTokenResponse,
23+
processNonStream,
24+
processStream,
25+
} from './process_stream';
2226
import {
2327
Content,
2428
CountTokensRequest,
@@ -32,7 +36,11 @@ import {
3236
StreamGenerateContentResult,
3337
VertexInit,
3438
} from './types/content';
35-
import {GoogleAuthError} from './types/errors';
39+
import {
40+
ClientError,
41+
GoogleAuthError,
42+
GoogleGenerativeAIError,
43+
} from './types/errors';
3644
import {constants, postRequest} from './util';
3745
export * from './types';
3846

@@ -101,7 +109,7 @@ export class VertexAI_Preview {
101109
\n -`auth.authenticate_user()`\
102110
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication';
103111
const tokenPromise = this.googleAuth.getAccessToken().catch(e => {
104-
throw new GoogleAuthError(`${credential_error_message}\n${e}`);
112+
throw new GoogleAuthError(credential_error_message, e);
105113
});
106114
return tokenPromise;
107115
}
@@ -194,10 +202,13 @@ export class ChatSession {
194202
generation_config: this.generation_config,
195203
};
196204

197-
const generateContentResult = await this._model_instance.generateContent(
198-
generateContentrequest
199-
);
200-
const generateContentResponse = await generateContentResult.response;
205+
const generateContentResult: GenerateContentResult =
206+
await this._model_instance
207+
.generateContent(generateContentrequest)
208+
.catch(e => {
209+
throw e;
210+
});
211+
const generateContentResponse = generateContentResult.response;
201212
// Only push the latest message to history if the response returned a result
202213
if (generateContentResponse.candidates.length !== 0) {
203214
this.historyInternal.push(newContent);
@@ -253,13 +264,18 @@ export class ChatSession {
253264
generation_config: this.generation_config,
254265
};
255266

256-
const streamGenerateContentResultPromise =
257-
this._model_instance.generateContentStream(generateContentrequest);
267+
const streamGenerateContentResultPromise = this._model_instance
268+
.generateContentStream(generateContentrequest)
269+
.catch(e => {
270+
throw e;
271+
});
258272

259273
this._send_stream_promise = this.appendHistory(
260274
streamGenerateContentResultPromise,
261275
newContent
262-
);
276+
).catch(e => {
277+
throw new GoogleGenerativeAIError('exception appending chat history', e);
278+
});
263279
return streamGenerateContentResultPromise;
264280
}
265281
}
@@ -320,7 +336,9 @@ export class GenerativeModel {
320336

321337
if (!this._use_non_stream) {
322338
const streamGenerateContentResult: StreamGenerateContentResult =
323-
await this.generateContentStream(request);
339+
await this.generateContentStream(request).catch(e => {
340+
throw e;
341+
});
324342
const result: GenerateContentResult = {
325343
response: await streamGenerateContentResult.response,
326344
};
@@ -333,27 +351,18 @@ export class GenerativeModel {
333351
safety_settings: request.safety_settings ?? this.safety_settings,
334352
};
335353

336-
let response;
337-
try {
338-
response = await postRequest({
339-
region: this._vertex_instance.location,
340-
project: this._vertex_instance.project,
341-
resourcePath: this.publisherModelEndpoint,
342-
resourceMethod: constants.GENERATE_CONTENT_METHOD,
343-
token: await this._vertex_instance.token,
344-
data: generateContentRequest,
345-
apiEndpoint: this._vertex_instance.apiEndpoint,
346-
});
347-
if (response === undefined) {
348-
throw new Error('did not get a valid response.');
349-
}
350-
if (!response.ok) {
351-
throw new Error(`${response.status} ${response.statusText}`);
352-
}
353-
} catch (e) {
354-
console.log(e);
355-
}
356-
354+
const response: Response | undefined = await postRequest({
355+
region: this._vertex_instance.location,
356+
project: this._vertex_instance.project,
357+
resourcePath: this.publisherModelEndpoint,
358+
resourceMethod: constants.GENERATE_CONTENT_METHOD,
359+
token: await this._vertex_instance.token,
360+
data: generateContentRequest,
361+
apiEndpoint: this._vertex_instance.apiEndpoint,
362+
}).catch(e => {
363+
throw new GoogleGenerativeAIError('exception posting request', e);
364+
});
365+
throwErrorIfNotOK(response);
357366
const result: GenerateContentResult = processNonStream(response);
358367
return Promise.resolve(result);
359368
}
@@ -379,27 +388,18 @@ export class GenerativeModel {
379388
generation_config: request.generation_config ?? this.generation_config,
380389
safety_settings: request.safety_settings ?? this.safety_settings,
381390
};
382-
let response;
383-
try {
384-
response = await postRequest({
385-
region: this._vertex_instance.location,
386-
project: this._vertex_instance.project,
387-
resourcePath: this.publisherModelEndpoint,
388-
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
389-
token: await this._vertex_instance.token,
390-
data: generateContentRequest,
391-
apiEndpoint: this._vertex_instance.apiEndpoint,
392-
});
393-
if (response === undefined) {
394-
throw new Error('did not get a valid response.');
395-
}
396-
if (!response.ok) {
397-
throw new Error(`${response.status} ${response.statusText}`);
398-
}
399-
} catch (e) {
400-
console.log(e);
401-
}
402-
391+
const response = await postRequest({
392+
region: this._vertex_instance.location,
393+
project: this._vertex_instance.project,
394+
resourcePath: this.publisherModelEndpoint,
395+
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
396+
token: await this._vertex_instance.token,
397+
data: generateContentRequest,
398+
apiEndpoint: this._vertex_instance.apiEndpoint,
399+
}).catch(e => {
400+
throw new GoogleGenerativeAIError('exception posting request', e);
401+
});
402+
throwErrorIfNotOK(response);
403403
const streamResult = processStream(response);
404404
return Promise.resolve(streamResult);
405405
}
@@ -410,32 +410,19 @@ export class GenerativeModel {
410410
* @return The CountTokensResponse object with the token count.
411411
*/
412412
async countTokens(request: CountTokensRequest): Promise<CountTokensResponse> {
413-
let response;
414-
try {
415-
response = await postRequest({
416-
region: this._vertex_instance.location,
417-
project: this._vertex_instance.project,
418-
resourcePath: this.publisherModelEndpoint,
419-
resourceMethod: 'countTokens',
420-
token: await this._vertex_instance.token,
421-
data: request,
422-
apiEndpoint: this._vertex_instance.apiEndpoint,
423-
});
424-
if (response === undefined) {
425-
throw new Error('did not get a valid response.');
426-
}
427-
if (!response.ok) {
428-
throw new Error(`${response.status} ${response.statusText}`);
429-
}
430-
} catch (e) {
431-
console.log(e);
432-
}
433-
if (response) {
434-
const responseJson = await response.json();
435-
return responseJson as CountTokensResponse;
436-
} else {
437-
throw new Error('did not get a valid response.');
438-
}
413+
const response = await postRequest({
414+
region: this._vertex_instance.location,
415+
project: this._vertex_instance.project,
416+
resourcePath: this.publisherModelEndpoint,
417+
resourceMethod: 'countTokens',
418+
token: await this._vertex_instance.token,
419+
data: request,
420+
apiEndpoint: this._vertex_instance.apiEndpoint,
421+
}).catch(e => {
422+
throw new GoogleGenerativeAIError('exception posting request', e);
423+
});
424+
throwErrorIfNotOK(response);
425+
return processCountTokenResponse(response);
439426
}
440427

441428
/**
@@ -481,6 +468,21 @@ function formulateNewContent(request: string | Array<string | Part>): Content {
481468
return newContent;
482469
}
483470

471+
function throwErrorIfNotOK(response: Response | undefined) {
472+
if (response === undefined) {
473+
throw new GoogleGenerativeAIError('response is undefined');
474+
}
475+
const status: number = response.status;
476+
const statusText: string = response.statusText;
477+
const errorMessage = `got status: ${status} ${statusText}`;
478+
if (status >= 400 && status < 500) {
479+
throw new ClientError(errorMessage);
480+
}
481+
if (!response.ok) {
482+
throw new GoogleGenerativeAIError(errorMessage);
483+
}
484+
}
485+
484486
function validateGcsInput(contents: Content[]) {
485487
for (const content of contents) {
486488
for (const part of content.parts) {

src/process_stream.ts

+11
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import {
1919
CitationSource,
20+
CountTokensResponse,
2021
GenerateContentCandidate,
2122
GenerateContentResponse,
2223
GenerateContentResult,
@@ -218,3 +219,13 @@ export function processNonStream(response: any): GenerateContentResult {
218219
response: {candidates: []},
219220
};
220221
}
222+
223+
/**
224+
* Process model responses from countTokens
225+
* @ignore
226+
*/
227+
export function processCountTokenResponse(response: any): CountTokensResponse {
228+
// ts-ignore
229+
const responseJson = response.json();
230+
return responseJson as CountTokensResponse;
231+
}

src/types/errors.ts

+43-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,52 @@
1515
* limitations under the License.
1616
*/
1717

18+
/**
19+
* GoogleAuthError is thrown when there is authentication issue with the request
20+
*/
1821
class GoogleAuthError extends Error {
19-
constructor(message: string) {
22+
public readonly stack_trace: any = undefined;
23+
constructor(message: string, stack_trace: any = undefined) {
2024
super(message);
25+
this.message = constructErrorMessage('GoogleAuthError', message);
2126
this.name = 'GoogleAuthError';
27+
this.stack_trace = stack_trace;
28+
}
29+
}
30+
31+
/**
32+
* ClientError is thrown when http 4XX status is received.
33+
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses
34+
*/
35+
class ClientError extends Error {
36+
public readonly stack_trace: any = undefined;
37+
constructor(message: string, stack_trace: any = undefined) {
38+
super(message);
39+
this.message = constructErrorMessage('ClientError', message);
40+
this.name = 'ClientError';
41+
this.stack_trace = stack_trace;
2242
}
2343
}
2444

25-
export {GoogleAuthError};
45+
/**
46+
* GoogleGenerativeAIError is thrown when http response is not ok and status code is not 4XX
47+
* For details please refer to https://developer.mozilla.org/en-US/docs/Web/HTTP/Status
48+
*/
49+
class GoogleGenerativeAIError extends Error {
50+
public readonly stack_trace: any = undefined;
51+
constructor(message: string, stack_trace: any = undefined) {
52+
super(message);
53+
this.message = constructErrorMessage('GoogleGenerativeAIError', message);
54+
this.name = 'GoogleGenerativeAIError';
55+
this.stack_trace = stack_trace;
56+
}
57+
}
58+
59+
function constructErrorMessage(
60+
exceptionClass: string,
61+
message: string
62+
): string {
63+
return `[VertexAI.${exceptionClass}]: ${message}`;
64+
}
65+
66+
export {ClientError, GoogleAuthError, GoogleGenerativeAIError};

src/util/post_request.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ export async function postRequest({
5252
vertexEndpoint += '?alt=sse';
5353
}
5454

55-
return await fetch(vertexEndpoint, {
55+
return fetch(vertexEndpoint, {
5656
method: 'POST',
5757
headers: {
5858
Authorization: `Bearer ${token}`,

system_test/end_to_end_sample_test.ts

+25-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
// @ts-ignore
1919
import * as assert from 'assert';
2020

21-
import {VertexAI, TextPart} from '../src';
21+
import {ClientError, VertexAI, TextPart} from '../src';
2222

2323
// TODO: this env var isn't getting populated correctly
2424
const PROJECT = process.env.GCLOUD_PROJECT;
@@ -129,7 +129,7 @@ describe('generateContentStream', () => {
129129
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`
130130
);
131131
});
132-
it('should should return a stream and aggregated response when passed multipart base64 content', async () => {
132+
it('should return a stream and aggregated response when passed multipart base64 content', async () => {
133133
const streamingResp = await generativeVisionModel.generateContentStream(
134134
MULTI_PART_BASE64_REQUEST
135135
);
@@ -147,6 +147,29 @@ describe('generateContentStream', () => {
147147
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`
148148
);
149149
});
150+
it('should throw ClientError when having invalid input', async () => {
151+
const badRequest = {
152+
contents: [
153+
{
154+
role: 'user',
155+
parts: [
156+
{text: 'describe this image:'},
157+
{inline_data: {mime_type: 'image/png', data: 'invalid data'}},
158+
],
159+
},
160+
],
161+
};
162+
await generativeVisionModel.generateContentStream(badRequest).catch(e => {
163+
assert(
164+
e instanceof ClientError,
165+
`sys test failure on generateContentStream when having bad request should throw ClientError but actually thrown ${e}`
166+
);
167+
assert(
168+
e.message === '[VertexAI.ClientError]: got status: 400 Bad Request',
169+
`sys test failure on generateContentStream when having bad request got wrong error message: ${e.message}`
170+
);
171+
});
172+
});
150173
// TODO: this is returning a 500 on the system test project
151174
// it('should should return a stream and aggregated response when passed
152175
// multipart GCS content',

0 commit comments

Comments
 (0)