13
13
# permissions and limitations under the License.
14
14
"""The base interface to extend the ZenML artifact store."""
15
15
16
+ import inspect
16
17
import textwrap
17
18
from abc import abstractmethod
19
+ from pathlib import Path
18
20
from typing import (
19
21
Any ,
20
22
Callable ,
44
46
PathType = Union [bytes , str ]
45
47
46
48
47
- def _sanitize_potential_path ( potential_path : Any ) -> Any :
48
- """Sanitizes the input if it is a path .
49
+ class _sanitize_paths :
50
+ """Sanitizes path inputs before calling the original function .
49
51
50
- If the input is a **remote** path, this function replaces backslash path
51
- separators by forward slashes .
52
+ Extra decoration layer is needed to pass in fixed artifact store root
53
+ path for static methods that are called on filesystems directly .
52
54
53
55
Args:
54
- potential_path: Value that potentially refers to a (remote) path.
56
+ func: The function to decorate.
57
+ fixed_root_path: The fixed artifact store root path.
58
+ is_static: Whether the function is static or not.
55
59
56
60
Returns:
57
- The original input or a sanitized version of it in case of a remote
58
- path.
61
+ Function that calls the input function with sanitized path inputs.
59
62
"""
60
- if isinstance (potential_path , bytes ):
61
- path = fileio .convert_to_str (potential_path )
62
- elif isinstance (potential_path , str ):
63
- path = potential_path
64
- else :
65
- # Neither string nor bytes, this is not a path
66
- return potential_path
67
63
68
- if io_utils .is_remote (path ):
69
- # If we have a remote path, replace windows path separators with
70
- # slashes
71
- import ntpath
72
- import posixpath
64
+ def __init__ (self , func : Callable [..., Any ], fixed_root_path : str ) -> None :
65
+ """Initializes the decorator.
73
66
74
- path = path .replace (ntpath .sep , posixpath .sep )
67
+ Args:
68
+ func: The function to decorate.
69
+ fixed_root_path: The fixed artifact store root path.
70
+ """
71
+ self .func = func
72
+ self .fixed_root_path = fixed_root_path
75
73
76
- return path
74
+ self .path_args : List [int ] = []
75
+ self .path_kwargs : List [str ] = []
76
+ for i , param in enumerate (
77
+ inspect .signature (self .func ).parameters .values ()
78
+ ):
79
+ if param .annotation == PathType :
80
+ self .path_kwargs .append (param .name )
81
+ if param .default == inspect .Parameter .empty :
82
+ self .path_args .append (i )
77
83
84
+ def _validate_path (self , path : str ) -> None :
85
+ """Validates a path.
78
86
79
- def _sanitize_paths ( _func : Callable [..., Any ]) -> Callable [..., Any ] :
80
- """Sanitizes path inputs before calling the original function .
87
+ Args :
88
+ path: The path to validate .
81
89
82
- Args:
83
- _func: The function for which to sanitize the inputs.
90
+ Raises:
91
+ FileNotFoundError: If the path is outside of the artifact store
92
+ bounds.
93
+ """
94
+ if not path .startswith (self .fixed_root_path ):
95
+ raise FileNotFoundError (
96
+ f"File `{ path } ` is outside of "
97
+ f"artifact store bounds `{ self .fixed_root_path } `"
98
+ )
84
99
85
- Returns:
86
- Function that calls the input function with sanitized path inputs.
87
- """
100
+ def _sanitize_potential_path (self , potential_path : Any ) -> Any :
101
+ """Sanitizes the input if it is a path.
102
+
103
+ If the input is a **remote** path, this function replaces backslash path
104
+ separators by forward slashes.
88
105
89
- def inner_function (* args : Any , ** kwargs : Any ) -> Any :
90
- """Inner function.
106
+ Args:
107
+ potential_path: Value that potentially refers to a (remote) path.
108
+
109
+ Returns:
110
+ The original input or a sanitized version of it in case of a remote
111
+ path.
112
+ """
113
+ if isinstance (potential_path , bytes ):
114
+ path = fileio .convert_to_str (potential_path )
115
+ elif isinstance (potential_path , str ):
116
+ path = potential_path
117
+ else :
118
+ # Neither string nor bytes, this is not a path
119
+ return potential_path
120
+
121
+ if io_utils .is_remote (path ):
122
+ # If we have a remote path, replace windows path separators with
123
+ # slashes
124
+ import ntpath
125
+ import posixpath
126
+
127
+ path = path .replace (ntpath .sep , posixpath .sep )
128
+ self ._validate_path (path )
129
+ else :
130
+ self ._validate_path (str (Path (path ).absolute ().resolve ()))
131
+
132
+ return path
133
+
134
+ def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
135
+ """Decorator function that sanitizes paths before calling the original function.
91
136
92
137
Args:
93
138
*args: Positional args.
@@ -96,15 +141,28 @@ def inner_function(*args: Any, **kwargs: Any) -> Any:
96
141
Returns:
97
142
Output of the input function called with sanitized paths.
98
143
"""
99
- args = tuple (_sanitize_potential_path (arg ) for arg in args )
144
+ # verify if `self` is part of the args
145
+ has_self = bool (args and isinstance (args [0 ], BaseArtifactStore ))
146
+
147
+ # sanitize inputs for relevant args and kwargs, keep rest unchanged
148
+ args = tuple (
149
+ self ._sanitize_potential_path (
150
+ arg ,
151
+ )
152
+ if i + has_self in self .path_args
153
+ else arg
154
+ for i , arg in enumerate (args )
155
+ )
100
156
kwargs = {
101
- key : _sanitize_potential_path (value )
157
+ key : self ._sanitize_potential_path (
158
+ value ,
159
+ )
160
+ if key in self .path_kwargs
161
+ else value
102
162
for key , value in kwargs .items ()
103
163
}
104
164
105
- return _func (* args , ** kwargs )
106
-
107
- return inner_function
165
+ return self .func (* args , ** kwargs )
108
166
109
167
110
168
class BaseArtifactStoreConfig (StackComponentConfig ):
@@ -323,6 +381,7 @@ def stat(self, path: PathType) -> Any:
323
381
The stat descriptor.
324
382
"""
325
383
384
+ @abstractmethod
326
385
def size (self , path : PathType ) -> Optional [int ]:
327
386
"""Get the size of a file in bytes.
328
387
@@ -376,30 +435,30 @@ def _register(self) -> None:
376
435
from zenml .io .filesystem_registry import default_filesystem_registry
377
436
from zenml .io .local_filesystem import LocalFilesystem
378
437
438
+ overloads : Dict [str , Any ] = {
439
+ "SUPPORTED_SCHEMES" : self .config .SUPPORTED_SCHEMES ,
440
+ }
441
+ for abc_method in inspect .getmembers (BaseArtifactStore ):
442
+ if getattr (abc_method [1 ], "__isabstractmethod__" , False ):
443
+ sanitized_method = _sanitize_paths (
444
+ getattr (self , abc_method [0 ]), self .path
445
+ )
446
+ # prepare overloads for filesystem methods
447
+ overloads [abc_method [0 ]] = staticmethod (sanitized_method )
448
+
449
+ # decorate artifact store methods
450
+ setattr (
451
+ self ,
452
+ abc_method [0 ],
453
+ sanitized_method ,
454
+ )
455
+
379
456
# Local filesystem is always registered, no point in doing it again.
380
457
if isinstance (self , LocalFilesystem ):
381
458
return
382
459
383
460
filesystem_class = type (
384
- self .__class__ .__name__ ,
385
- (BaseFilesystem ,),
386
- {
387
- "SUPPORTED_SCHEMES" : self .config .SUPPORTED_SCHEMES ,
388
- "open" : staticmethod (_sanitize_paths (self .open )),
389
- "copyfile" : staticmethod (_sanitize_paths (self .copyfile )),
390
- "exists" : staticmethod (_sanitize_paths (self .exists )),
391
- "glob" : staticmethod (_sanitize_paths (self .glob )),
392
- "isdir" : staticmethod (_sanitize_paths (self .isdir )),
393
- "listdir" : staticmethod (_sanitize_paths (self .listdir )),
394
- "makedirs" : staticmethod (_sanitize_paths (self .makedirs )),
395
- "mkdir" : staticmethod (_sanitize_paths (self .mkdir )),
396
- "remove" : staticmethod (_sanitize_paths (self .remove )),
397
- "rename" : staticmethod (_sanitize_paths (self .rename )),
398
- "rmtree" : staticmethod (_sanitize_paths (self .rmtree )),
399
- "size" : staticmethod (_sanitize_paths (self .size )),
400
- "stat" : staticmethod (_sanitize_paths (self .stat )),
401
- "walk" : staticmethod (_sanitize_paths (self .walk )),
402
- },
461
+ self .__class__ .__name__ , (BaseFilesystem ,), overloads
403
462
)
404
463
405
464
default_filesystem_registry .register (filesystem_class )
0 commit comments