@@ -17,7 +17,10 @@
 from contextlib import contextmanager
 import json
 import os
+from pathlib import Path
 import pickle
+import shutil
+import subprocess
 import time
 from typing import Any, Dict
 
@@ -37,6 +40,7 @@
 from ray.train.base_trainer import TrainingFailedError
 from ray.train.torch import TorchTrainer
 
+
 from test_new_persistence import (
     train_fn,
     _assert_storage_contents,
@@ -145,13 +149,35 @@ def strip_prefix(path: str) -> str:
     return path.replace("s3://", "").replace("gs://", "")
 
 
+def delete_at_uri(uri: str):
+    if uri.startswith("s3://"):
+        subprocess.check_output(["aws", "s3", "rm", "--recursive", uri])
+    elif uri.startswith("gs://"):
+        subprocess.check_output(["gsutil", "-m", "rm", "-r", uri])
+    else:
+        raise NotImplementedError(f"Invalid URI: {uri}")
+
+
+def download_from_uri(uri: str, local_path: str):
+    if uri.startswith("s3://"):
+        subprocess.check_output(["aws", "s3", "cp", "--recursive", uri, local_path])
+    elif uri.startswith("gs://"):
+        subprocess.check_output(
+            ["gsutil", "-m", "cp", "-r", uri.rstrip("/") + "/*", local_path]
+        )
+    else:
+        raise NotImplementedError(f"Invalid URI: {uri}")
+
+
 @pytest.mark.parametrize(
     "storage_path_storage_filesystem_label",
     [
-        (os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence", None, "cloud"),
+        (os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/", None, "cloud"),
         ("/mnt/cluster_storage/test-persistence", None, "nfs"),
         (
-            strip_prefix(os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence"),
+            strip_prefix(
+                os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/"
+            ),
             get_custom_cloud_fs(),
             "cloud+custom_fs",
         ),
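For context on the two helpers above: they shell out to the `aws` and `gsutil` CLIs instead of going through pyarrow/fsspec. A minimal usage sketch, assuming the relevant CLI is installed and authenticated; the bucket URI and the temp dir below are placeholders, not from this PR:

import tempfile

# Placeholder bucket prefix; any s3:// or gs:// URI the CLI can reach works.
uri = "s3://example-bucket/test-persistence/"

# Recursively wipe the prefix before a fresh test run.
delete_at_uri(uri)

# ... the experiment runs and uploads results under `uri` ...

# Mirror the uploaded results locally for inspection.
with tempfile.TemporaryDirectory() as local_dir:
    download_from_uri(uri, local_dir)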
@@ -184,21 +210,13 @@ def test_trainer(storage_path_storage_filesystem_label, tmp_path, monkeypatch):
     )
     exp_name = "test_trainer"
 
-    # NOTE: We use fsspec directly for cleaning up the cloud folders and
-    # downloading for inspection, since the pyarrow default implementation
-    # doesn't delete/download files properly from GCS.
-    fsspec_fs, storage_fs_path = (
-        fsspec.core.url_to_fs(
-            os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence"
-        )
-        if "cloud" in label
-        else fsspec.core.url_to_fs(storage_path)
-    )
-    experiment_fs_path = os.path.join(storage_fs_path, exp_name)
-    if fsspec_fs.exists(experiment_fs_path):
-        print("\nDeleting results from a previous run...\n")
-        fsspec_fs.rm(experiment_fs_path, recursive=True)
-    assert not fsspec_fs.exists(experiment_fs_path)
+    print("Deleting files from previous run...")
+    if "cloud" in label:
+        # NOTE: Use the CLI to delete files on cloud, since the python libraries
+        # (pyarrow, fsspec) aren't consistent across cloud platforms (s3, gs).
+        delete_at_uri(os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/")
+    else:
+        shutil.rmtree(storage_path, ignore_errors=True)
 
     trainer = TorchTrainer(
         train_fn,
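A note on the non-cloud branch above: `shutil.rmtree(storage_path, ignore_errors=True)` is idempotent, which is why the previous exists()/rm/assert sequence could be dropped. The same pattern in isolation (the path is a placeholder):

import shutil
from pathlib import Path

storage_path = "/tmp/example-test-persistence"  # placeholder, not the PR's path

# Safe whether or not the directory exists; no exists() pre-check needed.
shutil.rmtree(storage_path, ignore_errors=True)
assert not Path(storage_path).exists()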
@@ -234,22 +252,24 @@ def test_trainer(storage_path_storage_filesystem_label, tmp_path, monkeypatch):
         storage_filesystem=storage_filesystem,
     )
     result = restored_trainer.fit()
-
-    # First, inspect that the result object returns the correct paths.
     print(result)
-    trial_fs_path = result.path
-    assert trial_fs_path.startswith(storage_fs_path)
-    for checkpoint, _ in result.best_checkpoints:
-        assert checkpoint.path.startswith(trial_fs_path)
 
     print("\nAsserting contents of uploaded results.\n")
     local_inspect_dir = tmp_path / "inspect_dir"
     local_inspect_dir.mkdir()
     # Download the results from storage
-    fsspec_fs.get(storage_fs_path, str(local_inspect_dir), recursive=True)
+    if "cloud" in label:
+        # NOTE: Use the CLI to download, since the python libraries
+        # (pyarrow, fsspec) aren't consistent across cloud platforms (s3, gs).
+        download_from_uri(
+            os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/",
+            str(local_inspect_dir),
+        )
+    else:
+        local_inspect_dir = Path(storage_path)
 
     _assert_storage_contents(
-        local_inspect_dir / "test-persistence",
+        local_inspect_dir,
         exp_name,
         checkpoint_config,
         "TorchTrainer",
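One property both helpers inherit from `subprocess.check_output`: a non-zero CLI exit raises `CalledProcessError`, so a failed delete or download aborts the test instead of passing silently. If the CLI's error message is wanted in the test log, a wrapper along these lines would work (illustrative only, not part of the PR):

import subprocess

def _run_cli(cmd):
    # Capture stderr into the exception output so a failing `aws`/`gsutil`
    # call prints its error message before re-raising.
    try:
        return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(e.output.decode())
        raise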