Skip to content

Commit 1b37f40

Browse files
yyyu-googlecopybara-github
authored andcommitted
feat: Introduce Request Timeout Configuration
PiperOrigin-RevId: 611274545
1 parent 27befcc commit 1b37f40

9 files changed

+463
-32
lines changed

src/functions/count_tokens.ts

+8-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
* limitations under the License.
1616
*/
1717

18-
import {CountTokensRequest, CountTokensResponse} from '../types/content';
18+
import {
19+
CountTokensRequest,
20+
CountTokensResponse,
21+
RequestOptions,
22+
} from '../types/content';
1923
import {GoogleGenerativeAIError} from '../types/errors';
2024
import {
2125
throwErrorIfNotOK,
@@ -34,7 +38,8 @@ export async function countTokens(
3438
publisherModelEndpoint: string,
3539
token: Promise<any>,
3640
request: CountTokensRequest,
37-
apiEndpoint?: string
41+
apiEndpoint?: string,
42+
requestOptions?: RequestOptions
3843
): Promise<CountTokensResponse> {
3944
const response = await postRequest({
4045
region: location,
@@ -44,6 +49,7 @@ export async function countTokens(
4449
token: await token,
4550
data: request,
4651
apiEndpoint: apiEndpoint,
52+
requestOptions: requestOptions,
4753
}).catch(e => {
4854
throw new GoogleGenerativeAIError('exception posting request', e);
4955
});

src/functions/generate_content.ts

+7-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {
2525
GenerateContentRequest,
2626
GenerateContentResult,
2727
GenerationConfig,
28+
RequestOptions,
2829
SafetySetting,
2930
StreamGenerateContentResult,
3031
} from '../types/content';
@@ -51,7 +52,8 @@ export async function generateContent(
5152
request: GenerateContentRequest | string,
5253
apiEndpoint?: string,
5354
generation_config?: GenerationConfig,
54-
safety_settings?: SafetySetting[]
55+
safety_settings?: SafetySetting[],
56+
requestOptions?: RequestOptions
5557
): Promise<GenerateContentResult> {
5658
request = formatContentRequest(request, generation_config, safety_settings);
5759

@@ -78,6 +80,7 @@ export async function generateContent(
7880
token: await token,
7981
data: generateContentRequest,
8082
apiEndpoint: apiEndpoint,
83+
requestOptions: requestOptions,
8184
apiVersion: apiVersion,
8285
}).catch(e => {
8386
throw new GoogleGenerativeAIError('exception posting request', e);
@@ -103,7 +106,8 @@ export async function generateContentStream(
103106
request: GenerateContentRequest | string,
104107
apiEndpoint?: string,
105108
generation_config?: GenerationConfig,
106-
safety_settings?: SafetySetting[]
109+
safety_settings?: SafetySetting[],
110+
requestOptions?: RequestOptions
107111
): Promise<StreamGenerateContentResult> {
108112
request = formatContentRequest(request, generation_config, safety_settings);
109113
validateGenerateContentRequest(request);
@@ -129,6 +133,7 @@ export async function generateContentStream(
129133
token: await token,
130134
data: generateContentRequest,
131135
apiEndpoint: apiEndpoint,
136+
requestOptions: requestOptions,
132137
apiVersion: apiVersion,
133138
}).catch(e => {
134139
throw new GoogleGenerativeAIError('exception posting request', e);

src/functions/post_request.ts

+24-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717

1818
const API_BASE_PATH = 'aiplatform.googleapis.com';
1919

20-
import {GenerateContentRequest, CountTokensRequest} from '../types/content';
20+
import {
21+
GenerateContentRequest,
22+
CountTokensRequest,
23+
RequestOptions,
24+
} from '../types/content';
2125
import * as constants from '../util/constants';
2226

2327
/**
@@ -32,6 +36,7 @@ export async function postRequest({
3236
token,
3337
data,
3438
apiEndpoint,
39+
requestOptions,
3540
apiVersion = 'v1',
3641
}: {
3742
region: string;
@@ -41,6 +46,7 @@ export async function postRequest({
4146
token: string;
4247
data: GenerateContentRequest | CountTokensRequest;
4348
apiEndpoint?: string;
49+
requestOptions?: RequestOptions;
4450
apiVersion?: string;
4551
}): Promise<Response | undefined> {
4652
const vertexBaseEndpoint = apiEndpoint ?? `${region}-${API_BASE_PATH}`;
@@ -53,6 +59,7 @@ export async function postRequest({
5359
}
5460

5561
return fetch(vertexEndpoint, {
62+
...getFetchOptions(requestOptions),
5663
method: 'POST',
5764
headers: {
5865
Authorization: `Bearer ${token}`,
@@ -62,3 +69,19 @@ export async function postRequest({
6269
body: JSON.stringify(data),
6370
});
6471
}
72+
73+
function getFetchOptions(requestOptions?: RequestOptions): RequestInit {
74+
const fetchOptions = {} as RequestInit;
75+
if (
76+
!requestOptions ||
77+
requestOptions.timeoutMillis === undefined ||
78+
requestOptions.timeoutMillis < 0
79+
) {
80+
return fetchOptions;
81+
}
82+
const abortController = new AbortController();
83+
const signal = abortController.signal;
84+
setTimeout(() => abortController.abort, requestOptions.timeoutMillis);
85+
fetchOptions.signal = signal;
86+
return fetchOptions;
87+
}

0 commit comments

Comments
 (0)