Skip to content

Commit accaa97

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: support PSC-Interface in Ray on Vertex
feat: support disable Cloud logging in Ray on Vertex PiperOrigin-RevId: 661019434
1 parent a521ba6 commit accaa97

File tree

7 files changed

+129
-30
lines changed

7 files changed

+129
-30
lines changed

google/cloud/aiplatform/vertex_ray/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from google.cloud.aiplatform.vertex_ray.util.resources import (
3939
Resources,
4040
NodeImages,
41+
PscIConfig,
4142
)
4243

4344
from google.cloud.aiplatform.vertex_ray.dashboard_sdk import (
@@ -61,4 +62,5 @@
6162
"update_ray_cluster",
6263
"Resources",
6364
"NodeImages",
65+
"PscIConfig",
6466
)

google/cloud/aiplatform/vertex_ray/cluster_init.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,20 @@
2323
from google.cloud.aiplatform import initializer
2424
from google.cloud.aiplatform import utils
2525
from google.cloud.aiplatform.utils import resource_manager_utils
26-
from google.cloud.aiplatform_v1.types import persistent_resource_service
26+
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service
2727

28-
from google.cloud.aiplatform_v1.types.persistent_resource import (
28+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
2929
PersistentResource,
30+
RayLogsSpec,
3031
RaySpec,
3132
RayMetricSpec,
3233
ResourcePool,
3334
ResourceRuntimeSpec,
3435
ServiceAccountSpec,
3536
)
36-
37+
from google.cloud.aiplatform_v1beta1.types.service_networking import (
38+
PscInterfaceConfig,
39+
)
3740
from google.cloud.aiplatform.vertex_ray.util import (
3841
_gapic_utils,
3942
_validation_utils,
@@ -56,6 +59,8 @@ def create_ray_cluster(
5659
worker_node_types: Optional[List[resources.Resources]] = [resources.Resources()],
5760
custom_images: Optional[resources.NodeImages] = None,
5861
enable_metrics_collection: Optional[bool] = True,
62+
enable_logging: Optional[bool] = True,
63+
psc_interface_config: Optional[resources.PscIConfig] = None,
5964
labels: Optional[Dict[str, str]] = None,
6065
) -> str:
6166
"""Create a ray cluster on the Vertex AI.
@@ -119,6 +124,8 @@ def create_ray_cluster(
119124
head/worker_node_type(s). Note that configuring `Resources.custom_image`
120125
will override `custom_images` here. Allowlist only.
121126
enable_metrics_collection: Enable Ray metrics collection for visualization.
127+
enable_logging: Enable exporting Ray logs to Cloud Logging.
128+
psc_interface_config: PSC-I config.
122129
labels:
123130
The labels with user-defined metadata to organize Ray cluster.
124131
@@ -258,10 +265,17 @@ def create_ray_cluster(
258265
i += 1
259266

260267
resource_pools = [resource_pool_0] + worker_pools
261-
disabled = not enable_metrics_collection
262-
ray_metric_spec = RayMetricSpec(disabled=disabled)
268+
269+
metrics_collection_disabled = not enable_metrics_collection
270+
ray_metric_spec = RayMetricSpec(disabled=metrics_collection_disabled)
271+
272+
logging_disabled = not enable_logging
273+
ray_logs_spec = RayLogsSpec(disabled=logging_disabled)
274+
263275
ray_spec = RaySpec(
264-
resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec
276+
resource_pool_images=resource_pool_images,
277+
ray_metric_spec=ray_metric_spec,
278+
ray_logs_spec=ray_logs_spec,
265279
)
266280
if service_account:
267281
service_account_spec = ServiceAccountSpec(
@@ -274,11 +288,18 @@ def create_ray_cluster(
274288
)
275289
else:
276290
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
291+
if psc_interface_config:
292+
gapic_psc_interface_config = PscInterfaceConfig(
293+
network_attachment=psc_interface_config.network_attachment,
294+
)
295+
else:
296+
gapic_psc_interface_config = None
277297
persistent_resource = PersistentResource(
278298
resource_pools=resource_pools,
279299
network=network,
280300
labels=labels,
281301
resource_runtime_spec=resource_runtime_spec,
302+
psc_interface_config=gapic_psc_interface_config,
282303
)
283304

284305
location = initializer.global_config.location

google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@
2828
from google.cloud.aiplatform.vertex_ray.util import _validation_utils
2929
from google.cloud.aiplatform.vertex_ray.util.resources import (
3030
Cluster,
31+
PscIConfig,
3132
Resources,
3233
)
33-
from google.cloud.aiplatform_v1.types.persistent_resource import (
34+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3435
PersistentResource,
3536
)
36-
from google.cloud.aiplatform_v1.types.persistent_resource_service import (
37+
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
3738
GetPersistentResourceRequest,
3839
)
3940

@@ -47,7 +48,7 @@ def create_persistent_resource_client():
4748
return initializer.global_config.create_client(
4849
client_class=PersistentResourceClientWithOverride,
4950
appended_gapic_version="vertex_ray",
50-
)
51+
).select_version("v1beta1")
5152

5253

5354
def polling_delay(num_attempts: int, time_scale: float) -> datetime.timedelta:
@@ -159,6 +160,10 @@ def persistent_resource_to_cluster(
159160
% persistent_resource.name,
160161
)
161162
return
163+
if persistent_resource.psc_interface_config:
164+
cluster.psc_interface_config = PscIConfig(
165+
network_attachment=persistent_resource.psc_interface_config.network_attachment
166+
)
162167
resource_pools = persistent_resource.resource_pools
163168

164169
head_resource_pool = resource_pools[0]
@@ -192,6 +197,12 @@ def persistent_resource_to_cluster(
192197
ray_version = None
193198
cluster.python_version = python_version
194199
cluster.ray_version = ray_version
200+
cluster.ray_metric_enabled = not (
201+
persistent_resource.resource_runtime_spec.ray_spec.ray_metric_spec.disabled
202+
)
203+
cluster.ray_logs_enabled = not (
204+
persistent_resource.resource_runtime_spec.ray_spec.ray_logs_spec.disabled
205+
)
195206

196207
accelerator_type = head_resource_pool.machine_spec.accelerator_type
197208
if accelerator_type.value != 0:

google/cloud/aiplatform/vertex_ray/util/resources.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717
import dataclasses
1818
from typing import Dict, List, Optional
19-
from google.cloud.aiplatform_v1.types import PersistentResource
19+
from google.cloud.aiplatform_v1beta1.types import PersistentResource
2020

2121

2222
@dataclasses.dataclass
@@ -68,6 +68,27 @@ class NodeImages:
6868
worker: str = None
6969

7070

71+
@dataclasses.dataclass
72+
class PscIConfig:
73+
"""PSC-I config.
74+
75+
Attributes:
76+
network_attachment: Optional. The name or full name of the Compute Engine
77+
`network attachment <https://cloud.google.com/vpc/docs/about-network-attachments>`
78+
to attach to the resource. It has a format:
79+
``projects/{project}/regions/{region}/networkAttachments/{networkAttachment}``.
80+
Where {project} is a project number, as in ``12345``, and
81+
{networkAttachment} is a network attachment name. To specify
82+
this field, you must have already [created a network
83+
attachment]
84+
(https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments).
85+
This field is only used for resources using PSC-I. Make sure you do not
86+
specify the network here for VPC peering.
87+
"""
88+
89+
network_attachment: str = None
90+
91+
7192
@dataclasses.dataclass
7293
class Cluster:
7394
"""Ray cluster (output only).
@@ -111,6 +132,9 @@ class Cluster:
111132
head_node_type: Resources = None
112133
worker_node_types: List[Resources] = None
113134
dashboard_address: str = None
135+
ray_metric_enabled: bool = True
136+
ray_logs_enabled: bool = True
137+
psc_interface_config: PscIConfig = None
114138
labels: Dict[str, str] = None
115139

116140

tests/unit/vertex_ray/conftest.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@
1919
from google.auth import credentials as auth_credentials
2020
from google.cloud import resourcemanager
2121
from google.cloud.aiplatform import vertex_ray
22-
from google.cloud.aiplatform_v1.services.persistent_resource_service import (
22+
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
2323
PersistentResourceServiceClient,
2424
)
25-
from google.cloud.aiplatform_v1.types.persistent_resource import (
25+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
2626
PersistentResource,
2727
)
28-
from google.cloud.aiplatform_v1.types.persistent_resource import (
28+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
2929
ResourceRuntime,
3030
)
31-
from google.cloud.aiplatform_v1.types.persistent_resource_service import (
31+
from google.cloud.aiplatform_v1beta1.types.persistent_resource_service import (
3232
DeletePersistentResourceRequest,
3333
)
3434
import test_constants as tc

tests/unit/vertex_ray/test_cluster_init.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
Resources,
2323
NodeImages,
2424
)
25-
from google.cloud.aiplatform_v1.services.persistent_resource_service import (
25+
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
2626
PersistentResourceServiceClient,
2727
)
28-
from google.cloud.aiplatform_v1.types import persistent_resource_service
28+
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service
2929
import test_constants as tc
3030
import mock
3131
import pytest
@@ -352,13 +352,15 @@ def test_create_ray_cluster_1_pool_gpu_with_labels_success(
352352
self, create_persistent_resource_1_pool_mock
353353
):
354354
"""If head and worker nodes are duplicate, merge to head pool."""
355+
# Also test disable logging and metrics collection.
355356
cluster_name = vertex_ray.create_ray_cluster(
356357
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_1_POOL,
357358
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_1_POOL,
358359
network=tc.ProjectConstants.TEST_VPC_NETWORK,
359360
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
360361
labels=tc.ClusterConstants.TEST_LABELS,
361362
enable_metrics_collection=False,
363+
enable_logging=False,
362364
)
363365

364366
assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
@@ -401,11 +403,15 @@ def test_create_ray_cluster_2_pools_success(
401403
self, create_persistent_resource_2_pools_mock
402404
):
403405
"""If head and worker nodes are not duplicate, create separate resource_pools."""
406+
# Also test PSC-I.
407+
psc_interface_config = vertex_ray.PscIConfig(
408+
network_attachment=tc.ClusterConstants.TEST_PSC_NETWORK_ATTACHMENT
409+
)
404410
cluster_name = vertex_ray.create_ray_cluster(
405411
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS,
406412
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS,
407-
network=tc.ProjectConstants.TEST_VPC_NETWORK,
408413
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
414+
psc_interface_config=psc_interface_config,
409415
)
410416

411417
assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name

0 commit comments

Comments
 (0)