Skip to content

Commit a821d50

Browse files
yeesiancopybara-github
authored andcommitted
fix: Add validation for langchain tools.
PiperOrigin-RevId: 625092824
1 parent bb5690c commit a821d50

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@
2424
from vertexai.preview import reasoning_engines
2525
import pytest
2626

27+
2728
from langchain_core import agents
2829
from langchain_core import messages
2930
from langchain_core import outputs
3031
from langchain_core import tools as lc_tools
32+
from langchain.tools.base import StructuredTool
3133

3234

3335
_DEFAULT_PLACE_TOOL_ACTIVITY = "museums"
@@ -100,7 +102,7 @@ def test_initialization_with_tools(self):
100102
model=_TEST_MODEL,
101103
tools=[
102104
place_tool_query,
103-
place_photo_query,
105+
StructuredTool.from_function(place_photo_query),
104106
],
105107
)
106108
for tool in agent._tools:
@@ -178,3 +180,22 @@ def test_parse_text_errors(self, vertexai_init_mock):
178180
agent.set_up()
179181
with pytest.raises(ValueError, match=r"Can only parse messages"):
180182
agent._output_parser.parse("text")
183+
184+
185+
class TestConvertToolsOrRaise:
186+
def test_convert_tools_or_raise(self, vertexai_init_mock):
187+
pass
188+
189+
190+
def _return_input_no_typing(input_):
191+
"""Returns input back to user."""
192+
return input_
193+
194+
195+
class TestConvertToolsOrRaiseErrors:
196+
def test_raise_untyped_input_args(self, vertexai_init_mock):
197+
with pytest.raises(TypeError, match=r"has untyped input_arg"):
198+
reasoning_engines.LangchainAgent(
199+
model=_TEST_MODEL,
200+
tools=[_return_input_no_typing],
201+
)

vertexai/preview/reasoning_engines/templates/langchain.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,40 @@ def _format_to_messages(
199199
])
200200

201201

202+
def _validate_callable_parameters_are_annotated(callable: Callable):
203+
"""Validates that the parameters of the callable have type annotations.
204+
205+
This ensures that they can be used for constructing LangChain tools that are
206+
usable with Gemini function calling.
207+
"""
208+
import inspect
209+
parameters = dict(inspect.signature(callable).parameters)
210+
for name, parameter in parameters.items():
211+
if parameter.annotation == inspect.Parameter.empty:
212+
raise TypeError(
213+
f"Callable={callable.__name__} has untyped input_arg={name}. "
214+
f"Please specify a type when defining it, e.g. `{name}: str`."
215+
)
216+
217+
218+
def _convert_tools_or_raise(
219+
tools: Sequence[Union[Callable, "BaseTool"]]
220+
) -> Sequence["BaseTool"]:
221+
"""Converts the tools into Langchain tools (if needed).
222+
223+
See https://blog.langchain.dev/structured-tools/ for details.
224+
"""
225+
from langchain_core import tools as lc_tools
226+
from langchain.tools.base import StructuredTool
227+
result = []
228+
for tool in tools:
229+
if not isinstance(tool, lc_tools.BaseTool):
230+
_validate_callable_parameters_are_annotated(tool)
231+
tool = StructuredTool.from_function(tool)
232+
result.append(tool)
233+
return result
234+
235+
202236
class LangchainAgent:
203237
"""A Langchain Agent.
204238
@@ -302,19 +336,19 @@ def __init__(
302336
langchain.runnables.history.RunnableWithMessageHistory if
303337
chat_history is specified. If chat_history is None, this will be
304338
ignored.
339+
340+
Raises:
341+
TypeError: If there is an invalid tool (e.g. function with an input
342+
that did not specify its type).
305343
"""
306344
from google.cloud.aiplatform import initializer
307345
self._project = initializer.global_config.project
308346
self._location = initializer.global_config.location
309347
self._tools = []
310348
if tools:
311-
from langchain_core import tools as lc_tools
312-
from langchain.tools.base import StructuredTool
313-
self._tools = [
314-
tool if isinstance(tool, lc_tools.BaseTool)
315-
else StructuredTool.from_function(tool)
316-
for tool in tools
317-
]
349+
# Unlike the other fields, we convert tools at initialization to
350+
# validate the functions/tools before they are deployed.
351+
self._tools = _convert_tools_or_raise(tools)
318352
self._model_name = model
319353
self._prompt = prompt
320354
self._output_parser = output_parser

0 commit comments

Comments
 (0)