Skip to content

Commit 030a83f

Browse files
committed
add tets, correct tool generation values
1 parent f4c8f93 commit 030a83f

File tree

3 files changed

+47
-4
lines changed

3 files changed

+47
-4
lines changed

src/smolagents/tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,10 @@ def save(self, output_dir):
235235
from typing import Optional
236236
237237
class {class_name}(Tool):
238-
name = {json.dumps(self.name)}
238+
name = "{self.name}"
239239
description = {json.dumps(textwrap.dedent(self.description).strip())}
240240
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
241-
output_type = {json.dumps(self.output_type)}
241+
output_type = "{self.output_type}"
242242
"""
243243
).strip()
244244
import re

src/smolagents/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,12 @@ def instance_to_source(instance, base_cls=None):
298298

299299
for name, value in class_attrs.items():
300300
if isinstance(value, str):
301+
# multiline value
301302
if "\n" in value:
302-
class_lines.append(f' {name} = """{value}"""')
303+
escaped_value = value.replace('"""', r"\"\"\"") # Escape triple quotes
304+
class_lines.append(f' {name} = """{escaped_value}"""')
303305
else:
304-
class_lines.append(f' {name} = "{value}"')
306+
class_lines.append(f" {name} = {json.dumps(value)}")
305307
else:
306308
class_lines.append(f" {name} = {repr(value)}")
307309

tests/test_tools.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import os
1516
import tempfile
1617
import unittest
1718
from pathlib import Path
@@ -420,6 +421,46 @@ def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None)
420421
assert get_weather.inputs["locations"]["type"] == "array"
421422
assert get_weather.inputs["months"]["type"] == "array"
422423

424+
def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self):
425+
@tool
426+
def get_weather(location: Any) -> None:
427+
"""
428+
Get weather in the next days at given location.
429+
And works pretty well.
430+
431+
Args:
432+
location: The location to get the weather for.
433+
"""
434+
return
435+
436+
with tempfile.TemporaryDirectory() as tmp_dir:
437+
get_weather.save(tmp_dir)
438+
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
439+
source_code = f.read()
440+
compile(source_code, f.name, "exec")
441+
442+
def test_saving_tool_produces_valid_python_code_with_complex_name(self):
443+
# Test one cannot save tool with additional args in init
444+
class FailTool(Tool):
445+
name = 'spe"\rcific'
446+
description = """test \n\r
447+
description"""
448+
inputs = {"string_input": {"type": "string", "description": "input description"}}
449+
output_type = "string"
450+
451+
def __init__(self):
452+
super().__init__(self)
453+
454+
def forward(self, string_input):
455+
return "foo"
456+
457+
fail_tool = FailTool()
458+
with tempfile.TemporaryDirectory() as tmp_dir:
459+
fail_tool.save(tmp_dir)
460+
with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f:
461+
source_code = f.read()
462+
compile(source_code, f.name, "exec")
463+
423464

424465
@pytest.fixture
425466
def mock_server_parameters():

0 commit comments

Comments
 (0)