14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
#
17
+ import grpc
17
18
import logging
18
19
from typing import Dict
19
20
from typing import Optional
20
21
from google .cloud import aiplatform
22
+ from google .cloud .aiplatform import initializer
21
23
from ray import client_builder
22
24
from .render import VertexRayTemplate
23
25
from .util import _validation_utils
@@ -80,7 +82,8 @@ class VertexRayClientBuilder(client_builder.ClientBuilder):
80
82
def __init__ (self , address : Optional [str ]) -> None :
81
83
address = _validation_utils .maybe_reconstruct_resource_name (address )
82
84
_validation_utils .valid_resource_name (address )
83
-
85
+ self ._credentials = None
86
+ self ._metadata = None
84
87
self .vertex_address = address
85
88
logging .info (
86
89
"[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
@@ -89,9 +92,17 @@ def __init__(self, address: Optional[str]) -> None:
89
92
self .resource_name = address
90
93
91
94
self .response = _gapic_utils .get_persistent_resource (self .resource_name )
92
- address = self .response .resource_runtime .access_uris .get (
95
+ private_address = self .response .resource_runtime .access_uris .get (
93
96
"RAY_HEAD_NODE_INTERNAL_IP"
94
97
)
98
+ public_address = self .response .resource_runtime .access_uris .get (
99
+ "RAY_CLIENT_ENDPOINT"
100
+ )
101
+ if public_address is None :
102
+ address = private_address
103
+ else :
104
+ address = public_address
105
+
95
106
if address is None :
96
107
persistent_resource_id = self .resource_name .split ("/" )[5 ]
97
108
raise ValueError (
@@ -143,6 +154,17 @@ def __init__(self, address: Optional[str]) -> None:
143
154
def connect (self ) -> _VertexRayClientContext :
144
155
# Can send any other params to ray cluster here
145
156
logging .info ("[Ray on Vertex AI]: Connecting..." )
157
+
158
+ public_address = self .response .resource_runtime .access_uris .get (
159
+ "RAY_CLIENT_ENDPOINT"
160
+ )
161
+ if public_address :
162
+ self ._credentials = grpc .ssl_channel_credentials ()
163
+ bearer_token = _validation_utils .get_bearer_token ()
164
+ self ._metadata = [
165
+ ("authorization" , "Bearer {}" .format (bearer_token )),
166
+ ("x-goog-user-project" , "{}" .format (initializer .global_config .project )),
167
+ ]
146
168
ray_client_context = super ().connect ()
147
169
ray_head_uris = self .response .resource_runtime .access_uris
148
170
0 commit comments