diff --git a/src/rai_core/rai/agents/langchain/agent.py b/src/rai_core/rai/agents/langchain/agent.py
index e3033b193..e67c309d7 100644
--- a/src/rai_core/rai/agents/langchain/agent.py
+++ b/src/rai_core/rai/agents/langchain/agent.py
@@ -53,6 +53,8 @@ class LangChainAgent(BaseAgent):
         Dict of target_name: connector. Agent will send it's output to these targets using connectors.
     runnable : Runnable
         LangChain runnable that will be used to generate output.
+    stream_response : bool, optional
+        If True, the agent will stream the response to the target connectors. Make sure that the runnable is configured to stream.
     state : BaseState | None, optional
         State to seed the LangChain runnable. If None - empty state is used.
     new_message_behavior : newMessageBehaviorType, optional
@@ -107,6 +109,7 @@ def __init__(
         self,
         target_connectors: Dict[str, HRIConnector[HRIMessage]],
         runnable: Runnable[Any, Any],
+        stream_response: bool = True,
         state: BaseState | None = None,
         new_message_behavior: newMessageBehaviorType = "interrupt_keep_last",
         max_size: int = 100,
@@ -114,6 +117,7 @@
         super().__init__()
         self.logger = logging.getLogger(__name__)
         self.agent = runnable
+        self.stream_response = stream_response
         self.new_message_behavior: newMessageBehaviorType = new_message_behavior
         self.tracing_callbacks = get_tracing_callbacks()
         self.state = state or ReActAgentState(messages=[])
@@ -121,6 +125,7 @@
             connectors=target_connectors,
             aggregate_chunks=True,
             logger=self.logger,
+            stream_response=stream_response,
         )
 
         self._received_messages: Deque[HRIMessage] = deque()
@@ -182,8 +187,8 @@ def _interrupt_agent_and_run(self):
     def _run_agent(self):
         if len(self._received_messages) == 0:
             self._agent_ready_event.set()
-            self.logger.info("Waiting for messages...")
-            time.sleep(0.5)
+            self.logger.debug("Waiting for messages...")
+            time.sleep(0.1)
             return
         self._agent_ready_event.clear()
         try:
diff --git a/src/rai_core/rai/agents/langchain/callback.py b/src/rai_core/rai/agents/langchain/callback.py
index 5b4be684e..4978598ce 100644
--- a/src/rai_core/rai/agents/langchain/callback.py
+++ b/src/rai_core/rai/agents/langchain/callback.py
@@ -14,7 +14,7 @@
 
 import logging
 import threading
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 from uuid import UUID
 
 from langchain_core.callbacks import BaseCallbackHandler
@@ -32,9 +32,11 @@ def __init__(
         splitting_chars: Optional[List[str]] = None,
         max_buffer_size: int = 200,
         logger: Optional[logging.Logger] = None,
+        stream_response: bool = True,
     ):
         self.connectors = connectors
         self.aggregate_chunks = aggregate_chunks
+        self.stream_response = stream_response
         self.splitting_chars = splitting_chars or ["\n", ".", "!", "?"]
         self.chunks_buffer = ""
         self.max_buffer_size = max_buffer_size
@@ -42,6 +44,8 @@
         self.logger = logger or logging.getLogger(__name__)
         self.current_conversation_id = None
         self.current_chunk_id = 0
+        self.working = False
+        self.hit_on_llm_new_token = False
 
     def _should_split(self, token: str) -> bool:
         return token in self.splitting_chars
@@ -63,8 +67,22 @@
                     f"Failed to send {len(tokens)} tokens to hri_connector: {e}"
                 )
 
+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        prompts: list[str],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self.working = True
+
     def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):
-        if token == "":
+        self.hit_on_llm_new_token = True
+        if token == "" or not self.stream_response:
             return
         if self.current_conversation_id != str(run_id):
             self.current_conversation_id = str(run_id)
@@ -93,7 +111,22 @@ def on_llm_end(
         **kwargs,
     ):
         self.current_conversation_id = str(run_id)
-        if self.aggregate_chunks and self.chunks_buffer:
+        if self.stream_response and not self.hit_on_llm_new_token:
+            self.logger.error(
+                (
+                    "No tokens were sent to the callback handler. "
+                    "LLM did not stream response. "
+                    "Is your BaseChatModel configured to stream? "
+                    "Sending generated text as a single message."
+                )
+            )
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif not self.stream_response:
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif self.aggregate_chunks and self.chunks_buffer:
             with self._buffer_lock:
                 self._send_all_targets(self.chunks_buffer, done=True)
                 self.chunks_buffer = ""
+        self.working = False
diff --git a/src/rai_core/rai/agents/langchain/react_agent.py b/src/rai_core/rai/agents/langchain/react_agent.py
index 03754442f..1f622fd72 100644
--- a/src/rai_core/rai/agents/langchain/react_agent.py
+++ b/src/rai_core/rai/agents/langchain/react_agent.py
@@ -34,6 +34,7 @@ def __init__(
        tools: Optional[List[BaseTool]] = None,
        state: Optional[ReActAgentState] = None,
        system_prompt: Optional[str | SystemMultimodalMessage] = None,
+       stream_response: bool = True,
    ):
        runnable = create_react_runnable(
            llm=llm, tools=tools, system_prompt=system_prompt
        )
@@ -42,4 +43,5 @@
            target_connectors=target_connectors,
            runnable=runnable,
            state=state,
+           stream_response=stream_response,
        )
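
Usage sketch for reviewers (not part of the change set): how the new `stream_response` flag would be used. Only `llm`, `target_connectors`, and `stream_response` are taken from the signatures in this diff; the `ReActAgent` class name (suggested by `react_agent.py`), its import path, the `ChatOpenAI` model, and the `/to_human` target name are assumptions for illustration.

```python
# Hedged sketch under the assumptions stated above.
from langchain_openai import ChatOpenAI

from rai.agents.langchain.react_agent import ReActAgent  # assumed module path

# The docstring added in this diff notes the runnable must be configured to
# stream; a streaming chat model is what makes on_llm_new_token fire.
llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)

connector = ...  # any HRIConnector[HRIMessage] instance, e.g. a ROS 2 HRI connector

# Default behavior: chunks are pushed to the target connectors as the model
# emits tokens.
streaming_agent = ReActAgent(
    target_connectors={"/to_human": connector},  # hypothetical target name
    llm=llm,
    stream_response=True,
)

# Single-message behavior: on_llm_new_token returns early and the full
# response.generations[0][0].message.content is sent once from on_llm_end.
batch_agent = ReActAgent(
    target_connectors={"/to_human": connector},
    llm=llm,
    stream_response=False,
)
```

Note the safety net in `on_llm_end`: with `stream_response=True` but a model that never triggers `on_llm_new_token`, the handler logs an error and falls back to sending the whole generation as one message, so output is not silently dropped.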