Skip to content

Commit 57a5f78

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support public endpoint for Ray Client
PiperOrigin-RevId: 630181847
1 parent 4ce2f60 commit 57a5f78

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

google/cloud/aiplatform/preview/vertex_ray/client_builder.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17+
import grpc
1718
import logging
1819
from typing import Dict
1920
from typing import Optional
2021
from google.cloud import aiplatform
22+
from google.cloud.aiplatform import initializer
2123
from ray import client_builder
2224
from .render import VertexRayTemplate
2325
from .util import _validation_utils
@@ -80,7 +82,8 @@ class VertexRayClientBuilder(client_builder.ClientBuilder):
8082
def __init__(self, address: Optional[str]) -> None:
8183
address = _validation_utils.maybe_reconstruct_resource_name(address)
8284
_validation_utils.valid_resource_name(address)
83-
85+
self._credentials = None
86+
self._metadata = None
8487
self.vertex_address = address
8588
logging.info(
8689
"[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
@@ -89,9 +92,17 @@ def __init__(self, address: Optional[str]) -> None:
8992
self.resource_name = address
9093

9194
self.response = _gapic_utils.get_persistent_resource(self.resource_name)
92-
address = self.response.resource_runtime.access_uris.get(
95+
private_address = self.response.resource_runtime.access_uris.get(
9396
"RAY_HEAD_NODE_INTERNAL_IP"
9497
)
98+
public_address = self.response.resource_runtime.access_uris.get(
99+
"RAY_CLIENT_ENDPOINT"
100+
)
101+
if public_address is None:
102+
address = private_address
103+
else:
104+
address = public_address
105+
95106
if address is None:
96107
persistent_resource_id = self.resource_name.split("/")[5]
97108
raise ValueError(
@@ -143,6 +154,17 @@ def __init__(self, address: Optional[str]) -> None:
143154
def connect(self) -> _VertexRayClientContext:
144155
# Can send any other params to ray cluster here
145156
logging.info("[Ray on Vertex AI]: Connecting...")
157+
158+
public_address = self.response.resource_runtime.access_uris.get(
159+
"RAY_CLIENT_ENDPOINT"
160+
)
161+
if public_address:
162+
self._credentials = grpc.ssl_channel_credentials()
163+
bearer_token = _validation_utils.get_bearer_token()
164+
self._metadata = [
165+
("authorization", "Bearer {}".format(bearer_token)),
166+
("x-goog-user-project", "{}".format(initializer.global_config.project)),
167+
]
146168
ray_client_context = super().connect()
147169
ray_head_uris = self.response.resource_runtime.access_uris
148170

0 commit comments

Comments
 (0)