Skip to content

Commit e0265a3

Browse files
committed
feat: add countTokens method
PiperOrigin-RevId: 586023998
1 parent 038e3c2 commit e0265a3

File tree

4 files changed

+75
-6
lines changed

4 files changed

+75
-6
lines changed

src/index.ts

+37-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import {GoogleAuth} from 'google-auth-library';
2020

2121
import {emptyGenerator, processStream} from './process_stream';
22-
import {Content, GenerateContentParams, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting} from './types/content';
22+
import {Content, CountTokensRequest, CountTokensResponse, GenerateContentParams, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting} from './types/content';
2323
import {postRequest} from './util';
2424

2525
// TODO: update this when model names are available
@@ -264,6 +264,42 @@ export class GenerativeModel {
264264
}
265265
}
266266

267+
/**
268+
* Make a countTokens request.
269+
* @param request A CountTokensRequest object with the request contents.
270+
* @return The CountTokensResponse object with the token count.
271+
*/
272+
async countTokens(request: CountTokensRequest): Promise<CountTokensResponse> {
273+
let response;
274+
try {
275+
response = await postRequest({
276+
region: this._vertex_instance.location,
277+
project: this._vertex_instance.project,
278+
resourcePath: `publishers/google/models/${this.model}`,
279+
resourceMethod: 'countTokens',
280+
token: await this._vertex_instance.token,
281+
data: request,
282+
apiEndpoint: this._vertex_instance.apiEndpoint,
283+
});
284+
if (response === undefined) {
285+
throw new Error('did not get a valid response.');
286+
}
287+
if (!response.ok) {
288+
throw new Error(`${response.status} ${response.statusText}`);
289+
}
290+
291+
} catch (e) {
292+
console.log(e);
293+
}
294+
if (response) {
295+
const responseJson = await response.json();
296+
return responseJson as CountTokensResponse;
297+
} else {
298+
throw new Error('did not get a valid response.');
299+
}
300+
}
301+
302+
267303
startChat(request: StartChatParams): ChatSession {
268304
const startChatRequest = {
269305
history: request.history,

src/index_test.ts

+21-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import 'jasmine';
2020

2121
import {ChatSession, GenerativeModel, StartChatParams, VertexAI} from './index';
2222
import * as StreamFunctions from './process_stream';
23-
import {GenerateContentParams, GenerateContentResult} from './types/content';
23+
import {CountTokensRequest, CountTokensResponse, GenerateContentParams, GenerateContentResult} from './types/content';
2424
import * as PostRequest from './util/post_request';
2525

2626
const PROJECT = 'test_project';
@@ -135,6 +135,26 @@ describe('VertexAI', () => {
135135
expect(resp).toBeInstanceOf(ChatSession);
136136
});
137137
});
138+
139+
describe('countTokens', () => {
140+
it('returns the token count', async () => {
141+
const req: CountTokensRequest = {
142+
contents: TEST_USER_CHAT_MESSAGE,
143+
};
144+
const responseBody = {
145+
totalTokens: 1,
146+
};
147+
const response = new Response(JSON.stringify(responseBody), {
148+
status: 200,
149+
statusText: 'OK',
150+
headers: {'Content-Type': 'application/json'},
151+
});
152+
const responsePromise = Promise.resolve(response);
153+
spyOn(global, 'fetch').and.returnValue(responsePromise);
154+
const resp = await model.countTokens(req);
155+
expect(resp).toEqual(responseBody);
156+
});
157+
});
138158
});
139159

140160
describe('ChatSession', () => {

src/types/content.ts

+15
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ export declare interface GenerateContentRequest extends BaseModelParams {
3131
contents: Content[];
3232
}
3333

34+
/**
 * Params used to call the countTokens method.
 */
export declare interface CountTokensRequest {
  /** The content whose tokens should be counted. */
  contents: Content[];
}

/**
 * Response returned from the countTokens method.
 */
export declare interface CountTokensResponse {
  /** The total number of tokens in the request contents. */
  totalTokens: number;
  /** The total number of billable characters, if reported by the service. */
  totalBillableCharacters?: number;
}
48+
3449

3550
/**
3651
* Configuration for initializing a model, for example via getGenerativeModel

src/util/post_request.ts

+2-4
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
// TODO: update to prod endpoint when ready
1919
const API_BASE_PATH = 'autopush-aiplatform.sandbox.googleapis.com';
2020

21-
import {
22-
GenerateContentRequest, CLIENT_INFO,
23-
} from '../types/content';
21+
import {GenerateContentRequest, CLIENT_INFO, CountTokensRequest} from '../types/content';
2422

2523
/**
2624
* Makes a POST request to a Vertex service
@@ -41,7 +39,7 @@ export async function postRequest({
4139
resourcePath: string,
4240
resourceMethod: string,
4341
token: string,
44-
data: GenerateContentRequest,
42+
data: GenerateContentRequest|CountTokensRequest,
4543
apiEndpoint?: string,
4644
apiVersion?: string,
4745
}): Promise<Response|undefined> {

0 commit comments

Comments
 (0)