Skip to content

Commit 4046fcb

Browse files
authored
Add is_enabled to FunctionTool (#808)
### Summary: Allows a user to do `function_tool(is_enabled=<some_callable>)`; the callable is called when the agent runs. This allows you to dynamically enable/disable a tool based on the context/env. The meta-goal is to allow `Agent` to be effectively immutable. That enables some nice things down the line, and this allows you to dynamically modify the tools list without mutating the agent. ### Test Plan: Unit tests
1 parent 995af4d commit 4046fcb

File tree

6 files changed

+102
-24
lines changed

6 files changed

+102
-24
lines changed

src/agents/agent.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import dataclasses
45
import inspect
56
from collections.abc import Awaitable
@@ -17,7 +18,7 @@
1718
from .model_settings import ModelSettings
1819
from .models.interface import Model
1920
from .run_context import RunContextWrapper, TContext
20-
from .tool import FunctionToolResult, Tool, function_tool
21+
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
2122
from .util import _transforms
2223
from .util._types import MaybeAwaitable
2324

@@ -246,7 +247,22 @@ async def get_mcp_tools(self) -> list[Tool]:
246247
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
247248
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
248249

249-
async def get_all_tools(self) -> list[Tool]:
250+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
250251
"""All agent tools, including MCP tools and function tools."""
251252
mcp_tools = await self.get_mcp_tools()
252-
return mcp_tools + self.tools
253+
254+
async def _check_tool_enabled(tool: Tool) -> bool:
255+
if not isinstance(tool, FunctionTool):
256+
return True
257+
258+
attr = tool.is_enabled
259+
if isinstance(attr, bool):
260+
return attr
261+
res = attr(run_context, self)
262+
if inspect.isawaitable(res):
263+
return bool(await res)
264+
return bool(res)
265+
266+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
267+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
268+
return [*mcp_tools, *enabled]

src/agents/run.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async def run(
181181

182182
try:
183183
while True:
184-
all_tools = await cls._get_all_tools(current_agent)
184+
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
185185

186186
# Start an agent span if we don't have one. This span is ended if the current
187187
# agent changes, or if the agent loop ends.
@@ -525,7 +525,7 @@ async def _run_streamed_impl(
525525
if streamed_result.is_complete:
526526
break
527527

528-
all_tools = await cls._get_all_tools(current_agent)
528+
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
529529

530530
# Start an agent span if we don't have one. This span is ended if the current
531531
# agent changes, or if the agent loop ends.
@@ -980,8 +980,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
980980
return handoffs
981981

982982
@classmethod
983-
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
984-
return await agent.get_all_tools()
983+
async def _get_all_tools(
984+
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
985+
) -> list[Tool]:
986+
return await agent.get_all_tools(context_wrapper)
985987

986988
@classmethod
987989
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:

src/agents/tool.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
from collections.abc import Awaitable
66
from dataclasses import dataclass
7-
from typing import Any, Callable, Literal, Union, overload
7+
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
88

99
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
1010
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
@@ -24,6 +24,9 @@
2424
from .util import _error_tracing
2525
from .util._types import MaybeAwaitable
2626

27+
if TYPE_CHECKING:
28+
from .agent import Agent
29+
2730
ToolParams = ParamSpec("ToolParams")
2831

2932
ToolFunctionWithoutContext = Callable[ToolParams, Any]
@@ -74,6 +77,11 @@ class FunctionTool:
7477
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
7578
as it increases the likelihood of correct JSON input."""
7679

80+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
81+
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
82+
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
83+
based on your context/state."""
84+
7785

7886
@dataclass
7987
class FileSearchTool:
@@ -262,6 +270,7 @@ def function_tool(
262270
use_docstring_info: bool = True,
263271
failure_error_function: ToolErrorFunction | None = None,
264272
strict_mode: bool = True,
273+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
265274
) -> FunctionTool:
266275
"""Overload for usage as @function_tool (no parentheses)."""
267276
...
@@ -276,6 +285,7 @@ def function_tool(
276285
use_docstring_info: bool = True,
277286
failure_error_function: ToolErrorFunction | None = None,
278287
strict_mode: bool = True,
288+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
279289
) -> Callable[[ToolFunction[...]], FunctionTool]:
280290
"""Overload for usage as @function_tool(...)."""
281291
...
@@ -290,6 +300,7 @@ def function_tool(
290300
use_docstring_info: bool = True,
291301
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
292302
strict_mode: bool = True,
303+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
293304
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
294305
"""
295306
Decorator to create a FunctionTool from a function. By default, we will:
@@ -318,6 +329,9 @@ def function_tool(
318329
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
319330
value, it will be optional, additional properties are allowed, etc. See here for more:
320331
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
332+
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
333+
context and agent and returns whether the tool is enabled. Disabled tools are hidden
334+
from the LLM at runtime.
321335
"""
322336

323337
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@@ -407,6 +421,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
407421
params_json_schema=schema.params_json_schema,
408422
on_invoke_tool=_on_invoke_tool,
409423
strict_json_schema=strict_mode,
424+
is_enabled=is_enabled,
410425
)
411426

412427
# If func is actually a callable, we were used as @function_tool with no parentheses

tests/test_function_tool.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel
66
from typing_extensions import TypedDict
77

8-
from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
8+
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
99
from agents.tool import default_tool_error_function
1010

1111

@@ -255,3 +255,44 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
255255

256256
result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
257257
assert result == "error_ValueError"
258+
259+
260+
class BoolCtx(BaseModel):
261+
enable_tools: bool
262+
263+
264+
@pytest.mark.asyncio
265+
async def test_is_enabled_bool_and_callable():
266+
@function_tool(is_enabled=False)
267+
def disabled_tool():
268+
return "nope"
269+
270+
async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
271+
return ctx.context.enable_tools
272+
273+
@function_tool(is_enabled=cond_enabled)
274+
def another_tool():
275+
return "hi"
276+
277+
async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str:
278+
return "third"
279+
280+
third_tool = FunctionTool(
281+
name="third_tool",
282+
description="third tool",
283+
on_invoke_tool=third_tool_on_invoke_tool,
284+
is_enabled=lambda ctx, agent: ctx.context.enable_tools,
285+
params_json_schema={},
286+
)
287+
288+
agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])
289+
context_1 = RunContextWrapper(BoolCtx(enable_tools=False))
290+
context_2 = RunContextWrapper(BoolCtx(enable_tools=True))
291+
292+
tools_with_ctx = await agent.get_all_tools(context_1)
293+
assert tools_with_ctx == []
294+
295+
tools_with_ctx = await agent.get_all_tools(context_2)
296+
assert len(tools_with_ctx) == 2
297+
assert tools_with_ctx[0].name == "another_tool"
298+
assert tools_with_ctx[1].name == "third_tool"

tests/test_run_step_execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ async def get_execute_result(
290290

291291
processed_response = RunImpl.process_model_response(
292292
agent=agent,
293-
all_tools=await agent.get_all_tools(),
293+
all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
294294
response=response,
295295
output_schema=output_schema,
296296
handoffs=handoffs,

tests/test_run_step_processing.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
)
3535

3636

37+
def _dummy_ctx() -> RunContextWrapper[None]:
38+
return RunContextWrapper(context=None)
39+
40+
3741
def test_empty_response():
3842
agent = Agent(name="test")
3943
response = ModelResponse(
@@ -83,7 +87,7 @@ async def test_single_tool_call():
8387
response=response,
8488
output_schema=None,
8589
handoffs=[],
86-
all_tools=await agent.get_all_tools(),
90+
all_tools=await agent.get_all_tools(_dummy_ctx()),
8791
)
8892
assert not result.handoffs
8993
assert result.functions and len(result.functions) == 1
@@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error():
111115
response=response,
112116
output_schema=None,
113117
handoffs=[],
114-
all_tools=await agent.get_all_tools(),
118+
all_tools=await agent.get_all_tools(_dummy_ctx()),
115119
)
116120

117121

@@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
140144
response=response,
141145
output_schema=None,
142146
handoffs=[],
143-
all_tools=await agent.get_all_tools(),
147+
all_tools=await agent.get_all_tools(_dummy_ctx()),
144148
)
145149
assert not result.handoffs
146150
assert result.functions and len(result.functions) == 2
@@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
169173
response=response,
170174
output_schema=None,
171175
handoffs=[],
172-
all_tools=await agent_3.get_all_tools(),
176+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
173177
)
174178
assert not result.handoffs, "Shouldn't have a handoff here"
175179

@@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
183187
response=response,
184188
output_schema=None,
185189
handoffs=Runner._get_handoffs(agent_3),
186-
all_tools=await agent_3.get_all_tools(),
190+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
187191
)
188192
assert len(result.handoffs) == 1, "Should have a handoff here"
189193
handoff = result.handoffs[0]
@@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
213217
response=response,
214218
output_schema=None,
215219
handoffs=Runner._get_handoffs(agent_3),
216-
all_tools=await agent_3.get_all_tools(),
220+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
217221
)
218222

219223

@@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
236240
response=response,
237241
output_schema=None,
238242
handoffs=Runner._get_handoffs(agent_3),
239-
all_tools=await agent_3.get_all_tools(),
243+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
240244
)
241245
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
242246

@@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
262266
response=response,
263267
output_schema=Runner._get_output_schema(agent),
264268
handoffs=[],
265-
all_tools=await agent.get_all_tools(),
269+
all_tools=await agent.get_all_tools(_dummy_ctx()),
266270
)
267271

268272

@@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
288292
response=response,
289293
output_schema=None,
290294
handoffs=[],
291-
all_tools=await agent.get_all_tools(),
295+
all_tools=await agent.get_all_tools(_dummy_ctx()),
292296
)
293297
# The final item should be a ToolCallItem for the file search call
294298
assert any(
@@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
313317
response=response,
314318
output_schema=None,
315319
handoffs=[],
316-
all_tools=await agent.get_all_tools(),
320+
all_tools=await agent.get_all_tools(_dummy_ctx()),
317321
)
318322
assert any(
319323
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
340344
response=response,
341345
output_schema=None,
342346
handoffs=[],
343-
all_tools=await Agent(name="test").get_all_tools(),
347+
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
344348
)
345349
assert any(
346350
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
409413
response=response,
410414
output_schema=None,
411415
handoffs=[],
412-
all_tools=await Agent(name="test").get_all_tools(),
416+
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
413417
)
414418

415419

@@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
437441
response=response,
438442
output_schema=None,
439443
handoffs=[],
440-
all_tools=await agent.get_all_tools(),
444+
all_tools=await agent.get_all_tools(_dummy_ctx()),
441445
)
442446
assert any(
443447
isinstance(item, ToolCallItem) and item.raw_item is computer_call
@@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
468472
response=response,
469473
output_schema=None,
470474
handoffs=Runner._get_handoffs(agent_3),
471-
all_tools=await agent_3.get_all_tools(),
475+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
472476
)
473477
assert result.functions and len(result.functions) == 1
474478
assert len(result.handoffs) == 1, "Should have a handoff here"

0 commit comments

Comments
 (0)