Skip to content

Fix tests of Agent.save and Tool.save #1029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,8 @@ def test_from_folder(self, agent_dict_version, get_agent_dict):
assert agent.prompt_templates["system_prompt"] == "dummy system prompt"


class MultiAgentsTests(unittest.TestCase):
def test_multiagents_save(self):
class TestMultiAgents:
def test_multiagents_save(self, tmp_path):
model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)

web_agent = ToolCallingAgent(
Expand All @@ -1045,7 +1045,7 @@ def test_multiagents_save(self):
executor_type="local",
executor_kwargs={"max_workers": 2},
)
agent.save("agent_export")
agent.save(tmp_path)

expected_structure = {
"managed_agents": {
Expand Down Expand Up @@ -1074,10 +1074,10 @@ def verify_structure(current_path: Path, structure: dict):
assert file_path.exists(), f"File {file_path} does not exist"
assert file_path.is_file(), f"{file_path} is not a file"

verify_structure(Path("agent_export"), expected_structure)
verify_structure(tmp_path, expected_structure)

# Test that re-loaded agents work as expected.
agent2 = CodeAgent.from_folder("agent_export", planning_interval=5)
agent2 = CodeAgent.from_folder(tmp_path, planning_interval=5)
assert agent2.planning_interval == 5 # Check that kwargs are used
assert set(agent2.authorized_imports) == set(["pandas", "datetime"] + BASE_BUILTIN_MODULES)
assert agent2.max_print_outputs_length == 1000
Expand Down
43 changes: 19 additions & 24 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -100,7 +98,7 @@ def test_agent_type_output(self):
self.assertTrue(isinstance(output, agent_type))


class ToolTests(unittest.TestCase):
class TestTool:
def test_tool_init_with_decorator(self):
@tool
def coolfunc(a: str, b: int) -> float:
Expand Down Expand Up @@ -165,7 +163,7 @@ def coolfunc(a: str, b: int) -> int:
assert coolfunc.output_type == "number"
assert "docstring has no description for the argument" in str(e)

def test_saving_tool_raises_error_imports_outside_function(self):
def test_saving_tool_raises_error_imports_outside_function(self, tmp_path):
with pytest.raises(Exception) as e:
import numpy as np

Expand All @@ -176,7 +174,7 @@ def get_current_time() -> str:
"""
return str(np.random.random())

get_current_time.save("output")
get_current_time.save(tmp_path)

assert "np" in str(e)

Expand All @@ -193,7 +191,7 @@ def forward(self):
return str(np.random.random())

get_current_time = GetCurrentTimeTool()
get_current_time.save("output")
get_current_time.save(tmp_path)

assert "np" in str(e)

Expand Down Expand Up @@ -255,7 +253,7 @@ def forward(self, string_input: str) -> str:
fail_tool = PassTool()
fail_tool.to_dict()

def test_saving_tool_allows_no_imports_from_outside_methods(self):
def test_saving_tool_allows_no_imports_from_outside_methods(self, tmp_path):
# Test that using imports from outside functions fails
import numpy as np

Expand All @@ -274,7 +272,7 @@ def forward(self, string_input):

fail_tool = FailTool()
with pytest.raises(Exception) as e:
fail_tool.save("output")
fail_tool.save(tmp_path)
assert "'np' is undefined" in str(e)

# Test that putting these imports inside functions works
Expand All @@ -294,7 +292,7 @@ def forward(self, string_input):
return self.useless_method() + string_input

success_tool = SuccessTool()
success_tool.save("output")
success_tool.save(tmp_path)

def test_tool_missing_class_attributes_raises_error(self):
with pytest.raises(Exception) as e:
Expand Down Expand Up @@ -409,7 +407,7 @@ def get_weather(location: str, celsius: bool = False) -> str:

assert get_weather.inputs["celsius"]["nullable"]

def test_tool_supports_any_none(self):
def test_tool_supports_any_none(self, tmp_path):
@tool
def get_weather(location: Any) -> None:
"""
Expand All @@ -420,8 +418,7 @@ def get_weather(location: Any) -> None:
"""
return

with tempfile.TemporaryDirectory() as tmp_dir:
get_weather.save(tmp_dir)
get_weather.save(tmp_path)
assert get_weather.inputs["location"]["type"] == "any"
assert get_weather.output_type == "null"

Expand All @@ -440,7 +437,7 @@ def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None)
assert get_weather.inputs["locations"]["type"] == "array"
assert get_weather.inputs["months"]["type"] == "array"

def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self):
def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self, tmp_path):
@tool
def get_weather(location: Any) -> None:
"""
Expand All @@ -452,13 +449,12 @@ def get_weather(location: Any) -> None:
"""
return

with tempfile.TemporaryDirectory() as tmp_dir:
get_weather.save(tmp_dir)
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
source_code = f.read()
compile(source_code, f.name, "exec")
get_weather.save(tmp_path)
with open(os.path.join(tmp_path, "tool.py"), "r", encoding="utf-8") as f:
source_code = f.read()
compile(source_code, f.name, "exec")

def test_saving_tool_produces_valid_python_code_with_complex_name(self):
def test_saving_tool_produces_valid_python_code_with_complex_name(self, tmp_path):
# Test one cannot save tool with additional args in init
class FailTool(Tool):
name = 'spe"\rcific'
Expand All @@ -474,11 +470,10 @@ def forward(self, string_input):
return "foo"

fail_tool = FailTool()
with tempfile.TemporaryDirectory() as tmp_dir:
fail_tool.save(tmp_dir)
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
source_code = f.read()
compile(source_code, f.name, "exec")
fail_tool.save(tmp_path)
with open(os.path.join(tmp_path, "tool.py"), "r", encoding="utf-8") as f:
source_code = f.read()
compile(source_code, f.name, "exec")


@pytest.fixture
Expand Down
Loading