
Commit e048e3a

matthew29tang authored and copybara-github committed
feat: Add Ray on Vertex BigQuery read/write support for Ray 2.9
PiperOrigin-RevId: 611326589
1 parent e0f7250 commit e048e3a

3 files changed, +200 −8 lines changed

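For orientation before the diffs, a minimal usage sketch of the API this commit extends (hypothetical project and table names; "ds" is any Ray dataset, and the call signatures are inferred from the diffs below):

import ray
from google.cloud.aiplatform.preview.vertex_ray import data as vertex_ray_data

# Read a BigQuery table into a Ray dataset (same code path on Ray 2.4 and 2.9).
ds = vertex_ray_data.read_bigquery(
    project_id="my-project",          # placeholder
    dataset="my_dataset.my_table",    # placeholder "dataset.table"
)

# Write it back out; on Ray 2.9.3 this dispatches to the new _BigQueryDatasink.
vertex_ray_data.write_bigquery(
    ds,
    project_id="my-project",
    dataset="my_dataset.my_output_table",
    max_retry_cnt=10,
)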
google/cloud/aiplatform/preview/vertex_ray/bigquery_datasink.py

+158

@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import logging
+import os
+import tempfile
+import time
+import uuid
+from typing import Any, Iterable, Optional
+
+import pyarrow.parquet as pq
+
+from google.api_core import client_info
+from google.api_core import exceptions
+from google.cloud import bigquery
+from google.cloud.aiplatform import initializer
+
+import ray
+from ray.data._internal.execution.interfaces import TaskContext
+from ray.data._internal.remote_fn import cached_remote_fn
+from ray.data.block import Block, BlockAccessor
+
+try:
+    from ray.data.datasource.datasink import Datasink
+except ImportError:
+    # If Datasink cannot be imported, Ray 2.9.3 is not installed
+    Datasink = None
+
+DEFAULT_MAX_RETRY_CNT = 10
+RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
+
+_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
+bq_info = client_info.ClientInfo(
+    gapic_version=_BQ_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}"
+)
+
+if Datasink is None:
+    _BigQueryDatasink = None
+else:
+    # BigQuery write for Ray 2.9.3
+    class _BigQueryDatasink(Datasink):
+        def __init__(
+            self,
+            dataset: str,
+            project_id: Optional[str] = None,
+            max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
+            overwrite_table: Optional[bool] = True,
+        ) -> None:
+            self.dataset = dataset
+            self.project_id = project_id or initializer.global_config.project
+            self.max_retry_cnt = max_retry_cnt
+            self.overwrite_table = overwrite_table
+
+        def on_write_start(self) -> None:
+            # Set up datasets to write
+            client = bigquery.Client(project=self.project_id, client_info=bq_info)
+            dataset_id = self.dataset.split(".", 1)[0]
+            try:
+                client.get_dataset(dataset_id)
+            except exceptions.NotFound:
+                client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
+                print("[Ray on Vertex AI]: Created dataset " + dataset_id)
+
+            # Delete table if overwrite_table is True
+            if self.overwrite_table:
+                print(
+                    f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
+                    + " if it already exists since kwarg overwrite_table = True."
+                )
+                client.delete_table(
+                    f"{self.project_id}.{self.dataset}", not_found_ok=True
+                )
+            else:
+                print(
+                    "[Ray on Vertex AI]: The write will append to table "
+                    + f"{self.dataset} if it already exists "
+                    + "since kwarg overwrite_table = False."
+                )
+
+        def write(
+            self,
+            blocks: Iterable[Block],
+            ctx: TaskContext,
+        ) -> Any:
+            def _write_single_block(
+                block: Block, project_id: str, dataset: str
+            ) -> None:
+                block = BlockAccessor.for_block(block).to_arrow()
+
+                client = bigquery.Client(project=project_id, client_info=bq_info)
+                job_config = bigquery.LoadJobConfig(autodetect=True)
+                job_config.source_format = bigquery.SourceFormat.PARQUET
+                job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
+
+                with tempfile.TemporaryDirectory() as temp_dir:
+                    fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
+                    pq.write_table(block, fp, compression="SNAPPY")
+
+                    retry_cnt = 0
+                    while retry_cnt <= self.max_retry_cnt:
+                        with open(fp, "rb") as source_file:
+                            job = client.load_table_from_file(
+                                source_file, dataset, job_config=job_config
+                            )
+                        try:
+                            logging.info(job.result())
+                            break
+                        except exceptions.Forbidden as e:
+                            retry_cnt += 1
+                            if retry_cnt > self.max_retry_cnt:
+                                break
+                            print(
+                                "[Ray on Vertex AI]: A block write encountered"
+                                + f" a rate limit exceeded error {retry_cnt} time(s)."
+                                + " Sleeping to try again."
+                            )
+                            logging.debug(e)
+                            time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
+
+                    # Raise exception if retry_cnt exceeds max_retry_cnt
+                    if retry_cnt > self.max_retry_cnt:
+                        print(
+                            f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
+                            + " Ray will attempt to retry the block write via fault tolerance."
+                            + " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
+                        )
+                        raise RuntimeError(
+                            f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
+                            + " repeated API rate limit exceeded responses. Consider"
+                            + " specifying the max_retry_cnt kwarg with a higher value."
+                        )
+
+            _write_single_block = cached_remote_fn(_write_single_block)
+
+            # Launch a remote task for each block within this write task
+            ray.get(
+                [
+                    _write_single_block.remote(block, self.project_id, self.dataset)
+                    for block in blocks
+                ]
+            )
+
+            return "ok"
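Design note on the sink above: each Ray write task fans out one remote task per block, serializes the block to Parquet, and submits it as a BigQuery load job, retrying rate-limit (Forbidden) responses with a fixed sleep. A sketch of driving the sink directly instead of through write_bigquery (placeholder names; Ray 2.9.3 assumed):

import ray
from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasink import (
    _BigQueryDatasink,
)

ds = ray.data.from_items([{"x": 1}, {"x": 2}])

# Placeholder project/table; overwrite_table=True drops any existing table first.
sink = _BigQueryDatasink(dataset="my_dataset.my_table", project_id="my-project")

# max_retries=0 mirrors what write_bigquery enforces to avoid duplicate block writes.
ds.write_datasink(sink, ray_remote_args={"max_retries": 0})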

google/cloud/aiplatform/preview/vertex_ray/bigquery_datasource.py

+1
@@ -171,6 +171,7 @@ class BigQueryDatasource(Datasource):
     def create_reader(self, **kwargs) -> Reader:
         return _BigQueryDatasourceReader(**kwargs)
 
+    # BigQuery write for Ray 2.4.0
     def do_write(
         self,
         blocks: List[ObjectRef[Block]],
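As the new comment notes, do_write remains the Ray 2.4.0 write path on the Datasource; on Ray 2.9.3, writes go through the _BigQueryDatasink added above instead, with the version dispatch handled in data.py below.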

google/cloud/aiplatform/preview/vertex_ray/data.py

+41 −8
@@ -17,12 +17,19 @@
 
 import ray.data
 from ray.data.dataset import Dataset
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import (
     BigQueryDatasource,
 )
 
+try:
+    from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasink import (
+        _BigQueryDatasink,
+    )
+except ImportError:
+    _BigQueryDatasink = None
+
 
 def read_bigquery(
     project_id: Optional[str] = None,
@@ -31,6 +38,7 @@ def read_bigquery(
     *,
     parallelism: int = -1,
 ) -> Dataset:
+    # The read is identical in Ray 2.4 and 2.9
     return ray.data.read_datasource(
         BigQueryDatasource(),
         project_id=project_id,
@@ -45,10 +53,35 @@ def write_bigquery(
     project_id: Optional[str] = None,
     dataset: Optional[str] = None,
     max_retry_cnt: int = 10,
-) -> None:
-    return ds.write_datasource(
-        BigQueryDatasource(),
-        project_id=project_id,
-        dataset=dataset,
-        max_retry_cnt=max_retry_cnt,
-    )
+    ray_remote_args: Optional[Dict[str, Any]] = None,
+) -> Any:
+    if ray.__version__ == "2.4.0":
+        return ds.write_datasource(
+            BigQueryDatasource(),
+            project_id=project_id,
+            dataset=dataset,
+            max_retry_cnt=max_retry_cnt,
+        )
+    elif ray.__version__ == "2.9.3":
+        if ray_remote_args is None:
+            ray_remote_args = {}
+
+        # Each write task will launch individual remote tasks to write each block.
+        # To avoid duplicate block writes, the write task should not be retried.
+        if ray_remote_args.get("max_retries", 0) != 0:
+            print(
+                "[Ray on Vertex AI]: The max_retries of a BigQuery Write "
+                "Task should be set to 0 to avoid duplicate writes."
+            )
+        else:
+            ray_remote_args["max_retries"] = 0
+
+        datasink = _BigQueryDatasink(
+            project_id=project_id, dataset=dataset, max_retry_cnt=max_retry_cnt
+        )
+        return ds.write_datasink(datasink, ray_remote_args=ray_remote_args)
+    else:
+        raise ImportError(
+            f"[Ray on Vertex AI]: Unsupported version {ray.__version__}."
+            + " Only 2.4.0 and 2.9.3 are supported."
+        )
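A note on the two retry layers above: Ray-level task retries would re-run an entire write task and re-submit every block's load job, so write_bigquery forces max_retries=0 at the task level, while transient rate-limit errors are retried inside _write_single_block up to max_retry_cnt times. Other remote args still flow through; a hypothetical call (placeholder values):

write_bigquery(
    ds,
    project_id="my-project",
    dataset="my_dataset.my_table",
    max_retry_cnt=20,                  # more in-task rate-limit retries
    ray_remote_args={"num_cpus": 1},   # max_retries=0 is filled in automatically
)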
