Skip to content

Commit 00e934f

Browse files
avishniakov, actions-user, stefannica
authored
Improve Artifact Store isolation (#2490)
* dir traversal issue
* Auto-update of Starter template
* Auto-update of NLP template
* reroute artifacts and logs via AS
* reroute materializers via AS
* simplify to one deco
* fix materializer tests
* allow local download
* Auto-update of E2E template
* fix test issues
* rework based on comments
* fix bugs
* lint
* Candidate (#2493) Co-authored-by: Stefan Nica <[email protected]>
* darglint
---------
Co-authored-by: GitHub Actions <[email protected]>
Co-authored-by: Stefan Nica <[email protected]>
1 parent 683e943 commit 00e934f

17 files changed

+279
-113
lines changed

src/zenml/artifact_stores/base_artifact_store.py

Lines changed: 113 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# permissions and limitations under the License.
1414
"""The base interface to extend the ZenML artifact store."""
1515

16+
import inspect
1617
import textwrap
1718
from abc import abstractmethod
19+
from pathlib import Path
1820
from typing import (
1921
Any,
2022
Callable,
@@ -44,50 +46,93 @@
4446
PathType = Union[bytes, str]
4547

4648

47-
def _sanitize_potential_path(potential_path: Any) -> Any:
48-
"""Sanitizes the input if it is a path.
49+
class _sanitize_paths:
50+
"""Sanitizes path inputs before calling the original function.
4951
50-
If the input is a **remote** path, this function replaces backslash path
51-
separators by forward slashes.
52+
Extra decoration layer is needed to pass in fixed artifact store root
53+
path for static methods that are called on filesystems directly.
5254
5355
Args:
54-
potential_path: Value that potentially refers to a (remote) path.
56+
func: The function to decorate.
57+
fixed_root_path: The fixed artifact store root path.
58+
is_static: Whether the function is static or not.
5559
5660
Returns:
57-
The original input or a sanitized version of it in case of a remote
58-
path.
61+
Function that calls the input function with sanitized path inputs.
5962
"""
60-
if isinstance(potential_path, bytes):
61-
path = fileio.convert_to_str(potential_path)
62-
elif isinstance(potential_path, str):
63-
path = potential_path
64-
else:
65-
# Neither string nor bytes, this is not a path
66-
return potential_path
6763

68-
if io_utils.is_remote(path):
69-
# If we have a remote path, replace windows path separators with
70-
# slashes
71-
import ntpath
72-
import posixpath
64+
def __init__(self, func: Callable[..., Any], fixed_root_path: str) -> None:
65+
"""Initializes the decorator.
7366
74-
path = path.replace(ntpath.sep, posixpath.sep)
67+
Args:
68+
func: The function to decorate.
69+
fixed_root_path: The fixed artifact store root path.
70+
"""
71+
self.func = func
72+
self.fixed_root_path = fixed_root_path
7573

76-
return path
74+
self.path_args: List[int] = []
75+
self.path_kwargs: List[str] = []
76+
for i, param in enumerate(
77+
inspect.signature(self.func).parameters.values()
78+
):
79+
if param.annotation == PathType:
80+
self.path_kwargs.append(param.name)
81+
if param.default == inspect.Parameter.empty:
82+
self.path_args.append(i)
7783

84+
def _validate_path(self, path: str) -> None:
85+
"""Validates a path.
7886
79-
def _sanitize_paths(_func: Callable[..., Any]) -> Callable[..., Any]:
80-
"""Sanitizes path inputs before calling the original function.
87+
Args:
88+
path: The path to validate.
8189
82-
Args:
83-
_func: The function for which to sanitize the inputs.
90+
Raises:
91+
FileNotFoundError: If the path is outside of the artifact store
92+
bounds.
93+
"""
94+
if not path.startswith(self.fixed_root_path):
95+
raise FileNotFoundError(
96+
f"File `{path}` is outside of "
97+
f"artifact store bounds `{self.fixed_root_path}`"
98+
)
8499

85-
Returns:
86-
Function that calls the input function with sanitized path inputs.
87-
"""
100+
def _sanitize_potential_path(self, potential_path: Any) -> Any:
101+
"""Sanitizes the input if it is a path.
102+
103+
If the input is a **remote** path, this function replaces backslash path
104+
separators by forward slashes.
88105
89-
def inner_function(*args: Any, **kwargs: Any) -> Any:
90-
"""Inner function.
106+
Args:
107+
potential_path: Value that potentially refers to a (remote) path.
108+
109+
Returns:
110+
The original input or a sanitized version of it in case of a remote
111+
path.
112+
"""
113+
if isinstance(potential_path, bytes):
114+
path = fileio.convert_to_str(potential_path)
115+
elif isinstance(potential_path, str):
116+
path = potential_path
117+
else:
118+
# Neither string nor bytes, this is not a path
119+
return potential_path
120+
121+
if io_utils.is_remote(path):
122+
# If we have a remote path, replace windows path separators with
123+
# slashes
124+
import ntpath
125+
import posixpath
126+
127+
path = path.replace(ntpath.sep, posixpath.sep)
128+
self._validate_path(path)
129+
else:
130+
self._validate_path(str(Path(path).absolute().resolve()))
131+
132+
return path
133+
134+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
135+
"""Decorator function that sanitizes paths before calling the original function.
91136
92137
Args:
93138
*args: Positional args.
@@ -96,15 +141,28 @@ def inner_function(*args: Any, **kwargs: Any) -> Any:
96141
Returns:
97142
Output of the input function called with sanitized paths.
98143
"""
99-
args = tuple(_sanitize_potential_path(arg) for arg in args)
144+
# verify if `self` is part of the args
145+
has_self = bool(args and isinstance(args[0], BaseArtifactStore))
146+
147+
# sanitize inputs for relevant args and kwargs, keep rest unchanged
148+
args = tuple(
149+
self._sanitize_potential_path(
150+
arg,
151+
)
152+
if i + has_self in self.path_args
153+
else arg
154+
for i, arg in enumerate(args)
155+
)
100156
kwargs = {
101-
key: _sanitize_potential_path(value)
157+
key: self._sanitize_potential_path(
158+
value,
159+
)
160+
if key in self.path_kwargs
161+
else value
102162
for key, value in kwargs.items()
103163
}
104164

105-
return _func(*args, **kwargs)
106-
107-
return inner_function
165+
return self.func(*args, **kwargs)
108166

109167

110168
class BaseArtifactStoreConfig(StackComponentConfig):
@@ -323,6 +381,7 @@ def stat(self, path: PathType) -> Any:
323381
The stat descriptor.
324382
"""
325383

384+
@abstractmethod
326385
def size(self, path: PathType) -> Optional[int]:
327386
"""Get the size of a file in bytes.
328387
@@ -376,30 +435,30 @@ def _register(self) -> None:
376435
from zenml.io.filesystem_registry import default_filesystem_registry
377436
from zenml.io.local_filesystem import LocalFilesystem
378437

438+
overloads: Dict[str, Any] = {
439+
"SUPPORTED_SCHEMES": self.config.SUPPORTED_SCHEMES,
440+
}
441+
for abc_method in inspect.getmembers(BaseArtifactStore):
442+
if getattr(abc_method[1], "__isabstractmethod__", False):
443+
sanitized_method = _sanitize_paths(
444+
getattr(self, abc_method[0]), self.path
445+
)
446+
# prepare overloads for filesystem methods
447+
overloads[abc_method[0]] = staticmethod(sanitized_method)
448+
449+
# decorate artifact store methods
450+
setattr(
451+
self,
452+
abc_method[0],
453+
sanitized_method,
454+
)
455+
379456
# Local filesystem is always registered, no point in doing it again.
380457
if isinstance(self, LocalFilesystem):
381458
return
382459

383460
filesystem_class = type(
384-
self.__class__.__name__,
385-
(BaseFilesystem,),
386-
{
387-
"SUPPORTED_SCHEMES": self.config.SUPPORTED_SCHEMES,
388-
"open": staticmethod(_sanitize_paths(self.open)),
389-
"copyfile": staticmethod(_sanitize_paths(self.copyfile)),
390-
"exists": staticmethod(_sanitize_paths(self.exists)),
391-
"glob": staticmethod(_sanitize_paths(self.glob)),
392-
"isdir": staticmethod(_sanitize_paths(self.isdir)),
393-
"listdir": staticmethod(_sanitize_paths(self.listdir)),
394-
"makedirs": staticmethod(_sanitize_paths(self.makedirs)),
395-
"mkdir": staticmethod(_sanitize_paths(self.mkdir)),
396-
"remove": staticmethod(_sanitize_paths(self.remove)),
397-
"rename": staticmethod(_sanitize_paths(self.rename)),
398-
"rmtree": staticmethod(_sanitize_paths(self.rmtree)),
399-
"size": staticmethod(_sanitize_paths(self.size)),
400-
"stat": staticmethod(_sanitize_paths(self.stat)),
401-
"walk": staticmethod(_sanitize_paths(self.walk)),
402-
},
461+
self.__class__.__name__, (BaseFilesystem,), overloads
403462
)
404463

405464
default_filesystem_registry.register(filesystem_class)

src/zenml/artifacts/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def save_artifact(
152152
if not uri.startswith(artifact_store.path):
153153
uri = os.path.join(artifact_store.path, uri)
154154

155-
if manual_save and fileio.exists(uri):
155+
if manual_save and artifact_store.exists(uri):
156156
# This check is only necessary for manual saves as we already check
157157
# it when creating the directory for step output artifacts
158158
other_artifacts = client.list_artifact_versions(uri=uri, size=1)
@@ -162,7 +162,7 @@ def save_artifact(
162162
f"{uri} because the URI is already used by artifact "
163163
f"{other_artifact.name} (version {other_artifact.version})."
164164
)
165-
fileio.makedirs(uri)
165+
artifact_store.makedirs(uri)
166166

167167
# Find and initialize the right materializer class
168168
if isinstance(materializer, type):
@@ -752,6 +752,7 @@ def _load_file_from_artifact_store(
752752
Raises:
753753
DoesNotExistException: If the file does not exist in the artifact store.
754754
NotImplementedError: If the artifact store cannot open the file.
755+
IOError: If the artifact store rejects the request.
755756
"""
756757
try:
757758
with artifact_store.open(uri, mode) as text_file:
@@ -761,6 +762,8 @@ def _load_file_from_artifact_store(
761762
f"File '{uri}' does not exist in artifact store "
762763
f"'{artifact_store.name}'."
763764
)
765+
except IOError as e:
766+
raise e
764767
except Exception as e:
765768
logger.exception(e)
766769
link = "https://docs.zenml.io/stacks-and-components/component-guide/artifact-stores/custom#enabling-artifact-visualizations-with-custom-artifact-stores"
@@ -819,7 +822,8 @@ def load_model_from_metadata(model_uri: str) -> Any:
819822
The ML model object loaded into memory.
820823
"""
821824
# Load the model from its metadata
822-
with fileio.open(
825+
artifact_store = Client().active_stack.artifact_store
826+
with artifact_store.open(
823827
os.path.join(model_uri, MODEL_METADATA_YAML_FILE_NAME), "r"
824828
) as f:
825829
metadata = read_yaml(f.name)

src/zenml/logging/step_logging.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from uuid import uuid4
2424

2525
from zenml.artifact_stores import BaseArtifactStore
26-
from zenml.io import fileio
26+
from zenml.client import Client
2727
from zenml.logger import get_logger
2828
from zenml.logging import (
2929
STEP_LOGS_STORAGE_INTERVAL_SECONDS,
@@ -64,6 +64,7 @@ def prepare_logs_uri(
6464
Returns:
6565
The URI of the logs file.
6666
"""
67+
artifact_store = Client().active_stack.artifact_store
6768
if log_key is None:
6869
log_key = str(uuid4())
6970

@@ -74,16 +75,16 @@ def prepare_logs_uri(
7475
)
7576

7677
# Create the dir
77-
if not fileio.exists(logs_base_uri):
78-
fileio.makedirs(logs_base_uri)
78+
if not artifact_store.exists(logs_base_uri):
79+
artifact_store.makedirs(logs_base_uri)
7980

8081
# Delete the file if it already exists
8182
logs_uri = os.path.join(logs_base_uri, f"{log_key}.log")
82-
if fileio.exists(logs_uri):
83+
if artifact_store.exists(logs_uri):
8384
logger.warning(
8485
f"Logs file {logs_uri} already exists! Removing old log file..."
8586
)
86-
fileio.remove(logs_uri)
87+
artifact_store.remove(logs_uri)
8788
return logs_uri
8889

8990

@@ -135,12 +136,13 @@ def write(self, text: str) -> None:
135136

136137
def save_to_file(self) -> None:
137138
"""Method to save the buffer to the given URI."""
139+
artifact_store = Client().active_stack.artifact_store
138140
if not self.disabled:
139141
try:
140142
self.disabled = True
141143

142144
if self.buffer:
143-
with fileio.open(self.logs_uri, "a") as file:
145+
with artifact_store.open(self.logs_uri, "a") as file:
144146
for message in self.buffer:
145147
file.write(
146148
remove_ansi_escape_codes(message) + "\n"

src/zenml/materializers/base_materializer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ def save_visualizations(self, data: Any) -> Dict[str, VisualizationType]:
156156
157157
Example:
158158
```
159+
artifact_store = Client().active_stack.artifact_store
159160
visualization_uri = os.path.join(self.uri, "visualization.html")
160-
with fileio.open(visualization_uri, "w") as f:
161+
with artifact_store.open(visualization_uri, "w") as f:
161162
f.write("<html><body>data</body></html>")
162163
163164
visualization_uri_2 = os.path.join(self.uri, "visualization.png")

0 commit comments

Comments
 (0)