Skip to content

Commit 662d039

Browse files
yinghsienwu and copybara-github
authored and committed
fix: support VPC and BYOSA case in Ray on Vertex JobSubmissionClient using cluster resource name
PiperOrigin-RevId: 642446002
1 parent 17c59c4 commit 662d039

File tree

3 files changed

+33
-45
lines changed

3 files changed

+33
-45
lines changed

google/cloud/aiplatform/vertex_ray/client_builder.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def __init__(self, address: Optional[str]) -> None:
110110
address,
111111
" failed to start Head node properly because custom service"
112112
" account isn't supported in peered VPC network. Use public"
113-
" endpoint instead (createa a cluster withought specifying"
113+
" endpoint instead (createa a cluster without specifying"
114114
" VPC network).",
115115
)
116116
else:

google/cloud/aiplatform/vertex_ray/dashboard_sdk.py

+19-41
Original file line number | Diff line number | Diff line change
@@ -46,55 +46,33 @@ def get_job_submission_client_cluster_info(
4646
Raises:
4747
RuntimeError if head_address is None.
4848
"""
49-
# If passing the dashboard uri, programmatically get headers
5049
if _validation_utils.valid_dashboard_address(address):
51-
bearer_token = _validation_utils.get_bearer_token()
52-
if kwargs.get("headers", None) is None:
53-
kwargs["headers"] = {
54-
"Content-Type": "application/json",
55-
"Authorization": "Bearer {}".format(bearer_token),
56-
}
57-
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
58-
address=address,
59-
_use_tls=True,
60-
*args,
61-
**kwargs,
62-
)
63-
address = _validation_utils.maybe_reconstruct_resource_name(address)
64-
_validation_utils.valid_resource_name(address)
50+
dashboard_address = address
51+
else:
52+
address = _validation_utils.maybe_reconstruct_resource_name(address)
53+
_validation_utils.valid_resource_name(address)
54+
55+
resource_name = address
56+
response = _gapic_utils.get_persistent_resource(resource_name)
6557

66-
resource_name = address
67-
response = _gapic_utils.get_persistent_resource(resource_name)
68-
head_address = response.resource_runtime.access_uris.get(
69-
"RAY_HEAD_NODE_INTERNAL_IP", None
70-
)
71-
if head_address is None:
72-
# No peering. Try to get the dashboard address.
7358
dashboard_address = response.resource_runtime.access_uris.get(
7459
"RAY_DASHBOARD_URI", None
7560
)
61+
7662
if dashboard_address is None:
7763
raise RuntimeError(
7864
"[Ray on Vertex AI]: Unable to obtain a response from the backend."
7965
)
80-
if _validation_utils.valid_dashboard_address(dashboard_address):
81-
bearer_token = _validation_utils.get_bearer_token()
82-
if kwargs.get("headers", None) is None:
83-
kwargs["headers"] = {
84-
"Content-Type": "application/json",
85-
"Authorization": "Bearer {}".format(bearer_token),
86-
}
87-
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
88-
address=dashboard_address,
89-
_use_tls=True,
90-
*args,
91-
**kwargs,
92-
)
93-
# Assume that head node internal IP in a form of xxx.xxx.xxx.xxx:10001.
94-
# Ray-on-Vertex cluster serves the Dashboard at port 8888 instead of
95-
# the default 8251.
96-
head_address = ":".join([head_address.split(":")[0], "8888"])
97-
66+
# If passing the dashboard uri, programmatically get headers
67+
bearer_token = _validation_utils.get_bearer_token()
68+
if kwargs.get("headers", None) is None:
69+
kwargs["headers"] = {
70+
"Content-Type": "application/json",
71+
"Authorization": "Bearer {}".format(bearer_token),
72+
}
9873
return oss_dashboard_sdk.get_job_submission_client_cluster_info(
99-
address=head_address, *args, **kwargs
74+
address=dashboard_address,
75+
_use_tls=True,
76+
*args,
77+
**kwargs,
10078
)

tests/unit/vertex_ray/test_dashboard_sdk.py

+13-3
Original file line number | Diff line number | Diff line change
@@ -73,16 +73,22 @@ def setup_method(self):
7373
def teardown_method(self):
7474
aiplatform.initializer.global_pool.shutdown(wait=True)
7575

76-
@pytest.mark.usefixtures("get_persistent_resource_status_running_mock")
76+
@pytest.mark.usefixtures(
77+
"get_persistent_resource_status_running_mock", "google_auth_mock"
78+
)
7779
def test_job_submission_client_cluster_info_with_full_resource_name(
7880
self,
7981
ray_get_job_submission_client_cluster_info_mock,
82+
get_bearer_token_mock,
8083
):
8184
vertex_ray.get_job_submission_client_cluster_info(
8285
tc.ClusterConstants.TEST_VERTEX_RAY_PR_ADDRESS
8386
)
87+
get_bearer_token_mock.assert_called_once_with()
8488
ray_get_job_submission_client_cluster_info_mock.assert_called_once_with(
85-
address=tc.ClusterConstants.TEST_VERTEX_RAY_JOB_CLIENT_IP
89+
address=tc.ClusterConstants.TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
90+
_use_tls=True,
91+
headers=tc.ClusterConstants.TEST_HEADERS,
8692
)
8793

8894
@pytest.mark.usefixtures(
@@ -92,6 +98,7 @@ def test_job_submission_client_cluster_info_with_cluster_name(
9298
self,
9399
ray_get_job_submission_client_cluster_info_mock,
94100
get_project_number_mock,
101+
get_bearer_token_mock,
95102
):
96103
aiplatform.init(project=tc.ProjectConstants.TEST_GCP_PROJECT_ID)
97104

@@ -101,8 +108,11 @@ def test_job_submission_client_cluster_info_with_cluster_name(
101108
get_project_number_mock.assert_called_once_with(
102109
name="projects/{}".format(tc.ProjectConstants.TEST_GCP_PROJECT_ID)
103110
)
111+
get_bearer_token_mock.assert_called_once_with()
104112
ray_get_job_submission_client_cluster_info_mock.assert_called_once_with(
105-
address=tc.ClusterConstants.TEST_VERTEX_RAY_JOB_CLIENT_IP
113+
address=tc.ClusterConstants.TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
114+
_use_tls=True,
115+
headers=tc.ClusterConstants.TEST_HEADERS,
106116
)
107117

108118
@pytest.mark.usefixtures(

0 commit comments

Comments
 (0)