Skip to content

Move check_module_authorized out of import_module for use in get_safe_module #507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
for attr_name in dir(raw_module):
# Skip dangerous patterns at any level
if any(
pattern in raw_module.__name__.split(".") + [attr_name] and pattern not in authorized_imports
pattern in raw_module.__name__.split(".") + [attr_name]
and not check_module_authorized(pattern, authorized_imports, dangerous_patterns)
for pattern in dangerous_patterns
):
logger.info(f"Skipping dangerous attribute {raw_module.__name__}.{attr_name}")
Expand All @@ -1007,6 +1008,18 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
return safe_module


def check_module_authorized(module_name, authorized_imports, dangerous_patterns):
if "*" in authorized_imports:
Copy link
Contributor

@sysradium sysradium Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am actually not sure it's a good thing. If I understand correctly, given, for example, itertools.*, instead of allowing only the sub-packages of itertools, it will allow everything.

Copy link
Contributor Author

@CalOmnie CalOmnie Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might misunderstand. Basically the current issue is that get_safe_module does not recognize the special additional import * while import_module does. The main effect of that is that when you use "*", you can execute otherwise forbidden import statements e.g. import os, but those modules will be gutted of all their attributes because get_safe_module will consider them all "dangerous".

See this example:

from smolagents import CodeAgent, HfApiModel, tool
from types import ModuleType


model = HfApiModel()

@tool
def print_dir(module: ModuleType) -> None:
    """
    print the dir
    Args:
        module: The module to print dir of
    """
    import os
    print("module dir is", dir(module))
    print("os dir is", dir(os))
agent = CodeAgent(model=model, additional_authorized_imports=["*"], tools=[print_dir])

code = """
import os

print_dir(os)
"""

agent.run(f"Run the following code: {code}")

This is the output:

module dir is ['__doc__', '__loader__', '__name__', '__package__', '__spec__']
os dir is ['CLD_CONTINUED', 'CLD_DUMPED', 'CLD_EXITED', 'CLD_KILLED', 'CLD_STOPPED', 'CLD_TRAPPED', 'DirEntry', 'EFD_CLOEXEC', 'EFD_NONBLOCK', 'EFD_SEMAPHORE', 'EX_CANTCREAT', 'EX_CONFIG', 'EX_DATAERR', 'EX_IOERR', 'EX_NOHOST', 'EX_NOINPUT', 'EX_NOPERM', 'EX_NOUSER', 'EX_OK', 'EX_OSERR', 'EX_OSFILE', 'EX_PROTOCOL', 'EX_SOFTWARE', 'EX_TEMPFAIL', 'EX_UNAVAILABLE', 'EX_USAGE', 'F_LOCK', 'F_OK', 'F_TEST', 'F_TLOCK', 'F_ULOCK', 'GRND_NONBLOCK', 'GRND_RANDOM', 'GenericAlias', 'MFD_ALLOW_SEALING', 'MFD_CLOEXEC', 'MFD_HUGETLB', 'MFD_HUGE_16GB', 'MFD_HUGE_16MB', 'MFD_HUGE_1GB', 'MFD_HUGE_1MB', 'MFD_HUGE_256MB', 'MFD_HUGE_2GB', 'MFD_HUGE_2MB', 'MFD_HUGE_32MB', 'MFD_HUGE_512KB', 'MFD_HUGE_512MB', 'MFD_HUGE_64KB', 'MFD_HUGE_8MB', 'MFD_HUGE_MASK', 'MFD_HUGE_SHIFT', 'Mapping', 'MutableMapping', 'NGROUPS_MAX', 'O_ACCMODE', 'O_APPEND', 'O_ASYNC', 'O_CLOEXEC', 'O_CREAT', 'O_DIRECT', 'O_DIRECTORY', 'O_DSYNC', 'O_EXCL', 'O_FSYNC', 'O_LARGEFILE', 'O_NDELAY', 'O_NOATIME', 'O_NOCTTY', 'O_NOFOLLOW', 'O_NONBLOCK', 'O_PATH', 'O_RDONLY', 'O_RDWR', 'O_RSYNC', 'O_SYNC', 'O_TMPFILE', 'O_TRUNC', 'O_WRONLY', 'POSIX_FADV_DONTNEED', 'POSIX_FADV_NOREUSE', 'POSIX_FADV_NORMAL', 'POSIX_FADV_RANDOM', 'POSIX_FADV_SEQUENTIAL', 'POSIX_FADV_WILLNEED', 'POSIX_SPAWN_CLOSE', 'POSIX_SPAWN_DUP2', 'POSIX_SPAWN_OPEN', 'PRIO_PGRP', 'PRIO_PROCESS', 'PRIO_USER', 'P_ALL', 'P_NOWAIT', 'P_NOWAITO', 'P_PGID', 'P_PID', 'P_PIDFD', 'P_WAIT', 'PathLike', 'RTLD_DEEPBIND', 'RTLD_GLOBAL', 'RTLD_LAZY', 'RTLD_LOCAL', 'RTLD_NODELETE', 'RTLD_NOLOAD', 'RTLD_NOW', 'RWF_APPEND', 'RWF_DSYNC', 'RWF_HIPRI', 'RWF_NOWAIT', 'RWF_SYNC', 'R_OK', 'SCHED_BATCH', 'SCHED_FIFO', 'SCHED_IDLE', 'SCHED_OTHER', 'SCHED_RESET_ON_FORK', 'SCHED_RR', 'SEEK_CUR', 'SEEK_DATA', 'SEEK_END', 'SEEK_HOLE', 'SEEK_SET', 'SPLICE_F_MORE', 'SPLICE_F_MOVE', 'SPLICE_F_NONBLOCK', 'ST_APPEND', 'ST_MANDLOCK', 'ST_NOATIME', 'ST_NODEV', 'ST_NODIRATIME', 'ST_NOEXEC', 'ST_NOSUID', 'ST_RDONLY', 'ST_RELATIME', 'ST_SYNCHRONOUS', 'ST_WRITE', 'TMP_MAX', 'WCONTINUED', 'WCOREDUMP', 
'WEXITED', 'WEXITSTATUS', 'WIFCONTINUED', 'WIFEXITED', 'WIFSIGNALED', 'WIFSTOPPED', 'WNOHANG', 'WNOWAIT', 'WSTOPPED', 'WSTOPSIG', 'WTERMSIG', 'WUNTRACED', 'W_OK', 'XATTR_CREATE', 'XATTR_REPLACE', 'XATTR_SIZE_MAX', 'X_OK', '_Environ', '__all__', '__builtins__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', '_check_methods', '_execvpe', '_exists', '_exit', '_fspath', '_fwalk', '_get_exports_list', '_spawnvef', '_walk', '_wrap_close', 'abc', 'abort', 'access', 'altsep', 'chdir', 'chmod', 'chown', 'chroot', 'close', 'closerange', 'confstr', 'confstr_names', 'copy_file_range', 'cpu_count', 'ctermid', 'curdir', 'defpath', 'device_encoding', 'devnull', 'dup', 'dup2', 'environ', 'environb', 'error', 'eventfd', 'eventfd_read', 'eventfd_write', 'execl', 'execle', 'execlp', 'execlpe', 'execv', 'execve', 'execvp', 'execvpe', 'extsep', 'fchdir', 'fchmod', 'fchown', 'fdatasync', 'fdopen', 'fork', 'forkpty', 'fpathconf', 'fsdecode', 'fsencode', 'fspath', 'fstat', 'fstatvfs', 'fsync', 'ftruncate', 'fwalk', 'get_blocking', 'get_exec_path', 'get_inheritable', 'get_terminal_size', 'getcwd', 'getcwdb', 'getegid', 'getenv', 'getenvb', 'geteuid', 'getgid', 'getgrouplist', 'getgroups', 'getloadavg', 'getlogin', 'getpgid', 'getpgrp', 'getpid', 'getppid', 'getpriority', 'getrandom', 'getresgid', 'getresuid', 'getsid', 'getuid', 'getxattr', 'initgroups', 'isatty', 'kill', 'killpg', 'lchown', 'linesep', 'link', 'listdir', 'listxattr', 'lockf', 'login_tty', 'lseek', 'lstat', 'major', 'makedev', 'makedirs', 'memfd_create', 'minor', 'mkdir', 'mkfifo', 'mknod', 'name', 'nice', 'open', 'openpty', 'pardir', 'path', 'pathconf', 'pathconf_names', 'pathsep', 'pidfd_open', 'pipe', 'pipe2', 'popen', 'posix_fadvise', 'posix_fallocate', 'posix_spawn', 'posix_spawnp', 'pread', 'preadv', 'putenv', 'pwrite', 'pwritev', 'read', 'readlink', 'readv', 'register_at_fork', 'remove', 'removedirs', 'removexattr', 'rename', 'renames', 'replace', 'rmdir', 'scandir', 
'sched_get_priority_max', 'sched_get_priority_min', 'sched_getaffinity', 'sched_getparam', 'sched_getscheduler', 'sched_param', 'sched_rr_get_interval', 'sched_setaffinity', 'sched_setparam', 'sched_setscheduler', 'sched_yield', 'sendfile', 'sep', 'set_blocking', 'set_inheritable', 'setegid', 'seteuid', 'setgid', 'setgroups', 'setpgid', 'setpgrp', 'setpriority', 'setregid', 'setresgid', 'setresuid', 'setreuid', 'setsid', 'setuid', 'setxattr', 'spawnl', 'spawnle', 'spawnlp', 'spawnlpe', 'spawnv', 'spawnve', 'spawnvp', 'spawnvpe', 'splice', 'st', 'stat', 'stat_result', 'statvfs', 'statvfs_result', 'strerror', 'supports_bytes_environ', 'supports_dir_fd', 'supports_effective_ids', 'supports_fd', 'supports_follow_symlinks', 'symlink', 'sync', 'sys', 'sysconf', 'sysconf_names', 'system', 'tcgetpgrp', 'tcsetpgrp', 'terminal_size', 'times', 'times_result', 'truncate', 'ttyname', 'umask', 'uname', 'uname_result', 'unlink', 'unsetenv', 'urandom', 'utime', 'wait', 'wait3', 'wait4', 'waitid', 'waitid_result', 'waitpid', 'waitstatus_to_exitcode', 'walk', 'write', 'writev']

As you can see, the "os" that was imported in the code run by the agent does not have any functions.
If you run the same code from my branch, you will see that the two dirs are equal, which I believe should be the case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mhm, I probably misread the code, i.e. for some reason I treated authorized_imports as a string. My worry was that, if it were a string, then `if "*" in authorized_imports` would behave the same for authorized_imports='*' and authorized_imports='foo.bar.*'. My bad.

return True
else:
module_path = module_name.split(".")
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
return False
# ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)


def import_modules(expression, state, authorized_imports):
dangerous_patterns = (
"_os",
Expand All @@ -1028,19 +1041,9 @@ def import_modules(expression, state, authorized_imports):
"multiprocessing",
)

def check_module_authorized(module_name):
if "*" in authorized_imports:
return True
else:
module_path = module_name.split(".")
if any([module in dangerous_patterns and module not in authorized_imports for module in module_path]):
return False
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)

if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
if check_module_authorized(alias.name, authorized_imports, dangerous_patterns):
raw_module = import_module(alias.name)
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
else:
Expand All @@ -1049,7 +1052,7 @@ def check_module_authorized(module_name):
)
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
if check_module_authorized(expression.module, authorized_imports, dangerous_patterns):
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
module = get_safe_module(raw_module, dangerous_patterns, authorized_imports)
if expression.names[0].name == "*": # Handle "from module import *"
Expand Down
41 changes: 41 additions & 0 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from smolagents.local_python_executor import (
InterpreterError,
PrintContainer,
check_module_authorized,
evaluate_delete,
evaluate_python_code,
fix_final_answer_code,
Expand Down Expand Up @@ -975,6 +976,10 @@ def test_can_import_os_if_explicitly_authorized(self):
dangerous_code = "import os; os.listdir('./')"
evaluate_python_code(dangerous_code, authorized_imports=["os"])

def test_can_import_os_if_all_imports_authorized(self):
dangerous_code = "import os; os.listdir('./')"
evaluate_python_code(dangerous_code, authorized_imports=["*"])


@pytest.mark.parametrize(
"code, expected_result",
Expand Down Expand Up @@ -1205,3 +1210,39 @@ def test_len(self):
pc = PrintContainer()
pc.append("Hello")
assert len(pc) == 5


@pytest.mark.parametrize(
"module,authorized_imports,expected",
[
# Each case: (module name to check, authorized_imports list, expected result).
# A "*" entry in authorized_imports authorizes any module name.
("os", ["*"], True),
("AnyModule", ["*"], True),
# A module listed by its exact name is authorized.
("os", ["os"], True),
("AnyModule", ["AnyModule"], True),
# A path component that is a dangerous pattern ("os") must itself be listed;
# authorizing only the parent package is not enough.
("Module.os", ["Module"], False),
("Module.os", ["Module", "os"], True),
# Authorizing a parent covers its submodule, but authorizing a submodule
# does not cover the parent.
("os.path", ["os"], True),
("os", ["os.path"], False),
],
)
def test_check_module_authorized(module: str, authorized_imports: list[str], expected: bool):
# Check the verdict of check_module_authorized for each parametrized case.
# NOTE(review): this tuple presumably mirrors the dangerous_patterns tuple
# used by import_modules — confirm the two stay in sync.
dangerous_patterns = (
"_os",
"os",
"subprocess",
"_subprocess",
"pty",
"system",
"popen",
"spawn",
"shutil",
"sys",
"pathlib",
"io",
"socket",
"compile",
"eval",
"exec",
"multiprocessing",
)
assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected