34
34
35
35
36
36
# -*- coding: utf-8 -*-
37
- # TODO(b/328684671)
38
37
_EXPECTED_MASK = field_mask_pb2 .FieldMask (paths = ["resource_pools.replica_count" ])
39
38
40
39
# for manual scaling
@@ -241,6 +240,22 @@ def update_persistent_resource_2_pools_mock():
241
240
yield update_persistent_resource_2_pools_mock
242
241
243
242
243
+ def cluster_eq (returned_cluster , expected_cluster ):
244
+ assert vars (returned_cluster .head_node_type ) == vars (
245
+ expected_cluster .head_node_type
246
+ )
247
+ assert vars (returned_cluster .worker_node_types [0 ]) == vars (
248
+ expected_cluster .worker_node_types [0 ]
249
+ )
250
+ assert (
251
+ returned_cluster .cluster_resource_name == expected_cluster .cluster_resource_name
252
+ )
253
+ assert returned_cluster .python_version == expected_cluster .python_version
254
+ assert returned_cluster .ray_version == expected_cluster .ray_version
255
+ assert returned_cluster .network == expected_cluster .network
256
+ assert returned_cluster .state == expected_cluster .state
257
+
258
+
244
259
@pytest .mark .usefixtures ("google_auth_mock" , "get_project_number_mock" )
245
260
class TestClusterManagement :
246
261
def setup_method (self ):
@@ -315,6 +330,7 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success(
315
330
network = tc .ProjectConstants .TEST_VPC_NETWORK ,
316
331
cluster_name = tc .ClusterConstants .TEST_VERTEX_RAY_PR_ID ,
317
332
labels = tc .ClusterConstants .TEST_LABELS ,
333
+ enable_metrics_collection = False ,
318
334
)
319
335
320
336
assert tc .ClusterConstants .TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
@@ -465,21 +481,7 @@ def test_get_ray_cluster_success(self, get_persistent_resource_1_pool_mock):
465
481
)
466
482
467
483
get_persistent_resource_1_pool_mock .assert_called_once ()
468
-
469
- assert vars (cluster .head_node_type ) == vars (
470
- tc .ClusterConstants .TEST_CLUSTER .head_node_type
471
- )
472
- assert vars (cluster .worker_node_types [0 ]) == vars (
473
- tc .ClusterConstants .TEST_CLUSTER .worker_node_types [0 ]
474
- )
475
- assert (
476
- cluster .cluster_resource_name
477
- == tc .ClusterConstants .TEST_CLUSTER .cluster_resource_name
478
- )
479
- assert cluster .python_version == tc .ClusterConstants .TEST_CLUSTER .python_version
480
- assert cluster .ray_version == tc .ClusterConstants .TEST_CLUSTER .ray_version
481
- assert cluster .network == tc .ClusterConstants .TEST_CLUSTER .network
482
- assert cluster .state == tc .ClusterConstants .TEST_CLUSTER .state
484
+ cluster_eq (cluster , tc .ClusterConstants .TEST_CLUSTER )
483
485
484
486
def test_get_ray_cluster_with_custom_image_success (
485
487
self , get_persistent_resource_2_pools_custom_image_mock
@@ -489,27 +491,7 @@ def test_get_ray_cluster_with_custom_image_success(
489
491
)
490
492
491
493
get_persistent_resource_2_pools_custom_image_mock .assert_called_once ()
492
-
493
- assert vars (cluster .head_node_type ) == vars (
494
- tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .head_node_type
495
- )
496
- assert vars (cluster .worker_node_types [0 ]) == vars (
497
- tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .worker_node_types [0 ]
498
- )
499
- assert (
500
- cluster .cluster_resource_name
501
- == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .cluster_resource_name
502
- )
503
- assert (
504
- cluster .python_version
505
- == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .python_version
506
- )
507
- assert (
508
- cluster .ray_version
509
- == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .ray_version
510
- )
511
- assert cluster .network == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .network
512
- assert cluster .state == tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE .state
494
+ cluster_eq (cluster , tc .ClusterConstants .TEST_CLUSTER_CUSTOM_IMAGE )
513
495
514
496
@pytest .mark .usefixtures ("get_persistent_resource_exception_mock" )
515
497
def test_get_ray_cluster_error (self ):
@@ -526,42 +508,9 @@ def test_list_ray_clusters_success(self, list_persistent_resources_mock):
526
508
list_persistent_resources_mock .assert_called_once ()
527
509
528
510
# first ray cluster
529
- assert vars (clusters [0 ].head_node_type ) == vars (
530
- tc .ClusterConstants .TEST_CLUSTER .head_node_type
531
- )
532
- assert vars (clusters [0 ].worker_node_types [0 ]) == vars (
533
- tc .ClusterConstants .TEST_CLUSTER .worker_node_types [0 ]
534
- )
535
- assert (
536
- clusters [0 ].cluster_resource_name
537
- == tc .ClusterConstants .TEST_CLUSTER .cluster_resource_name
538
- )
539
- assert (
540
- clusters [0 ].python_version
541
- == tc .ClusterConstants .TEST_CLUSTER .python_version
542
- )
543
- assert clusters [0 ].ray_version == tc .ClusterConstants .TEST_CLUSTER .ray_version
544
- assert clusters [0 ].network == tc .ClusterConstants .TEST_CLUSTER .network
545
- assert clusters [0 ].state == tc .ClusterConstants .TEST_CLUSTER .state
546
-
511
+ cluster_eq (clusters [0 ], tc .ClusterConstants .TEST_CLUSTER )
547
512
# second ray cluster
548
- assert vars (clusters [1 ].head_node_type ) == vars (
549
- tc .ClusterConstants .TEST_CLUSTER_2 .head_node_type
550
- )
551
- assert vars (clusters [1 ].worker_node_types [0 ]) == vars (
552
- tc .ClusterConstants .TEST_CLUSTER_2 .worker_node_types [0 ]
553
- )
554
- assert (
555
- clusters [1 ].cluster_resource_name
556
- == tc .ClusterConstants .TEST_CLUSTER_2 .cluster_resource_name
557
- )
558
- assert (
559
- clusters [1 ].python_version
560
- == tc .ClusterConstants .TEST_CLUSTER_2 .python_version
561
- )
562
- assert clusters [1 ].ray_version == tc .ClusterConstants .TEST_CLUSTER_2 .ray_version
563
- assert clusters [1 ].network == tc .ClusterConstants .TEST_CLUSTER_2 .network
564
- assert clusters [1 ].state == tc .ClusterConstants .TEST_CLUSTER_2 .state
513
+ cluster_eq (clusters [1 ], tc .ClusterConstants .TEST_CLUSTER_2 )
565
514
566
515
def test_list_ray_clusters_initialized_success (
567
516
self , get_project_number_mock , list_persistent_resources_mock
0 commit comments