Skip to content

Commit c0b31e2

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Add private async REST support for transport override
PiperOrigin-RevId: 691063218
1 parent 82bb938 commit c0b31e2

13 files changed

+236
-44
lines changed

google/cloud/aiplatform/initializer.py

+50-8
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import logging
2424
import os
2525
import types
26-
from typing import Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union
26+
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union
2727

2828
from google.api_core import client_options
2929
from google.api_core import gapic_v1
@@ -46,6 +46,15 @@
4646
encryption_spec_v1beta1 as gca_encryption_spec_v1beta1,
4747
)
4848

49+
try:
50+
import google.auth.aio
51+
52+
AsyncCredentials = google.auth.aio.credentials.Credentials
53+
_HAS_ASYNC_CRED_DEPS = True
54+
except ImportError:
55+
AsyncCredentials = Any
56+
_HAS_ASYNC_CRED_DEPS = False
57+
4958
_TVertexAiServiceClientWithOverride = TypeVar(
5059
"_TVertexAiServiceClientWithOverride",
5160
bound=utils.VertexAiServiceClientWithOverride,
@@ -121,6 +130,7 @@ def __init__(self):
121130
self._api_transport = None
122131
self._request_metadata = None
123132
self._resource_type = None
133+
self._async_rest_credentials = None
124134

125135
def init(
126136
self,
@@ -590,15 +600,24 @@ def create_client(
590600
}
591601

592602
# Do not pass "grpc", rely on gapic defaults unless "rest" is specified
593-
if self._api_transport == "rest":
594-
if "Async" in client_class.__name__:
595-
# Warn user that "rest" is not supported and use grpc instead
603+
if self._api_transport == "rest" and "Async" in client_class.__name__:
604+
# User requests async rest
605+
if self._async_rest_credentials:
606+
# Rest async recieves credentials from _async_rest_credentials
607+
kwargs["credentials"] = self._async_rest_credentials
608+
kwargs["transport"] = "rest_asyncio"
609+
else:
610+
# Rest async was specified, but no async credentials were set.
611+
# Fallback to gRPC instead.
596612
logging.warning(
597-
"REST is not supported for async clients, "
598-
+ "falling back to grpc."
613+
"REST async clients requires async credentials set using "
614+
+ "aiplatform.initializer._set_async_rest_credentials().\n"
615+
+ "Falling back to grpc since no async rest credentials "
616+
+ "were detected."
599617
)
600-
else:
601-
kwargs["transport"] = self._api_transport
618+
elif self._api_transport == "rest":
619+
# User requests sync REST
620+
kwargs["transport"] = self._api_transport
602621

603622
client = client_class(**kwargs)
604623
# We only wrap the client if the request_metadata is set at the creation time.
@@ -672,6 +691,29 @@ def __call__(self, *args, **kwargs):
672691
)
673692

674693

694+
def _set_async_rest_credentials(credentials: AsyncCredentials):
695+
"""Private method to set async REST credentials."""
696+
if global_config._api_transport != "rest":
697+
raise ValueError(
698+
"Async REST credentials can only be set when using REST transport."
699+
)
700+
elif not _HAS_ASYNC_CRED_DEPS or not isinstance(credentials, AsyncCredentials):
701+
raise ValueError(
702+
"Async REST transport requires async credentials of type"
703+
+ f"{AsyncCredentials} which is only supported in "
704+
+ "google-auth >= 2.35.0.\n\n"
705+
+ "Install the following dependencies:\n"
706+
+ "pip install google-api-core[grpc, async_rest] >= 2.21.0\n"
707+
+ "pip install google-auth[aiohttp] >= 2.35.0\n\n"
708+
+ "Example usage:\n"
709+
+ "from google.auth.aio.credentials import StaticCredentials\n"
710+
+ "async_credentials = StaticCredentials(token=YOUR_TOKEN_HERE)\n"
711+
+ "aiplatform.initializer._set_async_rest_credentials("
712+
+ "credentials=async_credentials)"
713+
)
714+
global_config._async_rest_credentials = credentials
715+
716+
675717
def _get_function_name_from_stack_frame(frame) -> str:
676718
"""Gates fully qualified function or method name.
677719

setup.py

+3
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@
195195
+ profiler_extra_require
196196
+ tokenization_testing_extra_require
197197
+ [
198+
# aiohttp is required for async rest tests (need google-auth[aiohttp],
199+
# but can't specify extras in constraints files)
200+
"aiohttp",
198201
"bigframes; python_version>='3.10'",
199202
# google-api-core 2.x is required since kfp requires protobuf > 4
200203
"google-api-core >= 2.11, < 3.0.0",

testing/constraints-3.10.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# This constraints file is required for unit tests.
33
# List all library dependencies and extras in this file.
4-
google-api-core
4+
google-api-core==2.21.0 # Tests google-api-core with rest async support
5+
google-auth==2.35.0 # Tests google-auth with rest async support
56
proto-plus==1.22.3
67
protobuf
78
mock==4.0.2

testing/constraints-3.11.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# This constraints file is required for unit tests.
33
# List all library dependencies and extras in this file.
4-
google-api-core
4+
google-api-core==2.21.0 # Tests google-api-core with rest async support
5+
google-auth==2.35.0 # Tests google-auth with rest async support
56
proto-plus
67
protobuf
78
mock==4.0.2

testing/constraints-3.12.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# This constraints file is required for unit tests.
33
# List all library dependencies and extras in this file.
4-
google-api-core
4+
google-api-core==2.21.0 # Tests google-api-core with rest async support
5+
google-auth==2.35.0 # Tests google-auth with rest async support
56
proto-plus
67
protobuf
78
mock==4.0.2

testing/constraints-3.8.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# are correct in setup.py
44
# List *all* library dependencies and extras in this file.
55
google-api-core==2.17.1 # Increased for gapic owlbot presubmit tests
6-
google-auth==2.14.1
6+
google-auth==2.14.1 # Tests google-auth without rest async support
77
proto-plus==1.22.3
88
protobuf
99
mock==4.0.2

testing/constraints-3.9.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# -*- coding: utf-8 -*-
22
# This constraints file is required for unit tests.
33
# List all library dependencies and extras in this file.
4-
google-api-core
4+
google-api-core==2.21.0 # Tests google-api-core with rest async support
5+
google-auth==2.35.0 # Tests google-auth with rest async support
56
proto-plus==1.22.3
67
protobuf
78
mock==4.0.2

tests/system/aiplatform/test_initializer.py

+18
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
# limitations under the License.
1616
#
1717

18+
import pytest
19+
1820
from google.auth import credentials as auth_credentials
1921

2022
from google.cloud import aiplatform
23+
from google.cloud.aiplatform import initializer as aiplatform_initializer
2124
from tests.system.aiplatform import e2e_base
2225

2326

@@ -39,3 +42,18 @@ def test_init_calls_set_google_auth_default(self):
3942
# init() with only project shouldn't overwrite creds
4043
aiplatform.init(project=e2e_base._PROJECT)
4144
assert aiplatform.initializer.global_config.credentials == creds
45+
46+
def test_init_rest_async_incorrect_credentials(self):
47+
# Async REST credentials must be explicitly set using
48+
# _set_async_rest_credentials() for async REST transport.
49+
creds = auth_credentials.AnonymousCredentials()
50+
aiplatform.init(
51+
project=e2e_base._PROJECT,
52+
location=e2e_base._LOCATION,
53+
api_transport="rest",
54+
)
55+
56+
# System tests are run on Python 3.10 which has async deps.
57+
with pytest.raises(ValueError):
58+
# Expect a ValueError for passing in sync credentials.
59+
aiplatform_initializer._set_async_rest_credentials(credentials=creds)

tests/system/aiplatform/test_language_models.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,10 @@ def test_text_generation_preview_count_tokens(self, api_transport):
8686
assert response.total_billable_characters
8787

8888
@pytest.mark.asyncio
89-
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
90-
async def test_text_generation_model_predict_async(self, api_transport):
89+
async def test_text_generation_model_predict_async(self):
9190
aiplatform.init(
9291
project=e2e_base._PROJECT,
9392
location=e2e_base._LOCATION,
94-
api_transport=api_transport,
9593
)
9694

9795
model = TextGenerationModel.from_pretrained("google/text-bison@001")
@@ -227,12 +225,10 @@ def test_chat_model_preview_count_tokens(self, api_transport):
227225
)
228226

229227
@pytest.mark.asyncio
230-
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
231-
async def test_chat_model_async(self, api_transport):
228+
async def test_chat_model_async(self):
232229
aiplatform.init(
233230
project=e2e_base._PROJECT,
234231
location=e2e_base._LOCATION,
235-
api_transport=api_transport,
236232
)
237233

238234
chat_model = ChatModel.from_pretrained("google/chat-bison@001")
@@ -343,12 +339,10 @@ def test_text_embedding(self, api_transport):
343339
assert embeddings[1].statistics.truncated
344340

345341
@pytest.mark.asyncio
346-
@pytest.mark.parametrize("api_transport", ["grpc", "rest"])
347-
async def test_text_embedding_async(self, api_transport):
342+
async def test_text_embedding_async(self):
348343
aiplatform.init(
349344
project=e2e_base._PROJECT,
350345
location=e2e_base._LOCATION,
351-
api_transport=api_transport,
352346
)
353347

354348
model = TextEmbeddingModel.from_pretrained("google/textembedding-gecko@001")

0 commit comments

Comments
 (0)