Skip to content

Commit 9318c8c

Browse files
CalOmnieaymeric-roucher
andauthored
Move check_module_authorized out of import_module for use in get_safe_module (#507)
Co-authored-by: Aymeric Roucher <[email protected]>
1 parent 8c6f90c commit 9318c8c

File tree

2 files changed

+57
-13
lines changed

2 files changed

+57
-13
lines changed

src/smolagents/local_python_executor.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
984984
for attr_name in dir(raw_module):
985985
# Skip dangerous patterns at any level
986986
if any(
987-
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
987+
pattern in raw_module.__name__.split(".") + [attr_name]
988+
and not check_module_authorized(pattern, authorized_imports, dangerous_patterns)
988989
for pattern in dangerous_patterns
989990
):
990991
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
@@ -1007,6 +1008,18 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
10071008
return safe_module
10081009

10091010

1011+
def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
1012+
if "*" in authorized_imports:
1013+
return True
1014+
else:
1015+
module_path = module_name.split(".")
1016+
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
1017+
return False
1018+
# ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
1019+
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
1020+
return any(subpath in authorized_imports for subpath in module_subpaths)
1021+
1022+
10101023
def import_modules(expression, state, authorized_imports):
10111024
dangerous_patterns = (
10121025
"_os",
@@ -1028,19 +1041,9 @@ def import_modules(expression, state, authorized_imports):
10281041
"multiprocessing",
10291042
)
10301043

1031-
def check_module_authorized(module_name):
1032-
if "*" in authorized_imports:
1033-
return True
1034-
else:
1035-
module_path = module_name.split(".")
1036-
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
1037-
return False
1038-
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
1039-
return any(subpath in authorized_imports for subpath in module_subpaths)
1040-
10411044
if isinstance(expression, ast.Import):
10421045
for alias in expression.names:
1043-
if check_module_authorized(alias.name):
1046+
if check_module_authorized(alias.name, authorized_imports, dangerous_patterns):
10441047
raw_module = import_module(alias.name)
10451048
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
10461049
else:
@@ -1049,7 +1052,7 @@ def check_module_authorized(module_name):
10491052
)
10501053
return None
10511054
elif isinstance(expression, ast.ImportFrom):
1052-
if check_module_authorized(expression.module):
1055+
if check_module_authorized(expression.module, authorized_imports, dangerous_patterns):
10531056
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
10541057
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
10551058
if expression.names[0].name == "*": # Handle "from module import *"

tests/test_local_python_executor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from smolagents.local_python_executor import (
2626
InterpreterError,
2727
PrintContainer,
28+
check_module_authorized,
2829
evaluate_delete,
2930
evaluate_python_code,
3031
fix_final_answer_code,
@@ -975,6 +976,10 @@ def test_can_import_os_if_explicitly_authorized(self):
975976
dangerous_code = "import os; os.listdir('./')"
976977
evaluate_python_code(dangerous_code, authorized_imports=["os"])
977978

979+
def test_can_import_os_if_all_imports_authorized(self):
980+
dangerous_code = "import os; os.listdir('./')"
981+
evaluate_python_code(dangerous_code, authorized_imports=["*"])
982+
978983

979984
@pytest.mark.parametrize(
980985
"code, expected_result",
@@ -1205,3 +1210,39 @@ def test_len(self):
12051210
pc = PrintContainer()
12061211
pc.append("Hello")
12071212
assert len(pc) == 5
1213+
1214+
1215+
@pytest.mark.parametrize(
1216+
"module,authorized_imports,expected",
1217+
[
1218+
("os", ["*"], True),
1219+
("AnyModule", ["*"], True),
1220+
("os", ["os"], True),
1221+
("AnyModule", ["AnyModule"], True),
1222+
("Module.os", ["Module"], False),
1223+
("Module.os", ["Module", "os"], True),
1224+
("os.path", ["os"], True),
1225+
("os", ["os.path"], False),
1226+
],
1227+
)
1228+
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
1229+
dangerous_patterns = (
1230+
"_os",
1231+
"os",
1232+
"subprocess",
1233+
"_subprocess",
1234+
"pty",
1235+
"system",
1236+
"popen",
1237+
"spawn",
1238+
"shutil",
1239+
"sys",
1240+
"pathlib",
1241+
"io",
1242+
"socket",
1243+
"compile",
1244+
"eval",
1245+
"exec",
1246+
"multiprocessing",
1247+
)
1248+
assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected

0 commit comments

Comments
 (0)