RunnableWithFallbacks doesn't implement streaming support #6026


Closed
5 tasks done
tomi-bigpi opened this issue Jul 10, 2024 · 8 comments · Fixed by #6487
Labels
auto:bug Related to a bug, vulnerability, unexpected error with an existing feature

Comments

@tomi-bigpi

Checked other resources

  • I added a very descriptive title to this issue.
  • I searched the LangChain.js documentation with the integrated search.
  • I used the GitHub search to find a similar question and didn't find it.
  • I am sure that this is a bug in LangChain.js rather than my code.
  • The bug is not resolved by updating to the latest stable version of LangChain (or the specific integration package).

Example Code

Sample code similar to the following should stream the results:

import { BedrockChat } from "@langchain/community/chat_models/bedrock";
import { concat } from "@langchain/core/utils/stream";
import type { AIMessageChunk } from "@langchain/core/messages";

function createModelWithFallbacks() {
  const primaryModel = new BedrockChat({
    model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
    region: "us-east-1",
  });

  const secondaryModel = new BedrockChat({
    model: "anthropic.claude-3-haiku-20240307-v1:0",
    region: "us-east-1",
  });

  // Using `withFallbacks` causes streaming to not work
  return primaryModel.withFallbacks({
    fallbacks: [
      secondaryModel,
    ],
  });
}

export async function testStream() {
  const model = createModelWithFallbacks();

  // This will result in a call to `.invoke()` instead
  const stream = await model.stream("What color is the sky on Mars?");

  let gathered: AIMessageChunk | undefined = undefined;

  for await (const chunk of stream) {
    console.log(chunk);
    if (gathered === undefined) {
      gathered = chunk;
    } else {
      gathered = concat(gathered, chunk);
    }
  }

  console.log(gathered);
}

testStream();

However, invoke is currently called on the model(s) instead, and no streaming happens.

A sample patch that resolves the issue:

diff --git a/dist/runnables/base.cjs b/dist/runnables/base.cjs
index 2276cfd01b0ad7072094f634bbe7ec115189f20e..cf87a60ff48ceb57b6bcc0abfbf9e7db27a9c5f5 100644
--- a/dist/runnables/base.cjs
+++ b/dist/runnables/base.cjs
@@ -1699,6 +1699,73 @@ class RunnableWithFallbacks extends Runnable {
         await runManager?.handleChainError(firstError);
         throw firstError;
     }
+    async *_streamIterator(input, options) {
+        const callbackManager_ = await manager_js_1.CallbackManager.configure(options?.callbacks, undefined, options?.tags, undefined, options?.metadata);
+        const { runId, ...otherOptions } = options ?? {};
+        const runManager = await callbackManager_?.handleChainStart(this.toJSON(), _coerceToDict(input, "input"), runId, undefined, undefined, undefined, otherOptions?.runName);
+
+        let chunk;
+        let firstError = undefined;
+        let stream = undefined;
+        let concatSupported = false;
+        let finalOutput = undefined;
+        for await (const runnable of this.runnables()) {
+            concatSupported = true;
+            try {
+                stream = await runnable.stream(input, config_js_1.patchConfig(otherOptions, { callbacks: runManager?.getChild() }));
+                chunk = (await stream.next()).value;
+
+                // Clear the error since we successfully got a chunk from this runnable
+                firstError = undefined;
+            } catch (e) {
+                if (firstError === undefined) {
+                    firstError = e;
+                }
+            }
+
+            // Check if we successfully got a chunk from this runnable
+            if (!firstError) {
+                // Use the stream from this runnable for the rest of the chunks
+                break;
+            }
+        }
+
+        if (firstError) {
+            await runManager?.handleChainError(firstError);
+            throw firstError;
+        }
+
+        if (chunk) {
+            yield chunk;
+        }
+
+        if (stream) {
+            try {
+                for await (chunk of stream) {
+                    if (concatSupported) {
+                        if (finalOutput === undefined) {
+                            finalOutput = chunk;
+                        } else {
+                            try {
+                                // eslint-disable-next-line @typescript-eslint/no-explicit-any
+                                finalOutput = concat(finalOutput, chunk);
+                            }
+                            catch (e) {
+                                finalOutput = undefined;
+                                concatSupported = false;
+                            }
+                        }
+                    }
+                    yield chunk;
+                }
+            } catch (e) {
+                await runManager?.handleChainError(e);
+                throw e;
+            }
+        }
+
+        await runManager?.handleChainEnd(finalOutput);
+    }
     async batch(inputs, options, batchOptions) {
         if (batchOptions?.returnExceptions) {
             throw new Error("Not implemented.");
diff --git a/dist/runnables/base.js b/dist/runnables/base.js
index d60358b742c78df8cb1edc7d3a872a1c383bc3ed..76e5ab8fcafbb5ffe134929411c09300970321b6 100644
--- a/dist/runnables/base.js
+++ b/dist/runnables/base.js
@@ -1683,6 +1683,73 @@ export class RunnableWithFallbacks extends Runnable {
         await runManager?.handleChainError(firstError);
         throw firstError;
     }
+    async *_streamIterator(input, options) {
+        const callbackManager_ = await CallbackManager.configure(options?.callbacks, undefined, options?.tags, undefined, options?.metadata);
+        const { runId, ...otherOptions } = options ?? {};
+        const runManager = await callbackManager_?.handleChainStart(this.toJSON(), _coerceToDict(input, "input"), runId, undefined, undefined, undefined, otherOptions?.runName);
+
+        let chunk;
+        let firstError = undefined;
+        let stream = undefined;
+        let concatSupported = false;
+        let finalOutput = undefined;
+        for await (const runnable of this.runnables()) {
+            concatSupported = true;
+            try {
+                stream = await runnable.stream(input, patchConfig(otherOptions, { callbacks: runManager?.getChild() }));
+                chunk = (await stream.next()).value;
+
+                // Clear the error since we successfully got a chunk from this runnable
+                firstError = undefined;
+            } catch (e) {
+                if (firstError === undefined) {
+                    firstError = e;
+                }
+            }
+
+            // Check if we successfully got a chunk from this runnable
+            if (!firstError) {
+                // Use the stream from this runnable for the rest of the chunks
+                break;
+            }
+        }
+
+        if (firstError) {
+            await runManager?.handleChainError(firstError);
+            throw firstError;
+        }
+
+        if (chunk) {
+            yield chunk;
+        }
+
+        if (stream) {
+            try {
+                for await (chunk of stream) {
+                    if (concatSupported) {
+                        if (finalOutput === undefined) {
+                            finalOutput = chunk;
+                        } else {
+                            try {
+                                // eslint-disable-next-line @typescript-eslint/no-explicit-any
+                                finalOutput = concat(finalOutput, chunk);
+                            }
+                            catch (e) {
+                                finalOutput = undefined;
+                                concatSupported = false;
+                            }
+                        }
+                    }
+                    yield chunk;
+                }
+            } catch (e) {
+                await runManager?.handleChainError(e);
+                throw e;
+            }
+        }
+
+        await runManager?.handleChainEnd(finalOutput);
+    }
     async batch(inputs, options, batchOptions) {
         if (batchOptions?.returnExceptions) {
             throw new Error("Not implemented.");

Error Message and Stack Trace (if applicable)

No response

Description

(Description adapted from dosubot's explanation of a similar issue:)
The stream method implementation provided in the base class (from which other runnables inherit) is designed to work with individual Runnable instances. It leverages an asynchronous generator (_streamIterator) to yield results as they become available.

RunnableWithFallbacks and similar constructs do not override the stream method (or _streamIterator) to handle streaming across their constituent runnables. Without such an override, streaming through a RunnableWithFallbacks falls back to the base class's default behavior, which does not actually stream: it calls invoke on the chosen primary or fallback model and yields the full result as a single chunk.
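For reference, the default streaming path in the Runnable base class looks roughly like this (a simplified sketch, not the exact @langchain/core source), which is why the wrapper degrades to a single chunk:

// Simplified sketch of the default streaming path in the Runnable base class
// (illustrative only; not the exact @langchain/core source).
type RunnableConfig = Record<string, unknown>; // stand-in for the real config type

abstract class RunnableSketch<RunInput, RunOutput> {
  abstract invoke(input: RunInput, options?: RunnableConfig): Promise<RunOutput>;

  // Default _streamIterator: one chunk containing the full invoke() result.
  async *_streamIterator(
    input: RunInput,
    options?: RunnableConfig
  ): AsyncGenerator<RunOutput> {
    yield await this.invoke(input, options);
  }

  // stream() just hands back the iterator, so a subclass that never overrides
  // _streamIterator (like RunnableWithFallbacks before the fix) produces a
  // single non-streamed chunk even when the underlying model can stream.
  async stream(
    input: RunInput,
    options?: RunnableConfig
  ): Promise<AsyncIterable<RunOutput>> {
    return this._streamIterator(input, options);
  }
}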

We did a quick local patch of the built files; the patch code above is included just for reference, even though it isn't the ideal way to fix the issue.

System Info

> rush-pnpm show langchain

langchain@0.2.9 | MIT | deps: 16 | versions: 281
Typescript bindings for langchain
https://github.com/langchain-ai/langchainjs/tree/main/langchain/

keywords: llm, ai, gpt3, chain, prompt, prompt engineering, chatgpt, machine learning, ml, openai, embeddings, vectorstores

dist
.tarball: https://registry.npmjs.org/langchain/-/langchain-0.2.9.tgz
.shasum: 1341bdd7166f4f6da0b9337f363e409a79523dbb
.integrity: sha512-iZ0l7BDVfoifqZlDl1gy3JP5mIdhYjWiToPlDnlmfHD748cw3okvF0gZo0ruT4nbftnQcaM7JzPUiNC43UPfgg==
.unpackedSize: 4.0 MB

dependencies:
@langchain/core: >=0.2.11 <0.3.0  @langchain/textsplitters: ~0.0.0  js-tiktoken: ^1.0.12              jsonpointer: ^5.0.1               langsmith: ~0.1.30                openapi-types: ^12.1.3            uuid: ^10.0.0                     zod-to-json-schema: ^3.22.3
@langchain/openai: >=0.1.0 <0.3.0 binary-extensions: ^2.2.0         js-yaml: ^4.1.0                   langchainhub: ~0.0.8              ml-distance: ^4.0.0               p-retry: 4                        yaml: ^2.2.1                      zod: ^3.22.4

maintainers:
- nfcampos <[email protected]>
- jacoblee93 <[email protected]>
- andrewnguonly <[email protected]>
- davidduong <[email protected]>
- hwchase17 <[email protected]>
- basproul <[email protected]>

dist-tags:
latest: 0.2.9     next: 0.2.3-rc.0

published 13 hours ago by jacoblee93 <[email protected]>

Platform:
mac

> node --version
Node v20.11.1
> rush-pnpm --version
pnpm 8.15.4
@dosubot dosubot bot added the auto:bug Related to a bug, vulnerability, unexpected error with an existing feature label Jul 10, 2024

dosubot bot commented Jul 10, 2024

Hey there, @tomi-bigpi! I'm here to help you out with any bugs, questions, or contributions you have in mind. Let's tackle this together!

To implement streaming support for the RunnableWithFallbacks class in LangChain.js, you need to override the stream method to handle the streaming logic, including the fallback mechanism. Here is how you can do it:

  1. Override the stream method to handle the primary runnable and fallbacks.
  2. Ensure that the method yields chunks from the primary runnable and, if it fails, from the fallback runnables.

Here is the implementation:

import { Runnable, type RunnableConfig } from "@langchain/core/runnables";

export class RunnableWithFallbacks<RunInput, RunOutput> extends Runnable<
  RunInput,
  RunOutput
> {
  static lc_name() {
    return "RunnableWithFallbacks";
  }

  lc_namespace = ["langchain_core", "runnables"];

  lc_serializable = true;

  runnable: Runnable<RunInput, RunOutput>;

  fallbacks: Runnable<RunInput, RunOutput>[];

  constructor(fields: {
    runnable: Runnable<RunInput, RunOutput>;
    fallbacks: Runnable<RunInput, RunOutput>[];
  }) {
    super(fields);
    this.runnable = fields.runnable;
    this.fallbacks = fields.fallbacks;
  }

  *runnables() {
    yield this.runnable;
    for (const fallback of this.fallbacks) {
      yield fallback;
    }
  }

  async *stream(
    input: RunInput,
    options?: Partial<RunnableConfig>
  ): AsyncGenerator<RunOutput> {
    const runnables = this.runnables();
    for (const runnable of runnables) {
      try {
        const stream = await runnable.stream(input, options);
        for await (const chunk of stream) {
          yield chunk;
        }
        return; // Exit if the primary runnable succeeds
      } catch (e) {
        // Log the error and continue to the next fallback
        console.error(`Runnable failed: ${e}`);
      }
    }
    throw new Error("All runnables failed.");
  }
}

This implementation ensures that the RunnableWithFallbacks class attempts to stream output from the primary runnable first and falls back to the other runnables if the primary one fails [1].

To continue talking to Dosu, mention @dosu.

@tomi-bigpi
Author

@dosu Seems like the streaming support should be implemented within LangChain, not by consumers. Please forward to a human for review.


dosubot bot commented Jul 10, 2024

@jacoblee93, this one's a bit out of my wheelhouse, could you step in and provide your guidance?

To continue talking to Dosu, mention @dosu.

@jacoblee93
Collaborator

So this will basically check if pulling the initial chunk throws an error, and falls back if it does? Otherwise, if there's a streaming error mid-stream, it'll throw?

@tomi-bigpi
Author

tomi-bigpi commented Jul 10, 2024

Yes, that code is very similar to the Python implementation (which should be used as the basis for a real implementation here; this version was cobbled together quickly, run through an LLM, and modeled after other methods in the Runnable* classes). The main idea is to keep falling back as long as fallback models are available, and only throw if the last one fails.
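Distilled, the intended semantics look roughly like the following (a hypothetical standalone helper just to illustrate the behavior, not the actual core code): fall back only when starting the stream or pulling its first chunk fails; once chunks are flowing, mid-stream errors propagate to the caller.

// Hypothetical helper illustrating the fallback-on-first-chunk behavior.
type StartResult<O> =
  | { ok: true; iterator: AsyncIterator<O>; first: IteratorResult<O> }
  | { ok: false; error: unknown };

async function startStream<I, O>(
  runnable: { stream(input: I): Promise<AsyncIterator<O>> },
  input: I
): Promise<StartResult<O>> {
  try {
    const iterator = await runnable.stream(input);
    // Only the first pull is guarded; a failure here triggers a fallback.
    return { ok: true, iterator, first: await iterator.next() };
  } catch (error) {
    return { ok: false, error };
  }
}

async function* streamWithFallbacks<I, O>(
  runnables: Array<{ stream(input: I): Promise<AsyncIterator<O>> }>,
  input: I
): AsyncGenerator<O> {
  let lastError: unknown;
  for (const runnable of runnables) {
    const started = await startStream(runnable, input);
    if (!started.ok) {
      lastError = started.error; // could not start streaming: try the next fallback
      continue;
    }
    // Streaming has started: yield chunks and let mid-stream errors propagate.
    let result = started.first;
    while (!result.done) {
      yield result.value;
      result = await started.iterator.next();
    }
    return;
  }
  throw lastError ?? new Error("All runnables failed to start streaming.");
}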

@jacoblee93
Collaborator

Seems reasonable - if you're up for making it a PR would appreciate that otherwise I can when I get a minute.

@jacoblee93 jacoblee93 self-assigned this Jul 10, 2024
@jsegeren

jsegeren commented Aug 8, 2024

Just to comment - we would greatly value this fix/improvement as well.

@jacoblee93
Collaborator

Live in [email protected]
