From 016f6687c2dce176d1a9ecaf22bffc0c41cb0cae Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Fri, 10 Jan 2025 19:32:34 -0800
Subject: [PATCH] Add support for disable_streaming, set for o1

---
 .../src/language_models/chat_models.ts    | 18 ++++++++++--
 libs/langchain-openai/src/chat_models.ts  |  4 +++
 .../src/tests/chat_models.int.test.ts     | 29 ++++++++++++++++++-
 3 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts
index e80e5cf90886..9ca563a38a04 100644
--- a/langchain-core/src/language_models/chat_models.ts
+++ b/langchain-core/src/language_models/chat_models.ts
@@ -74,7 +74,18 @@ export type SerializedLLM = {
 /**
  * Represents the parameters for a base chat model.
  */
-export type BaseChatModelParams = BaseLanguageModelParams;
+export type BaseChatModelParams = BaseLanguageModelParams & {
+  /**
+   * Whether to disable streaming.
+   *
+   * If streaming is bypassed, then `stream()` will defer to
+   * `invoke()`.
+   *
+   * - If true, will always bypass streaming case.
+   * - If false (default), will always use streaming case if available.
+   */
+  disableStreaming?: boolean;
+};
 
 /**
  * Represents the call options for a base chat model.
@@ -152,6 +163,8 @@ export abstract class BaseChatModel<
   // Only ever instantiated in main LangChain
   lc_namespace = ["langchain", "chat_models", this._llmType()];
 
+  disableStreaming = false;
+
   constructor(fields: BaseChatModelParams) {
     super(fields);
   }
@@ -220,7 +233,8 @@ export abstract class BaseChatModel<
     // Subclass check required to avoid double callbacks with default implementation
     if (
       this._streamResponseChunks ===
-      BaseChatModel.prototype._streamResponseChunks
+        BaseChatModel.prototype._streamResponseChunks ||
+      this.disableStreaming
     ) {
       yield this.invoke(input, options);
     } else {
diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts
index 3bd3a4e5ee57..2e04059c2fd4 100644
--- a/libs/langchain-openai/src/chat_models.ts
+++ b/libs/langchain-openai/src/chat_models.ts
@@ -1223,6 +1223,10 @@ export class ChatOpenAI<
       this.streamUsage = false;
     }
 
+    if (this.model === "o1") {
+      this.disableStreaming = true;
+    }
+
     this.streaming = fields?.streaming ?? false;
     this.streamUsage = fields?.streamUsage ?? this.streamUsage;
 
diff --git a/libs/langchain-openai/src/tests/chat_models.int.test.ts b/libs/langchain-openai/src/tests/chat_models.int.test.ts
index c2588312895c..2fbc484789db 100644
--- a/libs/langchain-openai/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -1165,7 +1165,7 @@ describe("Audio output", () => {
   });
 });
 
-test("Can stream o1 requests", async () => {
+test("Can stream o1-mini requests", async () => {
   const model = new ChatOpenAI({
     model: "o1-mini",
   });
@@ -1192,6 +1192,33 @@ test("Can stream o1 requests", async () => {
   expect(numChunks).toBeGreaterThan(3);
 });
 
+test("Doesn't stream o1 requests", async () => {
+  const model = new ChatOpenAI({
+    model: "o1",
+  });
+  const stream = await model.stream(
+    "Write me a very simple hello world program in Python. Ensure it is wrapped in a function called 'hello_world' and has descriptive comments."
+  );
+  let finalMsg: AIMessageChunk | undefined;
+  let numChunks = 0;
+  for await (const chunk of stream) {
+    finalMsg = finalMsg ? concat(finalMsg, chunk) : chunk;
+    numChunks += 1;
+  }
+
+  expect(finalMsg).toBeTruthy();
+  if (!finalMsg) {
+    throw new Error("No final message found");
+  }
+  if (typeof finalMsg.content === "string") {
+    expect(finalMsg.content.length).toBeGreaterThan(10);
+  } else {
+    expect(finalMsg.content.length).toBeGreaterThanOrEqual(1);
+  }
+
+  expect(numChunks).toBe(1);
+});
+
 test("Allows developer messages with o1", async () => {
   const model = new ChatOpenAI({
     model: "o1",