diff --git a/langchain-core/.eslintrc.cjs b/langchain-core/.eslintrc.cjs index 7743f29e2aae..29b23f23cd0a 100644 --- a/langchain-core/.eslintrc.cjs +++ b/langchain-core/.eslintrc.cjs @@ -69,6 +69,7 @@ module.exports = { "new-cap": ["error", { properties: false, capIsNew: false }], 'jest/no-focused-tests': 'error', "arrow-body-style": 0, + "prefer-destructuring": 0, }, overrides: [ { diff --git a/langchain-core/src/chat_history.ts b/langchain-core/src/chat_history.ts index f0b3b0368ec7..3219a7c8d829 100644 --- a/langchain-core/src/chat_history.ts +++ b/langchain-core/src/chat_history.ts @@ -16,6 +16,20 @@ export abstract class BaseChatMessageHistory extends Serializable { public abstract addAIChatMessage(message: string): Promise<void>; + /** + * Add a list of messages. + * + * Implementations should override this method to handle bulk addition of messages + * in an efficient manner to avoid unnecessary round-trips to the underlying store. + * + * @param messages - A list of BaseMessage objects to store. 
+ */ + public async addMessages(messages: BaseMessage[]): Promise<void> { + for (const message of messages) { + await this.addMessage(message); + } + } + public abstract clear(): Promise<void>; } diff --git a/langchain-core/src/runnables/history.ts b/langchain-core/src/runnables/history.ts index 68a649973e2c..ae42997b9543 100644 --- a/langchain-core/src/runnables/history.ts +++ b/langchain-core/src/runnables/history.ts @@ -149,14 +149,45 @@ export class RunnableWithMessageHistory< } _getInputMessages( - inputValue: string | BaseMessage | Array<BaseMessage> + // eslint-disable-next-line @typescript-eslint/no-explicit-any + inputValue: string | BaseMessage | Array<BaseMessage> | Record<string, any> ): Array<BaseMessage> { - if (typeof inputValue === "string") { - return [new HumanMessage(inputValue)]; - } else if (Array.isArray(inputValue)) { - return inputValue; + let parsedInputValue; + if ( + typeof inputValue === "object" && + !Array.isArray(inputValue) && + !isBaseMessage(inputValue) + ) { + let key; + if (this.inputMessagesKey) { + key = this.inputMessagesKey; + } else if (Object.keys(inputValue).length === 1) { + key = Object.keys(inputValue)[0]; + } else { + key = "input"; + } + if (Array.isArray(inputValue[key]) && Array.isArray(inputValue[key][0])) { + parsedInputValue = inputValue[key][0]; + } else { + parsedInputValue = inputValue[key]; + } } else { - return [inputValue]; + parsedInputValue = inputValue; + } + if (typeof parsedInputValue === "string") { + return [new HumanMessage(parsedInputValue)]; + } else if (Array.isArray(parsedInputValue)) { + return parsedInputValue; + } else if (isBaseMessage(parsedInputValue)) { + return [parsedInputValue]; + } else { + throw new Error( + `Expected a string, BaseMessage, or array of BaseMessages.\nGot ${JSON.stringify( + parsedInputValue, + null, + 2 + )}` + ); } } @@ -164,29 +195,46 @@ export class RunnableWithMessageHistory< // eslint-disable-next-line @typescript-eslint/no-explicit-any outputValue: string | BaseMessage | Array<BaseMessage> | Record<string, any> ): Array<BaseMessage> { - let newOutputValue = 
outputValue; + let parsedOutputValue; if ( !Array.isArray(outputValue) && !isBaseMessage(outputValue) && typeof outputValue !== "string" ) { - newOutputValue = outputValue[this.outputMessagesKey ?? "output"]; + let key; + if (this.outputMessagesKey !== undefined) { + key = this.outputMessagesKey; + } else if (Object.keys(outputValue).length === 1) { + key = Object.keys(outputValue)[0]; + } else { + key = "output"; + } + // If you are wrapping a chat model directly + // The output is actually this weird generations object + if (outputValue.generations !== undefined) { + parsedOutputValue = outputValue.generations[0][0].message; + } else { + parsedOutputValue = outputValue[key]; + } + } else { + parsedOutputValue = outputValue; } - if (typeof newOutputValue === "string") { - return [new AIMessage(newOutputValue)]; - } else if (Array.isArray(newOutputValue)) { - return newOutputValue; - } else if (isBaseMessage(newOutputValue)) { - return [newOutputValue]; + if (typeof parsedOutputValue === "string") { + return [new AIMessage(parsedOutputValue)]; + } else if (Array.isArray(parsedOutputValue)) { + return parsedOutputValue; + } else if (isBaseMessage(parsedOutputValue)) { + return [parsedOutputValue]; + } else { + throw new Error( + `Expected a string, BaseMessage, or array of BaseMessages. Received: ${JSON.stringify( + parsedOutputValue, + null, + 2 + )}` + ); } - throw new Error( - `Expected a string, BaseMessage, or array of BaseMessages. 
Received: ${JSON.stringify( - newOutputValue, - null, - 2 - )}` - ); } async _enterHistory( @@ -195,29 +243,31 @@ export class RunnableWithMessageHistory< kwargs?: { config?: RunnableConfig } ): Promise<BaseMessage[]> { const history = kwargs?.config?.configurable?.messageHistory; - - if (this.historyMessagesKey) { - return history.getMessages(); + const messages = await history.getMessages(); + if (this.historyMessagesKey === undefined) { + return messages.concat(this._getInputMessages(input)); } - - const inputVal = - input || - (this.inputMessagesKey ? input[this.inputMessagesKey] : undefined); - const historyMessages = history ? await history.getMessages() : []; - const returnType = [ - ...historyMessages, - ...this._getInputMessages(inputVal), - ]; - return returnType; + return messages; } async _exitHistory(run: Run, config: RunnableConfig): Promise<void> { const history = config.configurable?.messageHistory; // Get input messages - const { inputs } = run; - const inputValue = inputs[this.inputMessagesKey ?? "input"]; - const inputMessages = this._getInputMessages(inputValue); + let inputs; // Chat model inputs are nested arrays + if (Array.isArray(run.inputs) && Array.isArray(run.inputs[0])) { + inputs = run.inputs[0]; + } else { + inputs = run.inputs; + } + let inputMessages = this._getInputMessages(inputs); + // If historic messages were prepended to the input messages, remove them to + // avoid adding duplicate messages to history. 
+ if (this.historyMessagesKey === undefined) { + const existingMessages = await history.getMessages(); + inputMessages = inputMessages.slice(existingMessages.length); + } // Get output messages const outputValue = run.outputs; if (!outputValue) { @@ -230,10 +280,7 @@ export class RunnableWithMessageHistory< ); } const outputMessages = this._getOutputMessages(outputValue); - - for await (const message of [...inputMessages, ...outputMessages]) { - await history.addMessage(message); - } + await history.addMessages([...inputMessages, ...outputMessages]); } async _mergeConfig(...configs: Array<RunnableConfig | undefined>) { diff --git a/langchain-core/src/runnables/tests/runnable_history.test.ts b/langchain-core/src/runnables/tests/runnable_history.test.ts index 1d1b0f45e421..494cd53b2aa4 100644 --- a/langchain-core/src/runnables/tests/runnable_history.test.ts +++ b/langchain-core/src/runnables/tests/runnable_history.test.ts @@ -1,4 +1,9 @@ -import { BaseMessage, HumanMessage } from "../../messages/index.js"; +import { + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, +} from "../../messages/index.js"; import { RunnableLambda } from "../base.js"; import { RunnableConfig } from "../config.js"; import { RunnableWithMessageHistory } from "../history.js"; @@ -10,6 +15,7 @@ import { FakeChatMessageHistory, FakeLLM, FakeListChatMessageHistory, + FakeListChatModel, FakeStreamingLLM, } from "../../utils/testing/index.js"; import { ChatPromptTemplate, MessagesPlaceholder } from "../../prompts/chat.js"; @@ -73,6 +79,79 @@ expect(output).toBe("you said: hello\ngood bye"); }); +test("Runnable with message history with a chat model", async () => { + const runnable = new FakeListChatModel({ + responses: ["Hello world!"], + }); + + const getMessageHistory = await getGetSessionHistory(); + const withHistory = new RunnableWithMessageHistory({ + runnable, + config: {}, + getMessageHistory, + }); + const config: RunnableConfig = { configurable: { 
sessionId: "2" } }; + const output = await withHistory.invoke([new HumanMessage("hello")], config); + expect(output.content).toBe("Hello world!"); + const stream = await withHistory.stream( + [new HumanMessage("good bye")], + config + ); + const chunks = []; + for await (const chunk of stream) { + console.log(chunk); + chunks.push(chunk); + } + expect(chunks.map((chunk) => chunk.content).join("")).toEqual("Hello world!"); + const sessionHistory = await getMessageHistory("2"); + expect(await sessionHistory.getMessages()).toEqual([ + new HumanMessage("hello"), + new AIMessage("Hello world!"), + new HumanMessage("good bye"), + new AIMessageChunk("Hello world!"), + ]); +}); + +test("Runnable with message history with a messages in, messages out chain", async () => { + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "you are a robot"], + ["placeholder", "{messages}"], + ]); + const model = new FakeListChatModel({ + responses: ["So long and thanks for the fish!!"], + }); + const runnable = prompt.pipe(model); + + const getMessageHistory = await getGetSessionHistory(); + const withHistory = new RunnableWithMessageHistory({ + runnable, + config: {}, + getMessageHistory, + }); + const config: RunnableConfig = { configurable: { sessionId: "2" } }; + const output = await withHistory.invoke([new HumanMessage("hello")], config); + expect(output.content).toBe("So long and thanks for the fish!!"); + const stream = await withHistory.stream( + [new HumanMessage("good bye")], + config + ); + const chunks = []; + for await (const chunk of stream) { + console.log(chunk); + chunks.push(chunk); + } + expect(chunks.map((chunk) => chunk.content).join("")).toEqual( + "So long and thanks for the fish!!" 
+ ); + const sessionHistory = await getMessageHistory("2"); + expect(await sessionHistory.getMessages()).toEqual([ + new HumanMessage("hello"), + new AIMessage("So long and thanks for the fish!!"), + new HumanMessage("good bye"), + new AIMessageChunk("So long and thanks for the fish!!"), + ]); +}); + test("Runnable with message history work with chat list memory", async () => { const runnable = new RunnableLambda({ func: (messages: BaseMessage[]) => @@ -88,7 +167,7 @@ test("Runnable with message history work with chat list memory", async () => { config: {}, getMessageHistory: getListMessageHistory, }); - const config: RunnableConfig = { configurable: { sessionId: "1" } }; + const config: RunnableConfig = { configurable: { sessionId: "3" } }; let output = await withHistory.invoke([new HumanMessage("hello")], config); expect(output).toBe("you said: hello"); output = await withHistory.invoke([new HumanMessage("good bye")], config); @@ -112,7 +191,7 @@ test("Runnable with message history and RunnableSequence", async () => { inputMessagesKey: "input", historyMessagesKey: "history", }); - const config: RunnableConfig = { configurable: { sessionId: "1" } }; + const config: RunnableConfig = { configurable: { sessionId: "4" } }; let output = await withHistory.invoke({ input: "hello" }, config); expect(output).toBe("AI: You are a helpful assistant\nHuman: hello"); output = await withHistory.invoke({ input: "good bye" }, config); @@ -140,7 +219,7 @@ test("Runnable with message history should stream through", async () => { inputMessagesKey: "input", historyMessagesKey: "history", }).pipe(new StringOutputParser()); - const config: RunnableConfig = { configurable: { sessionId: "1" } }; + const config: RunnableConfig = { configurable: { sessionId: "5" } }; const stream = await withHistory.stream({ input: "hello" }, config); const chunks = []; for await (const chunk of stream) {