Skip to content

Commit 1da0527

Browse files
Support managed agents in ToolCallingAgent (#1456)
1 parent 76ecb9b commit 1da0527

File tree

5 files changed

+52
-6
lines changed

5 files changed

+52
-6
lines changed

src/smolagents/agents.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
324324
"All managed agents need both a name and a description!"
325325
)
326326
self.managed_agents = {agent.name: agent for agent in managed_agents}
327+
# Ensure managed agents can be called as tools by the model: set their inputs and output_type
328+
for agent in self.managed_agents.values():
329+
agent.inputs = {"task": {"type": "string", "description": "Long detailed description of the task."}}
330+
agent.output_type = "string"
327331

328332
def _setup_tools(self, tools, add_base_tools):
329333
assert all(isinstance(tool, Tool) for tool in tools), "All elements must be instance of Tool (or a subclass)"
@@ -1190,6 +1194,11 @@ def __init__(
11901194
# Tool calling setup
11911195
self.max_tool_threads = max_tool_threads
11921196

1197+
@property
1198+
def tools_and_managed_agents(self):
1199+
"""Returns a combined list of tools and managed agents."""
1200+
return list(self.tools.values()) + list(self.managed_agents.values())
1201+
11931202
def initialize_system_prompt(self) -> str:
11941203
system_prompt = populate_template(
11951204
self.prompt_templates["system_prompt"],
@@ -1219,7 +1228,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
12191228
output_stream = self.model.generate_stream(
12201229
input_messages,
12211230
stop_sequences=["Observation:", "Calling tools:"],
1222-
tools_to_call_from=list(self.tools.values()),
1231+
tools_to_call_from=self.tools_and_managed_agents,
12231232
)
12241233

12251234
model_output = ""
@@ -1254,7 +1263,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
12541263
chat_message: ChatMessage = self.model.generate(
12551264
input_messages,
12561265
stop_sequences=["Observation:", "Calling tools:"],
1257-
tools_to_call_from=list(self.tools.values()),
1266+
tools_to_call_from=self.tools_and_managed_agents,
12581267
)
12591268

12601269
model_output = chat_message.content

src/smolagents/prompts/code_agent.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,12 @@ system_prompt: |-
150150
Here is a list of the team members that you can call:
151151
```python
152152
{%- for agent in managed_agents.values() %}
153-
def {{ agent.name }}("Your query goes here.") -> str:
154-
"""{{ agent.description }}"""
153+
def {{ agent.name }}(task: str) -> str:
154+
"""{{ agent.description }}
155+
156+
Args:
157+
task: Long detailed description of the task.
158+
"""
155159
{% endfor %}
156160
```
157161
{%- endif %}

src/smolagents/prompts/structured_code_agent.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,12 @@ system_prompt: |-
9696
Here is a list of the team members that you can call:
9797
```python
9898
{%- for agent in managed_agents.values() %}
99-
def {{ agent.name }}("Your query goes here.") -> str:
100-
"""{{ agent.description }}"""
99+
def {{ agent.name }}(task: str) -> str:
100+
"""{{ agent.description }}
101+
102+
Args:
103+
task: Long detailed description of the task.
104+
"""
101105
{% endfor %}
102106
```
103107
{%- endif %}

src/smolagents/prompts/toolcalling_agent.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ system_prompt: |-
103103
Here is a list of the team members that you can call:
104104
{%- for agent in managed_agents.values() %}
105105
- {{ agent.name }}: {{ agent.description }}
106+
Takes inputs: {{agent.inputs}}
107+
Returns an output of type: {{agent.output_type}}
106108
{%- endfor %}
107109
{%- endif %}
108110

tests/test_agents.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,6 +1102,33 @@ def test_toolcalling_agent_instructions(self):
11021102
assert agent.instructions == "Test instructions"
11031103
assert "Test instructions" in agent.system_prompt
11041104

1105+
def test_toolcalling_agent_passes_both_tools_and_managed_agents(self, test_tool):
1106+
"""Test that both tools and managed agents are passed to the model."""
1107+
managed_agent = MagicMock()
1108+
managed_agent.name = "managed_agent"
1109+
model = MagicMock()
1110+
model.generate.return_value = ChatMessage(
1111+
role="assistant",
1112+
content="",
1113+
tool_calls=[
1114+
ChatMessageToolCall(
1115+
id="call_0",
1116+
type="function",
1117+
function=ChatMessageToolCallDefinition(name="test_tool", arguments={"input": "test_value"}),
1118+
)
1119+
],
1120+
)
1121+
agent = ToolCallingAgent(tools=[test_tool], managed_agents=[managed_agent], model=model)
1122+
# Run the agent one step to trigger the model call
1123+
next(agent.run("Test task", stream=True))
1124+
# Check that the model was called with both tools and managed agents:
1125+
# - Get all tool_to_call_from names passed to the model
1126+
tools_to_call_from_names = [tool.name for tool in model.generate.call_args.kwargs["tools_to_call_from"]]
1127+
# - Verify both regular tools and managed agents are included
1128+
assert "test_tool" in tools_to_call_from_names # The regular tool
1129+
assert "managed_agent" in tools_to_call_from_names # The managed agent
1130+
assert "final_answer" in tools_to_call_from_names # The final_answer tool (added by default)
1131+
11051132
@patch("huggingface_hub.InferenceClient")
11061133
def test_toolcalling_agent_api(self, mock_inference_client):
11071134
mock_client = mock_inference_client.return_value

0 commit comments

Comments
 (0)