@@ -6,9 +6,14 @@ import {
   type FunctionDeclarationSchema as GenerativeAIFunctionDeclarationSchema,
   GenerateContentRequest,
   SafetySetting,
+  Part as GenerativeAIPart,
 } from "@google/generative-ai";
 import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
-import { AIMessageChunk, BaseMessage } from "@langchain/core/messages";
+import {
+  AIMessageChunk,
+  BaseMessage,
+  UsageMetadata,
+} from "@langchain/core/messages";
 import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
 import { getEnvironmentVariable } from "@langchain/core/utils/env";
 import {
@@ -56,12 +61,20 @@ export interface GoogleGenerativeAIChatCallOptions
   tools?:
     | StructuredToolInterface[]
     | GoogleGenerativeAIFunctionDeclarationsTool[];
+  /**
+   * Whether or not to include usage data, like token counts
+   * in the streamed response chunks.
+   * @default true
+   */
+  streamUsage?: boolean;
 }
 
 /**
  * An interface defining the input to the ChatGoogleGenerativeAI class.
  */
-export interface GoogleGenerativeAIChatInput extends BaseChatModelParams {
+export interface GoogleGenerativeAIChatInput
+  extends BaseChatModelParams,
+    Pick<GoogleGenerativeAIChatCallOptions, "streamUsage"> {
   /**
    * Model Name to use
    *
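
For reference, the new streamUsage flag can be supplied either when constructing the model (via GoogleGenerativeAIChatInput) or per call through the call options. A minimal sketch, not part of the diff; the model name, prompt, and constructor fields other than streamUsage are placeholders:

    import { ChatGoogleGenerativeAI } from "@langchain/google-genai";

    // Disable per-chunk usage reporting for this instance (defaults to true).
    const model = new ChatGoogleGenerativeAI({
      model: "gemini-pro", // placeholder model name
      streamUsage: false,
    });

    // ...or override it for a single streaming call.
    const stream = await model.stream("Hello!", { streamUsage: false });
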
@@ -222,6 +235,8 @@ export class ChatGoogleGenerativeAI
 
   streaming = false;
 
+  streamUsage = true;
+
   private client: GenerativeModel;
 
   get _isMultimodalModel() {
@@ -306,6 +321,7 @@ export class ChatGoogleGenerativeAI
         baseUrl: fields?.baseUrl,
       }
     );
+    this.streamUsage = fields?.streamUsage ?? this.streamUsage;
   }
 
   getLsParams(options: this["ParsedCallOptions"]): LangSmithParams {
@@ -398,27 +414,31 @@ export class ChatGoogleGenerativeAI
       return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
     }
 
-    const res = await this.caller.callWithOptions(
-      { signal: options?.signal },
-      async () => {
-        let output;
-        try {
-          output = await this.client.generateContent({
-            ...parameters,
-            contents: prompt,
-          });
-          // eslint-disable-next-line @typescript-eslint/no-explicit-any
-        } catch (e: any) {
-          // TODO: Improve error handling
-          if (e.message?.includes("400 Bad Request")) {
-            e.status = 400;
-          }
-          throw e;
-        }
-        return output;
+    const res = await this.completionWithRetry({
+      ...parameters,
+      contents: prompt,
+    });
+
+    let usageMetadata: UsageMetadata | undefined;
+    if ("usageMetadata" in res.response) {
+      const genAIUsageMetadata = res.response.usageMetadata as {
+        promptTokenCount: number | undefined;
+        candidatesTokenCount: number | undefined;
+        totalTokenCount: number | undefined;
+      };
+      usageMetadata = {
+        input_tokens: genAIUsageMetadata.promptTokenCount ?? 0,
+        output_tokens: genAIUsageMetadata.candidatesTokenCount ?? 0,
+        total_tokens: genAIUsageMetadata.totalTokenCount ?? 0,
+      };
+    }
+
+    const generationResult = mapGenerateContentResultToChatResult(
+      res.response,
+      {
+        usageMetadata,
       }
     );
-    const generationResult = mapGenerateContentResultToChatResult(res.response);
     await runManager?.handleLLMNewToken(
       generationResult.generations[0].text ?? ""
     );
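
The usageMetadata assembled above is handed to mapGenerateContentResultToChatResult, which elsewhere in this PR (not shown in this hunk) attaches it to the generated message. Assuming it surfaces as the standard usage_metadata field on the returned message, a caller could read the counts like this, reusing the model instance from the earlier sketch; the prompt is a placeholder:

    // Hypothetical caller-side check of the mapped token counts.
    const response = await model.invoke("Why is the sky blue?");
    // Expected shape: { input_tokens, output_tokens, total_tokens }
    console.log(response.usage_metadata);
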
@@ -435,19 +455,53 @@ export class ChatGoogleGenerativeAI
       this._isMultimodalModel
     );
     const parameters = this.invocationParams(options);
+    const request = {
+      ...parameters,
+      contents: prompt,
+    };
     const stream = await this.caller.callWithOptions(
       { signal: options?.signal },
       async () => {
-        const { stream } = await this.client.generateContentStream({
-          ...parameters,
-          contents: prompt,
-        });
+        const { stream } = await this.client.generateContentStream(request);
         return stream;
       }
     );
 
+    let usageMetadata: UsageMetadata | undefined;
     for await (const response of stream) {
-      const chunk = convertResponseContentToChatGenerationChunk(response);
+      if (
+        "usageMetadata" in response &&
+        this.streamUsage !== false &&
+        options.streamUsage !== false
+      ) {
+        const genAIUsageMetadata = response.usageMetadata as {
+          promptTokenCount: number;
+          candidatesTokenCount: number;
+          totalTokenCount: number;
+        };
+        if (!usageMetadata) {
+          usageMetadata = {
+            input_tokens: genAIUsageMetadata.promptTokenCount,
+            output_tokens: genAIUsageMetadata.candidatesTokenCount,
+            total_tokens: genAIUsageMetadata.totalTokenCount,
+          };
+        } else {
+          // Under the hood, LangChain combines the prompt tokens. Google returns the updated
+          // total each time, so we need to find the difference between the tokens.
+          const outputTokenDiff =
+            genAIUsageMetadata.candidatesTokenCount -
+            usageMetadata.output_tokens;
+          usageMetadata = {
+            input_tokens: 0,
+            output_tokens: outputTokenDiff,
+            total_tokens: outputTokenDiff,
+          };
+        }
+      }
+
+      const chunk = convertResponseContentToChatGenerationChunk(response, {
+        usageMetadata,
+      });
       if (!chunk) {
         continue;
       }
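
The delta logic above exists because Gemini reports cumulative token counts on each streamed response, while LangChain sums the usage_metadata of concatenated chunks. A consumer-side sketch of the intended accounting, not part of the diff, assuming AIMessageChunk.concat merges usage metadata by addition and reusing the model instance from the earlier sketch; the prompt is a placeholder:

    import { AIMessageChunk } from "@langchain/core/messages";

    // Concatenating the streamed chunks should add the per-chunk deltas back
    // up to the totals Gemini reported for the whole response.
    let finalChunk: AIMessageChunk | undefined;
    for await (const chunk of await model.stream("Tell me a joke")) {
      finalChunk = finalChunk ? finalChunk.concat(chunk) : chunk;
    }
    console.log(finalChunk?.usage_metadata);
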
@@ -457,6 +511,27 @@ export class ChatGoogleGenerativeAI
     }
   }
 
+  async completionWithRetry(
+    request: string | GenerateContentRequest | (string | GenerativeAIPart)[],
+    options?: this["ParsedCallOptions"]
+  ) {
+    return this.caller.callWithOptions(
+      { signal: options?.signal },
+      async () => {
+        try {
+          return this.client.generateContent(request);
+          // eslint-disable-next-line @typescript-eslint/no-explicit-any
+        } catch (e: any) {
+          // TODO: Improve error handling
+          if (e.message?.includes("400 Bad Request")) {
+            e.status = 400;
+          }
+          throw e;
+        }
+      }
+    );
+  }
+
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     RunOutput extends Record<string, any> = Record<string, any>
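
The new completionWithRetry helper routes generateContent through this.caller.callWithOptions, so the request inherits the model's retry behavior and can be cancelled via an AbortSignal passed in the call options. A small sketch of cancellation from the caller's side, not part of the diff, reusing the model instance from the earlier sketch; the prompt is a placeholder:

    // The signal is forwarded through callWithOptions to the underlying
    // generateContent call, so aborting cancels the in-flight request.
    const controller = new AbortController();
    const pending = model.invoke("Summarize this document...", {
      signal: controller.signal,
    });
    controller.abort();
    await pending.catch((e) => console.error("Request aborted:", e));
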