
Commit 075ebbc

feat(google-common), feat(core): Improve token counting (#8128)
2 parents 26d3b0d + 02fbd1b commit 075ebbc

File tree: 4 files changed, +151 -32 lines

- langchain-core/src/messages/ai.ts
- libs/langchain-google-common/src/types.ts
- libs/langchain-google-common/src/utils/gemini.ts
- libs/langchain-google-webauth/src/tests/chat_models.int.test.ts

langchain-core/src/messages/ai.ts

Lines changed: 31 additions & 12 deletions
```diff
@@ -21,17 +21,41 @@ export type AIMessageFields = BaseMessageFields & {
   usage_metadata?: UsageMetadata;
 };
 
+export type ModalitiesTokenDetails = {
+  /**
+   * Text tokens.
+   * Does not need to be reported, but some models will do so.
+   */
+  text?: number;
+
+  /**
+   * Image (non-video) tokens.
+   */
+  image?: number;
+
+  /**
+   * Audio tokens.
+   */
+  audio?: number;
+
+  /**
+   * Video tokens.
+   */
+  video?: number;
+
+  /**
+   * Document tokens.
+   * e.g. PDF
+   */
+  document?: number;
+};
+
 /**
  * Breakdown of input token counts.
  *
  * Does not *need* to sum to full input token count. Does *not* need to have all keys.
  */
-export type InputTokenDetails = {
-  /**
-   * Audio input tokens.
-   */
-  audio?: number;
-
+export type InputTokenDetails = ModalitiesTokenDetails & {
   /**
    * Input tokens that were cached and there was a cache hit.
    *
@@ -53,12 +77,7 @@ export type InputTokenDetails = {
  *
  * Does *not* need to sum to full output token count. Does *not* need to have all keys.
  */
-export type OutputTokenDetails = {
-  /**
-   * Audio output tokens
-   */
-  audio?: number;
-
+export type OutputTokenDetails = ModalitiesTokenDetails & {
   /**
    * Reasoning output tokens.
    *
```
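The net effect in langchain-core: the per-modality keys (text, image, audio, video, document) are now shared between input and output details through an intersection type, replacing the previous audio-only fields, while each side keeps its own extras (cache fields on input, reasoning on output). A minimal sketch of how the composed types populate, assuming the existing cache_read field on InputTokenDetails; all counts are hypothetical:

```ts
import type {
  InputTokenDetails,
  OutputTokenDetails,
  UsageMetadata,
} from "@langchain/core/messages";

// Hypothetical counts for illustration only.
const input_token_details: InputTokenDetails = {
  text: 120, // shared ModalitiesTokenDetails key
  audio: 400, // shared ModalitiesTokenDetails key
  cache_read: 0, // InputTokenDetails-specific key
};

const output_token_details: OutputTokenDetails = {
  text: 80, // shared ModalitiesTokenDetails key
  reasoning: 25, // OutputTokenDetails-specific key
};

const usage: UsageMetadata = {
  input_tokens: 520,
  output_tokens: 105,
  total_tokens: 625,
  input_token_details,
  output_token_details,
};
```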

libs/langchain-google-common/src/types.ts

Lines changed: 30 additions & 1 deletion
```diff
@@ -596,10 +596,39 @@ interface GeminiResponsePromptFeedback {
   safetyRatings: GeminiSafetyRating[];
 }
 
+export type ModalityEnum =
+  | "TEXT"
+  | "IMAGE"
+  | "VIDEO"
+  | "AUDIO"
+  | "DOCUMENT"
+  | string;
+
+export interface ModalityTokenCount {
+  modality: ModalityEnum;
+  tokenCount: number;
+}
+
+export interface GenerateContentResponseUsageMetadata {
+  promptTokenCount: number;
+  toolUsePromptTokenCount: number;
+  cachedContentTokenCount: number;
+  thoughtsTokenCount: number;
+  candidatesTokenCount: number;
+  totalTokenCount: number;
+
+  promptTokensDetails: ModalityTokenCount[];
+  toolUsePromptTokensDetails: ModalityTokenCount[];
+  cacheTokensDetails: ModalityTokenCount[];
+  candidatesTokensDetails: ModalityTokenCount[];
+
+  [key: string]: unknown;
+}
+
 export interface GenerateContentResponseData {
   candidates: GeminiResponseCandidate[];
   promptFeedback: GeminiResponsePromptFeedback;
-  usageMetadata: Record<string, unknown>;
+  usageMetadata: GenerateContentResponseUsageMetadata;
 }
 
 export type GoogleLLMModelFamily = null | "palm" | "gemini" | "gemma";
```
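For reference, a usageMetadata fragment shaped to the new interface might look like the sketch below. The values are hypothetical, and the relative import path is an assumption for illustration; the string fallback on ModalityEnum and the index signature leave room for modalities and fields the API may add later.

```ts
import type { GenerateContentResponseUsageMetadata } from "../types.js";

// Hypothetical fragment; real values come from the Gemini API.
const usageMetadata: GenerateContentResponseUsageMetadata = {
  promptTokenCount: 520,
  toolUsePromptTokenCount: 0,
  cachedContentTokenCount: 0,
  thoughtsTokenCount: 25,
  candidatesTokenCount: 80,
  totalTokenCount: 625,
  promptTokensDetails: [
    { modality: "TEXT", tokenCount: 120 },
    { modality: "AUDIO", tokenCount: 400 },
  ],
  toolUsePromptTokensDetails: [],
  cacheTokensDetails: [],
  candidatesTokensDetails: [{ modality: "TEXT", tokenCount: 80 }],
};
```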

libs/langchain-google-common/src/utils/gemini.ts

Lines changed: 59 additions & 15 deletions
```diff
@@ -19,6 +19,9 @@ import {
   parseBase64DataUrl,
   isDataContentBlock,
   convertToProviderContentBlock,
+  InputTokenDetails,
+  OutputTokenDetails,
+  ModalitiesTokenDetails,
 } from "@langchain/core/messages";
 import {
   ChatGeneration,
@@ -47,6 +50,7 @@ import type {
   GeminiLogprobsResult,
   GeminiLogprobsResultCandidate,
   GeminiLogprobsTopCandidate,
+  ModalityTokenCount,
 } from "../types.js";
 import { GoogleAISafetyError } from "./safety.js";
 import { MediaBlob } from "../experimental/utils/media_core.js";
@@ -855,6 +859,58 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
     };
   }
 
+  function addModalityCounts(
+    modalityTokenCounts: ModalityTokenCount[],
+    details: InputTokenDetails | OutputTokenDetails
+  ): void {
+    modalityTokenCounts?.forEach((modalityTokenCount) => {
+      const { modality, tokenCount } = modalityTokenCount;
+      const modalityLc: keyof ModalitiesTokenDetails =
+        modality.toLowerCase() as keyof ModalitiesTokenDetails;
+      const currentCount = details[modalityLc] ?? 0;
+      // eslint-disable-next-line no-param-reassign
+      details[modalityLc] = currentCount + tokenCount;
+    });
+  }
+
+  function responseToUsageMetadata(
+    response: GoogleLLMResponse
+  ): UsageMetadata | undefined {
+    if ("usageMetadata" in response.data) {
+      const data: GenerateContentResponseData = response?.data;
+      const usageMetadata = data?.usageMetadata;
+
+      const input_tokens = usageMetadata.promptTokenCount ?? 0;
+      const candidatesTokenCount = usageMetadata.candidatesTokenCount ?? 0;
+      const thoughtsTokenCount = usageMetadata.thoughtsTokenCount ?? 0;
+      const output_tokens = candidatesTokenCount + thoughtsTokenCount;
+      const total_tokens =
+        usageMetadata.totalTokenCount ?? input_tokens + output_tokens;
+
+      const input_token_details: InputTokenDetails = {};
+      addModalityCounts(usageMetadata.promptTokensDetails, input_token_details);
+
+      const output_token_details: OutputTokenDetails = {};
+      addModalityCounts(
+        usageMetadata?.candidatesTokensDetails,
+        output_token_details
+      );
+      if (typeof usageMetadata?.thoughtsTokenCount === "number") {
+        output_token_details.reasoning = usageMetadata.thoughtsTokenCount;
+      }
+
+      const ret: UsageMetadata = {
+        input_tokens,
+        output_tokens,
+        total_tokens,
+        input_token_details,
+        output_token_details,
+      };
+      return ret;
+    }
+    return undefined;
+  }
+
   function responseToGenerationInfo(response: GoogleLLMResponse) {
     const data =
       // eslint-disable-next-line no-nested-ternary
@@ -890,11 +946,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
     // Only add the usage_metadata on the last chunk
     // sent while streaming (see issue 8102).
     if (typeof finish_reason === "string") {
-      ret.usage_metadata = {
-        prompt_token_count: data.usageMetadata?.promptTokenCount,
-        candidates_token_count: data.usageMetadata?.candidatesTokenCount,
-        total_token_count: data.usageMetadata?.totalTokenCount,
-      };
+      ret.usage_metadata = responseToUsageMetadata(response);
     }
 
     return ret;
@@ -1115,15 +1167,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
     const lastContent = gen.content[gen.content.length - 1];
 
     // Add usage metadata
-    let usageMetadata: UsageMetadata | undefined;
-    if ("usageMetadata" in response.data) {
-      usageMetadata = {
-        input_tokens: response.data.usageMetadata.promptTokenCount as number,
-        output_tokens: response.data.usageMetadata
-          .candidatesTokenCount as number,
-        total_tokens: response.data.usageMetadata.totalTokenCount as number,
-      };
-    }
+    const usage_metadata = responseToUsageMetadata(response);
 
     // Add thinking / reasoning
     // if (gen.reasoning && gen.reasoning.length > 0) {
@@ -1134,7 +1178,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
     const message = new AIMessageChunk({
       content: combinedContent,
      additional_kwargs: kwargs,
-      usage_metadata: usageMetadata,
+      usage_metadata,
       tool_calls: combinedToolCalls.tool_calls,
       invalid_tool_calls: combinedToolCalls.invalid_tool_calls,
     });
```
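The core of the new mapping is addModalityCounts, which lowercases each modality name ("TEXT" becomes the text key) and accumulates counts into the details object. A standalone sketch of that folding logic, with simplified local types so it runs outside the package (the real helpers are scoped inside getGeminiAPI):

```ts
// Simplified local stand-ins for the package types; hypothetical values.
type ModalityTokenCount = { modality: string; tokenCount: number };
type Details = Record<string, number>;

function addModalityCounts(
  counts: ModalityTokenCount[] | undefined,
  details: Details
): void {
  counts?.forEach(({ modality, tokenCount }) => {
    const key = modality.toLowerCase(); // "TEXT" -> "text", "AUDIO" -> "audio"
    details[key] = (details[key] ?? 0) + tokenCount; // repeated modalities accumulate
  });
}

const inputDetails: Details = {};
addModalityCounts(
  [
    { modality: "TEXT", tokenCount: 100 },
    { modality: "AUDIO", tokenCount: 400 },
    { modality: "TEXT", tokenCount: 20 },
  ],
  inputDetails
);
console.log(inputDetails); // { text: 120, audio: 400 }
```

Note also that output_tokens is now candidatesTokenCount + thoughtsTokenCount, so reasoning tokens count toward the output total and are surfaced separately as output_token_details.reasoning.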

libs/langchain-google-webauth/src/tests/chat_models.int.test.ts

Lines changed: 31 additions & 4 deletions
```diff
@@ -39,6 +39,12 @@ import { ChatGoogle, ChatGoogleInput } from "../chat_models.js";
 import { BlobStoreAIStudioFile } from "../media.js";
 import MockedFunction = jest.MockedFunction;
 
+function propSum(o: Record<string, number>): number {
+  return Object.keys(o)
+    .map((key) => o[key])
+    .reduce((acc, val) => acc + val);
+}
+
 class WeatherTool extends StructuredTool {
   schema = z.object({
     locations: z
@@ -442,10 +448,16 @@ describe.each(testGeminiModelNames)(
     expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/);
 
     expect(res).toHaveProperty("response_metadata");
-    expect(res.response_metadata).not.toHaveProperty("groundingMetadata");
-    expect(res.response_metadata).not.toHaveProperty("groundingSupport");
-
-    console.log(recorder);
+    const meta = res.response_metadata;
+    expect(meta).not.toHaveProperty("groundingMetadata");
+    expect(meta).not.toHaveProperty("groundingSupport");
+    expect(meta).toHaveProperty("usage_metadata");
+    const usage = meta.usage_metadata;
+
+    // Although LangChainJS doesn't require that the details sum to the
+    // available tokens, this should be the case for how we're doing Gemini.
+    expect(propSum(usage.input_token_details)).toEqual(usage.input_tokens);
+    expect(propSum(usage.output_token_details)).toEqual(usage.output_tokens);
   });
 
   test(`generate`, async () => {
@@ -883,6 +895,21 @@ describe.each(testGeminiModelNames)(
 
     expect(typeof response.content).toBe("string");
     expect((response.content as string).length).toBeGreaterThan(15);
+
+    expect(response).toHaveProperty("response_metadata");
+    const meta = response.response_metadata;
+    expect(meta).not.toHaveProperty("groundingMetadata");
+    expect(meta).not.toHaveProperty("groundingSupport");
+    expect(meta).toHaveProperty("usage_metadata");
+    const usage = meta.usage_metadata;
+
+    // Although LangChainJS doesn't require that the details sum to the
+    // available tokens, this should be the case for how we're doing Gemini.
+    expect(propSum(usage.input_token_details)).toEqual(usage.input_tokens);
+    expect(propSum(usage.output_token_details)).toEqual(usage.output_tokens);
+    expect(usage.input_token_details).toHaveProperty("audio");
+
+    console.log(response);
   });
 
   test("Supports GoogleSearchRetrievalTool", async () => {
```
