Skip to content

Commit 28a3c56

Browse files
yeesiancopybara-github
authored andcommitted
feat: Support VertexTool in langchain template.
PiperOrigin-RevId: 639027200
1 parent 0c874a4 commit 28a3c56

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import vertexai
2222
from google.cloud.aiplatform import initializer
2323
from vertexai.preview import reasoning_engines
24+
from vertexai.preview.generative_models import grounding
25+
from vertexai.generative_models import Tool
2426
import pytest
2527

2628

@@ -81,6 +83,12 @@ def langchain_dump_mock():
8183
yield langchain_dump_mock
8284

8385

86+
@pytest.fixture
87+
def mock_chatvertexai():
88+
with mock.patch("langchain_google_vertexai.ChatVertexAI") as model_mock:
89+
yield model_mock
90+
91+
8492
@pytest.mark.usefixtures("google_auth_mock")
8593
class TestLangchainAgent:
8694
def setup_method(self):
@@ -113,19 +121,23 @@ def test_initialization(self):
113121
assert agent._location == _TEST_LOCATION
114122
assert agent._runnable is None
115123

116-
def test_initialization_with_tools(self):
124+
def test_initialization_with_tools(self, mock_chatvertexai):
117125
tools = [
118126
place_tool_query,
119127
StructuredTool.from_function(place_photo_query),
128+
Tool.from_google_search_retrieval(grounding.GoogleSearchRetrieval()),
120129
]
121130
agent = reasoning_engines.LangchainAgent(
122131
model=_TEST_MODEL,
123132
tools=tools,
124133
)
125134
for tool, agent_tool in zip(tools, agent._tools):
126135
assert isinstance(agent_tool, type(tool))
136+
assert agent._runnable is None
137+
agent.set_up()
138+
assert agent._runnable is not None
127139

128-
def test_set_up(self, vertexai_init_mock):
140+
def test_set_up(self):
129141
agent = reasoning_engines.LangchainAgent(
130142
model=_TEST_MODEL,
131143
prompt=self.prompt,
@@ -135,7 +147,7 @@ def test_set_up(self, vertexai_init_mock):
135147
agent.set_up()
136148
assert agent._runnable is not None
137149

138-
def test_clone(self, vertexai_init_mock):
150+
def test_clone(self):
139151
agent = reasoning_engines.LangchainAgent(
140152
model=_TEST_MODEL,
141153
prompt=self.prompt,

vertexai/preview/reasoning_engines/templates/langchain.py

+27-8
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
RunnableConfig = Any
4343
RunnableSerializable = Any
4444

45+
try:
46+
from langchain_google_vertexai.functions_utils import _ToolsType
47+
48+
_ToolLike = _ToolsType
49+
except ImportError:
50+
_ToolLike = Any
51+
4552

4653
def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
4754
# https://github.com/langchain-ai/langchain/blob/5784dfed001730530637793bea1795d9d5a7c244/libs/core/langchain_core/runnables/history.py#L237-L241
@@ -62,7 +69,13 @@ def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]:
6269

6370

6471
def _default_output_parser():
65-
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
72+
try:
73+
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
74+
except (ModuleNotFoundError, ImportError):
75+
# Fallback to an older version if needed.
76+
from langchain.agents.output_parsers.openai_tools import (
77+
OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
78+
)
6679

6780
return ToolsAgentOutputParser()
6881

@@ -90,7 +103,7 @@ def _default_model_builder(
90103
def _default_runnable_builder(
91104
model: "BaseLanguageModel",
92105
*,
93-
tools: Optional[Sequence[Union[Callable, "BaseTool"]]] = None,
106+
tools: Optional[Sequence["_ToolLike"]] = None,
94107
prompt: Optional["RunnableSerializable"] = None,
95108
output_parser: Optional["RunnableSerializable"] = None,
96109
chat_history: Optional["GetSessionHistoryCallable"] = None,
@@ -123,6 +136,7 @@ def _default_runnable_builder(
123136
if isinstance(tool, lc_tools.BaseTool)
124137
else StructuredTool.from_function(tool)
125138
for tool in tools
139+
if isinstance(tool, (Callable, lc_tools.BaseTool))
126140
],
127141
**agent_executor_kwargs,
128142
)
@@ -139,7 +153,14 @@ def _default_runnable_builder(
139153

140154
def _default_prompt(has_history: bool) -> "RunnableSerializable":
141155
from langchain_core import prompts
142-
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
156+
157+
try:
158+
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
159+
except (ModuleNotFoundError, ImportError):
160+
# Fallback to an older version if needed.
161+
from langchain.agents.format_scratchpad.openai_tools import (
162+
format_to_openai_tool_messages as format_to_tool_messages,
163+
)
143164

144165
if has_history:
145166
return {
@@ -186,12 +207,10 @@ def _validate_callable_parameters_are_annotated(callable: Callable):
186207
)
187208

188209

189-
def _validate_tools(tools: Sequence[Union[Callable, "BaseTool"]]):
210+
def _validate_tools(tools: Sequence["_ToolLike"]):
190211
"""Validates that the tools are usable for tool calling."""
191-
from langchain_core import tools as lc_tools
192-
193212
for tool in tools:
194-
if not isinstance(tool, lc_tools.BaseTool):
213+
if isinstance(tool, Callable):
195214
_validate_callable_parameters_are_annotated(tool)
196215

197216

@@ -208,7 +227,7 @@ def __init__(
208227
model: str,
209228
*,
210229
prompt: Optional["RunnableSerializable"] = None,
211-
tools: Optional[Sequence[Union[Callable, "BaseTool"]]] = None,
230+
tools: Optional[Sequence["_ToolLike"]] = None,
212231
output_parser: Optional["RunnableSerializable"] = None,
213232
chat_history: Optional["GetSessionHistoryCallable"] = None,
214233
model_kwargs: Optional[Mapping[str, Any]] = None,

0 commit comments

Comments
 (0)