Skip to content

Commit 10c6ad2

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Verify client and cluster Ray versions match
PiperOrigin-RevId: 588901140
1 parent 7c64672 commit 10c6ad2

File tree

4 files changed

+32
-0
lines changed

4 files changed

+32
-0
lines changed

google/cloud/aiplatform/preview/vertex_ray/client_builder.py

+12
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,18 @@ def __init__(self, address: Optional[str]) -> None:
111111
" failed to start Head node properly because custom service account isn't supported.",
112112
)
113113
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
114+
cluster = _gapic_utils.persistent_resource_to_cluster(
115+
persistent_resource=self.response
116+
)
117+
if cluster is None:
118+
raise ValueError(
119+
"[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
120+
)
121+
local_ray_verion = _validation_utils.get_local_ray_version()
122+
if cluster.ray_version != local_ray_verion:
123+
raise ValueError(
124+
f"[Ray on Vertex AI]: Local runtime has Ray version {local_ray_verion}, but the cluster runtime has {cluster.ray_version}. Please ensure that the Ray versions match."
125+
)
114126
super().__init__(address)
115127

116128
def connect(self) -> _VertexRayClientContext:

google/cloud/aiplatform/preview/vertex_ray/util/_validation_utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import google.auth
1919
import google.auth.transport.requests
2020
import logging
21+
import ray
2122
import re
2223

2324
from google.cloud.aiplatform import initializer
@@ -68,6 +69,13 @@ def maybe_reconstruct_resource_name(address) -> str:
6869
return address
6970

7071

72+
def get_local_ray_version():
73+
ray_version = ray.__version__.split(".")
74+
if len(ray_version) == 3:
75+
ray_version = ray_version[:2]
76+
return "_".join(ray_version)
77+
78+
7179
def get_image_uri(ray_version, python_version, enable_cuda):
7280
"""Image uri for a given ray version and python version."""
7381
if ray_version not in ["2_4"]:

tests/unit/vertex_ray/test_constants.py

+7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@
3939
ResourceRuntimeSpec,
4040
)
4141

42+
import pytest
43+
import sys
44+
45+
rovminversion = pytest.mark.skipif(
46+
sys.version_info > (3, 10), reason="Requires python3.10 or lower"
47+
)
48+
4249

4350
@dataclasses.dataclass(frozen=True)
4451
class ProjectConstants:

tests/unit/vertex_ray/test_vertex_ray_client.py

+5
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def setup_method(self):
8484
def teardown_method(self):
8585
aiplatform.initializer.global_pool.shutdown(wait=True)
8686

87+
@tc.rovminversion
8788
@pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
8889
def test_init_with_full_resource_name(
8990
self,
@@ -94,6 +95,7 @@ def test_init_with_full_resource_name(
9495
tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,
9596
)
9697

98+
@tc.rovminversion
9799
@pytest.mark.usefixtures(
98100
"get_persistent_resource_status_running_mock", "google_auth_mock"
99101
)
@@ -112,6 +114,7 @@ def test_init_with_cluster_name(
112114
tc.ClusterConstants._TEST_VERTEX_RAY_HEAD_NODE_IP,
113115
)
114116

117+
@tc.rovminversion
115118
@pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
116119
def test_connect_running(self, ray_client_connect_mock):
117120
connect_result = vertex_ray.ClientBuilder(
@@ -124,6 +127,7 @@ def test_connect_running(self, ray_client_connect_mock):
124127
== tc.ClusterConstants._TEST_VERTEX_RAY_PR_ID
125128
)
126129

130+
@tc.rovminversion
127131
@pytest.mark.usefixtures("get_persistent_resource_status_running_no_ray_mock")
128132
def test_connect_running_no_ray(self, ray_client_connect_mock):
129133
expected_message = (
@@ -139,6 +143,7 @@ def test_connect_running_no_ray(self, ray_client_connect_mock):
139143
ray_client_connect_mock.assert_called_once_with()
140144
assert str(exception.value) == expected_message
141145

146+
@tc.rovminversion
142147
@pytest.mark.parametrize(
143148
"address",
144149
[

0 commit comments

Comments
 (0)