46
46
)
47
47
_TEST_RESPONSE_RUNNING_2_POOLS_RESIZE .resource_pools [1 ].replica_count = 1
48
48
49
+ _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER = copy .deepcopy (
50
+ tc .ClusterConstants ._TEST_RESPONSE_RUNNING_1_POOL
51
+ )
52
+ _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER .resource_pools [0 ].replica_count = 1
53
+
49
54
50
55
@pytest .fixture
51
56
def create_persistent_resource_1_pool_mock ():
@@ -163,6 +168,22 @@ def update_persistent_resource_1_pool_mock():
163
168
yield update_persistent_resource_1_pool_mock
164
169
165
170
171
+ @pytest .fixture
172
+ def update_persistent_resource_1_pool_0_worker_mock ():
173
+ with mock .patch .object (
174
+ PersistentResourceServiceClient ,
175
+ "update_persistent_resource" ,
176
+ ) as update_persistent_resource_1_pool_0_worker_mock :
177
+ update_persistent_resource_lro_mock = mock .Mock (ga_operation .Operation )
178
+ update_persistent_resource_lro_mock .result .return_value = (
179
+ _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER
180
+ )
181
+ update_persistent_resource_1_pool_0_worker_mock .return_value = (
182
+ update_persistent_resource_lro_mock
183
+ )
184
+ yield update_persistent_resource_1_pool_0_worker_mock
185
+
186
+
166
187
@pytest .fixture
167
188
def update_persistent_resource_2_pools_mock ():
168
189
with mock .patch .object (
@@ -472,6 +493,30 @@ def test_update_ray_cluster_1_pool(self, update_persistent_resource_1_pool_mock)
472
493
473
494
assert returned_name == tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS
474
495
496
+ @pytest .mark .usefixtures ("get_persistent_resource_1_pool_mock" )
497
+ def test_update_ray_cluster_1_pool_to_0_worker (
498
+ self , update_persistent_resource_1_pool_mock
499
+ ):
500
+
501
+ new_worker_node_types = []
502
+ for worker_node_type in tc .ClusterConstants ._TEST_CLUSTER .worker_node_types :
503
+ # resize worker node to node_count = 0
504
+ worker_node_type .node_count = 0
505
+ new_worker_node_types .append (worker_node_type )
506
+
507
+ returned_name = vertex_ray .update_ray_cluster (
508
+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
509
+ worker_node_types = new_worker_node_types ,
510
+ )
511
+
512
+ request = persistent_resource_service .UpdatePersistentResourceRequest (
513
+ persistent_resource = _TEST_RESPONSE_RUNNING_1_POOL_RESIZE_0_WORKER ,
514
+ update_mask = _EXPECTED_MASK ,
515
+ )
516
+ update_persistent_resource_1_pool_mock .assert_called_once_with (request )
517
+
518
+ assert returned_name == tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS
519
+
475
520
@pytest .mark .usefixtures ("get_persistent_resource_2_pools_mock" )
476
521
def test_update_ray_cluster_2_pools (self , update_persistent_resource_2_pools_mock ):
477
522
@@ -493,3 +538,49 @@ def test_update_ray_cluster_2_pools(self, update_persistent_resource_2_pools_moc
493
538
update_persistent_resource_2_pools_mock .assert_called_once_with (request )
494
539
495
540
assert returned_name == tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS
541
+
542
+ @pytest .mark .usefixtures ("get_persistent_resource_2_pools_mock" )
543
+ def test_update_ray_cluster_2_pools_0_worker_fail (self ):
544
+
545
+ new_worker_node_types = []
546
+ for worker_node_type in tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types :
547
+ # resize worker node to node_count = 0
548
+ worker_node_type .node_count = 0
549
+ new_worker_node_types .append (worker_node_type )
550
+
551
+ with pytest .raises (ValueError ) as e :
552
+ vertex_ray .update_ray_cluster (
553
+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
554
+ worker_node_types = new_worker_node_types ,
555
+ )
556
+
557
+ e .match (regexp = r"must update to >= 1 nodes." )
558
+
559
+ @pytest .mark .usefixtures ("get_persistent_resource_1_pool_mock" )
560
+ def test_update_ray_cluster_duplicate_worker_node_types_error (self ):
561
+ new_worker_node_types = (
562
+ tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types
563
+ + tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types
564
+ )
565
+ with pytest .raises (ValueError ) as e :
566
+ vertex_ray .update_ray_cluster (
567
+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
568
+ worker_node_types = new_worker_node_types ,
569
+ )
570
+
571
+ e .match (regexp = r"Worker_node_types have duplicate machine specs" )
572
+
573
+ @pytest .mark .usefixtures ("get_persistent_resource_1_pool_mock" )
574
+ def test_update_ray_cluster_mismatch_worker_node_types_count_error (self ):
575
+ with pytest .raises (ValueError ) as e :
576
+ new_worker_node_types = (
577
+ tc .ClusterConstants ._TEST_CLUSTER_2 .worker_node_types
578
+ )
579
+ vertex_ray .update_ray_cluster (
580
+ cluster_resource_name = tc .ClusterConstants ._TEST_VERTEX_RAY_PR_ADDRESS ,
581
+ worker_node_types = new_worker_node_types ,
582
+ )
583
+
584
+ e .match (
585
+ regexp = r"does not match the number of the existing worker_node_type"
586
+ )
0 commit comments