Skip to content

Commit 754c89d

Browse files
yeesiancopybara-github
authored andcommitted
fix: Parse intermediate steps from LangChain into JSON.
PiperOrigin-RevId: 627444864
1 parent 76c5d6d commit 754c89d

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from langchain_core import messages
3030
from langchain_core import outputs
3131
from langchain_core import tools as lc_tools
32+
from langchain.load import dump as langchain_load_dump
3233
from langchain.tools.base import StructuredTool
3334

3435

@@ -77,6 +78,12 @@ def vertexai_init_mock():
7778
yield vertexai_init_mock
7879

7980

81+
@pytest.fixture
82+
def langchain_dump_mock():
83+
with mock.patch.object(langchain_load_dump, "dumpd") as langchain_dump_mock:
84+
yield langchain_dump_mock
85+
86+
8087
@pytest.mark.usefixtures("google_auth_mock")
8188
class TestLangchainAgent:
8289
def setup_method(self):
@@ -114,7 +121,7 @@ def test_set_up(self, vertexai_init_mock):
114121
agent.set_up()
115122
assert agent._runnable is not None
116123

117-
def test_query(self):
124+
def test_query(self, langchain_dump_mock):
118125
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
119126
agent._runnable = mock.Mock()
120127
mocks = mock.Mock()

vertexai/preview/reasoning_engines/templates/langchain.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
TYPE_CHECKING,
1919
Any,
2020
Callable,
21+
Dict,
2122
List,
2223
Mapping,
2324
Optional,
@@ -418,7 +419,7 @@ def query(
418419
input: Union[str, Mapping[str, Any]],
419420
config: Optional["RunnableConfig"] = None,
420421
**kwargs: Any,
421-
) -> Mapping[str, Any]:
422+
) -> Dict[str, Any]:
422423
"""Queries the Agent with the given input and config.
423424
424425
Args:
@@ -433,8 +434,11 @@ def query(
433434
Returns:
434435
The output of querying the Agent with the given input and config.
435436
"""
437+
from langchain.load import dump as langchain_load_dump
436438
if isinstance(input, str):
437439
input = {"input": input}
438440
if not self._runnable:
439441
self.set_up()
440-
return self._runnable.invoke(input=input, config=config, **kwargs)
442+
return langchain_load_dump.dumpd(
443+
self._runnable.invoke(input=input, config=config, **kwargs)
444+
)

0 commit comments

Comments
 (0)