Skip to content

Commit 36a56b9

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support reserved_ip_ranges for VPC network in Ray on Vertex cluster
chore: Update ray prediction tests for forward compatibility PiperOrigin-RevId: 670628417
1 parent 4a528c6 commit 36a56b9

File tree

7 files changed

+45
-5
lines changed

7 files changed

+45
-5
lines changed

google/cloud/aiplatform/vertex_ray/cluster_init.py

+7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def create_ray_cluster(
6161
enable_metrics_collection: Optional[bool] = True,
6262
enable_logging: Optional[bool] = True,
6363
psc_interface_config: Optional[resources.PscIConfig] = None,
64+
reserved_ip_ranges: Optional[List[str]] = None,
6465
labels: Optional[Dict[str, str]] = None,
6566
) -> str:
6667
"""Create a ray cluster on the Vertex AI.
@@ -126,6 +127,11 @@ def create_ray_cluster(
126127
enable_metrics_collection: Enable Ray metrics collection for visualization.
127128
enable_logging: Enable exporting Ray logs to Cloud Logging.
128129
psc_interface_config: PSC-I config.
130+
reserved_ip_ranges: A list of names for the reserved IP ranges under
131+
the VPC network that can be used for this cluster. If set, we will
132+
deploy the cluster within the provided IP ranges. Otherwise, the
133+
cluster is deployed to any IP ranges under the provided VPC network.
134+
Example: ["vertex-ai-ip-range"].
129135
labels:
130136
The labels with user-defined metadata to organize Ray cluster.
131137
@@ -325,6 +331,7 @@ def create_ray_cluster(
325331
labels=labels,
326332
resource_runtime_spec=resource_runtime_spec,
327333
psc_interface_config=gapic_psc_interface_config,
334+
reserved_ip_ranges=reserved_ip_ranges,
328335
)
329336

330337
location = initializer.global_config.location

google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@
4343
import xgboost
4444

4545
except ModuleNotFoundError as mnfe:
46-
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
46+
if ray.__version__ == "2.9.3":
47+
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
48+
else:
49+
xgboost = None
4750

4851

4952
def register_xgboost(

google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def persistent_resource_to_cluster(
150150
cluster = Cluster(
151151
cluster_resource_name=persistent_resource.name,
152152
network=persistent_resource.network,
153+
reserved_ip_ranges=persistent_resource.reserved_ip_ranges,
153154
state=persistent_resource.state.name,
154155
labels=persistent_resource.labels,
155156
dashboard_address=dashboard_address,

google/cloud/aiplatform/vertex_ray/util/resources.py

+6
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ class Cluster:
117117
managed in the Vertex API service. For Ray Job API, VPC network is
118118
not required because cluster connection can be accessed through
119119
dashboard address.
120+
reserved_ip_ranges: A list of names for the reserved IP ranges under
121+
the VPC network that can be used for this cluster. If set, we will
122+
deploy the cluster within the provided IP ranges. Otherwise, the
123+
cluster is deployed to any IP ranges under the provided VPC network.
124+
Example: ["vertex-ai-ip-range"].
120125
service_account: Service account to be used for running Ray programs on
121126
the cluster.
122127
state: Describes the cluster state (defined in PersistentResource.State).
@@ -140,6 +145,7 @@ class Cluster:
140145

141146
cluster_resource_name: str = None
142147
network: str = None
148+
reserved_ip_ranges: List[str] = None
143149
service_account: str = None
144150
state: PersistentResource.State = None
145151
python_version: str = None

tests/unit/vertex_ray/test_cluster_init.py

+1
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def test_create_ray_cluster_2_pools_custom_images_success(
384384
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
385385
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE,
386386
network=tc.ProjectConstants.TEST_VPC_NETWORK,
387+
reserved_ip_ranges=["vertex-dedicated-range"],
387388
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
388389
)
389390

tests/unit/vertex_ray/test_constants.py

+12
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,17 @@
5151
from google.cloud.aiplatform_v1beta1.types.service_networking import (
5252
PscInterfaceConfig,
5353
)
54+
import ray
5455
import pytest
5556

5657

5758
rovminversion = pytest.mark.skipif(
5859
sys.version_info > (3, 10), reason="Requires python3.10 or lower"
5960
)
61+
# TODO(b/363340317)
62+
xgbversion = pytest.mark.skipif(
63+
ray.__version__ != "2.9.3", reason="Requires xgboost 1.7 or higher"
64+
)
6065

6166

6267
@dataclasses.dataclass(frozen=True)
@@ -347,6 +352,7 @@ class ClusterConstants:
347352
),
348353
psc_interface_config=None,
349354
network=ProjectConstants.TEST_VPC_NETWORK,
355+
reserved_ip_ranges=["vertex-dedicated-range"],
350356
)
351357
# Responses
352358
TEST_RESOURCE_POOL_2.replica_count = 1
@@ -366,6 +372,7 @@ class ClusterConstants:
366372
network_attachment=TEST_PSC_NETWORK_ATTACHMENT
367373
),
368374
network=None,
375+
reserved_ip_ranges=None,
369376
resource_runtime=ResourceRuntime(
370377
access_uris={
371378
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -386,6 +393,7 @@ class ClusterConstants:
386393
),
387394
),
388395
network=ProjectConstants.TEST_VPC_NETWORK,
396+
reserved_ip_ranges=["vertex-dedicated-range"],
389397
resource_runtime=ResourceRuntime(
390398
access_uris={
391399
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -399,6 +407,7 @@ class ClusterConstants:
399407
python_version="3.10",
400408
ray_version="2.9",
401409
network=ProjectConstants.TEST_VPC_NETWORK,
410+
reserved_ip_ranges=None,
402411
service_account=None,
403412
state="RUNNING",
404413
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
@@ -412,6 +421,7 @@ class ClusterConstants:
412421
python_version="3.10",
413422
ray_version="2.9",
414423
network="",
424+
reserved_ip_ranges="",
415425
service_account=None,
416426
state="RUNNING",
417427
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS,
@@ -424,6 +434,7 @@ class ClusterConstants:
424434
TEST_CLUSTER_CUSTOM_IMAGE = Cluster(
425435
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
426436
network=ProjectConstants.TEST_VPC_NETWORK,
437+
reserved_ip_ranges=["vertex-dedicated-range"],
427438
service_account=None,
428439
state="RUNNING",
429440
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
@@ -438,6 +449,7 @@ class ClusterConstants:
438449
python_version="3.10",
439450
ray_version="2.9",
440451
network="",
452+
reserved_ip_ranges="",
441453
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
442454
state="RUNNING",
443455
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,

tests/unit/vertex_ray/test_ray_prediction.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import numpy as np
4242
import pytest
4343
import ray
44-
from ray.train import xgboost as ray_xgboost
4544
import tensorflow as tf
4645
import torch
4746
import xgboost
@@ -90,9 +89,14 @@ def ray_sklearn_checkpoint():
9089

9190
@pytest.fixture()
9291
def ray_xgboost_checkpoint():
93-
model = test_prediction_utils.get_xgboost_model()
94-
checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
95-
return checkpoint
92+
if ray.__version__ == "2.9.3":
93+
from ray.train import xgboost as ray_xgboost
94+
95+
model = test_prediction_utils.get_xgboost_model()
96+
checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
97+
return checkpoint
98+
else:
99+
return None
96100

97101

98102
@pytest.fixture()
@@ -374,6 +378,7 @@ def test_register_sklearnartifact_uri_not_gcs_uri_raise_error(
374378
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")
375379

376380
# XGBoost Tests
381+
@tc.xgbversion
377382
@tc.rovminversion
378383
def test_convert_checkpoint_to_xgboost_raise_exception(
379384
self, ray_checkpoint_from_dict
@@ -392,6 +397,7 @@ def test_convert_checkpoint_to_xgboost_raise_exception(
392397
"ray.train.xgboost.XGBoostCheckpoint .*"
393398
)
394399

400+
@tc.xgbversion
395401
def test_convert_checkpoint_to_xgboost_model_succeed(
396402
self, ray_xgboost_checkpoint
397403
) -> None:
@@ -406,6 +412,7 @@ def test_convert_checkpoint_to_xgboost_model_succeed(
406412
y_pred = model.predict(xgboost.DMatrix(np.array([[1, 2]])))
407413
assert y_pred[0] is not None
408414

415+
@tc.xgbversion
409416
def test_register_xgboost_succeed(
410417
self,
411418
ray_xgboost_checkpoint,
@@ -429,6 +436,7 @@ def test_register_xgboost_succeed(
429436
pickle_dump.assert_called_once()
430437
gcs_utils_upload_to_gcs.assert_called_once()
431438

439+
@tc.xgbversion
432440
def test_register_xgboost_initialized_succeed(
433441
self,
434442
ray_xgboost_checkpoint,
@@ -455,6 +463,7 @@ def test_register_xgboost_initialized_succeed(
455463
pickle_dump.assert_called_once()
456464
gcs_utils_upload_to_gcs.assert_called_once()
457465

466+
@tc.xgbversion
458467
def test_register_xgboostartifact_uri_is_none_raise_error(
459468
self, ray_xgboost_checkpoint
460469
) -> None:
@@ -467,6 +476,7 @@ def test_register_xgboostartifact_uri_is_none_raise_error(
467476
)
468477
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")
469478

479+
@tc.xgbversion
470480
def test_register_xgboostartifact_uri_not_gcs_uri_raise_error(
471481
self, ray_xgboost_checkpoint
472482
) -> None:

0 commit comments

Comments
 (0)