Skip to content

Commit ecabb9e

Browse files
Delete prompts_path argument and use prompt_templates (#541)
1 parent fd9eec8 commit ecabb9e

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

src/smolagents/agents.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class MultiStepAgent:
8888
Args:
8989
tools (`list[Tool]`): [`Tool`]s that the agent can use.
9090
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
91-
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
91+
prompt_templates (`dict`, *optional*): Prompt templates.
9292
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
9393
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
9494
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
@@ -107,7 +107,7 @@ def __init__(
107107
self,
108108
tools: List[Tool],
109109
model: Callable[[List[Dict[str, str]]], ChatMessage],
110-
prompts_path: Optional[str] = None,
110+
prompt_templates: Optional[dict] = None,
111111
max_steps: int = 6,
112112
tool_parser: Optional[Callable] = None,
113113
add_base_tools: bool = False,
@@ -125,6 +125,7 @@ def __init__(
125125
tool_parser = parse_json_tool_call
126126
self.agent_name = self.__class__.__name__
127127
self.model = model
128+
self.prompt_templates = prompt_templates or {}
128129
self.max_steps = max_steps
129130
self.step_number: int = 0
130131
self.tool_parser = tool_parser
@@ -633,7 +634,7 @@ class ToolCallingAgent(MultiStepAgent):
633634
Args:
634635
tools (`list[Tool]`): [`Tool`]s that the agent can use.
635636
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
636-
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
637+
prompt_templates (`dict`, *optional*): Prompt templates.
637638
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
638639
**kwargs: Additional keyword arguments.
639640
"""
@@ -642,17 +643,17 @@ def __init__(
642643
self,
643644
tools: List[Tool],
644645
model: Callable[[List[Dict[str, str]]], ChatMessage],
645-
prompts_path: Optional[str] = None,
646+
prompt_templates: Optional[dict] = None,
646647
planning_interval: Optional[int] = None,
647648
**kwargs,
648649
):
649-
self.prompt_templates = yaml.safe_load(
650+
prompt_templates = prompt_templates or yaml.safe_load(
650651
importlib.resources.read_text("smolagents.prompts", "toolcalling_agent.yaml")
651652
)
652653
super().__init__(
653654
tools=tools,
654655
model=model,
655-
prompts_path=prompts_path,
656+
prompt_templates=prompt_templates,
656657
planning_interval=planning_interval,
657658
**kwargs,
658659
)
@@ -755,7 +756,7 @@ class CodeAgent(MultiStepAgent):
755756
Args:
756757
tools (`list[Tool]`): [`Tool`]s that the agent can use.
757758
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
758-
prompts_path (`str`, *optional*): The path from which to load this agent's prompt dictionary.
759+
prompt_templates (`dict`, *optional*): Prompt templates.
759760
grammar (`dict[str, str]`, *optional*): Grammar used to parse the LLM output.
760761
additional_authorized_imports (`list[str]`, *optional*): Additional authorized imports for the agent.
761762
planning_interval (`int`, *optional*): Interval at which the agent will run a planning step.
@@ -769,7 +770,7 @@ def __init__(
769770
self,
770771
tools: List[Tool],
771772
model: Callable[[List[Dict[str, str]]], ChatMessage],
772-
prompts_path: Optional[str] = None,
773+
prompt_templates: Optional[dict] = None,
773774
grammar: Optional[Dict[str, str]] = None,
774775
additional_authorized_imports: Optional[List[str]] = None,
775776
planning_interval: Optional[int] = None,
@@ -779,10 +780,13 @@ def __init__(
779780
):
780781
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
781782
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
782-
self.prompt_templates = yaml.safe_load(importlib.resources.read_text("smolagents.prompts", "code_agent.yaml"))
783+
prompt_templates = prompt_templates or yaml.safe_load(
784+
importlib.resources.read_text("smolagents.prompts", "code_agent.yaml")
785+
)
783786
super().__init__(
784787
tools=tools,
785788
model=model,
789+
prompt_templates=prompt_templates,
786790
grammar=grammar,
787791
planning_interval=planning_interval,
788792
**kwargs,

tests/test_agents.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pathlib import Path
2020
from unittest.mock import MagicMock
2121

22+
import pytest
2223
from transformers.testing_utils import get_tests_dir
2324

2425
from smolagents.agent_types import AgentImage, AgentText
@@ -664,11 +665,19 @@ def check_always_fails(final_answer, agent_memory):
664665

665666

666667
class TestMultiStepAgent:
667-
def test_logging_to_terminal_is_disabled(self):
668+
def test_instantiation_disables_logging_to_terminal(self):
668669
fake_model = MagicMock()
669670
agent = MultiStepAgent(tools=[], model=fake_model)
670671
assert agent.logger.level == -1, "logging to terminal should be disabled for testing using a fixture"
671672

673+
def test_instantiation_with_prompt_templates(self, prompt_templates):
674+
agent = MultiStepAgent(tools=[], model=MagicMock(), prompt_templates=prompt_templates)
675+
assert agent.prompt_templates == prompt_templates
676+
assert agent.prompt_templates["system_prompt"] == "This is a test system prompt."
677+
assert "managed_agent" in agent.prompt_templates
678+
assert agent.prompt_templates["managed_agent"]["task"] == "Task for {{name}}: {{task}}"
679+
assert agent.prompt_templates["managed_agent"]["report"] == "Report for {{name}}: {{final_answer}}"
680+
672681
def test_step_number(self):
673682
fake_model = MagicMock()
674683
fake_model.last_input_token_count = 10
@@ -724,3 +733,11 @@ def test_planning_step_first_step(self):
724733
assert isinstance(content, dict)
725734
assert "type" in content
726735
assert "text" in content
736+
737+
738+
@pytest.fixture
739+
def prompt_templates():
740+
return {
741+
"system_prompt": "This is a test system prompt.",
742+
"managed_agent": {"task": "Task for {{name}}: {{task}}", "report": "Report for {{name}}: {{final_answer}}"},
743+
}

0 commit comments

Comments
 (0)