@@ -11,6 +11,7 @@ import {
11
11
import { z } from "zod" ;
12
12
import { StructuredTool } from "@langchain/core/tools" ;
13
13
import { zodToJsonSchema } from "zod-to-json-schema" ;
14
+ import { ChatPromptTemplate } from "@langchain/core/prompts" ;
14
15
import {
15
16
BaseChatModelsTests ,
16
17
BaseChatModelsTestsFields ,
@@ -37,6 +38,14 @@ class AdderTool extends StructuredTool {
37
38
}
38
39
}
39
40
41
+ const MATH_ADDITION_PROMPT = /* #__PURE__ */ ChatPromptTemplate . fromMessages ( [
42
+ [
43
+ "system" ,
44
+ "You are bad at math and must ALWAYS call the {toolName} function." ,
45
+ ] ,
46
+ [ "human" , "What is the sum of 1836281973 and 19973286?" ] ,
47
+ ] ) ;
48
+
40
49
interface ChatModelIntegrationTestsFields <
41
50
CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions ,
42
51
OutputMessageType extends BaseMessageChunk = BaseMessageChunk ,
@@ -228,11 +237,11 @@ export abstract class ChatModelIntegrationTests<
228
237
new ToolMessage ( functionResult , functionId , functionName ) ,
229
238
] ;
230
239
231
- const resultStringContent = await modelWithTools . invoke (
240
+ const result = await modelWithTools . invoke (
232
241
messagesStringContent ,
233
242
callOptions
234
243
) ;
235
- expect ( resultStringContent ) . toBeInstanceOf ( this . invokeResponseType ) ;
244
+ expect ( result ) . toBeInstanceOf ( this . invokeResponseType ) ;
236
245
}
237
246
238
247
/**
@@ -334,11 +343,11 @@ export abstract class ChatModelIntegrationTests<
334
343
new HumanMessage ( "What is 3 + 4" ) ,
335
344
] ;
336
345
337
- const resultStringContent = await modelWithTools . invoke (
346
+ const result = await modelWithTools . invoke (
338
347
messagesStringContent ,
339
348
callOptions
340
349
) ;
341
- expect ( resultStringContent ) . toBeInstanceOf ( this . invokeResponseType ) ;
350
+ expect ( result ) . toBeInstanceOf ( this . invokeResponseType ) ;
342
351
}
343
352
344
353
async testWithStructuredOutput ( ) {
@@ -353,13 +362,17 @@ export abstract class ChatModelIntegrationTests<
353
362
"withStructuredOutput undefined. Cannot test tool message histories."
354
363
) ;
355
364
}
356
- const modelWithTools = model . withStructuredOutput ( adderSchema ) ;
365
+ const modelWithTools = model . withStructuredOutput ( adderSchema , {
366
+ name : "math_addition" ,
367
+ } ) ;
357
368
358
- const resultStringContent = await modelWithTools . invoke ( "What is 1 + 2" ) ;
359
- expect ( resultStringContent . a ) . toBeDefined ( ) ;
360
- expect ( [ 1 , 2 ] . includes ( resultStringContent . a ) ) . toBeTruthy ( ) ;
361
- expect ( resultStringContent . b ) . toBeDefined ( ) ;
362
- expect ( [ 1 , 2 ] . includes ( resultStringContent . b ) ) . toBeTruthy ( ) ;
369
+ const result = await MATH_ADDITION_PROMPT . pipe ( modelWithTools ) . invoke ( {
370
+ toolName : "math_addition" ,
371
+ } ) ;
372
+ expect ( result . a ) . toBeDefined ( ) ;
373
+ expect ( typeof result . a ) . toBe ( "number" ) ;
374
+ expect ( result . b ) . toBeDefined ( ) ;
375
+ expect ( typeof result . b ) . toBe ( "number" ) ;
363
376
}
364
377
365
378
async testWithStructuredOutputIncludeRaw ( ) {
@@ -376,14 +389,17 @@ export abstract class ChatModelIntegrationTests<
376
389
}
377
390
const modelWithTools = model . withStructuredOutput ( adderSchema , {
378
391
includeRaw : true ,
392
+ name : "math_addition" ,
379
393
} ) ;
380
394
381
- const resultStringContent = await modelWithTools . invoke ( "What is 1 + 2" ) ;
382
- expect ( resultStringContent . raw ) . toBeInstanceOf ( this . invokeResponseType ) ;
383
- expect ( resultStringContent . parsed . a ) . toBeDefined ( ) ;
384
- expect ( [ 1 , 2 ] . includes ( resultStringContent . parsed . a ) ) . toBeTruthy ( ) ;
385
- expect ( resultStringContent . parsed . b ) . toBeDefined ( ) ;
386
- expect ( [ 1 , 2 ] . includes ( resultStringContent . parsed . b ) ) . toBeTruthy ( ) ;
395
+ const result = await MATH_ADDITION_PROMPT . pipe ( modelWithTools ) . invoke ( {
396
+ toolName : "math_addition" ,
397
+ } ) ;
398
+ expect ( result . raw ) . toBeInstanceOf ( this . invokeResponseType ) ;
399
+ expect ( result . parsed . a ) . toBeDefined ( ) ;
400
+ expect ( typeof result . parsed . a ) . toBe ( "number" ) ;
401
+ expect ( result . parsed . b ) . toBeDefined ( ) ;
402
+ expect ( typeof result . parsed . b ) . toBe ( "number" ) ;
387
403
}
388
404
389
405
async testBindToolsWithOpenAIFormattedTools ( ) {
@@ -409,7 +425,11 @@ export abstract class ChatModelIntegrationTests<
409
425
} ,
410
426
] ) ;
411
427
412
- const result : AIMessage = await modelWithTools . invoke ( "What is 1 + 2" ) ;
428
+ const result : AIMessage = await MATH_ADDITION_PROMPT . pipe (
429
+ modelWithTools
430
+ ) . invoke ( {
431
+ toolName : "math_addition" ,
432
+ } ) ;
413
433
expect ( result . tool_calls ) . toHaveLength ( 1 ) ;
414
434
if ( ! result . tool_calls ) {
415
435
throw new Error ( "result.tool_calls is undefined" ) ;
0 commit comments