
Commit a4b6c60

speedstorm1 authored and copybara-github committed
feat: add support for version 2.33 for RoV Bigquery read/write, remove dead code from version 2.4
PiperOrigin-RevId: 671554025
1 parent 58f1294 commit a4b6c60

File tree: 7 files changed, +263 -344 lines changed


google/cloud/aiplatform/vertex_ray/__init__.py

+2 -2
@@ -19,7 +19,7 @@
 import sys
 
 from google.cloud.aiplatform.vertex_ray.bigquery_datasource import (
-    BigQueryDatasource,
+    _BigQueryDatasource,
 )
 from google.cloud.aiplatform.vertex_ray.client_builder import (
     VertexRayClientBuilder as ClientBuilder,
@@ -52,7 +52,7 @@
 )
 
 __all__ = (
-    "BigQueryDatasource",
+    "_BigQueryDatasource",
    "data",
    "ClientBuilder",
    "get_job_submission_client_cluster_info",
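
Note on the rename above: exporting the class as _BigQueryDatasource marks it as internal to the package. A minimal sketch of the effect on caller code follows (the "before" import is hypothetical caller code, not part of this commit); reads and writes are presumably meant to go through the public helpers in the data submodule instead of the datasource class directly.

# Before: the datasource class was exported under its public name.
# from google.cloud.aiplatform.vertex_ray import BigQueryDatasource

# After this commit, only the underscored (internal) name is exported:
from google.cloud.aiplatform.vertex_ray import _BigQueryDatasource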

google/cloud/aiplatform/vertex_ray/bigquery_datasink.py

+100 -109
@@ -35,11 +35,8 @@
 from ray.data._internal.remote_fn import cached_remote_fn
 from ray.data.block import Block, BlockAccessor
 
-try:
-    from ray.data.datasource.datasink import Datasink
-except ImportError:
-    # If datasink cannot be imported, Ray 2.9.3 is not installed
-    Datasink = None
+from ray.data.datasource.datasink import Datasink
+
 
 DEFAULT_MAX_RETRY_CNT = 10
 RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11
@@ -49,110 +46,104 @@
     gapic_version=_BQ_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}"
 )
 
-if Datasink is None:
-    _BigQueryDatasink = None
-else:
-    # BigQuery write for Ray 2.9.3
-    class _BigQueryDatasink(Datasink):
-        def __init__(
-            self,
-            dataset: str,
-            project_id: str = None,
-            max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
-            overwrite_table: Optional[bool] = True,
-        ) -> None:
-            self.dataset = dataset
-            self.project_id = project_id or initializer.global_config.project
-            self.max_retry_cnt = max_retry_cnt
-            self.overwrite_table = overwrite_table
-
-        def on_write_start(self) -> None:
-            # Set up datasets to write
-            client = bigquery.Client(project=self.project_id, client_info=bq_info)
-            dataset_id = self.dataset.split(".", 1)[0]
-            try:
-                client.get_dataset(dataset_id)
-            except exceptions.NotFound:
-                client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
-                print("[Ray on Vertex AI]: Created dataset " + dataset_id)
-
-            # Delete table if overwrite_table is True
-            if self.overwrite_table:
-                print(
-                    f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
-                    + " if it already exists since kwarg overwrite_table = True."
-                )
-                client.delete_table(
-                    f"{self.project_id}.{self.dataset}", not_found_ok=True
-                )
-            else:
-                print(
-                    "[Ray on Vertex AI]: The write will append to table "
-                    + f"{self.dataset} if it already exists "
-                    + "since kwarg overwrite_table = False."
-                )
-
-        def write(
-            self,
-            blocks: Iterable[Block],
-            ctx: TaskContext,
-        ) -> Any:
-            def _write_single_block(
-                block: Block, project_id: str, dataset: str
-            ) -> None:
-                block = BlockAccessor.for_block(block).to_arrow()
-
-                client = bigquery.Client(project=project_id, client_info=bq_info)
-                job_config = bigquery.LoadJobConfig(autodetect=True)
-                job_config.source_format = bigquery.SourceFormat.PARQUET
-                job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
-
-                with tempfile.TemporaryDirectory() as temp_dir:
-                    fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
-                    pq.write_table(block, fp, compression="SNAPPY")
-
-                    retry_cnt = 0
-                    while retry_cnt <= self.max_retry_cnt:
-                        with open(fp, "rb") as source_file:
-                            job = client.load_table_from_file(
-                                source_file, dataset, job_config=job_config
-                            )
-                        try:
-                            logging.info(job.result())
-                            break
-                        except exceptions.Forbidden as e:
-                            retry_cnt += 1
-                            if retry_cnt > self.max_retry_cnt:
-                                break
-                            print(
-                                "[Ray on Vertex AI]: A block write encountered"
-                                + f" a rate limit exceeded error {retry_cnt} time(s)."
-                                + " Sleeping to try again."
-                            )
-                            logging.debug(e)
-                            time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
-
-                    # Raise exception if retry_cnt exceeds max_retry_cnt
-                    if retry_cnt > self.max_retry_cnt:
-                        print(
-                            f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
-                            + " Ray will attempt to retry the block write via fault tolerance."
-                            + " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
-                        )
-                        raise RuntimeError(
-                            f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
-                            + " repeated API rate limit exceeded responses. Consider"
-                            + " specifiying the max_retry_cnt kwarg with a higher value."
-                        )
-
-            _write_single_block = cached_remote_fn(_write_single_block)
 
-            # Launch a remote task for each block within this write task
-            ray.get(
-                [
-                    _write_single_block.remote(block, self.project_id, self.dataset)
-                    for block in blocks
-                ]
+# BigQuery write for Ray 2.33.0 and 2.9.3
+class _BigQueryDatasink(Datasink):
+    def __init__(
+        self,
+        dataset: str,
+        project_id: Optional[str] = None,
+        max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
+        overwrite_table: Optional[bool] = True,
+    ) -> None:
+        self.dataset = dataset
+        self.project_id = project_id or initializer.global_config.project
+        self.max_retry_cnt = max_retry_cnt
+        self.overwrite_table = overwrite_table
+
+    def on_write_start(self) -> None:
+        # Set up datasets to write
+        client = bigquery.Client(project=self.project_id, client_info=bq_info)
+        dataset_id = self.dataset.split(".", 1)[0]
+        try:
+            client.get_dataset(dataset_id)
+        except exceptions.NotFound:
+            client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
+            print("[Ray on Vertex AI]: Created dataset " + dataset_id)
+
+        # Delete table if overwrite_table is True
+        if self.overwrite_table:
+            print(
+                f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
+                + " if it already exists since kwarg overwrite_table = True."
+            )
+            client.delete_table(f"{self.project_id}.{self.dataset}", not_found_ok=True)
+        else:
+            print(
+                "[Ray on Vertex AI]: The write will append to table "
+                + f"{self.dataset} if it already exists "
+                + "since kwarg overwrite_table = False."
             )
 
-            return "ok"
+    def write(
+        self,
+        blocks: Iterable[Block],
+        ctx: TaskContext,
+    ) -> Any:
+        def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
+            block = BlockAccessor.for_block(block).to_arrow()
+
+            client = bigquery.Client(project=project_id, client_info=bq_info)
+            job_config = bigquery.LoadJobConfig(autodetect=True)
+            job_config.source_format = bigquery.SourceFormat.PARQUET
+            job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
+
+            with tempfile.TemporaryDirectory() as temp_dir:
+                fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
+                pq.write_table(block, fp, compression="SNAPPY")
+
+                retry_cnt = 0
+                while retry_cnt <= self.max_retry_cnt:
+                    with open(fp, "rb") as source_file:
+                        job = client.load_table_from_file(
+                            source_file, dataset, job_config=job_config
+                        )
+                    try:
+                        logging.info(job.result())
+                        break
+                    except exceptions.Forbidden as e:
+                        retry_cnt += 1
+                        if retry_cnt > self.max_retry_cnt:
+                            break
+                        print(
+                            "[Ray on Vertex AI]: A block write encountered"
+                            + f" a rate limit exceeded error {retry_cnt} time(s)."
+                            + " Sleeping to try again."
+                        )
+                        logging.debug(e)
+                        time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)
+
+                # Raise exception if retry_cnt exceeds max_retry_cnt
+                if retry_cnt > self.max_retry_cnt:
+                    print(
+                        f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
+                        + " Ray will attempt to retry the block write via fault tolerance."
+                        + " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
+                    )
+                    raise RuntimeError(
+                        f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
+                        + " repeated API rate limit exceeded responses. Consider"
+                        + " specifiying the max_retry_cnt kwarg with a higher value."
+                    )
+
+        _write_single_block = cached_remote_fn(_write_single_block)
+
+        # Launch a remote task for each block within this write task
+        ray.get(
+            [
+                _write_single_block.remote(block, self.project_id, self.dataset)
+                for block in blocks
+            ]
+        )
+
+        return "ok"
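
For context, a minimal sketch of how the rewritten datasink might be driven from a Ray Dataset (not part of this commit; assumes Ray >= 2.9 where Dataset.write_datasink is available, an initialized aiplatform project supplying the default project_id, and a hypothetical "my_dataset.my_table" target):

import ray
from google.cloud.aiplatform.vertex_ray.bigquery_datasink import _BigQueryDatasink

# Build a small in-memory dataset; each block is written via one BigQuery load job.
ds = ray.data.from_items([{"id": 1, "name": "a"}, {"id": 2, "name": "b"}])

sink = _BigQueryDatasink(
    dataset="my_dataset.my_table",  # hypothetical "<dataset>.<table>" in the active project
    max_retry_cnt=10,               # retries per block on rate-limit (Forbidden) errors
    overwrite_table=True,           # drop any existing table first instead of appending
)
ds.write_datasink(sink)             # on_write_start() runs once, then write() per write task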
