
Commit 3be36e6

yinghsienwu authored and copybara-github committed
feat: Release Ray on Vertex SDK Preview
PiperOrigin-RevId: 565992844
1 parent 6fb30bc · commit 3be36e6

38 files changed: +3969 -7 lines

google/cloud/aiplatform/initializer.py

+7 -1
@@ -18,7 +18,7 @@
 
 from concurrent import futures
 import logging
-import pkg_resources  # Note this is used after copybara replacement
+import pkg_resources  # noqa: F401 # Note this is used after copybara replacement
 import os
 from typing import List, Optional, Type, TypeVar, Union
 
@@ -395,6 +395,7 @@ def create_client(
         api_base_path_override: Optional[str] = None,
         api_path_override: Optional[str] = None,
         appended_user_agent: Optional[List[str]] = None,
+        appended_gapic_version: Optional[str] = None,
     ) -> _TVertexAiServiceClientWithOverride:
         """Instantiates a given VertexAiServiceClient with optional
         overrides.
@@ -411,6 +412,8 @@ def create_client(
             appended_user_agent (List[str]):
                 Optional. User agent appended in the client info. If more than one, it will be
                 separated by spaces.
+            appended_gapic_version (str):
+                Optional. GAPIC version suffix appended in the client info.
         Returns:
             client: Instantiated Vertex AI Service client with optional overrides
         """
@@ -422,6 +425,9 @@ def create_client(
         if appended_user_agent:
             user_agent = f"{user_agent} {' '.join(appended_user_agent)}"
 
+        if appended_gapic_version:
+            gapic_version = f"{gapic_version}+{appended_gapic_version}"
+
         client_info = gapic_v1.client_info.ClientInfo(
             gapic_version=gapic_version,
             user_agent=user_agent,
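
The new `appended_gapic_version` parameter just suffixes the GAPIC version that goes into the request's client info, which is how Ray-on-Vertex traffic becomes distinguishable in API telemetry (the BigQuery datasource below applies the same `+vertex_ray` suffix by hand). A minimal sketch of the effect, assuming only `google-api-core` is installed; the version values are illustrative, not taken from the SDK:

```python
from google.api_core.gapic_v1.client_info import ClientInfo

# Illustrative values: the SDK derives gapic_version from the installed
# aiplatform package; "vertex_ray" is the kind of suffix a caller would
# pass as appended_gapic_version.
gapic_version = "1.33.0"
appended_gapic_version = "vertex_ray"

# Mirrors the new branch in create_client(): tag the GAPIC version.
if appended_gapic_version:
    gapic_version = f"{gapic_version}+{appended_gapic_version}"

info = ClientInfo(gapic_version=gapic_version)
# The suffix surfaces in the user agent, e.g. "... gapic/1.33.0+vertex_ray ..."
print(info.to_user_agent())
```
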
google/cloud/aiplatform/preview/vertex_ray/__init__.py

+58 -0
@@ -0,0 +1,58 @@
+"""Ray on Vertex AI."""
+
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import sys
+
+from google.cloud.aiplatform.preview.vertex_ray.bigquery_datasource import (
+    BigQueryDatasource,
+)
+from google.cloud.aiplatform.preview.vertex_ray.client_builder import (
+    VertexRayClientBuilder as ClientBuilder,
+)
+
+from google.cloud.aiplatform.preview.vertex_ray.cluster_init import (
+    create_ray_cluster,
+    delete_ray_cluster,
+    get_ray_cluster,
+    list_ray_clusters,
+    update_ray_cluster,
+)
+from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
+    Resources,
+)
+
+from google.cloud.aiplatform.preview.vertex_ray.dashboard_sdk import (
+    get_job_submission_client_cluster_info,
+)
+
+if sys.version_info[1] != 10:
+    print(
+        "[Ray on Vertex]: The client environment with Python version 3.10 is required."
+    )
+
+__all__ = (
+    "BigQueryDatasource",
+    "ClientBuilder",
+    "get_job_submission_client_cluster_info",
+    "create_ray_cluster",
+    "delete_ray_cluster",
+    "get_ray_cluster",
+    "list_ray_clusters",
+    "update_ray_cluster",
+    "Resources",
+)
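
Taken together, these exports are the preview surface for managing Ray clusters on Vertex AI. A hypothetical usage sketch follows; the diff only guarantees the exported names, so the argument names (`head_node_type`, `worker_node_types`) and the `Resources` defaults shown here are assumptions, not confirmed signatures, and the project/location values are placeholders:

```python
from google.cloud import aiplatform
from google.cloud.aiplatform.preview import vertex_ray
from google.cloud.aiplatform.preview.vertex_ray import Resources

aiplatform.init(project="my-project", location="us-central1")  # placeholders

# Assumed call shape; only the function names are confirmed by this diff.
cluster_name = vertex_ray.create_ray_cluster(
    head_node_type=Resources(),
    worker_node_types=[Resources()],
)

print(vertex_ray.get_ray_cluster(cluster_name))
for cluster in vertex_ray.list_ray_clusters():
    print(cluster)

vertex_ray.delete_ray_cluster(cluster_name)
```

Note that the version check above inspects only `sys.version_info[1]`, so the warning keys purely on the minor version matching 3.10's.
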
google/cloud/aiplatform/preview/vertex_ray/bigquery_datasource.py

+249 -0
@@ -0,0 +1,249 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import tempfile
+import time
+from typing import Any, Dict, List, Optional
+import uuid
+
+from google.api_core import client_info
+from google.api_core import exceptions
+from google.api_core.gapic_v1 import client_info as v1_client_info
+from google.cloud import bigquery
+from google.cloud import bigquery_storage
+from google.cloud.aiplatform import initializer
+from google.cloud.bigquery_storage import types
+import pyarrow.parquet as pq
+from ray.data._internal.remote_fn import cached_remote_fn
+from ray.data.block import Block
+from ray.data.block import BlockAccessor
+from ray.data.block import BlockMetadata
+from ray.data.datasource.datasource import Datasource
+from ray.data.datasource.datasource import Reader
+from ray.data.datasource.datasource import ReadTask
+from ray.data.datasource.datasource import WriteResult
+from ray.types import ObjectRef
+
+
+_BQ_GAPIC_VERSION = bigquery.__version__ + "+vertex_ray"
+_BQS_GAPIC_VERSION = bigquery_storage.__version__ + "+vertex_ray"
+bq_info = client_info.ClientInfo(
+    gapic_version=_BQ_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQ_GAPIC_VERSION}"
+)
+bqstorage_info = v1_client_info.ClientInfo(
+    gapic_version=_BQS_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}"
+)
+
+
+class _BigQueryDatasourceReader(Reader):
+    def __init__(
+        self,
+        project_id: Optional[str] = None,
+        dataset: Optional[str] = None,
+        query: Optional[str] = None,
+        parallelism: Optional[int] = -1,
+        **kwargs: Optional[Dict[str, Any]],
+    ):
+        self._project_id = project_id or initializer.global_config.project
+        self._dataset = dataset
+        self._query = query
+        self._kwargs = kwargs
+
+        if query is not None and dataset is not None:
+            raise ValueError(
+                "[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
+            )
+
+    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
+        # Executed by a worker node
+        def _read_single_partition(stream, kwargs) -> Block:
+            client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
+            reader = client.read_rows(stream.name)
+            return reader.to_arrow()
+
+        if self._query:
+            query_client = bigquery.Client(
+                project=self._project_id, client_info=bq_info
+            )
+            query_job = query_client.query(self._query)
+            query_job.result()
+            destination = str(query_job.destination)
+            dataset_id = destination.split(".")[-2]
+            table_id = destination.split(".")[-1]
+        else:
+            self._validate_dataset_table_exist(self._project_id, self._dataset)
+            dataset_id = self._dataset.split(".")[0]
+            table_id = self._dataset.split(".")[1]
+
+        bqs_client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
+        table = f"projects/{self._project_id}/datasets/{dataset_id}/tables/{table_id}"
+
+        if parallelism == -1:
+            parallelism = None
+        requested_session = types.ReadSession(
+            table=table,
+            data_format=types.DataFormat.ARROW,
+        )
+        read_session = bqs_client.create_read_session(
+            parent=f"projects/{self._project_id}",
+            read_session=requested_session,
+            max_stream_count=parallelism,
+        )
+
+        read_tasks = []
+        print("[Ray on Vertex AI]: Created streams:", len(read_session.streams))
+        if len(read_session.streams) < parallelism:
+            print(
+                "[Ray on Vertex AI]: The number of streams created by the "
+                + "BigQuery Storage Read API is less than the requested "
+                + "parallelism due to the size of the dataset."
+            )
+
+        for stream in read_session.streams:
+            # Create a metadata block object to store schema, etc.
+            metadata = BlockMetadata(
+                num_rows=None,
+                size_bytes=None,
+                schema=None,
+                input_files=None,
+                exec_stats=None,
+            )
+
+            # Create a no-arg wrapper read function which returns a block
+            read_single_partition = (
+                lambda stream=stream, kwargs=self._kwargs: [  # noqa: F731
+                    _read_single_partition(stream, kwargs)
+                ]
+            )
+
+            # Create the read task and pass the wrapper and metadata in
+            read_task = ReadTask(read_single_partition, metadata)
+            read_tasks.append(read_task)
+
+        return read_tasks
+
+    def estimate_inmemory_data_size(self) -> Optional[int]:
+        # TODO(b/281891467): Implement this method
+        return None
+
+    def _validate_dataset_table_exist(self, project_id: str, dataset: str) -> None:
+        client = bigquery.Client(project=project_id, client_info=bq_info)
+        dataset_id = dataset.split(".")[0]
+        try:
+            client.get_dataset(dataset_id)
+        except exceptions.NotFound:
+            raise ValueError(
+                "[Ray on Vertex AI]: Dataset {} is not found. Please ensure that it exists.".format(
+                    dataset_id
+                )
+            )
+
+        try:
+            client.get_table(dataset)
+        except exceptions.NotFound:
+            raise ValueError(
+                "[Ray on Vertex AI]: Table {} is not found. Please ensure that it exists.".format(
+                    dataset
+                )
+            )
+
+
+class BigQueryDatasource(Datasource):
+    def create_reader(self, **kwargs) -> Reader:
+        return _BigQueryDatasourceReader(**kwargs)
+
+    def do_write(
+        self,
+        blocks: List[ObjectRef[Block]],
+        metadata: List[BlockMetadata],
+        ray_remote_args: Optional[Dict[str, Any]],
+        project_id: Optional[str] = None,
+        dataset: Optional[str] = None,
+    ) -> List[ObjectRef[WriteResult]]:
+        def _write_single_block(
+            block: Block, metadata: BlockMetadata, project_id: str, dataset: str
+        ):
+            print("[Ray on Vertex AI]: Starting to write", metadata.num_rows, "rows")
+            block = BlockAccessor.for_block(block).to_arrow()
+
+            client = bigquery.Client(project=project_id, client_info=bq_info)
+            job_config = bigquery.LoadJobConfig(autodetect=True)
+            job_config.source_format = bigquery.SourceFormat.PARQUET
+            job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND
+
+            with tempfile.TemporaryDirectory() as temp_dir:
+                fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
+                pq.write_table(block, fp, compression="SNAPPY")
+
+                retry_cnt = 0
+                while retry_cnt < 10:
+                    with open(fp, "rb") as source_file:
+                        job = client.load_table_from_file(
+                            source_file, dataset, job_config=job_config
+                        )
+                    retry_cnt += 1
+                    try:
+                        logging.info(job.result())
+                        break
+                    except exceptions.Forbidden as e:
+                        print(
+                            "[Ray on Vertex AI]: Rate limit exceeded... Sleeping to try again"
+                        )
+                        logging.debug(e)
+                        time.sleep(11)
+            print("[Ray on Vertex AI]: Finished writing", metadata.num_rows, "rows")
+
+        project_id = project_id or initializer.global_config.project
+
+        if dataset is None:
+            raise ValueError(
+                "[Ray on Vertex AI]: Dataset is required when writing to BigQuery."
+            )
+
+        if ray_remote_args is None:
+            ray_remote_args = {}
+
+        _write_single_block = cached_remote_fn(_write_single_block).options(
+            **ray_remote_args
+        )
+        write_tasks = []
+
+        # Set up datasets to write
+        client = bigquery.Client(project=project_id, client_info=bq_info)
+        dataset_id = dataset.split(".", 1)[0]
+        try:
+            client.create_dataset(f"{project_id}.{dataset_id}", timeout=30)
+            print("[Ray on Vertex AI]: Created dataset", dataset_id)
+        except exceptions.Conflict:
+            print(
+                "[Ray on Vertex AI]: Dataset",
+                dataset_id,
+                "already exists. The table will be overwritten if it already exists.",
+            )
+
+        # Delete table if it already exists
+        client.delete_table(f"{project_id}.{dataset}", not_found_ok=True)
+
+        print("[Ray on Vertex AI]: Writing", len(blocks), "blocks")
+        for i in range(len(blocks)):
+            write_task = _write_single_block.remote(
+                blocks[i], metadata[i], project_id, dataset
+            )
+            write_tasks.append(write_task)
+        return write_tasks
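
A hypothetical end-to-end sketch of this datasource with Ray Data, assuming a Ray release whose `ray.data.read_datasource` / `Dataset.write_datasource` APIs match the `Reader`/`do_write` hooks implemented above (roughly the Ray 2.4 era); the project and table names are placeholders:

```python
import ray
from google.cloud.aiplatform.preview.vertex_ray import BigQueryDatasource

# Read: read_args are forwarded to _BigQueryDatasourceReader, so pass either
# dataset="dataset.table" or query="..." — never both (ValueError otherwise).
ds = ray.data.read_datasource(
    BigQueryDatasource(),
    project_id="my-project",  # placeholder
    query="SELECT * FROM `my-project.my_dataset.my_table` LIMIT 1000",
)

# Write: do_write() requires dataset ("dataset.output_table"); it drops any
# existing destination table, then appends one Parquet load job per block.
ds.write_datasource(
    BigQueryDatasource(),
    project_id="my-project",  # placeholder
    dataset="my_dataset.my_output_table",
)
```
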
