Skip to content

Commit 6ab4084

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Add transport override to enable the use of REST instead of GRPC
PiperOrigin-RevId: 611159115
1 parent 02829f1 commit 6ab4084

File tree

5 files changed

+202
-55
lines changed

5 files changed

+202
-55
lines changed

google/cloud/aiplatform/initializer.py

+24
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def __init__(self):
104104
self._network = None
105105
self._service_account = None
106106
self._api_endpoint = None
107+
self._api_transport = None
107108

108109
def init(
109110
self,
@@ -121,6 +122,7 @@ def init(
121122
network: Optional[str] = None,
122123
service_account: Optional[str] = None,
123124
api_endpoint: Optional[str] = None,
125+
api_transport: Optional[str] = None,
124126
):
125127
"""Updates common initialization parameters with provided options.
126128
@@ -179,6 +181,8 @@ def init(
179181
api_endpoint (str):
180182
Optional. The desired API endpoint,
181183
e.g., us-central1-aiplatform.googleapis.com
184+
api_transport (str):
185+
Optional. The transport method which is either 'grpc' or 'rest'
182186
Raises:
183187
ValueError:
184188
If experiment_description is provided but experiment is not.
@@ -231,6 +235,15 @@ def init(
231235
backing_tensorboard=experiment_tensorboard,
232236
)
233237

238+
if api_transport:
239+
VALID_TRANSPORT_TYPES = ["grpc", "rest"]
240+
if api_transport not in VALID_TRANSPORT_TYPES:
241+
raise ValueError(
242+
f"{api_transport} is not a valid transport type. "
243+
+ f"Valid transport types: {VALID_TRANSPORT_TYPES}"
244+
)
245+
self._api_transport = api_transport
246+
234247
def get_encryption_spec(
235248
self,
236249
encryption_spec_key_name: Optional[str],
@@ -481,6 +494,17 @@ def create_client(
481494
"client_info": client_info,
482495
}
483496

497+
# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
498+
if self._api_transport == "rest":
499+
if "Async" in client_class.__name__:
500+
# Warn user that "rest" is not supported and use grpc instead
501+
logging.warning(
502+
"REST is not supported for async clients, "
503+
+ "falling back to grpc."
504+
)
505+
else:
506+
kwargs["transport"] = self._api_transport
507+
484508
return client_class(**kwargs)
485509

486510

google/cloud/aiplatform/utils/__init__.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def __init__(
388388
client_options: client_options.ClientOptions,
389389
client_info: gapic_v1.client_info.ClientInfo,
390390
credentials: Optional[auth_credentials.Credentials] = None,
391+
transport: Optional[str] = None,
391392
):
392393
"""Stores parameters needed to instantiate client.
393394
@@ -400,20 +401,30 @@ def __init__(
400401
Required. Client info to pass to client.
401402
credentials (auth_credentials.credentials):
402403
Optional. Client credentials to pass to client.
404+
transport (str):
405+
Optional. Transport type to pass to client.
403406
"""
404407

405408
self._client_class = client_class
406409
self._credentials = credentials
407410
self._client_options = client_options
408411
self._client_info = client_info
412+
self._api_transport = transport
409413

410414
def __getattr__(self, name: str) -> Any:
411415
"""Instantiates client and returns attribute of the client."""
412-
temporary_client = self._client_class(
416+
417+
kwargs = dict(
413418
credentials=self._credentials,
414419
client_options=self._client_options,
415420
client_info=self._client_info,
416421
)
422+
423+
if self._api_transport is not None:
424+
kwargs["transport"] = self._api_transport
425+
426+
temporary_client = self._client_class(**kwargs)
427+
417428
return getattr(temporary_client, name)
418429

419430
@property
@@ -448,6 +459,7 @@ def __init__(
448459
client_options: client_options.ClientOptions,
449460
client_info: gapic_v1.client_info.ClientInfo,
450461
credentials: Optional[auth_credentials.Credentials] = None,
462+
transport: Optional[str] = None,
451463
):
452464
"""Stores parameters needed to instantiate client.
453465
@@ -458,21 +470,28 @@ def __init__(
458470
Required. Client info to pass to client.
459471
credentials (auth_credentials.credentials):
460472
Optional. Client credentials to pass to client.
473+
transport (str):
474+
Optional. Transport type to pass to client.
461475
"""
476+
kwargs = dict(
477+
credentials=credentials,
478+
client_options=client_options,
479+
client_info=client_info,
480+
)
481+
482+
if transport is not None:
483+
kwargs["transport"] = transport
462484

463485
self._clients = {
464486
version: self.WrappedClient(
465487
client_class=client_class,
466488
client_options=client_options,
467489
client_info=client_info,
468490
credentials=credentials,
491+
transport=transport,
469492
)
470493
if self._is_temporary
471-
else client_class(
472-
client_options=client_options,
473-
client_info=client_info,
474-
credentials=credentials,
475-
)
494+
else client_class(**kwargs)
476495
for version, client_class in self._version_map
477496
}
478497

0 commit comments

Comments
 (0)