Skip to content

Commit d727189

Browse files
yinghsienwu and copybara-github
authored and committed
feat: support custom image for Ray cluster creation
PiperOrigin-RevId: 606689613
1 parent 310ee49 commit d727189

File tree

6 files changed

+126
-7
lines changed

6 files changed

+126
-7
lines changed

google/cloud/aiplatform/preview/vertex_ray/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
3636
Resources,
37+
NodeImages,
3738
)
3839

3940
from google.cloud.aiplatform.preview.vertex_ray.dashboard_sdk import (
@@ -55,4 +56,5 @@
5556
"list_ray_clusters",
5657
"update_ray_cluster",
5758
"Resources",
59+
"NodeImages",
5860
)

google/cloud/aiplatform/preview/vertex_ray/cluster_init.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import copy
1919
import logging
20+
import time
2021
from typing import Dict, List, Optional
2122

2223
from google.cloud.aiplatform import initializer
@@ -47,6 +48,7 @@ def create_ray_cluster(
4748
network: Optional[str] = None,
4849
cluster_name: Optional[str] = None,
4950
worker_node_types: Optional[List[resources.Resources]] = None,
51+
custom_images: Optional[resources.NodeImages] = None,
5052
labels: Optional[Dict[str, str]] = None,
5153
) -> str:
5254
"""Create a ray cluster on the Vertex AI.
@@ -97,6 +99,8 @@ def create_ray_cluster(
9799
or hyphen.
98100
worker_node_types: The list of Resources of the worker nodes. The same
99101
Resources object should not appear multiple times in the list.
102+
custom_images: The NodeImages which specifies head node and worker nodes
103+
images. Allowlist only.
100104
labels:
101105
The labels with user-defined metadata to organize Ray cluster.
102106
@@ -157,6 +161,9 @@ def create_ray_cluster(
157161
image_uri = _validation_utils.get_image_uri(
158162
ray_version, python_version, enable_cuda
159163
)
164+
if custom_images is not None:
165+
if not (custom_images.head is None or custom_images.worker is None):
166+
image_uri = custom_images.head
160167
resource_pool_images[resource_pool_0.id] = image_uri
161168

162169
worker_pools = []
@@ -199,6 +206,9 @@ def create_ray_cluster(
199206
image_uri = _validation_utils.get_image_uri(
200207
ray_version, python_version, enable_cuda
201208
)
209+
if custom_images is not None:
210+
if not (custom_images.head is None or custom_images.worker is None):
211+
image_uri = custom_images.worker
202212
resource_pool_images[resource_pool.id] = image_uri
203213

204214
i += 1
@@ -425,6 +435,12 @@ def update_ray_cluster(
425435
) from e
426436

427437
# block before returning
438+
start_time = time.time()
428439
response = operation_future.result()
429-
print("[Ray on Vertex AI]: Successfully updated the cluster.")
440+
duration = (time.time() - start_time) // 60
441+
print(
442+
"[Ray on Vertex AI]: Successfully updated the cluster ({} mininutes elapsed).".format(
443+
duration
444+
)
445+
)
430446
return response.name

google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud.aiplatform.preview.vertex_ray.util import _validation_utils
2929
from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
3030
Cluster,
31+
NodeImages,
3132
Resources,
3233
)
3334
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
@@ -156,14 +157,24 @@ def persistent_resource_to_cluster(
156157
)
157158
return
158159

159-
image_uri = persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[
160-
"head-node"
161-
]
162-
if not image_uri:
163-
image_uri = persistent_resource.resource_runtime_spec.ray_spec.image_uri
160+
head_image_uri = (
161+
persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[
162+
"head-node"
163+
]
164+
)
165+
worker_image_uri = (
166+
persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images.get(
167+
"worker-pool1", None
168+
)
169+
)
170+
if worker_image_uri is None:
171+
worker_image_uri = head_image_uri
172+
173+
if not head_image_uri:
174+
head_image_uri = persistent_resource.resource_runtime_spec.ray_spec.image_uri
164175
try:
165176
python_version, ray_version = _validation_utils.get_versions_from_image_uri(
166-
image_uri
177+
head_image_uri
167178
)
168179
except IndexError:
169180
logging.info(
@@ -173,6 +184,7 @@ def persistent_resource_to_cluster(
173184
return
174185
cluster.python_version = python_version
175186
cluster.ray_version = ray_version
187+
cluster.node_images = NodeImages(head=head_image_uri, worker=worker_image_uri)
176188

177189
resource_pools = persistent_resource.resource_pools
178190

google/cloud/aiplatform/preview/vertex_ray/util/resources.py

+20
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,24 @@ class Resources:
4747
boot_disk_size_gb: Optional[int] = 100
4848

4949

50+
@dataclasses.dataclass
class NodeImages:
    """Custom images for a Ray cluster.

    We currently support Ray v2.4 and Python v3.10. The custom images must be
    extended from one of the following base images:
    "{region}-docker.pkg.dev/vertex-ai/training/ray-cpu.2-4.py310:latest" or
    "{region}-docker.pkg.dev/vertex-ai/training/ray-gpu.2-4.py310:latest".
    To use custom images, both the head and the worker image must be
    specified.

    Attributes:
        head: Head node image
            (e.g. us-docker.pkg.dev/my-project/ray-cpu.2-4.py310-tf:latest).
        worker: Worker node image
            (e.g. us-docker.pkg.dev/my-project/ray-gpu.2-4.py310-tf:latest).
    """

    # Both fields default to None, so the annotations are Optional[str]
    # rather than plain str.
    head: Optional[str] = None
    worker: Optional[str] = None
66+
67+
5068
@dataclasses.dataclass
5169
class Cluster:
5270
"""Ray cluster (output only).
@@ -69,6 +87,7 @@ class Cluster:
6987
duplicate the elements in the list.
7088
dashboard_address: For Ray Job API (JobSubmissionClient), with this
7189
cluster connection doesn't require VPC peering.
90+
node_images: The NodeImages for a ray cluster.
7291
labels:
7392
The labels with user-defined metadata to organize Ray cluster.
7493
@@ -87,6 +106,7 @@ class Cluster:
87106
head_node_type: Resources = None
88107
worker_node_types: List[Resources] = None
89108
dashboard_address: str = None
109+
node_images: NodeImages = None
90110
labels: Dict[str, str] = None
91111

92112

tests/unit/vertex_ray/test_cluster_init.py

+42
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.cloud.aiplatform.preview import vertex_ray
2121
from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
2222
Resources,
23+
NodeImages,
2324
)
2425
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
2526
PersistentResourceServiceClient,
@@ -80,6 +81,18 @@ def get_persistent_resource_1_pool_mock():
8081
yield get_persistent_resource_1_pool_mock
8182

8283

84+
@pytest.fixture
def get_persistent_resource_1_pool_custom_image_mock():
    """Patch ``get_persistent_resource`` to return the custom-image 1-pool response."""
    custom_image_response = (
        tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL_CUSTOM_IMAGES
    )
    with mock.patch.object(
        PersistentResourceServiceClient, "get_persistent_resource"
    ) as patched_get:
        patched_get.return_value = custom_image_response
        yield patched_get
94+
95+
8396
@pytest.fixture
8497
def create_persistent_resource_2_pools_mock():
8598
with mock.patch.object(
@@ -234,6 +247,35 @@ def test_create_ray_cluster_1_pool_gpu_success(
234247
request,
235248
)
236249

250+
@pytest.mark.usefixtures("get_persistent_resource_1_pool_custom_image_mock")
def test_create_ray_cluster_1_pool_custom_image_success(
    self, create_persistent_resource_1_pool_mock
):
    """Create a 1-pool cluster with custom head and worker images.

    Checks the returned cluster name and that the create request sent to
    the persistent resource service matches the custom-image request constant.
    """
    constants = tc.ClusterConstants

    returned_name = vertex_ray.create_ray_cluster(
        head_node_type=constants._TEST_HEAD_NODE_TYPE_1_POOL,
        worker_node_types=constants._TEST_WORKER_NODE_TYPES_1_POOL,
        network=tc.ProjectConstants._TEST_VPC_NETWORK,
        cluster_name=constants._TEST_VERTEX_RAY_PR_ID,
        custom_images=NodeImages(
            head=constants._TEST_CUSTOM_IMAGE,
            worker=constants._TEST_CUSTOM_IMAGE,
        ),
    )
    assert constants._TEST_VERTEX_RAY_PR_ADDRESS == returned_name

    expected_request = persistent_resource_service.CreatePersistentResourceRequest(
        parent=tc.ProjectConstants._TEST_PARENT,
        persistent_resource=constants._TEST_REQUEST_RUNNING_1_POOL_CUSTOM_IMAGES,
        persistent_resource_id=constants._TEST_VERTEX_RAY_PR_ID,
    )
    create_persistent_resource_1_pool_mock.assert_called_with(expected_request)
278+
237279
@pytest.mark.usefixtures("get_persistent_resource_1_pool_mock")
238280
def test_create_ray_cluster_1_pool_gpu_with_labels_success(
239281
self, create_persistent_resource_1_pool_mock

tests/unit/vertex_ray/test_constants.py

+27
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.cloud.aiplatform.preview.vertex_ray.util.resources import Cluster
2121
from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
2222
Resources,
23+
NodeImages,
2324
)
2425
from google.cloud.aiplatform_v1beta1.types.machine_resources import DiskSpec
2526
from google.cloud.aiplatform_v1beta1.types.machine_resources import (
@@ -82,6 +83,7 @@ class ClusterConstants:
8283
)
8384
_TEST_CPU_IMAGE = "us-docker.pkg.dev/vertex-ai/training/ray-cpu.2-4.py310:latest"
8485
_TEST_GPU_IMAGE = "us-docker.pkg.dev/vertex-ai/training/ray-gpu.2-4.py310:latest"
86+
_TEST_CUSTOM_IMAGE = "us-docker.pkg.dev/my-project/ray-custom.2-4.py310:latest"
8587
# RUNNING Persistent Cluster w/o Ray
8688
_TEST_RESPONSE_NO_RAY_RUNNING = PersistentResource(
8789
name=_TEST_VERTEX_RAY_PR_ADDRESS,
@@ -127,6 +129,13 @@ class ClusterConstants:
127129
network=ProjectConstants._TEST_VPC_NETWORK,
128130
labels=_TEST_LABELS,
129131
)
132+
_TEST_REQUEST_RUNNING_1_POOL_CUSTOM_IMAGES = PersistentResource(
133+
resource_pools=[_TEST_RESOURCE_POOL_0],
134+
resource_runtime_spec=ResourceRuntimeSpec(
135+
ray_spec=RaySpec(resource_pool_images={"head-node": _TEST_CUSTOM_IMAGE}),
136+
),
137+
network=ProjectConstants._TEST_VPC_NETWORK,
138+
)
130139
# Get response has generated name, and URIs
131140
_TEST_RESPONSE_RUNNING_1_POOL = PersistentResource(
132141
name=_TEST_VERTEX_RAY_PR_ADDRESS,
@@ -143,6 +152,22 @@ class ClusterConstants:
143152
),
144153
state="RUNNING",
145154
)
155+
# Get response has generated name, and URIs
156+
_TEST_RESPONSE_RUNNING_1_POOL_CUSTOM_IMAGES = PersistentResource(
157+
name=_TEST_VERTEX_RAY_PR_ADDRESS,
158+
resource_pools=[_TEST_RESOURCE_POOL_0],
159+
resource_runtime_spec=ResourceRuntimeSpec(
160+
ray_spec=RaySpec(resource_pool_images={"head-node": _TEST_CUSTOM_IMAGE}),
161+
),
162+
network=ProjectConstants._TEST_VPC_NETWORK,
163+
resource_runtime=ResourceRuntime(
164+
access_uris={
165+
"RAY_DASHBOARD_URI": _TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
166+
"RAY_HEAD_NODE_INTERNAL_IP": _TEST_VERTEX_RAY_HEAD_NODE_IP,
167+
}
168+
),
169+
state="RUNNING",
170+
)
146171
# 2_POOL: worker_node_types and head_node_type have different MachineSpecs
147172
_TEST_HEAD_NODE_TYPE_2_POOLS = Resources()
148173
_TEST_WORKER_NODE_TYPES_2_POOLS = [
@@ -213,6 +238,7 @@ class ClusterConstants:
213238
head_node_type=_TEST_HEAD_NODE_TYPE_1_POOL,
214239
worker_node_types=_TEST_WORKER_NODE_TYPES_1_POOL,
215240
dashboard_address=_TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
241+
node_images=NodeImages(head=_TEST_CPU_IMAGE, worker=_TEST_CPU_IMAGE),
216242
)
217243
_TEST_CLUSTER_2 = Cluster(
218244
cluster_resource_name=_TEST_VERTEX_RAY_PR_ADDRESS,
@@ -223,6 +249,7 @@ class ClusterConstants:
223249
head_node_type=_TEST_HEAD_NODE_TYPE_2_POOLS,
224250
worker_node_types=_TEST_WORKER_NODE_TYPES_2_POOLS,
225251
dashboard_address=_TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
252+
node_images=NodeImages(head=_TEST_CPU_IMAGE, worker=_TEST_GPU_IMAGE),
226253
)
227254
_TEST_BEARER_TOKEN = "test-bearer-token"
228255
_TEST_HEADERS = {

0 commit comments

Comments
 (0)