Skip to content

Add Detailed Streaming to the Tool Agent #4155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions packages/components/nodes/agents/ToolAgent/ToolAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {
IUsedTool,
IVisionChatModal
} from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { ConsoleCallbackHandler, CustomChainHandler, CustomStreamingHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, ToolCallingAgentOutputParser } from '../../../src/agents'
import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
Expand Down Expand Up @@ -101,6 +101,15 @@ class ToolAgent_Agents implements INode {
type: 'number',
optional: true,
additionalParams: true
},
{
label: 'Enable Detailed Streaming',
name: 'enableDetailedStreaming',
type: 'boolean',
default: false,
description: 'Stream detailed intermediate steps during agent execution',
optional: true,
additionalParams: true
}
]
this.sessionId = fields?.sessionId
Expand All @@ -113,6 +122,7 @@ class ToolAgent_Agents implements INode {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const memory = nodeData.inputs?.memory as FlowiseMemory
const moderations = nodeData.inputs?.inputModeration as Moderation[]
const enableDetailedStreaming = nodeData.inputs?.enableDetailedStreaming as boolean

const shouldStreamResponse = options.shouldStreamResponse
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
Expand All @@ -136,14 +146,28 @@ class ToolAgent_Agents implements INode {
const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)

// Add custom streaming handler if detailed streaming is enabled
let customStreamingHandler = null

if (enableDetailedStreaming && shouldStreamResponse) {
customStreamingHandler = new CustomStreamingHandler(sseStreamer, chatId)
}

let res: ChainValues = {}
let sourceDocuments: ICommonObject[] = []
let usedTools: IUsedTool[] = []
let artifacts = []

if (shouldStreamResponse) {
const handler = new CustomChainHandler(sseStreamer, chatId)
res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
const allCallbacks = [loggerHandler, handler, ...callbacks]

// Add detailed streaming handler if enabled
if (enableDetailedStreaming && customStreamingHandler) {
allCallbacks.push(customStreamingHandler)
}

res = await executor.invoke({ input }, { callbacks: allCallbacks })
if (res.sourceDocuments) {
if (sseStreamer) {
sseStreamer.streamSourceDocumentsEvent(chatId, flatten(res.sourceDocuments))
Expand Down Expand Up @@ -174,7 +198,14 @@ class ToolAgent_Agents implements INode {
}
}
} else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
const allCallbacks = [loggerHandler, ...callbacks]

// Add detailed streaming handler if enabled
if (enableDetailedStreaming && customStreamingHandler) {
allCallbacks.push(customStreamingHandler)
}

res = await executor.invoke({ input }, { callbacks: allCallbacks })
if (res.sourceDocuments) {
sourceDocuments = res.sourceDocuments
}
Expand Down
84 changes: 84 additions & 0 deletions packages/components/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import { LangWatch, LangWatchSpan, LangWatchTrace, autoconvertTypedValues } from
import { DataSource } from 'typeorm'
import { ChatGenerationChunk } from '@langchain/core/outputs'
import { AIMessageChunk } from '@langchain/core/messages'
import { Serialized } from '@langchain/core/load/serializable'

interface AgentRun extends Run {
actions: AgentAction[]
Expand Down Expand Up @@ -1499,3 +1500,86 @@ export class AnalyticHandler {
}
}
}

/**
* Custom callback handler for streaming detailed intermediate information
* during agent execution, specifically tool invocation inputs and outputs.
*/
export class CustomStreamingHandler extends BaseCallbackHandler {
name = 'custom_streaming_handler'

private sseStreamer: IServerSideEventStreamer
private chatId: string

constructor(sseStreamer: IServerSideEventStreamer, chatId: string) {
super()
this.sseStreamer = sseStreamer
this.chatId = chatId
}

/**
* Handle the start of a tool invocation
*/
async handleToolStart(tool: Serialized, input: string, runId: string, parentRunId?: string): Promise<void> {
if (!this.sseStreamer) return

const toolName = typeof tool === 'object' && tool.name ? tool.name : 'unknown-tool'
const toolInput = typeof input === 'string' ? input : JSON.stringify(input, null, 2)

// Stream the tool invocation details using the agent_trace event type for consistency
this.sseStreamer.streamCustomEvent(this.chatId, 'agent_trace', {
step: 'tool_start',
name: toolName,
input: toolInput,
runId,
parentRunId: parentRunId || null
})
}

/**
* Handle the end of a tool invocation
*/
async handleToolEnd(output: string | object, runId: string, parentRunId?: string): Promise<void> {
if (!this.sseStreamer) return

const toolOutput = typeof output === 'string' ? output : JSON.stringify(output, null, 2)

// Stream the tool output details using the agent_trace event type for consistency
this.sseStreamer.streamCustomEvent(this.chatId, 'agent_trace', {
step: 'tool_end',
output: toolOutput,
runId,
parentRunId: parentRunId || null
})
}

/**
* Handle tool errors
*/
async handleToolError(error: Error, runId: string, parentRunId?: string): Promise<void> {
if (!this.sseStreamer) return

// Stream the tool error details using the agent_trace event type for consistency
this.sseStreamer.streamCustomEvent(this.chatId, 'agent_trace', {
step: 'tool_error',
error: error.message,
runId,
parentRunId: parentRunId || null
})
}

/**
* Handle agent actions
*/
async handleAgentAction(action: AgentAction, runId: string, parentRunId?: string): Promise<void> {
if (!this.sseStreamer) return

// Stream the agent action details using the agent_trace event type for consistency
this.sseStreamer.streamCustomEvent(this.chatId, 'agent_trace', {
step: 'agent_action',
action: JSON.stringify(action),
runId,
parentRunId: parentRunId || null
})
}
}