Skip to content

Commit e33d11f

Browse files
yinghsienwu authored and copybara-github committed
feat: Add an arg to turn off Ray metrics collection during cluster creation
PiperOrigin-RevId: 617612703
1 parent e51c977 commit e33d11f

File tree

3 files changed: +60 -83 lines changed

google/cloud/aiplatform/preview/vertex_ray/cluster_init.py

+8 -2
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
2929
PersistentResource,
3030
RaySpec,
31+
RayMetricSpec,
3132
ResourcePool,
3233
ResourceRuntimeSpec,
3334
)
@@ -49,6 +50,7 @@ def create_ray_cluster(
4950
cluster_name: Optional[str] = None,
5051
worker_node_types: Optional[List[resources.Resources]] = None,
5152
custom_images: Optional[resources.NodeImages] = None,
53+
enable_metrics_collection: Optional[bool] = True,
5254
labels: Optional[Dict[str, str]] = None,
5355
) -> str:
5456
"""Create a ray cluster on the Vertex AI.
@@ -107,6 +109,7 @@ def create_ray_cluster(
107109
has a specific custom image, use `Resources.custom_image` for
108110
head/worker_node_type(s). Note that configuring `Resources.custom_image`
109111
will override `custom_images` here. Allowlist only.
112+
enable_metrics_collection: Enable Ray metrics collection for visualization.
110113
labels:
111114
The labels with user-defined metadata to organize Ray cluster.
112115
@@ -244,8 +247,11 @@ def create_ray_cluster(
244247
i += 1
245248

246249
resource_pools = [resource_pool_0] + worker_pools
247-
248-
ray_spec = RaySpec(resource_pool_images=resource_pool_images)
250+
disabled = not enable_metrics_collection
251+
ray_metric_spec = RayMetricSpec(disabled=disabled)
252+
ray_spec = RaySpec(
253+
resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec
254+
)
249255
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
250256
persistent_resource = PersistentResource(
251257
resource_pools=resource_pools,

tests/unit/vertex_ray/test_cluster_init.py

+21 -72
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@
3434

3535

3636
# -*- coding: utf-8 -*-
37-
# TODO(b/328684671)
3837
_EXPECTED_MASK = field_mask_pb2.FieldMask(paths=["resource_pools.replica_count"])
3938

4039
# for manual scaling
@@ -241,6 +240,22 @@ def update_persistent_resource_2_pools_mock():
241240
yield update_persistent_resource_2_pools_mock
242241

243242

243+
def cluster_eq(returned_cluster, expected_cluster):
244+
assert vars(returned_cluster.head_node_type) == vars(
245+
expected_cluster.head_node_type
246+
)
247+
assert vars(returned_cluster.worker_node_types[0]) == vars(
248+
expected_cluster.worker_node_types[0]
249+
)
250+
assert (
251+
returned_cluster.cluster_resource_name == expected_cluster.cluster_resource_name
252+
)
253+
assert returned_cluster.python_version == expected_cluster.python_version
254+
assert returned_cluster.ray_version == expected_cluster.ray_version
255+
assert returned_cluster.network == expected_cluster.network
256+
assert returned_cluster.state == expected_cluster.state
257+
258+
244259
@pytest.mark.usefixtures("google_auth_mock", "get_project_number_mock")
245260
class TestClusterManagement:
246261
def setup_method(self):
@@ -315,6 +330,7 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success(
315330
network=tc.ProjectConstants.TEST_VPC_NETWORK,
316331
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
317332
labels=tc.ClusterConstants.TEST_LABELS,
333+
enable_metrics_collection=False,
318334
)
319335

320336
assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
@@ -465,21 +481,7 @@ def test_get_ray_cluster_success(self, get_persistent_resource_1_pool_mock):
465481
)
466482

467483
get_persistent_resource_1_pool_mock.assert_called_once()
468-
469-
assert vars(cluster.head_node_type) == vars(
470-
tc.ClusterConstants.TEST_CLUSTER.head_node_type
471-
)
472-
assert vars(cluster.worker_node_types[0]) == vars(
473-
tc.ClusterConstants.TEST_CLUSTER.worker_node_types[0]
474-
)
475-
assert (
476-
cluster.cluster_resource_name
477-
== tc.ClusterConstants.TEST_CLUSTER.cluster_resource_name
478-
)
479-
assert cluster.python_version == tc.ClusterConstants.TEST_CLUSTER.python_version
480-
assert cluster.ray_version == tc.ClusterConstants.TEST_CLUSTER.ray_version
481-
assert cluster.network == tc.ClusterConstants.TEST_CLUSTER.network
482-
assert cluster.state == tc.ClusterConstants.TEST_CLUSTER.state
484+
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER)
483485

484486
def test_get_ray_cluster_with_custom_image_success(
485487
self, get_persistent_resource_2_pools_custom_image_mock
@@ -489,27 +491,7 @@ def test_get_ray_cluster_with_custom_image_success(
489491
)
490492

491493
get_persistent_resource_2_pools_custom_image_mock.assert_called_once()
492-
493-
assert vars(cluster.head_node_type) == vars(
494-
tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.head_node_type
495-
)
496-
assert vars(cluster.worker_node_types[0]) == vars(
497-
tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.worker_node_types[0]
498-
)
499-
assert (
500-
cluster.cluster_resource_name
501-
== tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.cluster_resource_name
502-
)
503-
assert (
504-
cluster.python_version
505-
== tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.python_version
506-
)
507-
assert (
508-
cluster.ray_version
509-
== tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.ray_version
510-
)
511-
assert cluster.network == tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.network
512-
assert cluster.state == tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE.state
494+
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE)
513495

514496
@pytest.mark.usefixtures("get_persistent_resource_exception_mock")
515497
def test_get_ray_cluster_error(self):
@@ -526,42 +508,9 @@ def test_list_ray_clusters_success(self, list_persistent_resources_mock):
526508
list_persistent_resources_mock.assert_called_once()
527509

528510
# first ray cluster
529-
assert vars(clusters[0].head_node_type) == vars(
530-
tc.ClusterConstants.TEST_CLUSTER.head_node_type
531-
)
532-
assert vars(clusters[0].worker_node_types[0]) == vars(
533-
tc.ClusterConstants.TEST_CLUSTER.worker_node_types[0]
534-
)
535-
assert (
536-
clusters[0].cluster_resource_name
537-
== tc.ClusterConstants.TEST_CLUSTER.cluster_resource_name
538-
)
539-
assert (
540-
clusters[0].python_version
541-
== tc.ClusterConstants.TEST_CLUSTER.python_version
542-
)
543-
assert clusters[0].ray_version == tc.ClusterConstants.TEST_CLUSTER.ray_version
544-
assert clusters[0].network == tc.ClusterConstants.TEST_CLUSTER.network
545-
assert clusters[0].state == tc.ClusterConstants.TEST_CLUSTER.state
546-
511+
cluster_eq(clusters[0], tc.ClusterConstants.TEST_CLUSTER)
547512
# second ray cluster
548-
assert vars(clusters[1].head_node_type) == vars(
549-
tc.ClusterConstants.TEST_CLUSTER_2.head_node_type
550-
)
551-
assert vars(clusters[1].worker_node_types[0]) == vars(
552-
tc.ClusterConstants.TEST_CLUSTER_2.worker_node_types[0]
553-
)
554-
assert (
555-
clusters[1].cluster_resource_name
556-
== tc.ClusterConstants.TEST_CLUSTER_2.cluster_resource_name
557-
)
558-
assert (
559-
clusters[1].python_version
560-
== tc.ClusterConstants.TEST_CLUSTER_2.python_version
561-
)
562-
assert clusters[1].ray_version == tc.ClusterConstants.TEST_CLUSTER_2.ray_version
563-
assert clusters[1].network == tc.ClusterConstants.TEST_CLUSTER_2.network
564-
assert clusters[1].state == tc.ClusterConstants.TEST_CLUSTER_2.state
513+
cluster_eq(clusters[1], tc.ClusterConstants.TEST_CLUSTER_2)
565514

566515
def test_list_ray_clusters_initialized_success(
567516
self, get_project_number_mock, list_persistent_resources_mock

tests/unit/vertex_ray/test_constants.py

+31 -9
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,9 @@
2929
PersistentResource,
3030
)
3131
from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec
32+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
33+
RayMetricSpec,
34+
)
3235
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3336
ResourcePool,
3437
)
@@ -116,22 +119,31 @@ class ClusterConstants:
116119
TEST_REQUEST_RUNNING_1_POOL = PersistentResource(
117120
resource_pools=[TEST_RESOURCE_POOL_0],
118121
resource_runtime_spec=ResourceRuntimeSpec(
119-
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_GPU_IMAGE}),
122+
ray_spec=RaySpec(
123+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
124+
ray_metric_spec=RayMetricSpec(disabled=False),
125+
),
120126
),
121127
network=ProjectConstants.TEST_VPC_NETWORK,
122128
)
123129
TEST_REQUEST_RUNNING_1_POOL_WITH_LABELS = PersistentResource(
124130
resource_pools=[TEST_RESOURCE_POOL_0],
125131
resource_runtime_spec=ResourceRuntimeSpec(
126-
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_GPU_IMAGE}),
132+
ray_spec=RaySpec(
133+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
134+
ray_metric_spec=RayMetricSpec(disabled=True),
135+
),
127136
),
128137
network=ProjectConstants.TEST_VPC_NETWORK,
129138
labels=TEST_LABELS,
130139
)
131140
TEST_REQUEST_RUNNING_1_POOL_CUSTOM_IMAGES = PersistentResource(
132141
resource_pools=[TEST_RESOURCE_POOL_0],
133142
resource_runtime_spec=ResourceRuntimeSpec(
134-
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_CUSTOM_IMAGE}),
143+
ray_spec=RaySpec(
144+
resource_pool_images={"head-node": TEST_CUSTOM_IMAGE},
145+
ray_metric_spec=RayMetricSpec(disabled=False),
146+
),
135147
),
136148
network=ProjectConstants.TEST_VPC_NETWORK,
137149
)
@@ -140,7 +152,10 @@ class ClusterConstants:
140152
name=TEST_VERTEX_RAY_PR_ADDRESS,
141153
resource_pools=[TEST_RESOURCE_POOL_0],
142154
resource_runtime_spec=ResourceRuntimeSpec(
143-
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_GPU_IMAGE}),
155+
ray_spec=RaySpec(
156+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
157+
ray_metric_spec=RayMetricSpec(disabled=False),
158+
),
144159
),
145160
network=ProjectConstants.TEST_VPC_NETWORK,
146161
resource_runtime=ResourceRuntime(
@@ -156,7 +171,10 @@ class ClusterConstants:
156171
name=TEST_VERTEX_RAY_PR_ADDRESS,
157172
resource_pools=[TEST_RESOURCE_POOL_0],
158173
resource_runtime_spec=ResourceRuntimeSpec(
159-
ray_spec=RaySpec(resource_pool_images={"head-node": TEST_CUSTOM_IMAGE}),
174+
ray_spec=RaySpec(
175+
resource_pool_images={"head-node": TEST_CUSTOM_IMAGE},
176+
ray_metric_spec=RayMetricSpec(disabled=False),
177+
),
160178
),
161179
network=ProjectConstants.TEST_VPC_NETWORK,
162180
resource_runtime=ResourceRuntime(
@@ -218,7 +236,8 @@ class ClusterConstants:
218236
resource_pool_images={
219237
"head-node": TEST_CPU_IMAGE,
220238
"worker-pool1": TEST_GPU_IMAGE,
221-
}
239+
},
240+
ray_metric_spec=RayMetricSpec(disabled=False),
222241
),
223242
),
224243
network=ProjectConstants.TEST_VPC_NETWORK,
@@ -230,7 +249,8 @@ class ClusterConstants:
230249
resource_pool_images={
231250
"head-node": TEST_CUSTOM_IMAGE,
232251
"worker-pool1": TEST_CUSTOM_IMAGE,
233-
}
252+
},
253+
ray_metric_spec=RayMetricSpec(disabled=False),
234254
),
235255
),
236256
network=ProjectConstants.TEST_VPC_NETWORK,
@@ -243,7 +263,8 @@ class ClusterConstants:
243263
resource_pool_images={
244264
"head-node": TEST_CPU_IMAGE,
245265
"worker-pool1": TEST_GPU_IMAGE,
246-
}
266+
},
267+
ray_metric_spec=RayMetricSpec(disabled=False),
247268
),
248269
),
249270
network=ProjectConstants.TEST_VPC_NETWORK,
@@ -263,7 +284,8 @@ class ClusterConstants:
263284
resource_pool_images={
264285
"head-node": TEST_CUSTOM_IMAGE,
265286
"worker-pool1": TEST_CUSTOM_IMAGE,
266-
}
287+
},
288+
ray_metric_spec=RayMetricSpec(disabled=False),
267289
),
268290
),
269291
network=ProjectConstants.TEST_VPC_NETWORK,

0 commit comments

Comments
 (0)