Skip to content

Commit 9bda328

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add support for ToolConfig in the LangChain template
PiperOrigin-RevId: 633283927
1 parent e586041 commit 9bda328

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

vertexai/preview/reasoning_engines/templates/langchain.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def _default_runnable_builder(
9696
prompt: Optional["RunnableSerializable"] = None,
9797
output_parser: Optional["RunnableSerializable"] = None,
9898
chat_history: Optional["GetSessionHistoryCallable"] = None,
99+
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
99100
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
100101
runnable_kwargs: Optional[Mapping[str, Any]] = None,
101102
) -> "RunnableSerializable":
@@ -109,10 +110,11 @@ def _default_runnable_builder(
109110
has_history: bool = chat_history is not None
110111
prompt = prompt or _default_prompt(has_history)
111112
output_parser = output_parser or _default_output_parser()
113+
model_tool_kwargs = model_tool_kwargs or {}
112114
agent_executor_kwargs = agent_executor_kwargs or {}
113115
runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history)
114116
if tools:
115-
model = model.bind_tools(tools=tools)
117+
model = model.bind_tools(tools=tools, **model_tool_kwargs)
116118
else:
117119
tools = []
118120
agent_executor = AgentExecutor(
@@ -202,6 +204,7 @@ def __init__(
202204
output_parser: Optional["RunnableSerializable"] = None,
203205
chat_history: Optional["GetSessionHistoryCallable"] = None,
204206
model_kwargs: Optional[Mapping[str, Any]] = None,
207+
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
205208
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
206209
runnable_kwargs: Optional[Mapping[str, Any]] = None,
207210
model_builder: Optional[Callable] = None,
@@ -233,8 +236,9 @@ def __init__(
233236
# runnable_builder
234237
from langchain import agents
235238
from langchain_core.runnables.history import RunnableWithMessageHistory
239+
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
236240
agent_executor = agents.AgentExecutor(
237-
agent=prompt | llm.bind_tools(tools=tools) | output_parser,
241+
agent=prompt | llm_with_tools | output_parser,
238242
tools=tools,
239243
**agent_executor_kwargs,
240244
)
@@ -282,6 +286,9 @@ def __init__(
282286
"top_k": 40,
283287
}
284288
```
289+
model_tool_kwargs (Mapping[str, Any]):
290+
Optional. Additional keyword arguments when binding tools to the
291+
model using `model.bind_tools()`.
285292
agent_executor_kwargs (Mapping[str, Any]):
286293
Optional. Additional keyword arguments for the constructor of
287294
langchain.agents.AgentExecutor. An example would be
@@ -334,6 +341,7 @@ def __init__(
334341
self._output_parser = output_parser
335342
self._chat_history = chat_history
336343
self._model_kwargs = model_kwargs
344+
self._model_tool_kwargs = model_tool_kwargs
337345
self._agent_executor_kwargs = agent_executor_kwargs
338346
self._runnable_kwargs = runnable_kwargs
339347
self._model = None
@@ -365,6 +373,7 @@ def set_up(self):
365373
tools=self._tools,
366374
output_parser=self._output_parser,
367375
chat_history=self._chat_history,
376+
model_tool_kwargs=self._model_tool_kwargs,
368377
agent_executor_kwargs=self._agent_executor_kwargs,
369378
runnable_kwargs=self._runnable_kwargs,
370379
)

0 commit comments

Comments
 (0)