Skip to content

Commit 5b1c57d

Browse files
refactor: update AlloyDBClient and ConnectionInfo classes (#335)
Refactor of AlloyDBClient and ConnectionInfo classes. Added the get_connection_info method to the AlloyDBClient. This method does an immediate refresh and calls the AlloyDB APIs to get a fresh ConnectionInfo. This allows better encapsulation. Refactored ConnectionInfo to a dataclass and added two new methods; create_ssl_context and get_preferred_ip. Moved all SSL/TLS configuration within create_ssl_context. This way ConnectionInfo() prepares all the info required to connect and then create_ssl_context will use the info to establish the SSL/TLS connection when called from Connector.connect at the time of connection.
1 parent 2adb5bd commit 5b1c57d

File tree

9 files changed

+280
-157
lines changed

9 files changed

+280
-157
lines changed

google/cloud/alloydb/connector/async_connector.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
ip_type: str | IPTypes = IPTypes.PRIVATE,
6161
user_agent: Optional[str] = None,
6262
) -> None:
63-
self._instances: Dict[str, RefreshAheadCache] = {}
63+
self._cache: Dict[str, RefreshAheadCache] = {}
6464
# initialize default params
6565
self._quota_project = quota_project
6666
self._alloydb_api_endpoint = alloydb_api_endpoint
@@ -125,11 +125,11 @@ async def connect(
125125
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
126126

127127
# use existing connection info if possible
128-
if instance_uri in self._instances:
129-
instance = self._instances[instance_uri]
128+
if instance_uri in self._cache:
129+
cache = self._cache[instance_uri]
130130
else:
131-
instance = RefreshAheadCache(instance_uri, self._client, self._keys)
132-
self._instances[instance_uri] = instance
131+
cache = RefreshAheadCache(instance_uri, self._client, self._keys)
132+
self._cache[instance_uri] = cache
133133

134134
connect_func = {
135135
"asyncpg": asyncpg.connect,
@@ -151,7 +151,8 @@ async def connect(
151151
# if ip_type is str, convert to IPTypes enum
152152
if isinstance(ip_type, str):
153153
ip_type = IPTypes(ip_type.upper())
154-
ip_address, context = await instance.connection_info(ip_type)
154+
conn_info = await cache.connect_info()
155+
ip_address = conn_info.get_preferred_ip(ip_type)
155156

156157
# callable to be used for auto IAM authn
157158
def get_authentication_token() -> str:
@@ -166,10 +167,10 @@ def get_authentication_token() -> str:
166167
if enable_iam_auth:
167168
kwargs["password"] = get_authentication_token
168169
try:
169-
return await connector(ip_address, context, **kwargs)
170+
return await connector(ip_address, conn_info.create_ssl_context(), **kwargs)
170171
except Exception:
171172
# we attempt a force refresh, then throw the error
172-
await instance.force_refresh()
173+
await cache.force_refresh()
173174
raise
174175

175176
async def __aenter__(self) -> Any:
@@ -188,8 +189,6 @@ async def __aexit__(
188189
async def close(self) -> None:
189190
"""Helper function to cancel RefreshAheadCaches' tasks
190191
and close client."""
191-
await asyncio.gather(
192-
*[instance.close() for instance in self._instances.values()]
193-
)
192+
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
194193
if self._client:
195194
await self._client.close()

google/cloud/alloydb/connector/client.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import logging
1819
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
1920

2021
import aiohttp
22+
from cryptography import x509
23+
from google.auth.credentials import TokenState
24+
from google.auth.transport import requests
2125

26+
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
2227
from google.cloud.alloydb.connector.version import __version__ as version
2328

2429
if TYPE_CHECKING:
@@ -181,6 +186,70 @@ async def _get_client_certificate(
181186

182187
return (resp_dict["caCert"], resp_dict["pemCertificateChain"])
183188

189+
async def get_connection_info(
190+
self,
191+
project: str,
192+
region: str,
193+
cluster: str,
194+
name: str,
195+
keys: asyncio.Future,
196+
) -> ConnectionInfo:
197+
"""Immediately performs a full refresh operation using the AlloyDB API.
198+
199+
Args:
200+
project (str): The name of the project the AlloyDB instance is
201+
located in.
202+
region (str): The region the AlloyDB instance is located in.
203+
cluster (str): The cluster the AlloyDB instance is located in.
204+
name (str): Name of the AlloyDB instance.
205+
keys (asyncio.Future): A future to the client's public-private key
206+
pair.
207+
208+
Returns:
209+
ConnectionInfo: All the information required to connect securely to
210+
the AlloyDB instance.
211+
"""
212+
priv_key, pub_key = await keys
213+
214+
# before making AlloyDB API calls, refresh creds if required
215+
if not self._credentials.token_state == TokenState.FRESH:
216+
self._credentials.refresh(requests.Request())
217+
218+
# fetch metadata
219+
metadata_task = asyncio.create_task(
220+
self._get_metadata(
221+
project,
222+
region,
223+
cluster,
224+
name,
225+
)
226+
)
227+
# generate client and CA certs
228+
certs_task = asyncio.create_task(
229+
self._get_client_certificate(
230+
project,
231+
region,
232+
cluster,
233+
pub_key,
234+
)
235+
)
236+
237+
ip_addrs, certs = await asyncio.gather(metadata_task, certs_task)
238+
239+
# unpack certs
240+
ca_cert, cert_chain = certs
241+
# get expiration from client certificate
242+
cert_obj = x509.load_pem_x509_certificate(cert_chain[0].encode("UTF-8"))
243+
expiration = cert_obj.not_valid_after_utc
244+
245+
return ConnectionInfo(
246+
cert_chain,
247+
ca_cert,
248+
priv_key,
249+
ip_addrs,
250+
expiration,
251+
)
252+
184253
async def close(self) -> None:
185254
"""Close AlloyDBClient gracefully."""
186255
await self._client.close()

google/cloud/alloydb/connector/connection_info.py

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,61 +14,76 @@
1414

1515
from __future__ import annotations
1616

17+
from dataclasses import dataclass
1718
import logging
1819
import ssl
1920
from tempfile import TemporaryDirectory
20-
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
21-
22-
from cryptography import x509
21+
from typing import Dict, List, Optional, TYPE_CHECKING
2322

23+
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
2424
from google.cloud.alloydb.connector.utils import _write_to_file
2525

2626
if TYPE_CHECKING:
27+
import datetime
28+
2729
from cryptography.hazmat.primitives.asymmetric import rsa
2830

31+
from google.cloud.alloydb.connector.instance import IPTypes
32+
2933
logger = logging.getLogger(name=__name__)
3034

3135

36+
@dataclass
3237
class ConnectionInfo:
33-
"""
34-
Manages the result of a refresh operation.
38+
"""Contains all necessary information to connect securely to the
39+
server-side Proxy running on an AlloyDB instance."""
3540

36-
Holds the certificates and IP address of an AlloyDB instance.
37-
Builds the TLS context required to connect to AlloyDB database.
41+
cert_chain: List[str]
42+
ca_cert: str
43+
key: rsa.RSAPrivateKey
44+
ip_addrs: Dict[str, Optional[str]]
45+
expiration: datetime.datetime
46+
context: Optional[ssl.SSLContext] = None
3847

39-
Args:
40-
ip_addrs (Dict[str, str]): The IP addresses of the AlloyDB instance.
41-
key (rsa.RSAPrivateKey): Private key for the client connection.
42-
certs (Tuple[str, List(str)]): Client cert and CA certs for establishing
43-
the chain of trust used in building the TLS context.
44-
"""
48+
def create_ssl_context(self) -> ssl.SSLContext:
49+
"""Constructs a SSL/TLS context for the given connection info.
50+
51+
Cache the SSL context to ensure we don't read from disk repeatedly when
52+
configuring a secure connection.
53+
"""
54+
# if SSL context is cached, use it
55+
if self.context is not None:
56+
return self.context
4557

46-
def __init__(
47-
self,
48-
ip_addrs: Dict[str, Optional[str]],
49-
key: rsa.RSAPrivateKey,
50-
certs: Tuple[str, List[str]],
51-
) -> None:
52-
self.ip_addrs = ip_addrs
5358
# create TLS context
54-
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
59+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
5560
# TODO: Set check_hostname to True to verify the identity in the
5661
# certificate once PSC DNS is populated in all existing clusters.
57-
self.context.check_hostname = False
62+
context.check_hostname = False
5863
# force TLSv1.3
59-
self.context.minimum_version = ssl.TLSVersion.TLSv1_3
60-
# unpack certs
61-
ca_cert, cert_chain = certs
62-
# get expiration from client certificate
63-
cert_obj = x509.load_pem_x509_certificate(cert_chain[0].encode("UTF-8"))
64-
self.expiration = cert_obj.not_valid_after_utc
64+
context.minimum_version = ssl.TLSVersion.TLSv1_3
6565

6666
# tmpdir and its contents are automatically deleted after the CA cert
6767
# and cert chain are loaded into the SSLcontext. The values
6868
# need to be written to files in order to be loaded by the SSLContext
6969
with TemporaryDirectory() as tmpdir:
7070
ca_filename, cert_chain_filename, key_filename = _write_to_file(
71-
tmpdir, ca_cert, cert_chain, key
71+
tmpdir, self.ca_cert, self.cert_chain, self.key
72+
)
73+
context.load_cert_chain(cert_chain_filename, keyfile=key_filename)
74+
context.load_verify_locations(cafile=ca_filename)
75+
# set class attribute to cache context for subsequent calls
76+
self.context = context
77+
return context
78+
79+
def get_preferred_ip(self, ip_type: IPTypes) -> str:
80+
"""Returns the first IP address for the instance, according to the preference
81+
supplied by ip_type. If no IP addressess with the given preference are found,
82+
an error is raised."""
83+
ip_address = self.ip_addrs.get(ip_type.value)
84+
if ip_address is None:
85+
raise IPTypeNotFoundError(
86+
"AlloyDB instance does not have an IP addresses matching "
87+
f"type: '{ip_type.value}'"
7288
)
73-
self.context.load_cert_chain(cert_chain_filename, keyfile=key_filename)
74-
self.context.load_verify_locations(cafile=ca_filename)
89+
return ip_address

google/cloud/alloydb/connector/connector.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,17 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
178178
# if ip_type is str, convert to IPTypes enum
179179
if isinstance(ip_type, str):
180180
ip_type = IPTypes(ip_type.upper())
181-
ip_address, context = await cache.connection_info(ip_type)
181+
conn_info = await cache.connect_info()
182+
ip_address = conn_info.get_preferred_ip(ip_type)
182183

183184
# synchronous drivers are blocking and run using executor
184185
try:
185186
metadata_partial = partial(
186-
self.metadata_exchange, ip_address, context, enable_iam_auth, driver
187+
self.metadata_exchange,
188+
ip_address,
189+
conn_info.create_ssl_context(),
190+
enable_iam_auth,
191+
driver,
187192
)
188193
sock = await self._loop.run_in_executor(None, metadata_partial)
189194
connect_partial = partial(connector, sock, **kwargs)

google/cloud/alloydb/connector/instance.py

Lines changed: 11 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,13 @@
2020
import re
2121
from typing import Tuple, TYPE_CHECKING
2222

23-
from google.auth.credentials import TokenState
24-
from google.auth.transport import requests
25-
2623
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
27-
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
2824
from google.cloud.alloydb.connector.exceptions import RefreshError
2925
from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter
3026
from google.cloud.alloydb.connector.refresh_utils import _is_valid
3127
from google.cloud.alloydb.connector.refresh_utils import _seconds_until_refresh
3228

3329
if TYPE_CHECKING:
34-
import ssl
35-
3630
from cryptography.hazmat.primitives.asymmetric import rsa
3731

3832
from google.cloud.alloydb.connector.client import AlloyDBClient
@@ -132,32 +126,13 @@ async def _perform_refresh(self) -> ConnectionInfo:
132126

133127
try:
134128
await self._refresh_rate_limiter.acquire()
135-
priv_key, pub_key = await self._keys
136-
137-
# before making AlloyDB API calls, refresh creds if required
138-
if not self._client._credentials.token_state == TokenState.FRESH:
139-
self._client._credentials.refresh(requests.Request())
140-
141-
# fetch metadata
142-
metadata_task = asyncio.create_task(
143-
self._client._get_metadata(
144-
self._project,
145-
self._region,
146-
self._cluster,
147-
self._name,
148-
)
129+
connection_info = await self._client.get_connection_info(
130+
self._project,
131+
self._region,
132+
self._cluster,
133+
self._name,
134+
self._keys,
149135
)
150-
# generate client and CA certs
151-
certs_task = asyncio.create_task(
152-
self._client._get_client_certificate(
153-
self._project,
154-
self._region,
155-
self._cluster,
156-
pub_key,
157-
)
158-
)
159-
160-
ip_addr, certs = await asyncio.gather(metadata_task, certs_task)
161136

162137
except Exception:
163138
logger.debug(
@@ -167,8 +142,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
167142

168143
finally:
169144
self._refresh_in_progress.clear()
170-
171-
return ConnectionInfo(ip_addr, priv_key, certs)
145+
return connection_info
172146

173147
def _schedule_refresh(self, delay: int) -> asyncio.Task:
174148
"""
@@ -241,24 +215,11 @@ async def force_refresh(self) -> None:
241215
if not await _is_valid(self._current):
242216
self._current = self._next
243217

244-
async def connection_info(self, ip_type: IPTypes) -> Tuple[str, ssl.SSLContext]:
245-
"""
246-
Return connection info for current refresh result.
247-
248-
Args:
249-
ip_type (IpTypes): Type of AlloyDB instance IP to connect over.
250-
Returns:
251-
Tuple[str, ssl.SSLContext]: AlloyDB instance IP address
252-
and configured TLS connection.
218+
async def connect_info(self) -> ConnectionInfo:
219+
"""Retrieves ConnectionInfo instance for establishing a secure
220+
connection to the AlloyDB instance.
253221
"""
254-
refresh: ConnectionInfo = await self._current
255-
ip_address = refresh.ip_addrs.get(ip_type.value)
256-
if ip_address is None:
257-
raise IPTypeNotFoundError(
258-
"AlloyDB instance does not have an IP addresses matching "
259-
f"type: '{ip_type.value}'"
260-
)
261-
return ip_address, refresh.context
222+
return await self._current
262223

263224
async def close(self) -> None:
264225
"""

0 commit comments

Comments
 (0)