Commit 4b4423a

google-genai[minor]: Add support for token counting via usage_metadata (#5757)
* google-genai[minor]: Add support for token counting via usage_metadata
* jsdoc nits
* pass entire request obj when getting input tok
* fix and stop making api calls
1 parent baab194 commit 4b4423a
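
With this change, token counts surface on the usage_metadata field of the returned AIMessage. A minimal consumer-side sketch of what that looks like after the commit (assumes GOOGLE_API_KEY is set; the logged numbers are illustrative):

import { ChatGoogleGenerativeAI } from "@langchain/google-genai";

const model = new ChatGoogleGenerativeAI({ temperature: 0 });

// usage_metadata is filled in from the usageMetadata field of the Google GenAI response.
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res.usage_metadata);
// e.g. { input_tokens: 10, output_tokens: 42, total_tokens: 52 }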

File tree

4 files changed, +189 −39 lines changed


libs/langchain-google-genai/src/chat_models.ts

+101 −26
@@ -6,9 +6,14 @@ import {
   type FunctionDeclarationSchema as GenerativeAIFunctionDeclarationSchema,
   GenerateContentRequest,
   SafetySetting,
+  Part as GenerativeAIPart,
 } from "@google/generative-ai";
 import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
-import { AIMessageChunk, BaseMessage } from "@langchain/core/messages";
+import {
+  AIMessageChunk,
+  BaseMessage,
+  UsageMetadata,
+} from "@langchain/core/messages";
 import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
 import { getEnvironmentVariable } from "@langchain/core/utils/env";
 import {
@@ -56,12 +61,20 @@ export interface GoogleGenerativeAIChatCallOptions
   tools?:
     | StructuredToolInterface[]
     | GoogleGenerativeAIFunctionDeclarationsTool[];
+  /**
+   * Whether or not to include usage data, like token counts
+   * in the streamed response chunks.
+   * @default true
+   */
+  streamUsage?: boolean;
 }
 
 /**
  * An interface defining the input to the ChatGoogleGenerativeAI class.
  */
-export interface GoogleGenerativeAIChatInput extends BaseChatModelParams {
+export interface GoogleGenerativeAIChatInput
+  extends BaseChatModelParams,
+    Pick<GoogleGenerativeAIChatCallOptions, "streamUsage"> {
   /**
    * Model Name to use
    *
@@ -222,6 +235,8 @@ export class ChatGoogleGenerativeAI
 
   streaming = false;
 
+  streamUsage = true;
+
   private client: GenerativeModel;
 
   get _isMultimodalModel() {
@@ -306,6 +321,7 @@ export class ChatGoogleGenerativeAI
         baseUrl: fields?.baseUrl,
       }
     );
+    this.streamUsage = fields?.streamUsage ?? this.streamUsage;
   }
 
   getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -398,27 +414,31 @@ export class ChatGoogleGenerativeAI
       return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
     }
 
-    const res = await this.caller.callWithOptions(
-      { signal: options?.signal },
-      async () => {
-        let output;
-        try {
-          output = await this.client.generateContent({
-            ...parameters,
-            contents: prompt,
-          });
-          // eslint-disable-next-line @typescript-eslint/no-explicit-any
-        } catch (e: any) {
-          // TODO: Improve error handling
-          if (e.message?.includes("400 Bad Request")) {
-            e.status = 400;
-          }
-          throw e;
-        }
-        return output;
+    const res = await this.completionWithRetry({
+      ...parameters,
+      contents: prompt,
+    });
+
+    let usageMetadata: UsageMetadata | undefined;
+    if ("usageMetadata" in res.response) {
+      const genAIUsageMetadata = res.response.usageMetadata as {
+        promptTokenCount: number | undefined;
+        candidatesTokenCount: number | undefined;
+        totalTokenCount: number | undefined;
+      };
+      usageMetadata = {
+        input_tokens: genAIUsageMetadata.promptTokenCount ?? 0,
+        output_tokens: genAIUsageMetadata.candidatesTokenCount ?? 0,
+        total_tokens: genAIUsageMetadata.totalTokenCount ?? 0,
+      };
+    }
+
+    const generationResult = mapGenerateContentResultToChatResult(
+      res.response,
+      {
+        usageMetadata,
       }
     );
-    const generationResult = mapGenerateContentResultToChatResult(res.response);
     await runManager?.handleLLMNewToken(
       generationResult.generations[0].text ?? ""
     );
@@ -435,19 +455,53 @@ export class ChatGoogleGenerativeAI
       this._isMultimodalModel
     );
     const parameters = this.invocationParams(options);
+    const request = {
+      ...parameters,
+      contents: prompt,
+    };
    const stream = await this.caller.callWithOptions(
       { signal: options?.signal },
       async () => {
-        const { stream } = await this.client.generateContentStream({
-          ...parameters,
-          contents: prompt,
-        });
+        const { stream } = await this.client.generateContentStream(request);
         return stream;
       }
     );
 
+    let usageMetadata: UsageMetadata | undefined;
     for await (const response of stream) {
-      const chunk = convertResponseContentToChatGenerationChunk(response);
+      if (
+        "usageMetadata" in response &&
+        this.streamUsage !== false &&
+        options.streamUsage !== false
+      ) {
+        const genAIUsageMetadata = response.usageMetadata as {
+          promptTokenCount: number;
+          candidatesTokenCount: number;
+          totalTokenCount: number;
+        };
+        if (!usageMetadata) {
+          usageMetadata = {
+            input_tokens: genAIUsageMetadata.promptTokenCount,
+            output_tokens: genAIUsageMetadata.candidatesTokenCount,
+            total_tokens: genAIUsageMetadata.totalTokenCount,
+          };
+        } else {
+          // Under the hood, LangChain combines the prompt tokens. Google returns the updated
+          // total each time, so we need to find the difference between the tokens.
+          const outputTokenDiff =
+            genAIUsageMetadata.candidatesTokenCount -
+            usageMetadata.output_tokens;
+          usageMetadata = {
+            input_tokens: 0,
+            output_tokens: outputTokenDiff,
+            total_tokens: outputTokenDiff,
+          };
+        }
+      }
+
+      const chunk = convertResponseContentToChatGenerationChunk(response, {
+        usageMetadata,
+      });
       if (!chunk) {
         continue;
       }
@@ -457,6 +511,27 @@ export class ChatGoogleGenerativeAI
     }
   }
 
+  async completionWithRetry(
+    request: string | GenerateContentRequest | (string | GenerativeAIPart)[],
+    options?: this["ParsedCallOptions"]
+  ) {
+    return this.caller.callWithOptions(
+      { signal: options?.signal },
+      async () => {
+        try {
+          return this.client.generateContent(request);
+          // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        } catch (e: any) {
+          // TODO: Improve error handling
+          if (e.message?.includes("400 Bad Request")) {
+            e.status = 400;
+          }
+          throw e;
+        }
+      }
+    );
+  }
+
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     RunOutput extends Record<string, any> = Record<string, any>
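
Because Google reports cumulative token counts on every streamed response while LangChain sums the usage_metadata of concatenated chunks, the streaming path above emits only the per-chunk difference, so the concatenated message ends up with the correct totals. A rough consumer-side sketch of that behavior, mirroring the new integration tests (token numbers are illustrative):

import { AIMessageChunk } from "@langchain/core/messages";
import { ChatGoogleGenerativeAI } from "@langchain/google-genai";

const model = new ChatGoogleGenerativeAI({ temperature: 0 });

let final: AIMessageChunk | null = null;
for await (const chunk of await model.stream("Why is the sky blue? Be concise.")) {
  // Each chunk carries only the newly generated output tokens; concat() adds them up.
  final = final ? final.concat(chunk) : chunk;
}
console.log(final?.usage_metadata);
// e.g. { input_tokens: 10, output_tokens: 57, total_tokens: 67 }

// Usage data can be left out per call (or via the streamUsage constructor field):
await model.stream("Hello!", { streamUsage: false });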

libs/langchain-google-genai/src/tests/chat_models.int.test.ts

+63 −1
@@ -2,7 +2,7 @@ import { test } from "@jest/globals";
 import * as fs from "node:fs/promises";
 import { fileURLToPath } from "node:url";
 import * as path from "node:path";
-import { HumanMessage } from "@langchain/core/messages";
+import { AIMessageChunk, HumanMessage } from "@langchain/core/messages";
 import {
   ChatPromptTemplate,
   MessagesPlaceholder,
@@ -320,3 +320,65 @@ test("ChatGoogleGenerativeAI can call withStructuredOutput genai tools and invok
   console.log(res);
   expect(typeof res.url === "string").toBe(true);
 });
+
+test("Stream token count usage_metadata", async () => {
+  const model = new ChatGoogleGenerativeAI({
+    temperature: 0,
+  });
+  let res: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+  }
+  console.log(res);
+  expect(res?.usage_metadata).toBeDefined();
+  if (!res?.usage_metadata) {
+    return;
+  }
+  expect(res.usage_metadata.input_tokens).toBe(10);
+  expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
+  expect(res.usage_metadata.total_tokens).toBe(
+    res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
+  );
+});
+
+test("streamUsage excludes token usage", async () => {
+  const model = new ChatGoogleGenerativeAI({
+    temperature: 0,
+    streamUsage: false,
+  });
+  let res: AIMessageChunk | null = null;
+  for await (const chunk of await model.stream(
+    "Why is the sky blue? Be concise."
+  )) {
+    if (!res) {
+      res = chunk;
+    } else {
+      res = res.concat(chunk);
+    }
+  }
+  console.log(res);
+  expect(res?.usage_metadata).not.toBeDefined();
+});
+
+test("Invoke token count usage_metadata", async () => {
+  const model = new ChatGoogleGenerativeAI({
+    temperature: 0,
+  });
+  const res = await model.invoke("Why is the sky blue? Be concise.");
+  console.log(res);
+  expect(res?.usage_metadata).toBeDefined();
+  if (!res?.usage_metadata) {
+    return;
+  }
+  expect(res.usage_metadata.input_tokens).toBe(10);
+  expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
+  expect(res.usage_metadata.total_tokens).toBe(
+    res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
+  );
+});

libs/langchain-google-genai/src/tests/chat_models.standard.int.test.ts

+14 −10
@@ -28,19 +28,23 @@ class ChatGoogleGenerativeAIStandardIntegrationTests extends ChatModelIntegratio
   }
 
   async testUsageMetadataStreaming() {
-    this.skipTestMessage(
-      "testUsageMetadataStreaming",
-      "ChatGoogleGenerativeAI",
-      "Streaming tokens is not currently supported."
-    );
+    // ChatGoogleGenerativeAI does not support streaming tokens by
+    // default, so we must pass in a call option to
+    // enable streaming tokens.
+    const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = {
+      streamUsage: true,
+    };
+    await super.testUsageMetadataStreaming(callOptions);
   }
 
   async testUsageMetadata() {
-    this.skipTestMessage(
-      "testUsageMetadata",
-      "ChatGoogleGenerativeAI",
-      "Usage metadata tokens is not currently supported."
-    );
+    // ChatGoogleGenerativeAI does not support counting tokens
+    // by default, so we must pass in a call option to enable
+    // streaming tokens.
+    const callOptions: ChatGoogleGenerativeAI["ParsedCallOptions"] = {
+      streamUsage: true,
+    };
+    await super.testUsageMetadata(callOptions);
   }
 
   async testToolMessageHistoriesStringContent() {

libs/langchain-google-genai/src/utils/common.ts

+11 −2
@@ -12,6 +12,7 @@ import {
   ChatMessage,
   MessageContent,
   MessageContentComplex,
+  UsageMetadata,
   isBaseMessage,
 } from "@langchain/core/messages";
 import {
@@ -179,7 +180,10 @@
 }
 
 export function mapGenerateContentResultToChatResult(
-  response: EnhancedGenerateContentResponse
+  response: EnhancedGenerateContentResponse,
+  extra?: {
+    usageMetadata: UsageMetadata | undefined;
+  }
 ): ChatResult {
   // if rejected or error, return empty generations with reason in filters
   if (
@@ -208,6 +212,7 @@ export function mapGenerateContentResultToChatResult(
       additional_kwargs: {
        ...generationInfo,
      },
+      usage_metadata: extra?.usageMetadata,
    }),
    generationInfo,
  };
@@ -218,7 +223,10 @@ export function mapGenerateContentResultToChatResult(
 }
 
 export function convertResponseContentToChatGenerationChunk(
-  response: EnhancedGenerateContentResponse
+  response: EnhancedGenerateContentResponse,
+  extra?: {
+    usageMetadata: UsageMetadata | undefined;
+  }
 ): ChatGenerationChunk | null {
   if (!response.candidates || response.candidates.length === 0) {
     return null;
@@ -235,6 +243,7 @@ export function convertResponseContentToChatGenerationChunk(
      // Each chunk can have unique "generationInfo", and merging strategy is unclear,
      // so leave blank for now.
      additional_kwargs: {},
+      usage_metadata: extra?.usageMetadata,
    }),
    generationInfo,
  });
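
The UsageMetadata threaded through these helpers uses LangChain's snake_case keys, renamed from Google's camelCase usageMetadata fields. A standalone sketch of that mapping (the genAIUsageMetadata literal is a hypothetical response excerpt):

import type { UsageMetadata } from "@langchain/core/messages";

// Hypothetical usageMetadata excerpt from a Google GenAI response.
const genAIUsageMetadata: {
  promptTokenCount?: number;
  candidatesTokenCount?: number;
  totalTokenCount?: number;
} = {
  promptTokenCount: 10,
  candidatesTokenCount: 42,
  totalTokenCount: 52,
};

// promptTokenCount -> input_tokens, candidatesTokenCount -> output_tokens,
// totalTokenCount -> total_tokens; missing counts default to 0.
const usageMetadata: UsageMetadata = {
  input_tokens: genAIUsageMetadata.promptTokenCount ?? 0,
  output_tokens: genAIUsageMetadata.candidatesTokenCount ?? 0,
  total_tokens: genAIUsageMetadata.totalTokenCount ?? 0,
};

console.log(usageMetadata); // { input_tokens: 10, output_tokens: 42, total_tokens: 52 }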
