
Commit 8c7bf27

speedstorm1 authored and copybara-github committed
feat: Update Ray system tests to be compatible with new RoV 2.33 changes
PiperOrigin-RevId: 673047153
1 parent 424ebbf commit 8c7bf27

File tree: 5 files changed (+142 -105 lines changed)

google/cloud/aiplatform/vertex_ray/bigquery_datasink.py (+109 -97)

@@ -35,7 +35,11 @@
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data.block import Block, BlockAccessor
 
-from ray.data.datasource.datasink import Datasink
+try:
+    from ray.data.datasource.datasink import Datasink
+except ImportError:
+    # If datasink cannot be imported, Ray >=2.9.3 is not installed
+    Datasink = None
 
 
 DEFAULT_MAX_RETRY_CNT = 10
@@ -48,102 +52,110 @@
 
 
 # BigQuery write for Ray 2.33.0 and 2.9.3
-class _BigQueryDatasink(Datasink):
-    def __init__(
-        self,
-        dataset: str,
-        project_id: Optional[str] = None,
-        max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
-        overwrite_table: Optional[bool] = True,
-    ) -> None:
-        self.dataset = dataset
-        self.project_id = project_id or initializer.global_config.project
-        self.max_retry_cnt = max_retry_cnt
-        self.overwrite_table = overwrite_table
-
-    def on_write_start(self) -> None:
-        # Set up datasets to write
-        client = bigquery.Client(project=self.project_id, client_info=bq_info)
-        dataset_id = self.dataset.split(".", 1)[0]
-        try:
-            client.get_dataset(dataset_id)
-        except exceptions.NotFound:
-            client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
-            print("[Ray on Vertex AI]: Created dataset " + dataset_id)
-
-        # Delete table if overwrite_table is True
-        if self.overwrite_table:
-            print(
-                f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
-                + " if it already exists since kwarg overwrite_table = True."
-            )
-            client.delete_table(f"{self.project_id}.{self.dataset}", not_found_ok=True)
-        else:
-            print(
-                "[Ray on Vertex AI]: The write will append to table "
-                + f"{self.dataset} if it already exists "
-                + "since kwarg overwrite_table = False."
-            )
-
-    def write(
-        self,
-        blocks: Iterable[Block],
-        ctx: TaskContext,
-    ) -> Any:
-        def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
-            block = BlockAccessor.for_block(block).to_arrow()
-
-            client = bigquery.Client(project=project_id, client_info=bq_info)
-            job_config = bigquery.LoadJobConfig(autodetect=True)
-            job_config.source_format = bigquery.SourceFormat.PARQUET
-            job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
-
-            with tempfile.TemporaryDirectory() as temp_dir:
-                fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
-                pq.write_table(block, fp, compression="SNAPPY")
-
-                retry_cnt = 0
-                while retry_cnt <= self.max_retry_cnt:
-                    with open(fp, "rb") as source_file:
-                        job = client.load_table_from_file(
-                            source_file, dataset, job_config=job_config
-                        )
-                    try:
-                        logging.info(job.result())
-                        break
-                    except exceptions.Forbidden as e:
-                        retry_cnt += 1
-                        if retry_cnt > self.max_retry_cnt:
+if Datasink is None:
+    _BigQueryDatasink = None
+else:
+
+    class _BigQueryDatasink(Datasink):
+        def __init__(
+            self,
+            dataset: str,
+            project_id: Optional[str] = None,
+            max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
+            overwrite_table: Optional[bool] = True,
+        ) -> None:
+            self.dataset = dataset
+            self.project_id = project_id or initializer.global_config.project
+            self.max_retry_cnt = max_retry_cnt
+            self.overwrite_table = overwrite_table
+
+        def on_write_start(self) -> None:
+            # Set up datasets to write
+            client = bigquery.Client(project=self.project_id, client_info=bq_info)
+            dataset_id = self.dataset.split(".", 1)[0]
+            try:
+                client.get_dataset(dataset_id)
+            except exceptions.NotFound:
+                client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
+                print("[Ray on Vertex AI]: Created dataset " + dataset_id)
+
+            # Delete table if overwrite_table is True
+            if self.overwrite_table:
+                print(
+                    f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
+                    + " if it already exists since kwarg overwrite_table = True."
+                )
+                client.delete_table(
+                    f"{self.project_id}.{self.dataset}", not_found_ok=True
+                )
+            else:
+                print(
+                    "[Ray on Vertex AI]: The write will append to table "
+                    + f"{self.dataset} if it already exists "
+                    + "since kwarg overwrite_table = False."
+                )
+
+        def write(
+            self,
+            blocks: Iterable[Block],
+            ctx: TaskContext,
+        ) -> Any:
+            def _write_single_block(
+                block: Block, project_id: str, dataset: str
+            ) -> None:
+                block = BlockAccessor.for_block(block).to_arrow()
+
+                client = bigquery.Client(project=project_id, client_info=bq_info)
+                job_config = bigquery.LoadJobConfig(autodetect=True)
+                job_config.source_format = bigquery.SourceFormat.PARQUET
+                job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
+
+                with tempfile.TemporaryDirectory() as temp_dir:
+                    fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
+                    pq.write_table(block, fp, compression="SNAPPY")
+
+                    retry_cnt = 0
+                    while retry_cnt <= self.max_retry_cnt:
+                        with open(fp, "rb") as source_file:
+                            job = client.load_table_from_file(
+                                source_file, dataset, job_config=job_config
+                            )
+                        try:
+                            logging.info(job.result())
                             break
+                        except exceptions.Forbidden as e:
+                            retry_cnt += 1
+                            if retry_cnt > self.max_retry_cnt:
+                                break
+                            print(
+                                "[Ray on Vertex AI]: A block write encountered"
+                                + f" a rate limit exceeded error {retry_cnt} time(s)."
+                                + " Sleeping to try again."
+                            )
+                            logging.debug(e)
+                            time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
+
+                    # Raise exception if retry_cnt exceeds max_retry_cnt
+                    if retry_cnt > self.max_retry_cnt:
                         print(
-                            "[Ray on Vertex AI]: A block write encountered"
-                            + f" a rate limit exceeded error {retry_cnt} time(s)."
-                            + " Sleeping to try again."
+                            f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
+                            + " Ray will attempt to retry the block write via fault tolerance."
+                            + " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
+                        )
+                        raise RuntimeError(
+                            f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
+                            + " repeated API rate limit exceeded responses. Consider"
+                            + " specifiying the max_retry_cnt kwarg with a higher value."
                         )
-                        logging.debug(e)
-                        time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
-
-                # Raise exception if retry_cnt exceeds max_retry_cnt
-                if retry_cnt > self.max_retry_cnt:
-                    print(
-                        f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
-                        + " Ray will attempt to retry the block write via fault tolerance."
-                        + " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
-                    )
-                    raise RuntimeError(
-                        f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
-                        + " repeated API rate limit exceeded responses. Consider"
-                        + " specifiying the max_retry_cnt kwarg with a higher value."
-                    )
-
-        _write_single_block = cached_remote_fn(_write_single_block)
-
-        # Launch a remote task for each block within this write task
-        ray.get(
-            [
-                _write_single_block.remote(block, self.project_id, self.dataset)
-                for block in blocks
-            ]
-        )
-
-        return "ok"
+
+            _write_single_block = cached_remote_fn(_write_single_block)
+
+            # Launch a remote task for each block within this write task
+            ray.get(
+                [
+                    _write_single_block.remote(block, self.project_id, self.dataset)
+                    for block in blocks
+                ]
+            )
+
+            return "ok"

google/cloud/aiplatform/vertex_ray/data.py (+6 -3)

@@ -23,9 +23,12 @@
     _BigQueryDatasource,
 )
 
-from google.cloud.aiplatform.vertex_ray.bigquery_datasink import (
-    _BigQueryDatasink,
-)
+try:
+    from google.cloud.aiplatform.vertex_ray.bigquery_datasink import (
+        _BigQueryDatasink,
+    )
+except ImportError:
+    _BigQueryDatasink = None
 
 from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
     _V2_4_WARNING_MESSAGE,
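
Note: because data.py now falls back to `_BigQueryDatasink = None` when the import fails, any code path that uses the datasink needs a None check before instantiating it. A hedged sketch of such a guard; the helper name and error message are illustrative, not the SDK's actual implementation:

# Illustrative guard only; not the actual vertex_ray.data code.
def _require_bigquery_datasink(dataset: str):
    if _BigQueryDatasink is None:
        raise ImportError(
            "BigQuery Datasink support requires a Ray version that provides "
            "ray.data.datasource.datasink.Datasink (Ray >= 2.9.3)."
        )
    return _BigQueryDatasink(dataset=dataset)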

tests/system/vertex_ray/test_cluster_management.py (+2 -2)

@@ -22,7 +22,7 @@
 import pytest
 import ray
 
-# Local ray version will always be 2.4 regardless of cluster version due to
+# Local ray version will always be 2.4.0 regardless of cluster version due to
 # depenency conflicts. Remote job execution's Ray version is 2.9.
 RAY_VERSION = "2.4.0"
 PROJECT_ID = "ucaip-sample-tests"
@@ -31,7 +31,7 @@
 class TestClusterManagement(e2e_base.TestEndToEnd):
     _temp_prefix = "temp-rov-cluster-management"
 
-    @pytest.mark.parametrize("cluster_ray_version", ["2.9"])
+    @pytest.mark.parametrize("cluster_ray_version", ["2.9", "2.33"])
     def test_cluster_management(self, cluster_ray_version):
         assert ray.__version__ == RAY_VERSION
         aiplatform.init(project=PROJECT_ID, location="us-central1")
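
Note: adding "2.33" to the parametrize list makes pytest generate one test invocation per cluster Ray version. A small standalone sketch of how the marker expands, independent of the Vertex AI test harness:

import pytest

# pytest collects test_cluster_version[2.9] and test_cluster_version[2.33].
@pytest.mark.parametrize("cluster_ray_version", ["2.9", "2.33"])
def test_cluster_version(cluster_ray_version):
    assert cluster_ray_version in ("2.9", "2.33")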

tests/system/vertex_ray/test_job_submission_dashboard.py (+1 -1)

@@ -35,7 +35,7 @@
 class TestJobSubmissionDashboard(e2e_base.TestEndToEnd):
     _temp_prefix = "temp-job-submission-dashboard"
 
-    @pytest.mark.parametrize("cluster_ray_version", ["2.9"])
+    @pytest.mark.parametrize("cluster_ray_version", ["2.9", "2.33"])
     def test_job_submission_dashboard(self, cluster_ray_version):
         assert ray.__version__ == RAY_VERSION
         aiplatform.init(project=PROJECT_ID, location="us-central1")
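
Note: this test exercises job submission against the cluster dashboard. For orientation, a generic sketch of submitting a script with Ray's job submission client; the dashboard address and entrypoint are placeholders, and this is not necessarily how the system-test harness drives the cluster:

from ray.job_submission import JobSubmissionClient

# Placeholder address; a real run would point at the cluster's dashboard.
client = JobSubmissionClient("http://127.0.0.1:8265")
job_id = client.submit_job(entrypoint="python my_script.py")
print(client.get_job_status(job_id))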

tests/system/vertex_ray/test_ray_data.py (+24 -2)

@@ -54,13 +54,35 @@
 )
 """
 
-my_script = {"2.9": my_script_ray29}
+my_script_ray233 = """
+import ray
+import vertex_ray
+
+override_num_blocks = 10
+query = "SELECT * FROM `bigquery-public-data.ml_datasets.ulb_fraud_detection` LIMIT 10000000"
+
+ds = vertex_ray.data.read_bigquery(
+    override_num_blocks=override_num_blocks,
+    query=query,
+)
+
+# The reads are lazy, so the end time cannot be captured until ds.materialize() is called
+ds.materialize()
+
+# Write
+vertex_ray.data.write_bigquery(
+    ds,
+    dataset="bugbashbq1.system_test_ray29_write",
+)
+"""
+
+my_script = {"2.9": my_script_ray29, "2.33": my_script_ray233}
 
 
 class TestRayData(e2e_base.TestEndToEnd):
     _temp_prefix = "temp-ray-data"
 
-    @pytest.mark.parametrize("cluster_ray_version", ["2.9"])
+    @pytest.mark.parametrize("cluster_ray_version", ["2.9", "2.33"])
     def test_ray_data(self, cluster_ray_version):
         head_node_type = vertex_ray.Resources()
         worker_node_types = [
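
Note: with a "2.33" entry added to my_script, the parametrized test can pick the script that matches the cluster version it spun up. A trimmed sketch of that lookup; the placeholder strings stand in for the full scripts defined above:

# Trimmed illustration; the real dict maps versions to the full Ray Data scripts.
my_script = {"2.9": "<ray 2.9 script>", "2.33": "<ray 2.33 script>"}

def select_script(cluster_ray_version: str) -> str:
    # A missing key would mean the test matrix does not cover that version.
    return my_script[cluster_ray_version]

assert select_script("2.33") == "<ray 2.33 script>"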
