@@ -199,6 +199,40 @@ def _format_to_messages(
199
199
])
200
200
201
201
202
+ def _validate_callable_parameters_are_annotated (callable : Callable ):
203
+ """Validates that the parameters of the callable have type annotations.
204
+
205
+ This ensures that they can be used for constructing LangChain tools that are
206
+ usable with Gemini function calling.
207
+ """
208
+ import inspect
209
+ parameters = dict (inspect .signature (callable ).parameters )
210
+ for name , parameter in parameters .items ():
211
+ if parameter .annotation == inspect .Parameter .empty :
212
+ raise TypeError (
213
+ f"Callable={ callable .__name__ } has untyped input_arg={ name } . "
214
+ f"Please specify a type when defining it, e.g. `{ name } : str`."
215
+ )
216
+
217
+
218
+ def _convert_tools_or_raise (
219
+ tools : Sequence [Union [Callable , "BaseTool" ]]
220
+ ) -> Sequence ["BaseTool" ]:
221
+ """Converts the tools into Langchain tools (if needed).
222
+
223
+ See https://blog.langchain.dev/structured-tools/ for details.
224
+ """
225
+ from langchain_core import tools as lc_tools
226
+ from langchain .tools .base import StructuredTool
227
+ result = []
228
+ for tool in tools :
229
+ if not isinstance (tool , lc_tools .BaseTool ):
230
+ _validate_callable_parameters_are_annotated (tool )
231
+ tool = StructuredTool .from_function (tool )
232
+ result .append (tool )
233
+ return result
234
+
235
+
202
236
class LangchainAgent :
203
237
"""A Langchain Agent.
204
238
@@ -302,19 +336,19 @@ def __init__(
302
336
langchain.runnables.history.RunnableWithMessageHistory if
303
337
chat_history is specified. If chat_history is None, this will be
304
338
ignored.
339
+
340
+ Raises:
341
+ TypeError: If there is an invalid tool (e.g. function with an input
342
+ that did not specify its type).
305
343
"""
306
344
from google .cloud .aiplatform import initializer
307
345
self ._project = initializer .global_config .project
308
346
self ._location = initializer .global_config .location
309
347
self ._tools = []
310
348
if tools :
311
- from langchain_core import tools as lc_tools
312
- from langchain .tools .base import StructuredTool
313
- self ._tools = [
314
- tool if isinstance (tool , lc_tools .BaseTool )
315
- else StructuredTool .from_function (tool )
316
- for tool in tools
317
- ]
349
+ # Unlike the other fields, we convert tools at initialization to
350
+ # validate the functions/tools before they are deployed.
351
+ self ._tools = _convert_tools_or_raise (tools )
318
352
self ._model_name = model
319
353
self ._prompt = prompt
320
354
self ._output_parser = output_parser
0 commit comments