|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
| 15 | +import os |
15 | 16 | import tempfile
|
16 | 17 | import unittest
|
17 | 18 | from pathlib import Path
|
@@ -420,6 +421,46 @@ def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None)
|
420 | 421 | assert get_weather.inputs["locations"]["type"] == "array"
|
421 | 422 | assert get_weather.inputs["months"]["type"] == "array"
|
422 | 423 |
|
| 424 | + def test_saving_tool_produces_valid_pyhon_code_with_multiline_description(self): |
| 425 | + @tool |
| 426 | + def get_weather(location: Any) -> None: |
| 427 | + """ |
| 428 | + Get weather in the next days at given location. |
| 429 | + And works pretty well. |
| 430 | +
|
| 431 | + Args: |
| 432 | + location: The location to get the weather for. |
| 433 | + """ |
| 434 | + return |
| 435 | + |
| 436 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 437 | + get_weather.save(tmp_dir) |
| 438 | + with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f: |
| 439 | + source_code = f.read() |
| 440 | + compile(source_code, f.name, "exec") |
| 441 | + |
| 442 | + def test_saving_tool_produces_valid_python_code_with_complex_name(self): |
| 443 | + # Test one cannot save tool with additional args in init |
| 444 | + class FailTool(Tool): |
| 445 | + name = 'spe"\rcific' |
| 446 | + description = """test \n\r |
| 447 | + description""" |
| 448 | + inputs = {"string_input": {"type": "string", "description": "input description"}} |
| 449 | + output_type = "string" |
| 450 | + |
| 451 | + def __init__(self): |
| 452 | + super().__init__(self) |
| 453 | + |
| 454 | + def forward(self, string_input): |
| 455 | + return "foo" |
| 456 | + |
| 457 | + fail_tool = FailTool() |
| 458 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 459 | + fail_tool.save(tmp_dir) |
| 460 | + with open(os.path.join(tmp_dir, "tool.py"), "r", encoding="utf-8") as f: |
| 461 | + source_code = f.read() |
| 462 | + compile(source_code, f.name, "exec") |
| 463 | + |
423 | 464 |
|
424 | 465 | @pytest.fixture
|
425 | 466 | def mock_server_parameters():
|
|
0 commit comments