Skip to content

Commit 6ade289

Browse files
authored
google-common[minor]: Add stream_usage (#5763)
* google-vertexai[minor]: Add stream usage
* fix tests
* bump min core version
1 parent ed493df commit 6ade289

File tree

7 files changed

+103
-23
lines changed

7 files changed

+103
-23
lines changed

libs/langchain-google-common/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"author": "LangChain",
4141
"license": "MIT",
4242
"dependencies": {
43-
"@langchain/core": ">0.1.56 <0.3.0",
43+
"@langchain/core": ">=0.2.5 <0.3.0",
4444
"uuid": "^9.0.0",
4545
"zod-to-json-schema": "^3.22.4"
4646
},

libs/langchain-google-common/src/chat_models.ts

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { getEnvironmentVariable } from "@langchain/core/utils/env";
2-
import { type BaseMessage } from "@langchain/core/messages";
2+
import { UsageMetadata, type BaseMessage } from "@langchain/core/messages";
33
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
44

55
import {
@@ -150,7 +150,8 @@ export interface ChatGoogleBaseInput<AuthOptions>
150150
extends BaseChatModelParams,
151151
GoogleConnectionParams<AuthOptions>,
152152
GoogleAIModelParams,
153-
GoogleAISafetyParams {}
153+
GoogleAISafetyParams,
154+
Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}
154155

155156
function convertToGeminiTools(
156157
structuredTools: (StructuredToolInterface | Record<string, unknown>)[]
@@ -216,6 +217,8 @@ export abstract class ChatGoogleBase<AuthOptions>
216217

217218
safetyHandler: GoogleAISafetyHandler;
218219

220+
streamUsage = true;
221+
219222
protected connection: ChatConnection<AuthOptions>;
220223

221224
protected streamedConnection: ChatConnection<AuthOptions>;
@@ -226,7 +229,7 @@ export abstract class ChatGoogleBase<AuthOptions>
226229
copyAndValidateModelParamsInto(fields, this);
227230
this.safetyHandler =
228231
fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();
229-
232+
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
230233
const client = this.buildClient(fields);
231234
this.buildConnection(fields ?? {}, client);
232235
}
@@ -342,12 +345,24 @@ export abstract class ChatGoogleBase<AuthOptions>
342345

343346
// Get the streaming parser of the response
344347
const stream = response.data as JsonStream;
345-
348+
let usageMetadata: UsageMetadata | undefined;
346349
// Loop until the end of the stream
347350
// During the loop, yield each time we get a chunk from the streaming parser
348351
// that is either available or added to the queue
349352
while (!stream.streamDone) {
350353
const output = await stream.nextChunk();
354+
if (
355+
output &&
356+
output.usageMetadata &&
357+
this.streamUsage !== false &&
358+
options.streamUsage !== false
359+
) {
360+
usageMetadata = {
361+
input_tokens: output.usageMetadata.promptTokenCount,
362+
output_tokens: output.usageMetadata.candidatesTokenCount,
363+
total_tokens: output.usageMetadata.totalTokenCount,
364+
};
365+
}
351366
const chunk =
352367
output !== null
353368
? safeResponseToChatGeneration({ data: output }, this.safetyHandler)
@@ -356,6 +371,7 @@ export abstract class ChatGoogleBase<AuthOptions>
356371
generationInfo: { finishReason: "stop" },
357372
message: new AIMessageChunk({
358373
content: "",
374+
usage_metadata: usageMetadata,
359375
}),
360376
});
361377
yield chunk;

libs/langchain-google-common/src/types.ts

+8-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,14 @@ export interface GoogleAIBaseLLMInput<AuthOptions>
122122
export interface GoogleAIBaseLanguageModelCallOptions
123123
extends BaseLanguageModelCallOptions,
124124
GoogleAIModelRequestParams,
125-
GoogleAISafetyParams {}
125+
GoogleAISafetyParams {
126+
/**
127+
* Whether or not to include usage data, like token counts
128+
* in the streamed response chunks.
129+
* @default true
130+
*/
131+
streamUsage?: boolean;
132+
}
126133

127134
/**
128135
* Input to LLM class.

libs/langchain-google-common/src/utils/gemini.ts

+11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
MessageContentText,
1313
SystemMessage,
1414
ToolMessage,
15+
UsageMetadata,
1516
isAIMessage,
1617
} from "@langchain/core/messages";
1718
import {
@@ -604,12 +605,22 @@ export function responseToChatGenerations(
604605
id: toolCall.id,
605606
index: i,
606607
}));
608+
let usageMetadata: UsageMetadata | undefined;
609+
if ("usageMetadata" in response.data) {
610+
usageMetadata = {
611+
input_tokens: response.data.usageMetadata.promptTokenCount as number,
612+
output_tokens: response.data.usageMetadata
613+
.candidatesTokenCount as number,
614+
total_tokens: response.data.usageMetadata.totalTokenCount as number,
615+
};
616+
}
607617
ret = [
608618
new ChatGenerationChunk({
609619
message: new AIMessageChunk({
610620
content: combinedContent,
611621
additional_kwargs: ret[ret.length - 1]?.message.additional_kwargs,
612622
tool_call_chunks: toolCallChunks,
623+
usage_metadata: usageMetadata,
613624
}),
614625
text: combinedText,
615626
generationInfo: ret[ret.length - 1].generationInfo,

libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts

+62
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,65 @@ describe("GAuth Chat", () => {
233233
expect(result).toHaveProperty("location");
234234
});
235235
});
236+
237+
test("Stream token count usage_metadata", async () => {
238+
const model = new ChatVertexAI({
239+
temperature: 0,
240+
});
241+
let res: AIMessageChunk | null = null;
242+
for await (const chunk of await model.stream(
243+
"Why is the sky blue? Be concise."
244+
)) {
245+
if (!res) {
246+
res = chunk;
247+
} else {
248+
res = res.concat(chunk);
249+
}
250+
}
251+
console.log(res);
252+
expect(res?.usage_metadata).toBeDefined();
253+
if (!res?.usage_metadata) {
254+
return;
255+
}
256+
expect(res.usage_metadata.input_tokens).toBe(9);
257+
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
258+
expect(res.usage_metadata.total_tokens).toBe(
259+
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
260+
);
261+
});
262+
263+
test("streamUsage excludes token usage", async () => {
264+
const model = new ChatVertexAI({
265+
temperature: 0,
266+
streamUsage: false,
267+
});
268+
let res: AIMessageChunk | null = null;
269+
for await (const chunk of await model.stream(
270+
"Why is the sky blue? Be concise."
271+
)) {
272+
if (!res) {
273+
res = chunk;
274+
} else {
275+
res = res.concat(chunk);
276+
}
277+
}
278+
console.log(res);
279+
expect(res?.usage_metadata).not.toBeDefined();
280+
});
281+
282+
test("Invoke token count usage_metadata", async () => {
283+
const model = new ChatVertexAI({
284+
temperature: 0,
285+
});
286+
const res = await model.invoke("Why is the sky blue? Be concise.");
287+
console.log(res);
288+
expect(res?.usage_metadata).toBeDefined();
289+
if (!res?.usage_metadata) {
290+
return;
291+
}
292+
expect(res.usage_metadata.input_tokens).toBe(9);
293+
expect(res.usage_metadata.output_tokens).toBeGreaterThan(10);
294+
expect(res.usage_metadata.total_tokens).toBe(
295+
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
296+
);
297+
});

libs/langchain-google-vertexai/src/tests/chat_models.standard.int.test.ts

-16
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,6 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
2525
});
2626
}
2727

28-
async testUsageMetadataStreaming() {
29-
this.skipTestMessage(
30-
"testUsageMetadataStreaming",
31-
"ChatVertexAI",
32-
"Streaming tokens is not currently supported."
33-
);
34-
}
35-
36-
async testUsageMetadata() {
37-
this.skipTestMessage(
38-
"testUsageMetadata",
39-
"ChatVertexAI",
40-
"Usage metadata tokens is not currently supported."
41-
);
42-
}
43-
4428
async testToolMessageHistoriesListContent() {
4529
this.skipTestMessage(
4630
"testToolMessageHistoriesListContent",

yarn.lock

+1-1
Original file line numberDiff line numberDiff line change
@@ -10194,7 +10194,7 @@ __metadata:
1019410194
resolution: "@langchain/google-common@workspace:libs/langchain-google-common"
1019510195
dependencies:
1019610196
"@jest/globals": ^29.5.0
10197-
"@langchain/core": ">0.1.56 <0.3.0"
10197+
"@langchain/core": ">=0.2.5 <0.3.0"
1019810198
"@langchain/scripts": ~0.0.14
1019910199
"@swc/core": ^1.3.90
1020010200
"@swc/jest": ^0.2.29

0 commit comments

Comments (0)