@@ -984,7 +984,8 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
984
984
for attr_name in dir (raw_module ):
985
985
# Skip dangerous patterns at any level
986
986
if any (
987
- pattern in raw_module .__name__ .split ("." ) + [attr_name ] and pattern not in authorized_imports
987
+ pattern in raw_module .__name__ .split ("." ) + [attr_name ]
988
+ and not check_module_authorized (pattern , authorized_imports , dangerous_patterns )
988
989
for pattern in dangerous_patterns
989
990
):
990
991
logger .info (f"Skipping dangerous attribute { raw_module .__name__ } .{ attr_name } " )
@@ -1007,6 +1008,18 @@ def get_safe_module(raw_module, dangerous_patterns, authorized_imports, visited=
1007
1008
return safe_module
1008
1009
1009
1010
1011
+ def check_module_authorized (module_name , authorized_imports , dangerous_patterns ):
1012
+ if "*" in authorized_imports :
1013
+ return True
1014
+ else :
1015
+ module_path = module_name .split ("." )
1016
+ if any ([module in dangerous_patterns and module not in authorized_imports for module in module_path ]):
1017
+ return False
1018
+ # ["A", "B", "C"] -> ["A", "A.B", "A.B.C"]
1019
+ module_subpaths = ["." .join (module_path [:i ]) for i in range (1 , len (module_path ) + 1 )]
1020
+ return any (subpath in authorized_imports for subpath in module_subpaths )
1021
+
1022
+
1010
1023
def import_modules (expression , state , authorized_imports ):
1011
1024
dangerous_patterns = (
1012
1025
"_os" ,
@@ -1028,19 +1041,9 @@ def import_modules(expression, state, authorized_imports):
1028
1041
"multiprocessing" ,
1029
1042
)
1030
1043
1031
- def check_module_authorized (module_name ):
1032
- if "*" in authorized_imports :
1033
- return True
1034
- else :
1035
- module_path = module_name .split ("." )
1036
- if any ([module in dangerous_patterns and module not in authorized_imports for module in module_path ]):
1037
- return False
1038
- module_subpaths = ["." .join (module_path [:i ]) for i in range (1 , len (module_path ) + 1 )]
1039
- return any (subpath in authorized_imports for subpath in module_subpaths )
1040
-
1041
1044
if isinstance (expression , ast .Import ):
1042
1045
for alias in expression .names :
1043
- if check_module_authorized (alias .name ):
1046
+ if check_module_authorized (alias .name , authorized_imports , dangerous_patterns ):
1044
1047
raw_module = import_module (alias .name )
1045
1048
state [alias .asname or alias .name ] = get_safe_module (raw_module , dangerous_patterns , authorized_imports )
1046
1049
else :
@@ -1049,7 +1052,7 @@ def check_module_authorized(module_name):
1049
1052
)
1050
1053
return None
1051
1054
elif isinstance (expression , ast .ImportFrom ):
1052
- if check_module_authorized (expression .module ):
1055
+ if check_module_authorized (expression .module , authorized_imports , dangerous_patterns ):
1053
1056
raw_module = __import__ (expression .module , fromlist = [alias .name for alias in expression .names ])
1054
1057
module = get_safe_module (raw_module , dangerous_patterns , authorized_imports )
1055
1058
if expression .names [0 ].name == "*" : # Handle "from module import *"
0 commit comments