Skip to content

Commit 984dd80

Browse files
fix: Modify AlloyDBClient to use sync transport for sync connector (#442)
1 parent e4e356b commit 984dd80

File tree

3 files changed

+185
-7
lines changed

3 files changed

+185
-7
lines changed

google/cloud/alloydb/connector/client.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,28 @@ def __init__(
7979
requests to AlloyDB APIs.
8080
Optional, defaults to None and creates new client.
8181
driver (str): Database driver to be used by the client.
82+
user_agent (str): The custom user-agent string to use in the HTTP
83+
header when making requests to AlloyDB APIs.
84+
Optional, defaults to None and uses a pre-defined one.
8285
"""
8386
user_agent = _format_user_agent(driver, user_agent)
8487

85-
self._client = (
86-
client
87-
if client
88-
else v1beta.AlloyDBAdminAsyncClient(
88+
# TODO(rhatgadkar-goog): Rollback the PR of deciding between creating
89+
# AlloyDBAdminClient or AlloyDBAdminAsyncClient when either
90+
# https://github.com/grpc/grpc/issues/25364 is resolved or an async REST
91+
# transport for AlloyDBAdminAsyncClient gets introduced.
92+
# The issue is that the async gRPC transport does not work with multiple
93+
# event loops in the same process. So all calls to the AlloyDB Admin
94+
# API, even from multiple threads, need to be made to a single-event
95+
# loop. See https://github.com/GoogleCloudPlatform/alloydb-python-connector/issues/435
96+
# for more details.
97+
self._is_sync = False
98+
if client:
99+
self._client = client
100+
elif driver == "pg8000":
101+
self._client = v1beta.AlloyDBAdminClient(
89102
credentials=credentials,
103+
transport="grpc",
90104
client_options=ClientOptions(
91105
api_endpoint=alloydb_api_endpoint,
92106
quota_project_id=quota_project,
@@ -95,7 +109,20 @@ def __init__(
95109
user_agent=user_agent,
96110
),
97111
)
98-
)
112+
self._is_sync = True
113+
else:
114+
self._client = v1beta.AlloyDBAdminAsyncClient(
115+
credentials=credentials,
116+
transport="grpc_asyncio",
117+
client_options=ClientOptions(
118+
api_endpoint=alloydb_api_endpoint,
119+
quota_project_id=quota_project,
120+
),
121+
client_info=ClientInfo(
122+
user_agent=user_agent,
123+
),
124+
)
125+
99126
self._credentials = credentials
100127
# asyncpg does not currently support using metadata exchange
101128
# only use metadata exchange for pg8000 driver
@@ -131,7 +158,10 @@ async def _get_metadata(
131158
)
132159

133160
req = v1beta.GetConnectionInfoRequest(parent=parent)
134-
resp = await self._client.get_connection_info(request=req)
161+
if self._is_sync:
162+
resp = self._client.get_connection_info(request=req)
163+
else:
164+
resp = await self._client.get_connection_info(request=req)
135165

136166
# Remove trailing period from PSC DNS name.
137167
psc_dns = resp.psc_dns_name
@@ -178,7 +208,10 @@ async def _get_client_certificate(
178208
public_key=pub_key,
179209
use_metadata_exchange=self._use_metadata,
180210
)
181-
resp = await self._client.generate_client_certificate(request=req)
211+
if self._is_sync:
212+
resp = self._client.generate_client_certificate(request=req)
213+
else:
214+
resp = await self._client.generate_client_certificate(request=req)
182215
return (resp.ca_cert, resp.pem_certificate_chain)
183216

184217
async def get_connection_info(

tests/unit/mocks.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,36 @@ async def generate_client_certificate(
478478
ccr.pem_certificate_chain.append("This is the intermediate cert")
479479
ccr.pem_certificate_chain.append("This is the root cert")
480480
return ccr
481+
482+
483+
class FakeAlloyDBAdminClient:
484+
def get_connection_info(
485+
self, request: alloydb_v1beta.GetConnectionInfoRequest
486+
) -> alloydb_v1beta.types.resources.ConnectionInfo:
487+
ci = alloydb_v1beta.types.resources.ConnectionInfo()
488+
ci.ip_address = "10.0.0.1"
489+
ci.public_ip_address = "127.0.0.1"
490+
ci.instance_uid = "123456789"
491+
ci.psc_dns_name = "x.y.alloydb.goog"
492+
493+
parent = request.parent
494+
instance = parent.split("/")[-1]
495+
if instance == "test-instance":
496+
ci.public_ip_address = ""
497+
ci.psc_dns_name = ""
498+
elif instance == "public-instance":
499+
ci.psc_dns_name = ""
500+
else:
501+
ci.ip_address = ""
502+
ci.public_ip_address = ""
503+
return ci
504+
505+
def generate_client_certificate(
506+
self, request: alloydb_v1beta.GenerateClientCertificateRequest
507+
) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
508+
ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
509+
ccr.ca_cert = "This is the CA cert"
510+
ccr.pem_certificate_chain.append("This is the client cert")
511+
ccr.pem_certificate_chain.append("This is the intermediate cert")
512+
ccr.pem_certificate_chain.append("This is the root cert")
513+
return ccr

tests/unit/test_client.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from typing import Optional
1616

17+
import google.cloud.alloydb_v1beta as v1beta
1718
from mocks import FakeAlloyDBAdminAsyncClient
19+
from mocks import FakeAlloyDBAdminClient
1820
from mocks import FakeCredentials
1921
import pytest
2022

@@ -80,6 +82,40 @@ async def test__get_metadata_with_psc(credentials: FakeCredentials) -> None:
8082
}
8183

8284

85+
async def test__get_metadata_with_async_client(credentials: FakeCredentials) -> None:
86+
"""
87+
Test _get_metadata returns successfully for an async client.
88+
"""
89+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
90+
test_client._is_sync = False
91+
assert (
92+
await test_client._get_metadata(
93+
"test-project",
94+
"test-region",
95+
"test-cluster",
96+
"psc-instance",
97+
)
98+
is not None
99+
)
100+
101+
102+
async def test__get_metadata_with_sync_client(credentials: FakeCredentials) -> None:
103+
"""
104+
Test _get_metadata returns successfully for a sync client.
105+
"""
106+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminClient())
107+
test_client._is_sync = True
108+
assert (
109+
await test_client._get_metadata(
110+
"test-project",
111+
"test-region",
112+
"test-cluster",
113+
"psc-instance",
114+
)
115+
is not None
116+
)
117+
118+
83119
@pytest.mark.asyncio
84120
async def test__get_client_certificate(credentials: FakeCredentials) -> None:
85121
"""
@@ -97,6 +133,40 @@ async def test__get_client_certificate(credentials: FakeCredentials) -> None:
97133
assert cert_chain[2] == "This is the root cert"
98134

99135

136+
async def test__get_client_certificate_with_async_client(
137+
credentials: FakeCredentials,
138+
) -> None:
139+
"""
140+
Test _get_client_certificate returns successfully for an async client.
141+
"""
142+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
143+
test_client._is_sync = False
144+
keys = await generate_keys()
145+
assert (
146+
await test_client._get_client_certificate(
147+
"test-project", "test-region", "test-cluster", keys[1]
148+
)
149+
is not None
150+
)
151+
152+
153+
async def test__get_client_certificate_with_sync_client(
154+
credentials: FakeCredentials,
155+
) -> None:
156+
"""
157+
Test _get_client_certificate returns successfully for a sync client.
158+
"""
159+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminClient())
160+
test_client._is_sync = True
161+
keys = await generate_keys()
162+
assert (
163+
await test_client._get_client_certificate(
164+
"test-project", "test-region", "test-cluster", keys[1]
165+
)
166+
is not None
167+
)
168+
169+
100170
@pytest.mark.asyncio
101171
async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None:
102172
"""
@@ -129,6 +199,48 @@ async def test_AlloyDBClient_init_custom_user_agent(
129199
)
130200

131201

202+
async def test_AlloyDBClient_init_specified_client(
203+
credentials: FakeCredentials,
204+
) -> None:
205+
"""
206+
Test to check that __init__ method of AlloyDBClient uses specified client.
207+
"""
208+
client = AlloyDBClient(
209+
"www.test-endpoint.com",
210+
"my-quota-project",
211+
credentials,
212+
FakeAlloyDBAdminAsyncClient(),
213+
)
214+
assert client._is_sync is False
215+
assert type(client._client) is FakeAlloyDBAdminAsyncClient
216+
217+
218+
async def test_AlloyDBClient_init_sync_client(credentials: FakeCredentials) -> None:
219+
"""
220+
Test to check that __init__ method of AlloyDBClient creates a sync client
221+
when client is not specified and driver is pg8000.
222+
"""
223+
client = AlloyDBClient(
224+
"www.test-endpoint.com", "my-quota-project", credentials, driver="pg8000"
225+
)
226+
assert client._is_sync is True
227+
assert type(client._client) is v1beta.AlloyDBAdminClient
228+
assert client._client.transport.kind == "grpc"
229+
230+
231+
async def test_AlloyDBClient_init_async_client(credentials: FakeCredentials) -> None:
232+
"""
233+
Test to check that __init__ method of AlloyDBClient creates an async client
234+
when client is not specified and driver is not pg8000.
235+
"""
236+
client = AlloyDBClient(
237+
"www.test-endpoint.com", "my-quota-project", credentials, driver=""
238+
)
239+
assert client._is_sync is False
240+
assert type(client._client) is v1beta.AlloyDBAdminAsyncClient
241+
assert client._client.transport.kind == "grpc_asyncio"
242+
243+
132244
@pytest.mark.parametrize(
133245
"driver",
134246
[None, "pg8000", "asyncpg"],

0 commit comments

Comments
 (0)