Skip to content

Commit e6616a5

Browse files
authored
standard-tests[minor]: Improve prompting to force model to call tool (#6004)
1 parent 3d52258 commit e6616a5

File tree

1 file changed

+37
-17
lines changed

1 file changed

+37
-17
lines changed

libs/langchain-standard-tests/src/integration_tests/chat_models.ts

+37-17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
import { z } from "zod";
1212
import { StructuredTool } from "@langchain/core/tools";
1313
import { zodToJsonSchema } from "zod-to-json-schema";
14+
import { ChatPromptTemplate } from "@langchain/core/prompts";
1415
import {
1516
BaseChatModelsTests,
1617
BaseChatModelsTestsFields,
@@ -37,6 +38,14 @@ class AdderTool extends StructuredTool {
3738
}
3839
}
3940

41+
const MATH_ADDITION_PROMPT = /* #__PURE__ */ ChatPromptTemplate.fromMessages([
42+
[
43+
"system",
44+
"You are bad at math and must ALWAYS call the {toolName} function.",
45+
],
46+
["human", "What is the sum of 1836281973 and 19973286?"],
47+
]);
48+
4049
interface ChatModelIntegrationTestsFields<
4150
CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions,
4251
OutputMessageType extends BaseMessageChunk = BaseMessageChunk,
@@ -228,11 +237,11 @@ export abstract class ChatModelIntegrationTests<
228237
new ToolMessage(functionResult, functionId, functionName),
229238
];
230239

231-
const resultStringContent = await modelWithTools.invoke(
240+
const result = await modelWithTools.invoke(
232241
messagesStringContent,
233242
callOptions
234243
);
235-
expect(resultStringContent).toBeInstanceOf(this.invokeResponseType);
244+
expect(result).toBeInstanceOf(this.invokeResponseType);
236245
}
237246

238247
/**
@@ -334,11 +343,11 @@ export abstract class ChatModelIntegrationTests<
334343
new HumanMessage("What is 3 + 4"),
335344
];
336345

337-
const resultStringContent = await modelWithTools.invoke(
346+
const result = await modelWithTools.invoke(
338347
messagesStringContent,
339348
callOptions
340349
);
341-
expect(resultStringContent).toBeInstanceOf(this.invokeResponseType);
350+
expect(result).toBeInstanceOf(this.invokeResponseType);
342351
}
343352

344353
async testWithStructuredOutput() {
@@ -353,13 +362,17 @@ export abstract class ChatModelIntegrationTests<
353362
"withStructuredOutput undefined. Cannot test tool message histories."
354363
);
355364
}
356-
const modelWithTools = model.withStructuredOutput(adderSchema);
365+
const modelWithTools = model.withStructuredOutput(adderSchema, {
366+
name: "math_addition",
367+
});
357368

358-
const resultStringContent = await modelWithTools.invoke("What is 1 + 2");
359-
expect(resultStringContent.a).toBeDefined();
360-
expect([1, 2].includes(resultStringContent.a)).toBeTruthy();
361-
expect(resultStringContent.b).toBeDefined();
362-
expect([1, 2].includes(resultStringContent.b)).toBeTruthy();
369+
const result = await MATH_ADDITION_PROMPT.pipe(modelWithTools).invoke({
370+
toolName: "math_addition",
371+
});
372+
expect(result.a).toBeDefined();
373+
expect(typeof result.a).toBe("number");
374+
expect(result.b).toBeDefined();
375+
expect(typeof result.b).toBe("number");
363376
}
364377

365378
async testWithStructuredOutputIncludeRaw() {
@@ -376,14 +389,17 @@ export abstract class ChatModelIntegrationTests<
376389
}
377390
const modelWithTools = model.withStructuredOutput(adderSchema, {
378391
includeRaw: true,
392+
name: "math_addition",
379393
});
380394

381-
const resultStringContent = await modelWithTools.invoke("What is 1 + 2");
382-
expect(resultStringContent.raw).toBeInstanceOf(this.invokeResponseType);
383-
expect(resultStringContent.parsed.a).toBeDefined();
384-
expect([1, 2].includes(resultStringContent.parsed.a)).toBeTruthy();
385-
expect(resultStringContent.parsed.b).toBeDefined();
386-
expect([1, 2].includes(resultStringContent.parsed.b)).toBeTruthy();
395+
const result = await MATH_ADDITION_PROMPT.pipe(modelWithTools).invoke({
396+
toolName: "math_addition",
397+
});
398+
expect(result.raw).toBeInstanceOf(this.invokeResponseType);
399+
expect(result.parsed.a).toBeDefined();
400+
expect(typeof result.parsed.a).toBe("number");
401+
expect(result.parsed.b).toBeDefined();
402+
expect(typeof result.parsed.b).toBe("number");
387403
}
388404

389405
async testBindToolsWithOpenAIFormattedTools() {
@@ -409,7 +425,11 @@ export abstract class ChatModelIntegrationTests<
409425
},
410426
]);
411427

412-
const result: AIMessage = await modelWithTools.invoke("What is 1 + 2");
428+
const result: AIMessage = await MATH_ADDITION_PROMPT.pipe(
429+
modelWithTools
430+
).invoke({
431+
toolName: "math_addition",
432+
});
413433
expect(result.tool_calls).toHaveLength(1);
414434
if (!result.tool_calls) {
415435
throw new Error("result.tool_calls is undefined");

0 commit comments

Comments
 (0)