Commit 520d347

feat(LangChainAgent): stream_response (#589)
1 parent 1b3f204 commit 520d347

File tree

3 files changed: +45 −5 lines

src/rai_core/rai/agents/langchain/agent.py

Lines changed: 7 additions & 2 deletions
@@ -53,6 +53,8 @@ class LangChainAgent(BaseAgent):
         Dict of target_name: connector. Agent will send it's output to these targets using connectors.
     runnable : Runnable
         LangChain runnable that will be used to generate output.
+    stream_response : bool, optional
+        If True, the agent will stream the response to the target connectors. Make sure that the runnable is configured to stream.
     state : BaseState | None, optional
         State to seed the LangChain runnable. If None - empty state is used.
     new_message_behavior : newMessageBehaviorType, optional
@@ -107,20 +109,23 @@ def __init__(
         self,
         target_connectors: Dict[str, HRIConnector[HRIMessage]],
         runnable: Runnable[Any, Any],
+        stream_response: bool = True,
         state: BaseState | None = None,
         new_message_behavior: newMessageBehaviorType = "interrupt_keep_last",
         max_size: int = 100,
     ):
         super().__init__()
         self.logger = logging.getLogger(__name__)
         self.agent = runnable
+        self.stream_response = stream_response
         self.new_message_behavior: newMessageBehaviorType = new_message_behavior
         self.tracing_callbacks = get_tracing_callbacks()
         self.state = state or ReActAgentState(messages=[])
         self._langchain_callback = HRICallbackHandler(
             connectors=target_connectors,
             aggregate_chunks=True,
             logger=self.logger,
+            stream_response=stream_response,
         )
 
         self._received_messages: Deque[HRIMessage] = deque()
@@ -182,8 +187,8 @@ def _interrupt_agent_and_run(self):
     def _run_agent(self):
         if len(self._received_messages) == 0:
             self._agent_ready_event.set()
-            self.logger.info("Waiting for messages...")
-            time.sleep(0.5)
+            self.logger.debug("Waiting for messages...")
+            time.sleep(0.1)
             return
         self._agent_ready_event.clear()
         try:
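
For orientation, here is a minimal sketch of how a caller might exercise the new flag. The connector instance, target name, and model choice are illustrative assumptions, not part of this commit; only `stream_response` and its default of `True` come from the diff above.

    from langchain_openai import ChatOpenAI

    from rai.agents.langchain.agent import LangChainAgent

    connector = ...  # a concrete HRIConnector[HRIMessage] instance (not shown here)

    # streaming=True asks the model to emit tokens incrementally, which is what
    # the docstring above means by "configured to stream".
    llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)

    agent = LangChainAgent(
        target_connectors={"/to_human": connector},  # hypothetical target name
        runnable=llm,  # any Runnable; a bare chat model is the simplest case
        stream_response=True,  # new in this commit; defaults to True
    )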

src/rai_core/rai/agents/langchain/callback.py

Lines changed: 36 additions & 3 deletions
@@ -14,7 +14,7 @@
 
 import logging
 import threading
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 from uuid import UUID
 
 from langchain_core.callbacks import BaseCallbackHandler
@@ -32,16 +32,20 @@ def __init__(
         splitting_chars: Optional[List[str]] = None,
         max_buffer_size: int = 200,
         logger: Optional[logging.Logger] = None,
+        stream_response: bool = True,
     ):
         self.connectors = connectors
         self.aggregate_chunks = aggregate_chunks
+        self.stream_response = stream_response
         self.splitting_chars = splitting_chars or ["\n", ".", "!", "?"]
         self.chunks_buffer = ""
         self.max_buffer_size = max_buffer_size
         self._buffer_lock = threading.Lock()
         self.logger = logger or logging.getLogger(__name__)
         self.current_conversation_id = None
         self.current_chunk_id = 0
+        self.working = False
+        self.hit_on_llm_new_token = False
 
     def _should_split(self, token: str) -> bool:
         return token in self.splitting_chars
@@ -63,8 +67,22 @@ def _send_all_targets(self, tokens: str, done: bool = False):
                     f"Failed to send {len(tokens)} tokens to hri_connector: {e}"
                 )
 
+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        prompts: list[str],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self.working = True
+
     def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):
-        if token == "":
+        self.hit_on_llm_new_token = True
+        if token == "" or not self.stream_response:
             return
         if self.current_conversation_id != str(run_id):
             self.current_conversation_id = str(run_id)
@@ -93,7 +111,22 @@ def on_llm_end(
         **kwargs,
     ):
         self.current_conversation_id = str(run_id)
-        if self.aggregate_chunks and self.chunks_buffer:
+        if self.stream_response and not self.hit_on_llm_new_token:
+            self.logger.error(
+                (
+                    "No tokens were sent to the callback handler. "
+                    "LLM did not stream response. "
+                    "Is your BaseChatModel configured to stream? "
+                    "Sending generated text as a single message."
+                )
+            )
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif not self.stream_response:
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif self.aggregate_chunks and self.chunks_buffer:
             with self._buffer_lock:
                 self._send_all_targets(self.chunks_buffer, done=True)
                 self.chunks_buffer = ""
+        self.working = False
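
The `on_llm_end` change is easiest to read as a three-way decision. Below is a standalone paraphrase of that branching, written as a reading aid; it is not code from the commit.

    # Paraphrase of the on_llm_end branching added in this commit.
    def final_send_action(stream_response: bool, saw_tokens: bool,
                          aggregate_chunks: bool, buffer: str) -> str:
        if stream_response and not saw_tokens:
            # Streaming was requested but on_llm_new_token never fired:
            # log an error and fall back to sending the full generation.
            return "log error, send full message"
        if not stream_response:
            # Streaming disabled: the whole response is sent once, here.
            return "send full message"
        if aggregate_chunks and buffer:
            # Normal streaming path: flush whatever is still buffered.
            return "flush remaining buffer"
        return "nothing left to send"

The new `working` flag brackets each generation (`on_llm_start` sets it, the end of `on_llm_end` clears it); the commit shows no consumer of it, so presumably it is groundwork for callers that need to know whether a generation is in flight.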

src/rai_core/rai/agents/langchain/react_agent.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@ def __init__(
         tools: Optional[List[BaseTool]] = None,
         state: Optional[ReActAgentState] = None,
         system_prompt: Optional[str | SystemMultimodalMessage] = None,
+        stream_response: bool = True,
     ):
         runnable = create_react_runnable(
             llm=llm, tools=tools, system_prompt=system_prompt
@@ -42,4 +43,5 @@ def __init__(
             target_connectors=target_connectors,
             runnable=runnable,
             state=state,
+            stream_response=stream_response,
         )
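
End to end, the flag threads from `ReActAgent` through `LangChainAgent` into `HRICallbackHandler`. A sketch of disabling streaming at the top level follows; the placeholders and keyword usage are assumptions for illustration (the diff confirms the parameter names but not their order):

    from rai.agents.langchain.react_agent import ReActAgent

    llm = ...        # any LangChain BaseChatModel
    connector = ...  # a concrete HRIConnector[HRIMessage] instance (not shown)

    # stream_response=False makes HRICallbackHandler skip on_llm_new_token
    # and deliver the whole generation once, from on_llm_end.
    agent = ReActAgent(
        target_connectors={"/to_human": connector},  # hypothetical target name
        llm=llm,
        stream_response=False,
    )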
