feat(LangChainAgent): stream_response #589

Merged · 3 commits · May 16, 2025

9 changes: 7 additions & 2 deletions src/rai_core/rai/agents/langchain/agent.py

@@ -53,6 +53,8 @@ class LangChainAgent(BaseAgent):
         Dict of target_name: connector. Agent will send its output to these targets using connectors.
     runnable : Runnable
         LangChain runnable that will be used to generate output.
+    stream_response : bool, optional
+        If True, the agent will stream the response to the target connectors. Make sure that the runnable is configured to stream.
     state : BaseState | None, optional
         State to seed the LangChain runnable. If None - empty state is used.
     new_message_behavior : newMessageBehaviorType, optional
@@ -107,20 +109,23 @@ def __init__(
         self,
         target_connectors: Dict[str, HRIConnector[HRIMessage]],
         runnable: Runnable[Any, Any],
+        stream_response: bool,
         state: BaseState | None = None,
         new_message_behavior: newMessageBehaviorType = "interrupt_keep_last",
         max_size: int = 100,
     ):
         super().__init__()
         self.logger = logging.getLogger(__name__)
         self.agent = runnable
+        self.stream_response = stream_response
         self.new_message_behavior: newMessageBehaviorType = new_message_behavior
         self.tracing_callbacks = get_tracing_callbacks()
         self.state = state or ReActAgentState(messages=[])
         self._langchain_callback = HRICallbackHandler(
             connectors=target_connectors,
             aggregate_chunks=True,
             logger=self.logger,
+            stream_response=stream_response,
         )

         self._received_messages: Deque[HRIMessage] = deque()
@@ -182,8 +187,8 @@ def _interrupt_agent_and_run(self):
     def _run_agent(self):
         if len(self._received_messages) == 0:
             self._agent_ready_event.set()
-            self.logger.info("Waiting for messages...")
-            time.sleep(0.5)
+            self.logger.debug("Waiting for messages...")
+            time.sleep(0.1)
             return
         self._agent_ready_event.clear()
         try:
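
Note for reviewers: `stream_response=True` only controls whether the agent forwards tokens as they arrive; whether tokens arrive at all depends on the runnable, as the new docstring warns. A minimal sketch of checking that prerequisite (assumes the `langchain-openai` package and an example model name, neither of which is part of this diff):

```python
# Sketch only: confirm the runnable actually streams before relying on
# stream_response=True. Package and model name are example assumptions.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")  # hypothetical example model

# With callbacks attached, on_llm_new_token fires once per chunk only when
# the model streams; otherwise the handler sees nothing until on_llm_end.
for chunk in llm.stream("Say hello in five words."):
    print(chunk.content, end="", flush=True)
```
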
39 changes: 36 additions & 3 deletions src/rai_core/rai/agents/langchain/callback.py

@@ -14,7 +14,7 @@

 import logging
 import threading
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 from uuid import UUID

 from langchain_core.callbacks import BaseCallbackHandler
@@ -32,16 +32,20 @@ def __init__(
         splitting_chars: Optional[List[str]] = None,
         max_buffer_size: int = 200,
         logger: Optional[logging.Logger] = None,
+        stream_response: bool = True,
     ):
         self.connectors = connectors
         self.aggregate_chunks = aggregate_chunks
+        self.stream_response = stream_response
         self.splitting_chars = splitting_chars or ["\n", ".", "!", "?"]
         self.chunks_buffer = ""
         self.max_buffer_size = max_buffer_size
         self._buffer_lock = threading.Lock()
         self.logger = logger or logging.getLogger(__name__)
         self.current_conversation_id = None
         self.current_chunk_id = 0
+        self.working = False
+        self.hit_on_llm_new_token = False

     def _should_split(self, token: str) -> bool:
         return token in self.splitting_chars
@@ -63,8 +67,22 @@ def _send_all_targets(self, tokens: str, done: bool = False):
                     f"Failed to send {len(tokens)} tokens to hri_connector: {e}"
                 )

+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        prompts: list[str],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self.working = True
+
     def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):
-        if token == "":
+        self.hit_on_llm_new_token = True
+        if token == "" or not self.stream_response:
             return
         if self.current_conversation_id != str(run_id):
             self.current_conversation_id = str(run_id)
@@ -93,7 +111,22 @@ def on_llm_end(
         **kwargs,
     ):
         self.current_conversation_id = str(run_id)
-        if self.aggregate_chunks and self.chunks_buffer:
+        if self.stream_response and not self.hit_on_llm_new_token:
+            self.logger.error(
+                (
+                    "No tokens were sent to the callback handler. "
+                    "LLM did not stream response. "
+                    "Is your BaseChatModel configured to stream? "
+                    "Sending generated text as a single message."
+                )
+            )
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif not self.stream_response:
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif self.aggregate_chunks and self.chunks_buffer:
             with self._buffer_lock:
                 self._send_all_targets(self.chunks_buffer, done=True)
                 self.chunks_buffer = ""
+        self.working = False
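
The end-of-generation handling in `on_llm_end` now has three paths: a fallback when streaming was requested but `on_llm_new_token` never fired, a single-message path when streaming is disabled, and the pre-existing buffer flush. A standalone paraphrase of that branching, illustration only, with the handler state reduced to plain arguments:

```python
# Illustration of the new on_llm_end branching; not the handler itself.
def finish(
    stream_response: bool,
    hit_on_llm_new_token: bool,
    aggregate_chunks: bool,
    chunks_buffer: str,
    full_text: str,
) -> str:
    """Return what gets sent to the connectors when generation ends."""
    if stream_response and not hit_on_llm_new_token:
        # Streaming was requested but no token callback ever fired: the
        # model did not stream, so fall back to one complete message.
        return full_text
    elif not stream_response:
        # Streaming disabled: the complete text is always sent here.
        return full_text
    elif aggregate_chunks and chunks_buffer:
        # Streaming worked: flush whatever tail is still buffered.
        return chunks_buffer
    return ""

assert finish(True, False, True, "", "Hello world") == "Hello world"
assert finish(False, True, True, "", "Hello world") == "Hello world"
assert finish(True, True, True, "tail.", "Hello world") == "tail."
```
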
2 changes: 2 additions & 0 deletions src/rai_core/rai/agents/langchain/react_agent.py

@@ -34,6 +34,7 @@ def __init__(
         tools: Optional[List[BaseTool]] = None,
         state: Optional[ReActAgentState] = None,
         system_prompt: Optional[str | SystemMultimodalMessage] = None,
+        stream_response: bool = True,
     ):
         runnable = create_react_runnable(
             llm=llm, tools=tools, system_prompt=system_prompt
@@ -42,4 +43,5 @@ def __init__(
             target_connectors=target_connectors,
             runnable=runnable,
             state=state,
+            stream_response=stream_response,
         )
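
A hypothetical call site for the updated `ReActAgent`, with the default preserving streaming behavior for existing callers. The connector and model below are placeholders supplied by the application, not part of this change:

```python
# Hypothetical usage sketch; `connector` is assumed to be a concrete
# HRIConnector[HRIMessage] implementation provided by the application,
# and the model package/name are example assumptions.
from langchain_openai import ChatOpenAI

from rai.agents.langchain.react_agent import ReActAgent

llm = ChatOpenAI(model="gpt-4o-mini")  # example model name
agent = ReActAgent(
    target_connectors={"/to_human": connector},  # placeholder connector
    llm=llm,
    stream_response=False,  # opt out: one final message instead of token chunks
)
```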