Skip to content

Commit de9c4c2

Browse files
yyyu-google authored and copybara-github committed
feat: enable inference request to tuned model.
PiperOrigin-RevId: 634533073
1 parent d3c0a64 commit de9c4c2

12 files changed

+758
-156
lines changed

src/functions/count_tokens.ts

+4-5
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ import {
2121
RequestOptions,
2222
} from '../types/content';
2323
import {GoogleGenerativeAIError} from '../types/errors';
24+
import * as constants from '../util/constants';
2425
import {
2526
throwErrorIfNotOK,
2627
processCountTokenResponse,
@@ -34,18 +35,16 @@ import {postRequest} from './post_request';
3435
*/
3536
export async function countTokens(
3637
location: string,
37-
project: string,
38-
publisherModelEndpoint: string,
38+
resourcePath: string,
3939
token: Promise<string | null | undefined>,
4040
request: CountTokensRequest,
4141
apiEndpoint?: string,
4242
requestOptions?: RequestOptions
4343
): Promise<CountTokensResponse> {
4444
const response: Response | undefined = await postRequest({
4545
region: location,
46-
project: project,
47-
resourcePath: publisherModelEndpoint,
48-
resourceMethod: 'countTokens',
46+
resourcePath: resourcePath,
47+
resourceMethod: constants.COUNT_TOKENS_METHOD,
4948
token: await token,
5049
data: request,
5150
apiEndpoint: apiEndpoint,

src/functions/generate_content.ts

+4-8
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,7 @@ import {
4747

4848
export async function generateContent(
4949
location: string,
50-
project: string,
51-
publisherModelEndpoint: string,
50+
resourcePath: string,
5251
token: Promise<string | null | undefined>,
5352
request: GenerateContentRequest | string,
5453
apiEndpoint?: string,
@@ -76,8 +75,7 @@ export async function generateContent(
7675
};
7776
const response: Response | undefined = await postRequest({
7877
region: location,
79-
project: project,
80-
resourcePath: publisherModelEndpoint,
78+
resourcePath: resourcePath,
8179
resourceMethod: constants.GENERATE_CONTENT_METHOD,
8280
token: await token,
8381
data: generateContentRequest,
@@ -101,8 +99,7 @@ export async function generateContent(
10199
*/
102100
export async function generateContentStream(
103101
location: string,
104-
project: string,
105-
publisherModelEndpoint: string,
102+
resourcePath: string,
106103
token: Promise<string | null | undefined>,
107104
request: GenerateContentRequest | string,
108105
apiEndpoint?: string,
@@ -129,8 +126,7 @@ export async function generateContentStream(
129126
};
130127
const response = await postRequest({
131128
region: location,
132-
project: project,
133-
resourcePath: publisherModelEndpoint,
129+
resourcePath: resourcePath,
134130
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD,
135131
token: await token,
136132
data: generateContentRequest,

src/functions/post_fetch_processing.ts

-1
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@ import {
2323
GenerateContentResponse,
2424
GenerateContentResult,
2525
GroundingMetadata,
26-
Part,
2726
StreamGenerateContentResult,
2827
} from '../types/content';
2928
import {constants} from '../util';

src/functions/post_request.ts

+1-3
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,6 @@ import * as constants from '../util/constants';
3838
*/
3939
export async function postRequest({
4040
region,
41-
project,
4241
resourcePath,
4342
resourceMethod,
4443
token,
@@ -48,7 +47,6 @@ export async function postRequest({
4847
apiVersion = 'v1',
4948
}: {
5049
region: string;
51-
project: string;
5250
resourcePath: string;
5351
resourceMethod: string;
5452
token: string | null | undefined;
@@ -59,7 +57,7 @@ export async function postRequest({
5957
}): Promise<Response | undefined> {
6058
const vertexBaseEndpoint = apiEndpoint ?? `${region}-${API_BASE_PATH}`;
6159

62-
let vertexEndpoint = `https://${vertexBaseEndpoint}/${apiVersion}/projects/${project}/locations/${region}/${resourcePath}:${resourceMethod}`;
60+
let vertexEndpoint = `https://${vertexBaseEndpoint}/${apiVersion}/${resourcePath}:${resourceMethod}`;
6361

6462
// Use server sent events for streamGenerateContent
6563
if (resourceMethod === constants.STREAMING_GENERATE_CONTENT_METHOD) {

src/functions/test/functions_test.ts

+25-50
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,8 @@ import {countTokens} from '../count_tokens';
3737
import {generateContent, generateContentStream} from '../generate_content';
3838
import * as StreamFunctions from '../post_fetch_processing';
3939

40-
const TEST_PROJECT = 'test-project';
4140
const TEST_LOCATION = 'test-location';
42-
const TEST_PUBLISHER_MODEL_ENDPOINT = 'test-publisher-model-endpoint';
41+
const TEST_RESOURCE_PATH = 'test-resource-path';
4342
const TEST_TOKEN = 'testtoken';
4443
const TEST_TOKEN_PROMISE = Promise.resolve(TEST_TOKEN);
4544
const TEST_API_ENDPOINT = 'test-api-endpoint';
@@ -249,8 +248,7 @@ describe('countTokens', () => {
249248

250249
const resp = await countTokens(
251250
TEST_LOCATION,
252-
TEST_PROJECT,
253-
TEST_PUBLISHER_MODEL_ENDPOINT,
251+
TEST_RESOURCE_PATH,
254252
TEST_TOKEN_PROMISE,
255253
req,
256254
TEST_API_ENDPOINT
@@ -268,8 +266,7 @@ describe('countTokens', () => {
268266
await expectAsync(
269267
countTokens(
270268
TEST_LOCATION,
271-
TEST_PROJECT,
272-
TEST_PUBLISHER_MODEL_ENDPOINT,
269+
TEST_RESOURCE_PATH,
273270
TEST_TOKEN_PROMISE,
274271
req,
275272
TEST_API_ENDPOINT,
@@ -296,8 +293,7 @@ describe('countTokens', () => {
296293
await expectAsync(
297294
countTokens(
298295
TEST_LOCATION,
299-
TEST_PROJECT,
300-
TEST_PUBLISHER_MODEL_ENDPOINT,
296+
TEST_RESOURCE_PATH,
301297
TEST_TOKEN_PROMISE,
302298
req,
303299
TEST_API_ENDPOINT
@@ -322,8 +318,7 @@ describe('countTokens', () => {
322318
await expectAsync(
323319
countTokens(
324320
TEST_LOCATION,
325-
TEST_PROJECT,
326-
TEST_PUBLISHER_MODEL_ENDPOINT,
321+
TEST_RESOURCE_PATH,
327322
TEST_TOKEN_PROMISE,
328323
req,
329324
TEST_API_ENDPOINT
@@ -351,8 +346,7 @@ describe('generateContent', () => {
351346
await expectAsync(
352347
generateContent(
353348
TEST_LOCATION,
354-
TEST_PROJECT,
355-
TEST_PUBLISHER_MODEL_ENDPOINT,
349+
TEST_RESOURCE_PATH,
356350
TEST_TOKEN_PROMISE,
357351
req,
358352
TEST_API_ENDPOINT,
@@ -374,8 +368,7 @@ describe('generateContent', () => {
374368
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
375369
const resp = await generateContent(
376370
TEST_LOCATION,
377-
TEST_PROJECT,
378-
TEST_PUBLISHER_MODEL_ENDPOINT,
371+
TEST_RESOURCE_PATH,
379372
TEST_TOKEN_PROMISE,
380373
req,
381374
TEST_API_ENDPOINT
@@ -389,8 +382,7 @@ describe('generateContent', () => {
389382
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
390383
const resp = await generateContent(
391384
TEST_LOCATION,
392-
TEST_PROJECT,
393-
TEST_PUBLISHER_MODEL_ENDPOINT,
385+
TEST_RESOURCE_PATH,
394386
TEST_TOKEN_PROMISE,
395387
TEST_CHAT_MESSAGE_TEXT,
396388
TEST_API_ENDPOINT
@@ -408,8 +400,7 @@ describe('generateContent', () => {
408400
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
409401
const resp = await generateContent(
410402
TEST_LOCATION,
411-
TEST_PROJECT,
412-
TEST_PUBLISHER_MODEL_ENDPOINT,
403+
TEST_RESOURCE_PATH,
413404
TEST_TOKEN_PROMISE,
414405
req,
415406
TEST_API_ENDPOINT
@@ -424,8 +415,7 @@ describe('generateContent', () => {
424415
await expectAsync(
425416
generateContent(
426417
TEST_LOCATION,
427-
TEST_PROJECT,
428-
TEST_PUBLISHER_MODEL_ENDPOINT,
418+
TEST_RESOURCE_PATH,
429419
TEST_TOKEN_PROMISE,
430420
req,
431421
TEST_API_ENDPOINT
@@ -445,8 +435,7 @@ describe('generateContent', () => {
445435
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
446436
const resp = await generateContent(
447437
TEST_LOCATION,
448-
TEST_PROJECT,
449-
TEST_PUBLISHER_MODEL_ENDPOINT,
438+
TEST_RESOURCE_PATH,
450439
TEST_TOKEN_PROMISE,
451440
req,
452441
TEST_API_ENDPOINT
@@ -460,8 +449,7 @@ describe('generateContent', () => {
460449
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
461450
await generateContent(
462451
TEST_LOCATION,
463-
TEST_PROJECT,
464-
TEST_PUBLISHER_MODEL_ENDPOINT,
452+
TEST_RESOURCE_PATH,
465453
TEST_TOKEN_PROMISE,
466454
req,
467455
TEST_ENDPOINT_BASE_PATH
@@ -480,8 +468,7 @@ describe('generateContent', () => {
480468
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
481469
await generateContent(
482470
TEST_LOCATION,
483-
TEST_PROJECT,
484-
TEST_PUBLISHER_MODEL_ENDPOINT,
471+
TEST_RESOURCE_PATH,
485472
TEST_TOKEN_PROMISE,
486473
reqWithEmptyConfigs,
487474
TEST_API_ENDPOINT
@@ -501,8 +488,7 @@ describe('generateContent', () => {
501488
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
502489
await generateContent(
503490
TEST_LOCATION,
504-
TEST_PROJECT,
505-
TEST_PUBLISHER_MODEL_ENDPOINT,
491+
TEST_RESOURCE_PATH,
506492
TEST_TOKEN_PROMISE,
507493
reqWithEmptyConfigs,
508494
TEST_API_ENDPOINT
@@ -520,8 +506,7 @@ describe('generateContent', () => {
520506
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE));
521507
const resp = await generateContent(
522508
TEST_LOCATION,
523-
TEST_PROJECT,
524-
TEST_PUBLISHER_MODEL_ENDPOINT,
509+
TEST_RESOURCE_PATH,
525510
TEST_TOKEN_PROMISE,
526511
req,
527512
TEST_API_ENDPOINT
@@ -552,8 +537,7 @@ describe('generateContent', () => {
552537

553538
const actualResult = await generateContent(
554539
TEST_LOCATION,
555-
TEST_PROJECT,
556-
TEST_PUBLISHER_MODEL_ENDPOINT,
540+
TEST_RESOURCE_PATH,
557541
TEST_TOKEN_PROMISE,
558542
req,
559543
TEST_API_ENDPOINT
@@ -595,8 +579,7 @@ describe('generateContent', () => {
595579

596580
const actualResult = await generateContent(
597581
TEST_LOCATION,
598-
TEST_PROJECT,
599-
TEST_PUBLISHER_MODEL_ENDPOINT,
582+
TEST_RESOURCE_PATH,
600583
TEST_TOKEN_PROMISE,
601584
req,
602585
TEST_API_ENDPOINT
@@ -623,8 +606,7 @@ describe('generateContent', () => {
623606

624607
const actualResult: GenerateContentResult = await generateContent(
625608
TEST_LOCATION,
626-
TEST_PROJECT,
627-
TEST_PUBLISHER_MODEL_ENDPOINT,
609+
TEST_RESOURCE_PATH,
628610
TEST_TOKEN_PROMISE,
629611
req,
630612
TEST_API_ENDPOINT
@@ -661,8 +643,7 @@ describe('generateContentStream', () => {
661643
await expectAsync(
662644
generateContentStream(
663645
TEST_LOCATION,
664-
TEST_PROJECT,
665-
TEST_PUBLISHER_MODEL_ENDPOINT,
646+
TEST_RESOURCE_PATH,
666647
TEST_TOKEN_PROMISE,
667648
req,
668649
TEST_API_ENDPOINT,
@@ -686,8 +667,7 @@ describe('generateContentStream', () => {
686667
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult);
687668
const resp = await generateContentStream(
688669
TEST_LOCATION,
689-
TEST_PROJECT,
690-
TEST_PUBLISHER_MODEL_ENDPOINT,
670+
TEST_RESOURCE_PATH,
691671
TEST_TOKEN_PROMISE,
692672
req,
693673
TEST_API_ENDPOINT
@@ -704,8 +684,7 @@ describe('generateContentStream', () => {
704684
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult);
705685
const resp = await generateContentStream(
706686
TEST_LOCATION,
707-
TEST_PROJECT,
708-
TEST_PUBLISHER_MODEL_ENDPOINT,
687+
TEST_RESOURCE_PATH,
709688
TEST_TOKEN_PROMISE,
710689
TEST_API_ENDPOINT,
711690
TEST_CHAT_MESSAGE_TEXT
@@ -725,8 +704,7 @@ describe('generateContentStream', () => {
725704
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult);
726705
const resp = await generateContentStream(
727706
TEST_LOCATION,
728-
TEST_PROJECT,
729-
TEST_PUBLISHER_MODEL_ENDPOINT,
707+
TEST_RESOURCE_PATH,
730708
TEST_TOKEN_PROMISE,
731709
req,
732710
TEST_API_ENDPOINT
@@ -746,8 +724,7 @@ describe('generateContentStream', () => {
746724
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult);
747725
const resp = await generateContentStream(
748726
TEST_LOCATION,
749-
TEST_PROJECT,
750-
TEST_PUBLISHER_MODEL_ENDPOINT,
727+
TEST_RESOURCE_PATH,
751728
TEST_TOKEN_PROMISE,
752729
req,
753730
TEST_API_ENDPOINT
@@ -769,8 +746,7 @@ describe('generateContentStream', () => {
769746
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedStreamResult);
770747
const result = await generateContentStream(
771748
TEST_LOCATION,
772-
TEST_PROJECT,
773-
TEST_PUBLISHER_MODEL_ENDPOINT,
749+
TEST_RESOURCE_PATH,
774750
TEST_TOKEN_PROMISE,
775751
req,
776752
TEST_API_ENDPOINT
@@ -805,8 +781,7 @@ describe('generateContentStream', () => {
805781

806782
const actualResult = await generateContentStream(
807783
TEST_LOCATION,
808-
TEST_PROJECT,
809-
TEST_PUBLISHER_MODEL_ENDPOINT,
784+
TEST_RESOURCE_PATH,
810785
TEST_TOKEN_PROMISE,
811786
req,
812787
TEST_API_ENDPOINT

0 commit comments

Comments
 (0)