|
23 | 23 | import logging
|
24 | 24 | import os
|
25 | 25 | import types
|
26 |
| -from typing import Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union |
| 26 | +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union |
27 | 27 |
|
28 | 28 | from google.api_core import client_options
|
29 | 29 | from google.api_core import gapic_v1
|
|
46 | 46 | encryption_spec_v1beta1 as gca_encryption_spec_v1beta1,
|
47 | 47 | )
|
48 | 48 |
|
| 49 | +try: |
| 50 | + import google.auth.aio |
| 51 | + |
| 52 | + AsyncCredentials = google.auth.aio.credentials.Credentials |
| 53 | + _HAS_ASYNC_CRED_DEPS = True |
| 54 | +except ImportError: |
| 55 | + AsyncCredentials = Any |
| 56 | + _HAS_ASYNC_CRED_DEPS = False |
| 57 | + |
49 | 58 | _TVertexAiServiceClientWithOverride = TypeVar(
|
50 | 59 | "_TVertexAiServiceClientWithOverride",
|
51 | 60 | bound=utils.VertexAiServiceClientWithOverride,
|
@@ -121,6 +130,7 @@ def __init__(self):
|
121 | 130 | self._api_transport = None
|
122 | 131 | self._request_metadata = None
|
123 | 132 | self._resource_type = None
|
| 133 | + self._async_rest_credentials = None |
124 | 134 |
|
125 | 135 | def init(
|
126 | 136 | self,
|
@@ -590,15 +600,24 @@ def create_client(
|
590 | 600 | }
|
591 | 601 |
|
592 | 602 | # Do not pass "grpc", rely on gapic defaults unless "rest" is specified
|
593 |
| - if self._api_transport == "rest": |
594 |
| - if "Async" in client_class.__name__: |
595 |
| - # Warn user that "rest" is not supported and use grpc instead |
| 603 | + if self._api_transport == "rest" and "Async" in client_class.__name__: |
| 604 | + # User requests async rest |
| 605 | + if self._async_rest_credentials: |
| 606 | + # Rest async recieves credentials from _async_rest_credentials |
| 607 | + kwargs["credentials"] = self._async_rest_credentials |
| 608 | + kwargs["transport"] = "rest_asyncio" |
| 609 | + else: |
| 610 | + # Rest async was specified, but no async credentials were set. |
| 611 | + # Fallback to gRPC instead. |
596 | 612 | logging.warning(
|
597 |
| - "REST is not supported for async clients, " |
598 |
| - + "falling back to grpc." |
| 613 | + "REST async clients requires async credentials set using " |
| 614 | + + "aiplatform.initializer._set_async_rest_credentials().\n" |
| 615 | + + "Falling back to grpc since no async rest credentials " |
| 616 | + + "were detected." |
599 | 617 | )
|
600 |
| - else: |
601 |
| - kwargs["transport"] = self._api_transport |
| 618 | + elif self._api_transport == "rest": |
| 619 | + # User requests sync REST |
| 620 | + kwargs["transport"] = self._api_transport |
602 | 621 |
|
603 | 622 | client = client_class(**kwargs)
|
604 | 623 | # We only wrap the client if the request_metadata is set at the creation time.
|
@@ -672,6 +691,29 @@ def __call__(self, *args, **kwargs):
|
672 | 691 | )
|
673 | 692 |
|
674 | 693 |
|
| 694 | +def _set_async_rest_credentials(credentials: AsyncCredentials): |
| 695 | + """Private method to set async REST credentials.""" |
| 696 | + if global_config._api_transport != "rest": |
| 697 | + raise ValueError( |
| 698 | + "Async REST credentials can only be set when using REST transport." |
| 699 | + ) |
| 700 | + elif not _HAS_ASYNC_CRED_DEPS or not isinstance(credentials, AsyncCredentials): |
| 701 | + raise ValueError( |
| 702 | + "Async REST transport requires async credentials of type" |
| 703 | + + f"{AsyncCredentials} which is only supported in " |
| 704 | + + "google-auth >= 2.35.0.\n\n" |
| 705 | + + "Install the following dependencies:\n" |
| 706 | + + "pip install google-api-core[grpc, async_rest] >= 2.21.0\n" |
| 707 | + + "pip install google-auth[aiohttp] >= 2.35.0\n\n" |
| 708 | + + "Example usage:\n" |
| 709 | + + "from google.auth.aio.credentials import StaticCredentials\n" |
| 710 | + + "async_credentials = StaticCredentials(token=YOUR_TOKEN_HERE)\n" |
| 711 | + + "aiplatform.initializer._set_async_rest_credentials(" |
| 712 | + + "credentials=async_credentials)" |
| 713 | + ) |
| 714 | + global_config._async_rest_credentials = credentials |
| 715 | + |
| 716 | + |
675 | 717 | def _get_function_name_from_stack_frame(frame) -> str:
|
676 | 718 | """Gates fully qualified function or method name.
|
677 | 719 |
|
|
0 commit comments