Skip to content

Commit ff148cd

Browse files
yinghsienwu authored and copybara-github committed
feat: Enable Ray cluster creation and registering TensorFlow checkpoint to Vertex with Ray version 2.9
PiperOrigin-RevId: 611334973
1 parent d947304 commit ff148cd

File tree

7 files changed

+104
-62
lines changed

7 files changed

+104
-62
lines changed

google/cloud/aiplatform/preview/vertex_ray/cluster_init.py

+11-5
Original file line number | Diff line number | Diff line change
@@ -121,10 +121,13 @@ def create_ray_cluster(
121121

122122
local_ray_verion = _validation_utils.get_local_ray_version()
123123
if ray_version != local_ray_verion:
124+
install_ray_version = ".".join(ray_version.split("_"))
124125
logging.info(
125126
f"[Ray on Vertex]: Local runtime has Ray version {local_ray_verion}"
126127
+ f", but the requested cluster runtime has {ray_version}. Please "
127-
+ "ensure that the Ray versions match for client connectivity."
128+
+ "ensure that the Ray versions match for client connectivity. You may "
129+
+ f'"pip install --user --force-reinstall ray[default]=={install_ray_version}"'
130+
+ " and restart runtime before cluster connection."
128131
)
129132

130133
if cluster_name is None:
@@ -162,8 +165,12 @@ def create_ray_cluster(
162165
ray_version, python_version, enable_cuda
163166
)
164167
if custom_images is not None:
165-
if not (custom_images.head is None or custom_images.worker is None):
166-
image_uri = custom_images.head
168+
if custom_images.head is None or custom_images.worker is None:
169+
raise ValueError(
170+
"[Ray on Vertex AI]: custom_images.head and custom_images.worker must be specified when custom_images is set."
171+
)
172+
image_uri = custom_images.head
173+
167174
resource_pool_images[resource_pool_0.id] = image_uri
168175

169176
worker_pools = []
@@ -207,8 +214,7 @@ def create_ray_cluster(
207214
ray_version, python_version, enable_cuda
208215
)
209216
if custom_images is not None:
210-
if not (custom_images.head is None or custom_images.worker is None):
211-
image_uri = custom_images.worker
217+
image_uri = custom_images.worker
212218
resource_pool_images[resource_pool.id] = image_uri
213219

214220
i += 1

google/cloud/aiplatform/preview/vertex_ray/predict/sklearn/register.py

+21-12
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import os
2222
import pickle
23+
import ray
2324
import tempfile
2425
from typing import Optional, TYPE_CHECKING
2526

@@ -117,15 +118,23 @@ def _get_estimator_from(
117118
Raises:
118119
ValueError: Invalid Argument.
119120
"""
120-
if not isinstance(checkpoint, ray_sklearn.SklearnCheckpoint):
121-
raise ValueError(
122-
"[Ray on Vertex AI]: arg checkpoint should be a"
123-
" ray.train.sklearn.SklearnCheckpoint instance"
124-
)
125-
if checkpoint.get_preprocessor() is not None:
126-
logging.warning(
127-
"Checkpoint contains preprocessor. However, converting from a Ray"
128-
" Checkpoint to framework specific model does NOT support"
129-
" preprocessing. The model will be exported without preprocessors."
130-
)
131-
return checkpoint.get_estimator()
121+
ray_version = ray.__version__
122+
if ray_version == "2.4.0":
123+
if not isinstance(checkpoint, ray_sklearn.SklearnCheckpoint):
124+
raise ValueError(
125+
"[Ray on Vertex AI]: arg checkpoint should be a"
126+
" ray.train.sklearn.SklearnCheckpoint instance"
127+
)
128+
if checkpoint.get_preprocessor() is not None:
129+
logging.warning(
130+
"Checkpoint contains preprocessor. However, converting from a Ray"
131+
" Checkpoint to framework specific model does NOT support"
132+
" preprocessing. The model will be exported without preprocessors."
133+
)
134+
return checkpoint.get_estimator()
135+
136+
# Other Ray versions: the checkpoint API may differ, so try get_estimator() and report an unsupported version if it is missing
137+
try:
138+
return checkpoint.get_estimator()
139+
except AttributeError:
140+
raise RuntimeError("Unsupported Ray version.")

google/cloud/aiplatform/preview/vertex_ray/predict/tensorflow/register.py

+22-15
Original file line number | Diff line number | Diff line change
@@ -123,22 +123,29 @@ def _get_tensorflow_model_from(
123123
Raises:
124124
ValueError: Invalid Argument.
125125
"""
126-
if not isinstance(checkpoint, ray_tensorflow.TensorflowCheckpoint):
127-
raise ValueError(
128-
"[Ray on Vertex AI]: arg checkpoint should be a"
129-
" ray.train.tensorflow.TensorflowCheckpoint instance"
130-
)
131-
if checkpoint.get_preprocessor() is not None:
132-
logging.warning(
133-
"Checkpoint contains preprocessor. However, converting from a Ray"
134-
" Checkpoint to framework specific model does NOT support"
135-
" preprocessing. The model will be exported without preprocessors."
136-
)
137-
if ray.__version__ == "2.4.0":
126+
ray_version = ray.__version__
127+
if ray_version == "2.4.0":
128+
if not isinstance(checkpoint, ray_tensorflow.TensorflowCheckpoint):
129+
raise ValueError(
130+
"[Ray on Vertex AI]: arg checkpoint should be a"
131+
" ray.train.tensorflow.TensorflowCheckpoint instance"
132+
)
133+
if checkpoint.get_preprocessor() is not None:
134+
logging.warning(
135+
"Checkpoint contains preprocessor. However, converting from a Ray"
136+
" Checkpoint to framework specific model does NOT support"
137+
" preprocessing. The model will be exported without preprocessors."
138+
)
139+
138140
return checkpoint.get_model(model)
139141

140142
# get_model() signature changed in future versions
141143
try:
142-
return checkpoint.get_model()
143-
except AttributeError:
144-
raise RuntimeError("Unsupported Ray version.")
144+
from tensorflow import keras
145+
146+
try:
147+
return keras.models.load_model(checkpoint.path)
148+
except OSError:
149+
return keras.models.load_model("gs://" + checkpoint.path)
150+
except ImportError:
151+
logging.warning("TensorFlow must be installed to load the trained model.")

google/cloud/aiplatform/preview/vertex_ray/predict/torch/register.py

+21-12
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
import logging
1919
from typing import Optional
20+
import ray
2021

2122
try:
2223
from ray.train import torch as ray_torch
@@ -51,15 +52,23 @@ def get_pytorch_model_from(
5152
Raises:
5253
ValueError: Invalid Argument.
5354
"""
54-
if not isinstance(checkpoint, ray_torch.TorchCheckpoint):
55-
raise ValueError(
56-
"[Ray on Vertex AI]: arg checkpoint should be a"
57-
" ray.train.torch.TorchCheckpoint instance"
58-
)
59-
if checkpoint.get_preprocessor() is not None:
60-
logging.warning(
61-
"Checkpoint contains preprocessor. However, converting from a Ray"
62-
" Checkpoint to framework specific model does NOT support"
63-
" preprocessing. The model will be exported without preprocessors."
64-
)
65-
return checkpoint.get_model(model=model)
55+
ray_version = ray.__version__
56+
if ray_version == "2.4.0":
57+
if not isinstance(checkpoint, ray_torch.TorchCheckpoint):
58+
raise ValueError(
59+
"[Ray on Vertex AI]: arg checkpoint should be a"
60+
" ray.train.torch.TorchCheckpoint instance"
61+
)
62+
if checkpoint.get_preprocessor() is not None:
63+
logging.warning(
64+
"Checkpoint contains preprocessor. However, converting from a Ray"
65+
" Checkpoint to framework specific model does NOT support"
66+
" preprocessing. The model will be exported without preprocessors."
67+
)
68+
return checkpoint.get_model(model=model)
69+
70+
# get_model() signature changed in future versions
71+
try:
72+
return checkpoint.get_model()
73+
except AttributeError:
74+
raise RuntimeError("Unsupported Ray version.")

google/cloud/aiplatform/preview/vertex_ray/predict/xgboost/register.py

+21-12
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import os
2222
import pickle
23+
import ray
2324
import tempfile
2425
from typing import Optional, TYPE_CHECKING
2526

@@ -121,15 +122,23 @@ def _get_xgboost_model_from(
121122
Raises:
122123
ValueError: Invalid Argument.
123124
"""
124-
if not isinstance(checkpoint, ray_xgboost.XGBoostCheckpoint):
125-
raise ValueError(
126-
"[Ray on Vertex AI]: arg checkpoint should be a"
127-
" ray.train.xgboost.XGBoostCheckpoint instance"
128-
)
129-
if checkpoint.get_preprocessor() is not None:
130-
logging.warning(
131-
"Checkpoint contains preprocessor. However, converting from a Ray"
132-
" Checkpoint to framework specific model does NOT support"
133-
" preprocessing. The model will be exported without preprocessors."
134-
)
135-
return checkpoint.get_model()
125+
ray_version = ray.__version__
126+
if ray_version == "2.4.0":
127+
if not isinstance(checkpoint, ray_xgboost.XGBoostCheckpoint):
128+
raise ValueError(
129+
"[Ray on Vertex AI]: arg checkpoint should be a"
130+
" ray.train.xgboost.XGBoostCheckpoint instance"
131+
)
132+
if checkpoint.get_preprocessor() is not None:
133+
logging.warning(
134+
"Checkpoint contains preprocessor. However, converting from a Ray"
135+
" Checkpoint to framework specific model does NOT support"
136+
" preprocessing. The model will be exported without preprocessors."
137+
)
138+
return checkpoint.get_model()
139+
140+
# get_model() signature changed in future versions
141+
try:
142+
return checkpoint.get_model()
143+
except AttributeError:
144+
raise RuntimeError("Unsupported Ray version.")

google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py

+7-5
Original file line number | Diff line number | Diff line change
@@ -78,21 +78,23 @@ def get_local_ray_version():
7878

7979
def get_image_uri(ray_version, python_version, enable_cuda):
8080
"""Image uri for a given ray version and python version."""
81-
if ray_version not in ["2_4"]:
82-
raise ValueError("[Ray on Vertex AI]: The supported Ray version is 2_4.")
81+
if ray_version not in ["2_4", "2_9"]:
82+
raise ValueError(
83+
"[Ray on Vertex AI]: The supported Ray versions are 2_4 (2.4.0) and 2_9 (2.9.3)."
84+
)
8385
if python_version not in ["3_10"]:
8486
raise ValueError("[Ray on Vertex AI]: The supported Python version is 3_10.")
8587

8688
location = initializer.global_config.location
8789
region = location.split("-")[0]
8890
if region not in _AVAILABLE_REGIONS:
8991
region = _DEFAULT_REGION
90-
92+
ray_version = ray_version.replace("_", "-")
9193
if enable_cuda:
9294
# TODO(b/292003337) update eligible image uris
93-
return f"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-4.py310:latest"
95+
return f"{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.{ray_version}.py310:latest"
9496
else:
95-
return f"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-4.py310:latest"
97+
return f"{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.{ray_version}.py310:latest"
9698

9799

98100
def get_versions_from_image_uri(image_uri):

tests/unit/vertex_ray/test_cluster_init.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -381,7 +381,7 @@ def test_create_ray_cluster_ray_version_error(self):
381381
network=tc.ProjectConstants._TEST_VPC_NETWORK,
382382
ray_version="2_1",
383383
)
384-
e.match(regexp=r"The supported Ray version is 2_4.")
384+
e.match(regexp=r"The supported Ray versions are 2_4 ")
385385

386386
@pytest.mark.usefixtures("create_persistent_resource_exception_mock")
387387
def test_create_ray_cluster_state_error(self):

0 commit comments

Comments
 (0)