Skip to content

Commit d4cae46

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support global endpoint natively
PiperOrigin-RevId: 723767182
1 parent d205601 commit d4cae46

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

google/cloud/aiplatform/initializer.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -231,17 +231,26 @@ def init(
231231
f"{api_transport} is not a valid transport type. "
232232
+ f"Valid transport types: {VALID_TRANSPORT_TYPES}"
233233
)
234-
else:
235234
# Raise error if api_transport other than rest is specified for usage with API key.
235+
elif api_key and api_transport != "rest":
236+
raise ValueError(f"{api_transport} is not supported with API keys. ")
237+
else:
236238
if not project and not api_transport:
237239
api_transport = "rest"
238-
elif not project and api_transport != "rest":
239-
raise ValueError(f"{api_transport} is not supported with API keys. ")
240+
240241
if location:
241242
utils.validate_region(location)
243+
# Set api_transport as "rest" if location is "global".
244+
if location == "global" and not api_transport:
245+
self._api_transport = "rest"
246+
elif location == "global" and api_transport == "grpc":
247+
raise ValueError(
248+
"api_transport cannot be 'grpc' when location is 'global'."
249+
)
242250
if experiment_description and experiment is None:
243251
raise ValueError(
244-
"Experiment needs to be set in `init` in order to add experiment descriptions."
252+
"Experiment needs to be set in `init` in order to add experiment"
253+
" descriptions."
245254
)
246255

247256
# reset metadata_service config if project or location is updated.
@@ -464,8 +473,9 @@ def get_client_options(
464473
and not self._project
465474
and not self._location
466475
and not location_override
467-
):
468-
# Default endpoint is location invariant if using API key
476+
) or (self._location == "global"):
477+
# Default endpoint is location invariant if using API key or global
478+
# location.
469479
api_endpoint = "aiplatform.googleapis.com"
470480

471481
# If both project and API key are passed in, project takes precedence.

tests/unit/aiplatform/test_initializer.py

+20
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,26 @@ def test_create_client_with_invalid_api_transport_override(self, api_transport):
295295
api_transport=api_transport,
296296
)
297297

298+
def test_create_client_with_global_location(self):
299+
initializer.global_config.init(project=_TEST_PROJECT, location="global")
300+
client = initializer.global_config.create_client(
301+
client_class=utils.PredictionClientWithOverride
302+
)
303+
assert initializer.global_config.location == "global"
304+
assert initializer.global_config._api_transport == "rest"
305+
assert isinstance(client, utils.PredictionClientWithOverride)
306+
assert client._transport._host == f"https://{constants.API_BASE_PATH}"
307+
308+
def test_create_client_with_global_location_and_grpc_transport(self):
309+
with pytest.raises(ValueError):
310+
initializer.global_config.init(
311+
project=_TEST_PROJECT, location="global", api_transport="grpc"
312+
)
313+
314+
def test_create_client_with_api_key_and_grpc_transport(self):
315+
with pytest.raises(ValueError):
316+
initializer.global_config.init(api_key="test_api_key", api_transport="grpc")
317+
298318
def test_create_client_overrides(self):
299319
initializer.global_config.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
300320
creds = credentials.AnonymousCredentials()

0 commit comments

Comments
 (0)