Skip to content

Commit 10a4d7f

Browse files
feat: Replace aiohttp.ClientSession with AlloyDBAdminAsyncClient (#416)
1 parent 168c018 commit 10a4d7f

File tree

10 files changed

+214
-321
lines changed

10 files changed

+214
-321
lines changed

google/cloud/alloydb/connector/async_connector.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
3232
from google.cloud.alloydb.connector.types import CacheTypes
3333
from google.cloud.alloydb.connector.utils import generate_keys
34+
from google.cloud.alloydb.connector.utils import strip_http_prefix
3435

3536
if TYPE_CHECKING:
3637
from google.auth.credentials import Credentials
@@ -51,7 +52,7 @@ class AsyncConnector:
5152
billing purposes.
5253
Defaults to None, picking up project from environment.
5354
alloydb_api_endpoint (str): Base URL to use when calling
54-
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
55+
the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
5556
enable_iam_auth (bool): Enables automatic IAM database authentication.
5657
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
5758
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
@@ -66,7 +67,7 @@ def __init__(
6667
self,
6768
credentials: Optional[Credentials] = None,
6869
quota_project: Optional[str] = None,
69-
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
70+
alloydb_api_endpoint: str = "alloydb.googleapis.com",
7071
enable_iam_auth: bool = False,
7172
ip_type: str | IPTypes = IPTypes.PRIVATE,
7273
user_agent: Optional[str] = None,
@@ -75,7 +76,7 @@ def __init__(
7576
self._cache: dict[str, CacheTypes] = {}
7677
# initialize default params
7778
self._quota_project = quota_project
78-
self._alloydb_api_endpoint = alloydb_api_endpoint
79+
self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint)
7980
self._enable_iam_auth = enable_iam_auth
8081
# if ip_type is str, convert to IPTypes enum
8182
if isinstance(ip_type, str):
@@ -235,5 +236,3 @@ async def close(self) -> None:
235236
"""Helper function to cancel RefreshAheadCaches' tasks
236237
and close client."""
237238
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
238-
if self._client:
239-
await self._client.close()

google/cloud/alloydb/connector/client.py

Lines changed: 40 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
import logging
1919
from typing import Optional, TYPE_CHECKING
2020

21-
import aiohttp
2221
from cryptography import x509
22+
from google.api_core.client_options import ClientOptions
23+
from google.api_core.gapic_v1.client_info import ClientInfo
2324
from google.auth.credentials import TokenState
2425
from google.auth.transport import requests
26+
import google.cloud.alloydb_v1beta as v1beta
27+
from google.protobuf import duration_pb2
2528

2629
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
2730
from google.cloud.alloydb.connector.version import __version__ as version
@@ -55,7 +58,7 @@ def __init__(
5558
alloydb_api_endpoint: str,
5659
quota_project: Optional[str],
5760
credentials: Credentials,
58-
client: Optional[aiohttp.ClientSession] = None,
61+
client: Optional[v1beta.AlloyDBAdminAsyncClient] = None,
5962
driver: Optional[str] = None,
6063
user_agent: Optional[str] = None,
6164
) -> None:
@@ -72,23 +75,28 @@ def __init__(
7275
A credentials object created from the google-auth Python library.
7376
Must have the AlloyDB Admin scopes. For more info check out
7477
https://google-auth.readthedocs.io/en/latest/.
75-
client (aiohttp.ClientSession): Async client used to make requests to
76-
AlloyDB APIs.
78+
client (v1beta.AlloyDBAdminAsyncClient): Async client used to make
79+
requests to AlloyDB APIs.
7780
Optional, defaults to None and creates new client.
7881
driver (str): Database driver to be used by the client.
7982
"""
8083
user_agent = _format_user_agent(driver, user_agent)
81-
headers = {
82-
"x-goog-api-client": user_agent,
83-
"User-Agent": user_agent,
84-
"Content-Type": "application/json",
85-
}
86-
if quota_project:
87-
headers["x-goog-user-project"] = quota_project
8884

89-
self._client = client if client else aiohttp.ClientSession(headers=headers)
85+
self._client = (
86+
client
87+
if client
88+
else v1beta.AlloyDBAdminAsyncClient(
89+
credentials=credentials,
90+
client_options=ClientOptions(
91+
api_endpoint=alloydb_api_endpoint,
92+
quota_project_id=quota_project,
93+
),
94+
client_info=ClientInfo(
95+
user_agent=user_agent,
96+
),
97+
)
98+
)
9099
self._credentials = credentials
91-
self._alloydb_api_endpoint = alloydb_api_endpoint
92100
# asyncpg does not currently support using metadata exchange
93101
# only use metadata exchange for pg8000 driver
94102
self._use_metadata = True if driver == "pg8000" else False
@@ -118,35 +126,21 @@ async def _get_metadata(
118126
Returns:
119127
dict: IP addresses of the AlloyDB instance.
120128
"""
121-
headers = {
122-
"Authorization": f"Bearer {self._credentials.token}",
123-
}
129+
parent = (
130+
f"projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}"
131+
)
124132

125-
url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}/connectionInfo"
126-
127-
resp = await self._client.get(url, headers=headers)
128-
# try to get response json for better error message
129-
try:
130-
resp_dict = await resp.json()
131-
if resp.status >= 400:
132-
# if detailed error message is in json response, use as error message
133-
message = resp_dict.get("error", {}).get("message")
134-
if message:
135-
resp.reason = message
136-
# skip, raise_for_status will catch all errors in finally block
137-
except Exception:
138-
pass
139-
finally:
140-
resp.raise_for_status()
133+
req = v1beta.GetConnectionInfoRequest(parent=parent)
134+
resp = await self._client.get_connection_info(request=req)
141135

142136
# Remove trailing period from PSC DNS name.
143-
psc_dns = resp_dict.get("pscDnsName")
137+
psc_dns = resp.psc_dns_name
144138
if psc_dns:
145139
psc_dns = psc_dns.rstrip(".")
146140

147141
return {
148-
"PRIVATE": resp_dict.get("ipAddress"),
149-
"PUBLIC": resp_dict.get("publicIpAddress"),
142+
"PRIVATE": resp.ip_address,
143+
"PUBLIC": resp.public_ip_address,
150144
"PSC": psc_dns,
151145
}
152146

@@ -175,34 +169,17 @@ async def _get_client_certificate(
175169
tuple[str, list[str]]: tuple containing the CA certificate
176170
and certificate chain for the AlloyDB instance.
177171
"""
178-
headers = {
179-
"Authorization": f"Bearer {self._credentials.token}",
180-
}
181-
182-
url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"
183-
184-
data = {
185-
"publicKey": pub_key,
186-
"certDuration": "3600s",
187-
"useMetadataExchange": self._use_metadata,
188-
}
189-
190-
resp = await self._client.post(url, headers=headers, json=data)
191-
# try to get response json for better error message
192-
try:
193-
resp_dict = await resp.json()
194-
if resp.status >= 400:
195-
# if detailed error message is in json response, use as error message
196-
message = resp_dict.get("error", {}).get("message")
197-
if message:
198-
resp.reason = message
199-
# skip, raise_for_status will catch all errors in finally block
200-
except Exception:
201-
pass
202-
finally:
203-
resp.raise_for_status()
204-
205-
return (resp_dict["caCert"], resp_dict["pemCertificateChain"])
172+
parent = f"projects/{project}/locations/{region}/clusters/{cluster}"
173+
dur = duration_pb2.Duration()
174+
dur.seconds = 3600
175+
req = v1beta.GenerateClientCertificateRequest(
176+
parent=parent,
177+
cert_duration=dur,
178+
public_key=pub_key,
179+
use_metadata_exchange=self._use_metadata,
180+
)
181+
resp = await self._client.generate_client_certificate(request=req)
182+
return (resp.ca_cert, resp.pem_certificate_chain)
206183

207184
async def get_connection_info(
208185
self,
@@ -267,9 +244,3 @@ async def get_connection_info(
267244
ip_addrs,
268245
expiration,
269246
)
270-
271-
async def close(self) -> None:
272-
"""Close AlloyDBClient gracefully."""
273-
logger.debug("Waiting for connector's http client to close")
274-
await self._client.close()
275-
logger.debug("Closed connector's http client")

google/cloud/alloydb/connector/connector.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import google.cloud.alloydb.connector.pg8000 as pg8000
3838
from google.cloud.alloydb.connector.types import CacheTypes
3939
from google.cloud.alloydb.connector.utils import generate_keys
40+
from google.cloud.alloydb.connector.utils import strip_http_prefix
4041
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb
4142

4243
if TYPE_CHECKING:
@@ -64,7 +65,7 @@ class Connector:
6465
billing purposes.
6566
Defaults to None, picking up project from environment.
6667
alloydb_api_endpoint (str): Base URL to use when calling
67-
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
68+
the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
6869
enable_iam_auth (bool): Enables automatic IAM database authentication.
6970
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
7071
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
@@ -85,7 +86,7 @@ def __init__(
8586
self,
8687
credentials: Optional[Credentials] = None,
8788
quota_project: Optional[str] = None,
88-
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
89+
alloydb_api_endpoint: str = "alloydb.googleapis.com",
8990
enable_iam_auth: bool = False,
9091
ip_type: str | IPTypes = IPTypes.PRIVATE,
9192
user_agent: Optional[str] = None,
@@ -99,7 +100,7 @@ def __init__(
99100
self._cache: dict[str, CacheTypes] = {}
100101
# initialize default params
101102
self._quota_project = quota_project
102-
self._alloydb_api_endpoint = alloydb_api_endpoint
103+
self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint)
103104
self._enable_iam_auth = enable_iam_auth
104105
# if ip_type is str, convert to IPTypes enum
105106
if isinstance(ip_type, str):
@@ -389,5 +390,3 @@ async def close_async(self) -> None:
389390
"""Helper function to cancel RefreshAheadCaches' tasks
390391
and close client."""
391392
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
392-
if self._client:
393-
await self._client.close()

google/cloud/alloydb/connector/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
import re
18+
1719
import aiofiles
1820
from cryptography.hazmat.primitives import serialization
1921
from cryptography.hazmat.primitives.asymmetric import rsa
@@ -58,3 +60,13 @@ async def generate_keys() -> tuple[rsa.RSAPrivateKey, str]:
5860
.decode("UTF-8")
5961
)
6062
return (priv_key, pub_key)
63+
64+
65+
def strip_http_prefix(url: str) -> str:
66+
"""
67+
Returns a new URL with 'http://' or 'https://' prefix removed.
68+
"""
69+
m = re.search(r"^(https?://)?(.+)", url)
70+
if m is None:
71+
return ""
72+
return m.group(2)

mypy.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,9 @@ ignore_missing_imports = True
1111

1212
[mypy-asyncpg]
1313
ignore_missing_imports = True
14+
15+
[mypy-google.cloud.alloydb_v1beta]
16+
ignore_missing_imports = True
17+
18+
[mypy-google.api_core.*]
19+
ignore_missing_imports = True

tests/unit/mocks.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from google.auth.credentials import TokenState
3333
from google.auth.transport import requests
3434

35+
from google.cloud import alloydb_v1beta
3536
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
3637
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb
3738

@@ -232,7 +233,6 @@ def __init__(
232233
self, instance: Optional[FakeInstance] = None, driver: str = "pg8000"
233234
) -> None:
234235
self.instance = FakeInstance() if instance is None else instance
235-
self.closed = False
236236
self._user_agent = f"test-user-agent+{driver}"
237237
self._credentials = FakeCredentials()
238238

@@ -317,9 +317,6 @@ async def get_connection_info(
317317
expiration,
318318
)
319319

320-
async def close(self) -> None:
321-
self.closed = True
322-
323320

324321
def metadata_exchange(sock: ssl.SSLSocket) -> None:
325322
"""
@@ -448,3 +445,36 @@ def write_static_info(i: FakeInstance) -> io.StringIO:
448445
"pscInstanceConfig": {"pscDnsName": i.ip_addrs["PSC"]},
449446
}
450447
return io.StringIO(json.dumps(static))
448+
449+
450+
class FakeAlloyDBAdminAsyncClient:
451+
async def get_connection_info(
452+
self, request: alloydb_v1beta.GetConnectionInfoRequest
453+
) -> alloydb_v1beta.types.resources.ConnectionInfo:
454+
ci = alloydb_v1beta.types.resources.ConnectionInfo()
455+
ci.ip_address = "10.0.0.1"
456+
ci.public_ip_address = "127.0.0.1"
457+
ci.instance_uid = "123456789"
458+
ci.psc_dns_name = "x.y.alloydb.goog"
459+
460+
parent = request.parent
461+
instance = parent.split("/")[-1]
462+
if instance == "test-instance":
463+
ci.public_ip_address = ""
464+
ci.psc_dns_name = ""
465+
elif instance == "public-instance":
466+
ci.psc_dns_name = ""
467+
else:
468+
ci.ip_address = ""
469+
ci.public_ip_address = ""
470+
return ci
471+
472+
async def generate_client_certificate(
473+
self, request: alloydb_v1beta.GenerateClientCertificateRequest
474+
) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
475+
ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
476+
ccr.ca_cert = "This is the CA cert"
477+
ccr.pem_certificate_chain.append("This is the client cert")
478+
ccr.pem_certificate_chain.append("This is the intermediate cert")
479+
ccr.pem_certificate_chain.append("This is the root cert")
480+
return ccr

0 commit comments

Comments
 (0)