Skip to content

Commit 99f613b

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
feat: Support stream_query in LangChain Agent Templates in the Python Reasoning Engine Client
PiperOrigin-RevId: 706882105
1 parent 25622f8 commit 99f613b

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py

+10
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,16 @@ def test_query(self, langchain_dump_mock):
221221
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
222222
)
223223

224+
def test_stream_query(self, langchain_dump_mock):
225+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
226+
agent._runnable = mock.Mock()
227+
agent._runnable.stream.return_value = []
228+
list(agent.stream_query(input="test stream query"))
229+
agent._runnable.stream.assert_called_once_with(
230+
input={"input": "test stream query"},
231+
config=None,
232+
)
233+
224234
@pytest.mark.usefixtures("caplog")
225235
def test_enable_tracing(
226236
self,

vertexai/preview/reasoning_engines/templates/langchain.py

+31
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Any,
1919
Callable,
2020
Dict,
21+
Iterable,
2122
Mapping,
2223
Optional,
2324
Sequence,
@@ -609,3 +610,33 @@ def query(
609610
return langchain_load_dump.dumpd(
610611
self._runnable.invoke(input=input, config=config, **kwargs)
611612
)
613+
614+
def stream_query(
615+
self,
616+
*,
617+
input: Union[str, Mapping[str, Any]],
618+
config: Optional["RunnableConfig"] = None,
619+
**kwargs,
620+
) -> Iterable[Any]:
621+
"""Stream queries the Agent with the given input and config.
622+
623+
Args:
624+
input (Union[str, Mapping[str, Any]]):
625+
Required. The input to be passed to the Agent.
626+
config (langchain_core.runnables.RunnableConfig):
627+
Optional. The config (if any) to be used for invoking the Agent.
628+
**kwargs:
629+
Optional. Any additional keyword arguments to be passed to the
630+
`.invoke()` method of the corresponding AgentExecutor.
631+
632+
Yields:
633+
The output of querying the Agent with the given input and config.
634+
"""
635+
from langchain.load import dump as langchain_load_dump
636+
637+
if isinstance(input, str):
638+
input = {"input": input}
639+
if not self._runnable:
640+
self.set_up()
641+
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
642+
yield langchain_load_dump.dumpd(chunk)

0 commit comments

Comments
 (0)