Skip to content

Commit f755f84

Browse files
authored
feat(xai): xAI polish (#7722)
1 parent 993d0f8 commit f755f84

File tree

5 files changed

+122
-19
lines changed

5 files changed

+122
-19
lines changed

libs/langchain-xai/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
"author": "LangChain",
3636
"license": "MIT",
3737
"dependencies": {
38-
"@langchain/openai": "~0.3.0"
38+
"@langchain/openai": "~0.4.4",
39+
"zod": "^3.24.2"
3940
},
4041
"peerDependencies": {
4142
"@langchain/core": ">=0.2.21 <0.4.0"
@@ -67,7 +68,6 @@
6768
"rollup": "^4.5.2",
6869
"ts-jest": "^29.1.0",
6970
"typescript": "<5.2.0",
70-
"zod": "^3.22.4",
7171
"zod-to-json-schema": "^3.23.1"
7272
},
7373
"publishConfig": {

libs/langchain-xai/src/chat_models.ts

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1+
import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
12
import {
23
BaseChatModelCallOptions,
34
BindToolsInput,
45
LangSmithParams,
56
type BaseChatModelParams,
67
} from "@langchain/core/language_models/chat_models";
78
import { Serialized } from "@langchain/core/load/serializable";
9+
import { AIMessageChunk, BaseMessage } from "@langchain/core/messages";
10+
import { Runnable } from "@langchain/core/runnables";
811
import { getEnvironmentVariable } from "@langchain/core/utils/env";
912
import {
1013
type OpenAICoreRequestOptions,
1114
type OpenAIClient,
1215
ChatOpenAI,
1316
OpenAIToolChoice,
17+
ChatOpenAIStructuredOutputMethodOptions,
1418
} from "@langchain/openai";
19+
import { z } from "zod";
1520

1621
type ChatXAIToolType = BindToolsInput | OpenAIClient.ChatCompletionTool;
1722

@@ -494,4 +499,104 @@ export class ChatXAI extends ChatOpenAI<ChatXAICallOptions> {
494499

495500
return super.completionWithRetry(newRequest, options);
496501
}
502+
503+
protected override _convertOpenAIDeltaToBaseMessageChunk(
504+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
505+
delta: Record<string, any>,
506+
rawResponse: OpenAIClient.ChatCompletionChunk,
507+
defaultRole?:
508+
| "function"
509+
| "user"
510+
| "system"
511+
| "developer"
512+
| "assistant"
513+
| "tool"
514+
) {
515+
const messageChunk: AIMessageChunk =
516+
super._convertOpenAIDeltaToBaseMessageChunk(
517+
delta,
518+
rawResponse,
519+
defaultRole
520+
);
521+
// Make concatenating chunks work without merge warning
522+
if (!rawResponse.choices[0]?.finish_reason) {
523+
delete messageChunk.response_metadata.usage;
524+
delete messageChunk.usage_metadata;
525+
} else {
526+
messageChunk.usage_metadata = messageChunk.response_metadata.usage;
527+
}
528+
return messageChunk;
529+
}
530+
531+
protected override _convertOpenAIChatCompletionMessageToBaseMessage(
532+
message: OpenAIClient.ChatCompletionMessage,
533+
rawResponse: OpenAIClient.ChatCompletion
534+
) {
535+
const langChainMessage =
536+
super._convertOpenAIChatCompletionMessageToBaseMessage(
537+
message,
538+
rawResponse
539+
);
540+
langChainMessage.additional_kwargs.reasoning_content =
541+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
542+
(message as any).reasoning_content;
543+
return langChainMessage;
544+
}
545+
546+
override withStructuredOutput<
547+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
548+
RunOutput extends Record<string, any> = Record<string, any>
549+
>(
550+
outputSchema:
551+
| z.ZodType<RunOutput>
552+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
553+
| Record<string, any>,
554+
config?: ChatOpenAIStructuredOutputMethodOptions<false>
555+
): Runnable<BaseLanguageModelInput, RunOutput>;
556+
557+
override withStructuredOutput<
558+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
559+
RunOutput extends Record<string, any> = Record<string, any>
560+
>(
561+
outputSchema:
562+
| z.ZodType<RunOutput>
563+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
564+
| Record<string, any>,
565+
config?: ChatOpenAIStructuredOutputMethodOptions<true>
566+
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
567+
568+
override withStructuredOutput<
569+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
570+
RunOutput extends Record<string, any> = Record<string, any>
571+
>(
572+
outputSchema:
573+
| z.ZodType<RunOutput>
574+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
575+
| Record<string, any>,
576+
config?: ChatOpenAIStructuredOutputMethodOptions<boolean>
577+
):
578+
| Runnable<BaseLanguageModelInput, RunOutput>
579+
| Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
580+
581+
override withStructuredOutput<
582+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
583+
RunOutput extends Record<string, any> = Record<string, any>
584+
>(
585+
outputSchema:
586+
| z.ZodType<RunOutput>
587+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
588+
| Record<string, any>,
589+
config?: ChatOpenAIStructuredOutputMethodOptions<boolean>
590+
):
591+
| Runnable<BaseLanguageModelInput, RunOutput>
592+
| Runnable<
593+
BaseLanguageModelInput,
594+
{ raw: BaseMessage; parsed: RunOutput }
595+
> {
596+
const ensuredConfig = { ...config };
597+
if (ensuredConfig?.method === undefined) {
598+
ensuredConfig.method = "functionCalling";
599+
}
600+
return super.withStructuredOutput<RunOutput>(outputSchema, ensuredConfig);
601+
}
497602
}

libs/langchain-xai/src/tests/chat_models.int.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ test("streaming", async () => {
6868
test("invoke with bound tools", async () => {
6969
const chat = new ChatXAI({
7070
maxRetries: 0,
71-
model: "grok-beta",
71+
model: "grok-2-1212",
7272
});
7373
const message = new HumanMessage("What is the current weather in Hawaii?");
7474
const res = await chat
@@ -144,7 +144,7 @@ test("stream with bound tools, yielding a single chunk", async () => {
144144

145145
test("Few shotting with tool calls", async () => {
146146
const chat = new ChatXAI({
147-
model: "grok-beta",
147+
model: "grok-2-1212",
148148
temperature: 0,
149149
}).bind({
150150
tools: [
@@ -194,9 +194,9 @@ test("Few shotting with tool calls", async () => {
194194
expect(res.content).toContain("24");
195195
});
196196

197-
test("Groq can stream tool calls", async () => {
197+
test("xAI can stream tool calls", async () => {
198198
const model = new ChatXAI({
199-
model: "grok-beta",
199+
model: "grok-2-1212",
200200
temperature: 0,
201201
});
202202

libs/langchain-xai/src/tests/chat_models_structured_output.int.test.ts

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { ChatXAI } from "../chat_models.js";
77
test("withStructuredOutput zod schema function calling", async () => {
88
const model = new ChatXAI({
99
temperature: 0,
10-
model: "grok-beta",
10+
model: "grok-2-1212",
1111
});
1212

1313
const calculatorSchema = z.object({
@@ -37,7 +37,7 @@ test("withStructuredOutput zod schema function calling", async () => {
3737
test("withStructuredOutput zod schema JSON mode", async () => {
3838
const model = new ChatXAI({
3939
temperature: 0,
40-
model: "grok-beta",
40+
model: "grok-2-1212",
4141
});
4242

4343
const calculatorSchema = z.object({
@@ -76,7 +76,7 @@ Respond with a JSON object containing three keys:
7676
test("withStructuredOutput JSON schema function calling", async () => {
7777
const model = new ChatXAI({
7878
temperature: 0,
79-
model: "grok-beta",
79+
model: "grok-2-1212",
8080
});
8181

8282
const calculatorSchema = z.object({
@@ -106,7 +106,7 @@ test("withStructuredOutput JSON schema function calling", async () => {
106106
test("withStructuredOutput OpenAI function definition function calling", async () => {
107107
const model = new ChatXAI({
108108
temperature: 0,
109-
model: "grok-beta",
109+
model: "grok-2-1212",
110110
});
111111

112112
const calculatorSchema = z.object({
@@ -120,14 +120,12 @@ test("withStructuredOutput OpenAI function definition function calling", async (
120120
});
121121

122122
const prompt = ChatPromptTemplate.fromMessages([
123-
"system",
124-
`You are VERY bad at math and must always use a calculator.`,
125-
"human",
126-
"Please help me!! What is 2 + 2?",
123+
["system", `You are VERY bad at math and must always use a calculator.`],
124+
["human", "Please help me!! What is 2 + 2?"],
127125
]);
128126
const chain = prompt.pipe(modelWithStructuredOutput);
129127
const result = await chain.invoke({});
130-
// console.log(result);
128+
131129
expect("operation" in result).toBe(true);
132130
expect("number1" in result).toBe(true);
133131
expect("number2" in result).toBe(true);
@@ -136,7 +134,7 @@ test("withStructuredOutput OpenAI function definition function calling", async (
136134
test("withStructuredOutput JSON schema JSON mode", async () => {
137135
const model = new ChatXAI({
138136
temperature: 0,
139-
model: "grok-beta",
137+
model: "grok-2-1212",
140138
});
141139

142140
const calculatorSchema = z.object({
@@ -175,7 +173,7 @@ Respond with a JSON object containing three keys:
175173
test("withStructuredOutput JSON schema", async () => {
176174
const model = new ChatXAI({
177175
temperature: 0,
178-
model: "grok-beta",
176+
model: "grok-2-1212",
179177
});
180178

181179
const jsonSchema = {
@@ -216,7 +214,7 @@ Respond with a JSON object containing three keys:
216214
test("withStructuredOutput includeRaw true", async () => {
217215
const model = new ChatXAI({
218216
temperature: 0,
219-
model: "grok-beta",
217+
model: "grok-2-1212",
220218
});
221219

222220
const calculatorSchema = z.object({

yarn.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13335,7 +13335,7 @@ __metadata:
1333513335
rollup: ^4.5.2
1333613336
ts-jest: ^29.1.0
1333713337
typescript: <5.2.0
13338-
zod: ^3.22.4
13338+
zod: ^3.24.2
1333913339
zod-to-json-schema: ^3.23.1
1334013340
peerDependencies:
1334113341
"@langchain/core": ">=0.2.21 <0.4.0"

0 commit comments

Comments
 (0)