Skip to content

Commit e0c6227

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support custom service account for Ray cluster creation and Ray Client connection
PiperOrigin-RevId: 631998839
1 parent cc8bc96 commit e0c6227

File tree

7 files changed

+248
-19
lines changed

7 files changed

+248
-19
lines changed

google/cloud/aiplatform/preview/vertex_ray/client_builder.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,21 @@ def __init__(self, address: Optional[str]) -> None:
9898
public_address = self.response.resource_runtime.access_uris.get(
9999
"RAY_CLIENT_ENDPOINT"
100100
)
101+
service_account = (
102+
self.response.resource_runtime_spec.service_account_spec.service_account
103+
)
104+
101105
if public_address is None:
102106
address = private_address
107+
if service_account:
108+
raise ValueError(
109+
"[Ray on Vertex AI]: Ray Cluster ",
110+
address,
111+
" failed to start Head node properly because custom service"
112+
" account isn't supported in peered VPC network. Use public"
113+
" endpoint instead (createa a cluster withought specifying"
114+
" VPC network).",
115+
)
103116
else:
104117
address = public_address
105118

@@ -110,17 +123,7 @@ def __init__(self, address: Optional[str]) -> None:
110123
persistent_resource_id,
111124
" Head node is not reachable. Please ensure that a valid VPC network has been specified.",
112125
)
113-
# Handling service_account
114-
service_account = (
115-
self.response.resource_runtime_spec.service_account_spec.service_account
116-
)
117126

118-
if service_account:
119-
raise ValueError(
120-
"[Ray on Vertex AI]: Ray Cluster ",
121-
address,
122-
" failed to start Head node properly because custom service account isn't supported.",
123-
)
124127
logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
125128
cluster = _gapic_utils.persistent_resource_to_cluster(
126129
persistent_resource=self.response

google/cloud/aiplatform/preview/vertex_ray/cluster_init.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
RayMetricSpec,
3333
ResourcePool,
3434
ResourceRuntimeSpec,
35+
ServiceAccountSpec,
3536
)
3637

3738
from google.cloud.aiplatform.preview.vertex_ray.util import (
@@ -48,6 +49,7 @@ def create_ray_cluster(
4849
python_version: Optional[str] = "3.10",
4950
ray_version: Optional[str] = "2.9",
5051
network: Optional[str] = None,
52+
service_account: Optional[str] = None,
5153
cluster_name: Optional[str] = None,
5254
worker_node_types: Optional[List[resources.Resources]] = None,
5355
custom_images: Optional[resources.NodeImages] = None,
@@ -78,7 +80,9 @@ def create_ray_cluster(
7880
7981
cluster_resource_name = vertex_ray.create_ray_cluster(
8082
head_node_type=head_node_type,
81-
network="projects/my-project-number/global/networks/my-vpc-name",
83+
network="projects/my-project-number/global/networks/my-vpc-name", # Optional
84+
service_account="[email protected]", # Optional
85+
cluster_name="my-cluster-name", # Optional
8286
worker_node_types=worker_node_types,
8387
ray_version="2.9",
8488
)
@@ -100,6 +104,8 @@ def create_ray_cluster(
100104
Vertex API service. For Ray Job API, VPC network is not required
101105
because Ray Cluster connection can be accessed through dashboard
102106
address.
107+
service_account: Service account to be used for running Ray programs on
108+
the cluster.
103109
cluster_name: This value may be up to 63 characters, and valid
104110
characters are `[a-z0-9_-]`. The first character cannot be a number
105111
or hyphen.
@@ -254,7 +260,17 @@ def create_ray_cluster(
254260
ray_spec = RaySpec(
255261
resource_pool_images=resource_pool_images, ray_metric_spec=ray_metric_spec
256262
)
257-
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
263+
if service_account:
264+
service_account_spec = ServiceAccountSpec(
265+
enable_custom_service_account=True,
266+
service_account=service_account,
267+
)
268+
resource_runtime_spec = ResourceRuntimeSpec(
269+
ray_spec=ray_spec,
270+
service_account_spec=service_account_spec,
271+
)
272+
else:
273+
resource_runtime_spec = ResourceRuntimeSpec(ray_spec=ray_spec)
258274
persistent_resource = PersistentResource(
259275
resource_pools=resource_pools,
260276
network=network,

google/cloud/aiplatform/preview/vertex_ray/util/_gapic_utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,10 @@ def persistent_resource_to_cluster(
166166
head_image_uri = (
167167
persistent_resource.resource_runtime_spec.ray_spec.resource_pool_images[head_id]
168168
)
169-
169+
if persistent_resource.resource_runtime_spec.service_account_spec.service_account:
170+
cluster.service_account = (
171+
persistent_resource.resource_runtime_spec.service_account_spec.service_account
172+
)
170173
if not head_image_uri:
171174
head_image_uri = persistent_resource.resource_runtime_spec.ray_spec.image_uri
172175

google/cloud/aiplatform/preview/vertex_ray/util/resources.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Resources:
4141
us-docker.pkg.dev/my-project/ray-gpu.2-9.py310-tf:latest).
4242
"""
4343

44-
machine_type: Optional[str] = "n1-standard-8"
44+
machine_type: Optional[str] = "n1-standard-16"
4545
node_count: Optional[int] = 1
4646
accelerator_type: Optional[str] = None
4747
accelerator_count: Optional[int] = 0
@@ -81,6 +81,8 @@ class Cluster:
8181
managed in the Vertex API service. For Ray Job API, VPC network is
8282
not required because cluster connection can be accessed through
8383
dashboard address.
84+
service_account: Service account to be used for running Ray programs on
85+
the cluster.
8486
state: Describes the cluster state (defined in PersistentResource.State).
8587
python_version: Python version for the ray cluster (e.g. "3.10").
8688
ray_version: Ray version for the ray cluster (e.g. "2.4").
@@ -102,6 +104,7 @@ class Cluster:
102104

103105
cluster_resource_name: str = None
104106
network: str = None
107+
service_account: str = None
105108
state: PersistentResource.State = None
106109
python_version: str = None
107110
ray_version: str = None

tests/unit/vertex_ray/test_cluster_init.py

+62
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,34 @@ def get_persistent_resource_1_pool_custom_image_mock():
9292
yield get_persistent_resource_1_pool_custom_image_mock
9393

9494

95+
@pytest.fixture
96+
def create_persistent_resource_1_pool_byosa_mock():
97+
with mock.patch.object(
98+
PersistentResourceServiceClient,
99+
"create_persistent_resource",
100+
) as create_persistent_resource_1_pool_byosa_mock:
101+
create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation)
102+
create_persistent_resource_lro_mock.result.return_value = (
103+
tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA
104+
)
105+
create_persistent_resource_1_pool_byosa_mock.return_value = (
106+
create_persistent_resource_lro_mock
107+
)
108+
yield create_persistent_resource_1_pool_byosa_mock
109+
110+
111+
@pytest.fixture
112+
def get_persistent_resource_1_pool_byosa_mock():
113+
with mock.patch.object(
114+
PersistentResourceServiceClient,
115+
"get_persistent_resource",
116+
) as get_persistent_resource_1_pool_byosa_mock:
117+
get_persistent_resource_1_pool_byosa_mock.return_value = (
118+
tc.ClusterConstants.TEST_RESPONSE_RUNNING_1_POOL_BYOSA
119+
)
120+
yield get_persistent_resource_1_pool_byosa_mock
121+
122+
95123
@pytest.fixture
96124
def create_persistent_resource_2_pools_mock():
97125
with mock.patch.object(
@@ -426,6 +454,30 @@ def test_create_ray_cluster_initialized_success(
426454
]
427455
)
428456

457+
@pytest.mark.usefixtures("get_persistent_resource_1_pool_byosa_mock")
458+
def test_create_ray_cluster_byosa_success(
459+
self, create_persistent_resource_1_pool_byosa_mock
460+
):
461+
"""If head and worker nodes are duplicate, merge to head pool."""
462+
cluster_name = vertex_ray.create_ray_cluster(
463+
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_1_POOL,
464+
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_1_POOL,
465+
service_account=tc.ProjectConstants.TEST_SERVICE_ACCOUNT,
466+
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
467+
)
468+
469+
assert tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS == cluster_name
470+
471+
request = persistent_resource_service.CreatePersistentResourceRequest(
472+
parent=tc.ProjectConstants.TEST_PARENT,
473+
persistent_resource=tc.ClusterConstants.TEST_REQUEST_RUNNING_1_POOL_BYOSA,
474+
persistent_resource_id=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
475+
)
476+
477+
create_persistent_resource_1_pool_byosa_mock.assert_called_with(
478+
request,
479+
)
480+
429481
def test_create_ray_cluster_head_multinode_error(self):
430482
with pytest.raises(ValueError) as e:
431483
vertex_ray.create_ray_cluster(
@@ -508,6 +560,16 @@ def test_get_ray_cluster_with_custom_image_success(
508560
get_persistent_resource_2_pools_custom_image_mock.assert_called_once()
509561
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_CUSTOM_IMAGE)
510562

563+
def test_get_ray_cluster_byosa_success(
564+
self, get_persistent_resource_1_pool_byosa_mock
565+
):
566+
cluster = vertex_ray.get_ray_cluster(
567+
cluster_resource_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS
568+
)
569+
570+
get_persistent_resource_1_pool_byosa_mock.assert_called_once()
571+
cluster_eq(cluster, tc.ClusterConstants.TEST_CLUSTER_BYOSA)
572+
511573
@pytest.mark.usefixtures("get_persistent_resource_exception_mock")
512574
def test_get_ray_cluster_error(self):
513575
with pytest.raises(ValueError) as e:

tests/unit/vertex_ray/test_constants.py

+84-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
import dataclasses
19+
import sys
1920

2021
from google.cloud.aiplatform.preview.vertex_ray.util.resources import Cluster
2122
from google.cloud.aiplatform.preview.vertex_ray.util.resources import (
@@ -28,10 +29,10 @@
2829
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
2930
PersistentResource,
3031
)
31-
from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec
3232
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3333
RayMetricSpec,
3434
)
35+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import RaySpec
3536
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
3637
ResourcePool,
3738
)
@@ -41,9 +42,11 @@
4142
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
4243
ResourceRuntimeSpec,
4344
)
44-
45+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
46+
ServiceAccountSpec,
47+
)
4548
import pytest
46-
import sys
49+
4750

4851
rovminversion = pytest.mark.skipif(
4952
sys.version_info > (3, 10), reason="Requires python3.10 or lower"
@@ -67,6 +70,7 @@ class ProjectConstants:
6770
TEST_MODEL_ID = (
6871
f"projects/{TEST_GCP_PROJECT_NUMBER}/locations/{TEST_GCP_REGION}/models/456"
6972
)
73+
TEST_SERVICE_ACCOUNT = "[email protected]"
7074

7175

7276
@dataclasses.dataclass(frozen=True)
@@ -79,6 +83,9 @@ class ClusterConstants:
7983
TEST_VERTEX_RAY_DASHBOARD_ADDRESS = (
8084
"48b400ad90b8dd3c-dot-us-central1.aiplatform-training.googleusercontent.com"
8185
)
86+
TEST_VERTEX_RAY_CLIENT_ENDPOINT = (
87+
"88888.us-central1-1234567.staging-ray.vertexai.goog:443"
88+
)
8289
TEST_VERTEX_RAY_PR_ID = "user-persistent-resource-1234567890"
8390
TEST_VERTEX_RAY_PR_ADDRESS = (
8491
f"{ProjectConstants.TEST_PARENT}/persistentResources/" + TEST_VERTEX_RAY_PR_ID
@@ -106,7 +113,7 @@ class ClusterConstants:
106113
TEST_RESOURCE_POOL_0 = ResourcePool(
107114
id="head-node",
108115
machine_spec=MachineSpec(
109-
machine_type="n1-standard-8",
116+
machine_type="n1-standard-16",
110117
accelerator_type="NVIDIA_TESLA_P100",
111118
accelerator_count=1,
112119
),
@@ -147,6 +154,20 @@ class ClusterConstants:
147154
),
148155
network=ProjectConstants.TEST_VPC_NETWORK,
149156
)
157+
TEST_REQUEST_RUNNING_1_POOL_BYOSA = PersistentResource(
158+
resource_pools=[TEST_RESOURCE_POOL_0],
159+
resource_runtime_spec=ResourceRuntimeSpec(
160+
ray_spec=RaySpec(
161+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
162+
ray_metric_spec=RayMetricSpec(disabled=False),
163+
),
164+
service_account_spec=ServiceAccountSpec(
165+
enable_custom_service_account=True,
166+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
167+
),
168+
),
169+
network=None,
170+
)
150171
# Get response has generated name, and URIs
151172
TEST_RESPONSE_RUNNING_1_POOL = PersistentResource(
152173
name=TEST_VERTEX_RAY_PR_ADDRESS,
@@ -185,6 +206,50 @@ class ClusterConstants:
185206
),
186207
state="RUNNING",
187208
)
209+
TEST_RESPONSE_RUNNING_1_POOL_BYOSA = PersistentResource(
210+
name=TEST_VERTEX_RAY_PR_ADDRESS,
211+
resource_pools=[TEST_RESOURCE_POOL_0],
212+
resource_runtime_spec=ResourceRuntimeSpec(
213+
ray_spec=RaySpec(
214+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
215+
ray_metric_spec=RayMetricSpec(disabled=False),
216+
),
217+
service_account_spec=ServiceAccountSpec(
218+
enable_custom_service_account=True,
219+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
220+
),
221+
),
222+
network=None,
223+
resource_runtime=ResourceRuntime(
224+
access_uris={
225+
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
226+
"RAY_CLIENT_ENDPOINT": TEST_VERTEX_RAY_CLIENT_ENDPOINT,
227+
}
228+
),
229+
state="RUNNING",
230+
)
231+
TEST_RESPONSE_1_POOL_BYOSA_PRIVATE = PersistentResource(
232+
name=TEST_VERTEX_RAY_PR_ADDRESS,
233+
resource_pools=[TEST_RESOURCE_POOL_0],
234+
resource_runtime_spec=ResourceRuntimeSpec(
235+
ray_spec=RaySpec(
236+
resource_pool_images={"head-node": TEST_GPU_IMAGE},
237+
ray_metric_spec=RayMetricSpec(disabled=False),
238+
),
239+
service_account_spec=ServiceAccountSpec(
240+
enable_custom_service_account=True,
241+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
242+
),
243+
),
244+
network=ProjectConstants.TEST_VPC_NETWORK,
245+
resource_runtime=ResourceRuntime(
246+
access_uris={
247+
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
248+
"RAY_CLIENT_ENDPOINT": TEST_VERTEX_RAY_CLIENT_ENDPOINT,
249+
}
250+
),
251+
state="RUNNING",
252+
)
188253
# 2_POOL: worker_node_types and head_node_type have different MachineSpecs
189254
TEST_HEAD_NODE_TYPE_2_POOLS = Resources()
190255
TEST_WORKER_NODE_TYPES_2_POOLS = [
@@ -208,7 +273,7 @@ class ClusterConstants:
208273
TEST_RESOURCE_POOL_1 = ResourcePool(
209274
id="head-node",
210275
machine_spec=MachineSpec(
211-
machine_type="n1-standard-8",
276+
machine_type="n1-standard-16",
212277
),
213278
disk_spec=DiskSpec(
214279
boot_disk_type="pd-ssd",
@@ -302,6 +367,7 @@ class ClusterConstants:
302367
python_version="3.10",
303368
ray_version="2.9",
304369
network=ProjectConstants.TEST_VPC_NETWORK,
370+
service_account=None,
305371
state="RUNNING",
306372
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
307373
worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL,
@@ -312,6 +378,7 @@ class ClusterConstants:
312378
python_version="3.10",
313379
ray_version="2.9",
314380
network=ProjectConstants.TEST_VPC_NETWORK,
381+
service_account=None,
315382
state="RUNNING",
316383
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS,
317384
worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS,
@@ -320,11 +387,23 @@ class ClusterConstants:
320387
TEST_CLUSTER_CUSTOM_IMAGE = Cluster(
321388
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
322389
network=ProjectConstants.TEST_VPC_NETWORK,
390+
service_account=None,
323391
state="RUNNING",
324392
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
325393
worker_node_types=TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE,
326394
dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
327395
)
396+
TEST_CLUSTER_BYOSA = Cluster(
397+
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
398+
python_version="3.10",
399+
ray_version="2.9",
400+
network="",
401+
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
402+
state="RUNNING",
403+
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
404+
worker_node_types=TEST_WORKER_NODE_TYPES_1_POOL,
405+
dashboard_address=TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
406+
)
328407
TEST_BEARER_TOKEN = "test-bearer-token"
329408
TEST_HEADERS = {
330409
"Content-Type": "application/json",

0 commit comments

Comments
 (0)