
Commit 979a4f3

yinghsienwu authored and copybara-github committed
feat: Add explicit constraints for update_ray_cluster
PiperOrigin-RevId: 589973886
1 parent cd233ef commit 979a4f3

File tree

3 files changed: +152 −27 lines


google/cloud/aiplatform/preview/vertex_ray/cluster_init.py (+54 −5)

@@ -124,6 +124,14 @@ def create_ray_cluster(
                 "[Ray on Vertex AI]: For head_node_type, "
                 + "Resources.node_count must be 1."
             )
+        if (
+            head_node_type.accelerator_type is None
+            and head_node_type.accelerator_count > 0
+        ):
+            raise ValueError(
+                "[Ray on Vertex]: accelerator_type must be specified when"
+                + " accelerator_count is set to a value other than 0."
+            )

     resource_pool_images = {}

@@ -147,6 +155,14 @@ def create_ray_cluster(
     i = 0
     if worker_node_types:
         for worker_node_type in worker_node_types:
+            if (
+                worker_node_type.accelerator_type is None
+                and worker_node_type.accelerator_count > 0
+            ):
+                raise ValueError(
+                    "[Ray on Vertex]: accelerator_type must be specified when"
+                    + " accelerator_count is set to a value other than 0."
+                )
             # Worker and head share the same MachineSpec, merge them into the
             # same ResourcePool
             additional_replica_count = resources._check_machine_spec_identical(
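A minimal sketch of the constraint the two hunks above add, written as a hypothetical standalone helper that mirrors the in-function check (the helper is not part of the SDK); only the Resources dataclass from this commit is assumed:

    from google.cloud.aiplatform.preview.vertex_ray.util.resources import Resources

    def _validate_accelerator_spec(node_type: Resources) -> None:
        # Same rule as the checks added for head and worker pools: a nonzero
        # accelerator_count without an accelerator_type is rejected.
        if node_type.accelerator_type is None and node_type.accelerator_count > 0:
            raise ValueError(
                "[Ray on Vertex]: accelerator_type must be specified when"
                + " accelerator_count is set to a value other than 0."
            )

    _validate_accelerator_spec(Resources())  # passes: defaults, no accelerators
    try:
        _validate_accelerator_spec(Resources(accelerator_count=1))
    except ValueError as err:
        print(err)  # accelerator_type must be specified ...
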
@@ -327,31 +343,64 @@ def update_ray_cluster(
     Returns:
         The cluster_resource_name of the Ray cluster on Vertex.
     """
+    # worker_node_types should not be duplicated.
+    for i in range(len(worker_node_types)):
+        for j in range(len(worker_node_types)):
+            additional_replica_count = resources._check_machine_spec_identical(
+                worker_node_types[i], worker_node_types[j]
+            )
+            if additional_replica_count > 0 and i != j:
+                raise ValueError(
+                    "[Ray on Vertex AI]: Worker_node_types have duplicate machine specs: ",
+                    worker_node_types[i],
+                    "and ",
+                    worker_node_types[j],
+                )
+
     persistent_resource = _gapic_utils.get_persistent_resource(
         persistent_resource_name=cluster_resource_name
     )

     current_persistent_resource = copy.deepcopy(persistent_resource)
-    head_node_type = get_ray_cluster(cluster_resource_name).head_node_type
     current_persistent_resource.resource_pools[0].replica_count = 1
-    # TODO(b/300146407): Raise ValueError for duplicate resource pools
+
+    previous_ray_cluster = get_ray_cluster(cluster_resource_name)
+    head_node_type = previous_ray_cluster.head_node_type
+    previous_worker_node_types = previous_ray_cluster.worker_node_types
+
+    # new worker_node_types and previous_worker_node_types should be the same length.
+    if len(worker_node_types) != len(previous_worker_node_types):
+        raise ValueError(
+            f"[Ray on Vertex AI]: Desired number of worker_node_types ({len(worker_node_types)}) does not match the number of the existing worker_node_type({len(previous_worker_node_types)}).",
+        )
+
+    # Merge worker_node_type and head_node_type if the share
+    # the same machine spec.
     not_merged = 1
     for i in range(len(worker_node_types)):
         additional_replica_count = resources._check_machine_spec_identical(
             head_node_type, worker_node_types[i]
         )
-        if additional_replica_count != 0:
-            # merge the 1st duplicated worker with head
+        if additional_replica_count != 0 or (
+            additional_replica_count == 0 and worker_node_types[i].node_count == 0
+        ):
+            # Merge the 1st duplicated worker with head, allow scale down to 0 worker
             current_persistent_resource.resource_pools[0].replica_count = (
                 1 + additional_replica_count
             )
-            # reset not_merged
+            # Reset not_merged
             not_merged = 0
         else:
             # No duplication w/ head node, write the 2nd worker node to the 2nd resource pool.
             current_persistent_resource.resource_pools[
                 i + not_merged
             ].replica_count = worker_node_types[i].node_count
+            # New worker_node_type.node_count should be >=1 unless the worker_node_type
+            # and head_node_type are merged due to the same machine specs.
+            if worker_node_types[i].node_count == 0:
+                raise ValueError(
+                    f"[Ray on Vertex AI]: Worker_node_type ({worker_node_types[i]}) must update to >= 1 nodes",
+                )

     request = persistent_resource_service.UpdatePersistentResourceRequest(
         persistent_resource=current_persistent_resource,
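For callers, a sketch of the update_ray_cluster contract after this hunk, assuming an already-running cluster and initialized credentials; the persistent resource name below is hypothetical, and the call signature follows the tests in this commit:

    from google.cloud.aiplatform.preview import vertex_ray

    cluster_name = (  # hypothetical persistent resource name
        "projects/my-project/locations/us-central1/persistentResources/my-cluster"
    )

    # Start from the cluster's current worker pools so the list length matches.
    current = vertex_ray.get_ray_cluster(cluster_name)
    new_worker_node_types = current.worker_node_types
    for worker_node_type in new_worker_node_types:
        # Any value >= 1 is accepted; 0 is only allowed when the pool shares
        # the head node's machine spec and is merged into the head pool.
        worker_node_type.node_count = 4

    vertex_ray.update_ray_cluster(
        cluster_resource_name=cluster_name,
        worker_node_types=new_worker_node_types,
    )

A list whose length differs from the existing worker_node_types, duplicate machine specs within the list, or a standalone pool resized to 0 now fail fast with ValueError instead of sending an invalid UpdatePersistentResource request.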

google/cloud/aiplatform/preview/vertex_ray/util/resources.py (+7 −22)

@@ -19,6 +19,7 @@
 from google.cloud.aiplatform_v1beta1.types import PersistentResource


+@dataclasses.dataclass
 class Resources:
     """Resources for a ray cluster node.

@@ -38,28 +39,12 @@ class Resources:
         be either unspecified or within the range of [100, 64000].
     """

-    def __init__(
-        self,
-        machine_type: Optional[str] = "n1-standard-4",
-        node_count: Optional[int] = 1,
-        accelerator_type: Optional[str] = None,
-        accelerator_count: Optional[int] = 0,
-        boot_disk_type: Optional[str] = "pd-ssd",
-        boot_disk_size_gb: Optional[int] = 100,
-    ):
-
-        self.machine_type = machine_type
-        self.node_count = node_count
-        self.accelerator_type = accelerator_type
-        self.accelerator_count = accelerator_count
-        self.boot_disk_type = boot_disk_type
-        self.boot_disk_size_gb = boot_disk_size_gb
-
-        if accelerator_type is None and accelerator_count > 0:
-            raise ValueError(
-                "[Ray on Vertex]: accelerator_type must be specified when"
-                + " accelerator_count is set to a value other than 0."
-            )
+    machine_type: Optional[str] = "n1-standard-4"
+    node_count: Optional[int] = 1
+    accelerator_type: Optional[str] = None
+    accelerator_count: Optional[int] = 0
+    boot_disk_type: Optional[str] = "pd-ssd"
+    boot_disk_size_gb: Optional[int] = 100


 @dataclasses.dataclass
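For reference, a brief sketch of how Resources behaves as a dataclass after this change; the accelerator type string is only illustrative. The constructor-time accelerator check is gone, so an inconsistent spec is rejected later by create_ray_cluster rather than at construction:

    from google.cloud.aiplatform.preview.vertex_ray.util.resources import Resources

    head = Resources()  # defaults: n1-standard-4, node_count=1, pd-ssd, 100 GB
    gpu_worker = Resources(
        machine_type="n1-standard-16",
        node_count=2,
        accelerator_type="NVIDIA_TESLA_T4",  # illustrative accelerator type
        accelerator_count=1,
    )

    # No longer raises here; the checks added in cluster_init.py reject it instead.
    inconsistent = Resources(accelerator_count=1)
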

tests/unit/vertex_ray/test_cluster_init.py (+91)

@@ -46,6 +46,11 @@
 )
 _TEST_RESPONSE_RUNNING_2_POOLS_RESIZE.resource_pools[1].replica_count = 1

+_TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER = copy.deepcopy(
+    tc.ClusterConstants._TEST_RESPONSE_RUNNING_1_POOL
+)
+_TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER.resource_pools[0].replica_count = 1
+

 @pytest.fixture
 def create_persistent_resource_1_pool_mock():
@@ -163,6 +168,22 @@ def update_persistent_resource_1_pool_mock():
         yield update_persistent_resource_1_pool_mock


+@pytest.fixture
+def update_persistent_resource_1_pool_0_worker_mock():
+    with mock.patch.object(
+        PersistentResourceServiceClient,
+        "update_persistent_resource",
+    ) as update_persistent_resource_1_pool_0_worker_mock:
+        update_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation)
+        update_persistent_resource_lro_mock.result.return_value = (
+            _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER
+        )
+        update_persistent_resource_1_pool_0_worker_mock.return_value = (
+            update_persistent_resource_lro_mock
+        )
+        yield update_persistent_resource_1_pool_0_worker_mock
+
+
 @pytest.fixture
 def update_persistent_resource_2_pools_mock():
     with mock.patch.object(
@@ -472,6 +493,30 @@ def test_update_ray_cluster_1_pool(self, update_persistent_resource_1_pool_mock)

         assert returned_name == tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS

+    @pytest.mark.usefixtures("get_persistent_resource_1_pool_mock")
+    def test_update_ray_cluster_1_pool_to_0_worker(
+        self, update_persistent_resource_1_pool_mock
+    ):
+
+        new_worker_node_types = []
+        for worker_node_type in tc.ClusterConstants._TEST_CLUSTER.worker_node_types:
+            # resize worker node to node_count = 0
+            worker_node_type.node_count = 0
+            new_worker_node_types.append(worker_node_type)
+
+        returned_name = vertex_ray.update_ray_cluster(
+            cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS,
+            worker_node_types=new_worker_node_types,
+        )
+
+        request = persistent_resource_service.UpdatePersistentResourceRequest(
+            persistent_resource=_TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER,
+            update_mask=_EXPECTED_MASK,
+        )
+        update_persistent_resource_1_pool_mock.assert_called_once_with(request)
+
+        assert returned_name == tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS
+
     @pytest.mark.usefixtures("get_persistent_resource_2_pools_mock")
     def test_update_ray_cluster_2_pools(self, update_persistent_resource_2_pools_mock):

@@ -493,3 +538,49 @@ def test_update_ray_cluster_2_pools(self, update_persistent_resource_2_pools_moc
         update_persistent_resource_2_pools_mock.assert_called_once_with(request)

         assert returned_name == tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS
+
+    @pytest.mark.usefixtures("get_persistent_resource_2_pools_mock")
+    def test_update_ray_cluster_2_pools_0_worker_fail(self):
+
+        new_worker_node_types = []
+        for worker_node_type in tc.ClusterConstants._TEST_CLUSTER_2.worker_node_types:
+            # resize worker node to node_count = 0
+            worker_node_type.node_count = 0
+            new_worker_node_types.append(worker_node_type)
+
+        with pytest.raises(ValueError) as e:
+            vertex_ray.update_ray_cluster(
+                cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS,
+                worker_node_types=new_worker_node_types,
+            )
+
+        e.match(regexp=r"must update to >= 1 nodes.")
+
+    @pytest.mark.usefixtures("get_persistent_resource_1_pool_mock")
+    def test_update_ray_cluster_duplicate_worker_node_types_error(self):
+        new_worker_node_types = (
+            tc.ClusterConstants._TEST_CLUSTER_2.worker_node_types
+            + tc.ClusterConstants._TEST_CLUSTER_2.worker_node_types
+        )
+        with pytest.raises(ValueError) as e:
+            vertex_ray.update_ray_cluster(
+                cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS,
+                worker_node_types=new_worker_node_types,
+            )
+
+        e.match(regexp=r"Worker_node_types have duplicate machine specs")
+
+    @pytest.mark.usefixtures("get_persistent_resource_1_pool_mock")
+    def test_update_ray_cluster_mismatch_worker_node_types_count_error(self):
+        with pytest.raises(ValueError) as e:
+            new_worker_node_types = (
+                tc.ClusterConstants._TEST_CLUSTER_2.worker_node_types
+            )
+            vertex_ray.update_ray_cluster(
+                cluster_resource_name=tc.ClusterConstants._TEST_VERTEX_RAY_PR_ADDRESS,
+                worker_node_types=new_worker_node_types,
+            )
+
+        e.match(
+            regexp=r"does not match the number of the existing worker_node_type"
+        )
