Skip to content

Commit 8c1aa0a

Browse files
authored
[train] Fix train_multinode_persistence release test (#39563) (#39589)
Signed-off-by: Justin Yu <[email protected]>
1 parent eaf6ffd commit 8c1aa0a

File tree

1 file changed

+45
-25
lines changed

1 file changed

+45
-25
lines changed

release/train_tests/e2e/test_persistence.py

+45-25
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
from contextlib import contextmanager
1818
import json
1919
import os
20+
from pathlib import Path
2021
import pickle
22+
import shutil
23+
import subprocess
2124
import time
2225
from typing import Any, Dict
2326

@@ -37,6 +40,7 @@
3740
from ray.train.base_trainer import TrainingFailedError
3841
from ray.train.torch import TorchTrainer
3942

43+
4044
from test_new_persistence import (
4145
train_fn,
4246
_assert_storage_contents,
@@ -145,13 +149,35 @@ def strip_prefix(path: str) -> str:
145149
return path.replace("s3://", "").replace("gs://", "")
146150

147151

152+
def delete_at_uri(uri: str):
    """Recursively delete everything stored under a cloud storage URI.

    Shells out to the provider CLI (``aws`` for S3, ``gsutil`` for GCS)
    instead of using python filesystem libraries, which the surrounding
    test notes are inconsistent across cloud platforms.

    Args:
        uri: A ``s3://`` or ``gs://`` URI to delete recursively.

    Raises:
        NotImplementedError: If the URI uses an unsupported scheme.
        subprocess.CalledProcessError: If the CLI command fails.
    """
    if uri.startswith("s3://"):
        cmd = ["aws", "s3", "rm", "--recursive", uri]
    elif uri.startswith("gs://"):
        cmd = ["gsutil", "-m", "rm", "-r", uri]
    else:
        raise NotImplementedError(f"Invalid URI: {uri}")
    subprocess.check_output(cmd)
159+
160+
161+
def download_from_uri(uri: str, local_path: str):
    """Recursively download the contents of a cloud storage URI locally.

    Shells out to the provider CLI (``aws`` for S3, ``gsutil`` for GCS)
    instead of using python filesystem libraries, which the surrounding
    test notes are inconsistent across cloud platforms.

    Args:
        uri: A ``s3://`` or ``gs://`` URI to download from.
        local_path: Local directory to copy the contents into.

    Raises:
        NotImplementedError: If the URI uses an unsupported scheme.
        subprocess.CalledProcessError: If the CLI command fails.
    """
    if uri.startswith("s3://"):
        cmd = ["aws", "s3", "cp", "--recursive", uri, local_path]
    elif uri.startswith("gs://"):
        # gsutil copies the *contents* of the prefix via a trailing wildcard.
        cmd = ["gsutil", "-m", "cp", "-r", uri.rstrip("/") + "/*", local_path]
    else:
        raise NotImplementedError(f"Invalid URI: {uri}")
    subprocess.check_output(cmd)
170+
171+
148172
@pytest.mark.parametrize(
149173
"storage_path_storage_filesystem_label",
150174
[
151-
(os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence", None, "cloud"),
175+
(os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/", None, "cloud"),
152176
("/mnt/cluster_storage/test-persistence", None, "nfs"),
153177
(
154-
strip_prefix(os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence"),
178+
strip_prefix(
179+
os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/"
180+
),
155181
get_custom_cloud_fs(),
156182
"cloud+custom_fs",
157183
),
@@ -184,21 +210,13 @@ def test_trainer(storage_path_storage_filesystem_label, tmp_path, monkeypatch):
184210
)
185211
exp_name = "test_trainer"
186212

187-
# NOTE: We use fsspec directly for cleaning up the cloud folders and
188-
# downloading for inspection, since the pyarrow default implementation
189-
# doesn't delete/download files properly from GCS.
190-
fsspec_fs, storage_fs_path = (
191-
fsspec.core.url_to_fs(
192-
os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence"
193-
)
194-
if "cloud" in label
195-
else fsspec.core.url_to_fs(storage_path)
196-
)
197-
experiment_fs_path = os.path.join(storage_fs_path, exp_name)
198-
if fsspec_fs.exists(experiment_fs_path):
199-
print("\nDeleting results from a previous run...\n")
200-
fsspec_fs.rm(experiment_fs_path, recursive=True)
201-
assert not fsspec_fs.exists(experiment_fs_path)
213+
print("Deleting files from previous run...")
214+
if "cloud" in label:
215+
# NOTE: Use the CLI to delete files on cloud, since the python libraries
216+
# (pyarrow, fsspec) aren't consistent across cloud platforms (s3, gs).
217+
delete_at_uri(os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/")
218+
else:
219+
shutil.rmtree(storage_path, ignore_errors=True)
202220

203221
trainer = TorchTrainer(
204222
train_fn,
@@ -234,22 +252,24 @@ def test_trainer(storage_path_storage_filesystem_label, tmp_path, monkeypatch):
234252
storage_filesystem=storage_filesystem,
235253
)
236254
result = restored_trainer.fit()
237-
238-
# First, inspect that the result object returns the correct paths.
239255
print(result)
240-
trial_fs_path = result.path
241-
assert trial_fs_path.startswith(storage_fs_path)
242-
for checkpoint, _ in result.best_checkpoints:
243-
assert checkpoint.path.startswith(trial_fs_path)
244256

245257
print("\nAsserting contents of uploaded results.\n")
246258
local_inspect_dir = tmp_path / "inspect_dir"
247259
local_inspect_dir.mkdir()
248260
# Download the results from storage
249-
fsspec_fs.get(storage_fs_path, str(local_inspect_dir), recursive=True)
261+
if "cloud" in label:
262+
# NOTE: Use the CLI to download, since the python libraries
263+
# (pyarrow, fsspec) aren't consistent across cloud platforms (s3, gs).
264+
download_from_uri(
265+
os.environ["ANYSCALE_ARTIFACT_STORAGE"] + "/test-persistence/",
266+
str(local_inspect_dir),
267+
)
268+
else:
269+
local_inspect_dir = Path(storage_path)
250270

251271
_assert_storage_contents(
252-
local_inspect_dir / "test-persistence",
272+
local_inspect_dir,
253273
exp_name,
254274
checkpoint_config,
255275
"TorchTrainer",

0 commit comments

Comments
 (0)