Skip to content

Commit 2adb5bd

Browse files
refactor: rename RefreshResult to ConnectionInfo (#334)
Renaming RefreshResult to ConnectionInfo to be better aligned with other Connector libraries. Aligning file names with that of Cloud SQL Python Connector by splitting refresh.py into refresh_utils.py and connection_info.py
1 parent 9f0b213 commit 2adb5bd

File tree

6 files changed

+132
-98
lines changed

6 files changed

+132
-98
lines changed

google/cloud/alloydb/connector/refresh.py renamed to google/cloud/alloydb/connector/connection_info.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414

1515
from __future__ import annotations
1616

17-
import asyncio
18-
from datetime import datetime
19-
from datetime import timezone
2017
import logging
2118
import ssl
2219
from tempfile import TemporaryDirectory
@@ -31,39 +28,8 @@
3128

3229
logger = logging.getLogger(name=__name__)
3330

34-
# _refresh_buffer is the amount of time before a refresh's result expires
35-
# that a new refresh operation begins.
36-
_refresh_buffer: int = 4 * 60 # 4 minutes
3731

38-
39-
def _seconds_until_refresh(
40-
expiration: datetime, now: datetime = datetime.now(timezone.utc)
41-
) -> int:
42-
"""
43-
Calculates the duration to wait before starting the next refresh.
44-
Usually the duration will be half of the time until certificate
45-
expiration.
46-
47-
Args:
48-
expiration (datetime.datetime): Time of certificate expiration.
49-
now (datetime.datetime): Current time (UTC)
50-
Returns:
51-
int: Time in seconds to wait before performing next refresh.
52-
"""
53-
54-
duration = int((expiration - now).total_seconds())
55-
56-
# if certificate duration is less than 1 hour
57-
if duration < 3600:
58-
# something is wrong with certificate, refresh now
59-
if duration < _refresh_buffer:
60-
return 0
61-
# otherwise wait until 4 minutes before expiration for next refresh
62-
return duration - _refresh_buffer
63-
return duration // 2
64-
65-
66-
class RefreshResult:
32+
class ConnectionInfo:
6733
"""
6834
Manages the result of a refresh operation.
6935
@@ -91,8 +57,6 @@ def __init__(
9157
self.context.check_hostname = False
9258
# force TLSv1.3
9359
self.context.minimum_version = ssl.TLSVersion.TLSv1_3
94-
# add request_ssl attribute to ssl.SSLContext, required for pg8000 driver
95-
self.context.request_ssl = False # type: ignore
9660
# unpack certs
9761
ca_cert, cert_chain = certs
9862
# get expiration from client certificate
@@ -108,15 +72,3 @@ def __init__(
10872
)
10973
self.context.load_cert_chain(cert_chain_filename, keyfile=key_filename)
11074
self.context.load_verify_locations(cafile=ca_filename)
111-
112-
113-
async def _is_valid(task: asyncio.Task) -> bool:
114-
try:
115-
result = await task
116-
# valid if current time is before cert expiration
117-
if datetime.now(timezone.utc) < result.expiration:
118-
return True
119-
except Exception:
120-
# suppress any errors from task
121-
logger.debug("Current refresh result is invalid.")
122-
return False

google/cloud/alloydb/connector/instance.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
from google.auth.credentials import TokenState
2424
from google.auth.transport import requests
2525

26+
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
2627
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
2728
from google.cloud.alloydb.connector.exceptions import RefreshError
2829
from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter
29-
from google.cloud.alloydb.connector.refresh import _is_valid
30-
from google.cloud.alloydb.connector.refresh import _seconds_until_refresh
31-
from google.cloud.alloydb.connector.refresh import RefreshResult
30+
from google.cloud.alloydb.connector.refresh_utils import _is_valid
31+
from google.cloud.alloydb.connector.refresh_utils import _seconds_until_refresh
3232

3333
if TYPE_CHECKING:
3434
import ssl
@@ -117,15 +117,15 @@ def __init__(
117117
self._current: asyncio.Task = self._schedule_refresh(0)
118118
self._next: asyncio.Task = self._current
119119

120-
async def _perform_refresh(self) -> RefreshResult:
120+
async def _perform_refresh(self) -> ConnectionInfo:
121121
"""
122122
Perform a refresh operation on an AlloyDB instance.
123123
124124
Retrieves metadata and generates new client certificate
125125
required to connect securely to the AlloyDB instance.
126126
127127
Returns:
128-
RefreshResult: Result of the refresh operation.
128+
ConnectionInfo: Result of the refresh operation.
129129
"""
130130
self._refresh_in_progress.set()
131131
logger.debug(f"['{self._instance_uri}']: Entered _perform_refresh")
@@ -168,7 +168,7 @@ async def _perform_refresh(self) -> RefreshResult:
168168
finally:
169169
self._refresh_in_progress.clear()
170170

171-
return RefreshResult(ip_addr, priv_key, certs)
171+
return ConnectionInfo(ip_addr, priv_key, certs)
172172

173173
def _schedule_refresh(self, delay: int) -> asyncio.Task:
174174
"""
@@ -178,12 +178,12 @@ def _schedule_refresh(self, delay: int) -> asyncio.Task:
178178
delay (int): Time in seconds to sleep before performing refresh.
179179
180180
Returns:
181-
asyncio.Task[RefreshResult]: A task representing the scheduled
181+
asyncio.Task[ConnectionInfo]: A task representing the scheduled
182182
refresh operation.
183183
"""
184184
return asyncio.create_task(self._refresh_operation(delay))
185185

186-
async def _refresh_operation(self, delay: int) -> RefreshResult:
186+
async def _refresh_operation(self, delay: int) -> ConnectionInfo:
187187
"""
188188
A coroutine that sleeps for the specified amount of time before
189189
running _perform_refresh.
@@ -192,7 +192,7 @@ async def _refresh_operation(self, delay: int) -> RefreshResult:
192192
delay (int): Time in seconds to sleep before performing refresh.
193193
194194
Returns:
195-
RefreshResult: Refresh result for an AlloyDB instance.
195+
ConnectionInfo: Refresh result for an AlloyDB instance.
196196
"""
197197
refresh_task: asyncio.Task
198198
try:
@@ -251,7 +251,7 @@ async def connection_info(self, ip_type: IPTypes) -> Tuple[str, ssl.SSLContext]:
251251
Tuple[str, ssl.SSLContext]: AlloyDB instance IP address
252252
and configured TLS connection.
253253
"""
254-
refresh: RefreshResult = await self._current
254+
refresh: ConnectionInfo = await self._current
255255
ip_address = refresh.ip_addrs.get(ip_type.value)
256256
if ip_address is None:
257257
raise IPTypeNotFoundError(
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
from datetime import datetime
19+
from datetime import timezone
20+
import logging
21+
22+
logger = logging.getLogger(name=__name__)
23+
24+
# _refresh_buffer is the amount of time before a refresh's result expires
25+
# that a new refresh operation begins.
26+
_refresh_buffer: int = 4 * 60 # 4 minutes
27+
28+
29+
def _seconds_until_refresh(
30+
expiration: datetime, now: datetime = datetime.now(timezone.utc)
31+
) -> int:
32+
"""
33+
Calculates the duration to wait before starting the next refresh.
34+
Usually the duration will be half of the time until certificate
35+
expiration.
36+
37+
Args:
38+
expiration (datetime.datetime): Time of certificate expiration.
39+
now (datetime.datetime): Current time (UTC)
40+
Returns:
41+
int: Time in seconds to wait before performing next refresh.
42+
"""
43+
44+
duration = int((expiration - now).total_seconds())
45+
46+
# if certificate duration is less than 1 hour
47+
if duration < 3600:
48+
# something is wrong with certificate, refresh now
49+
if duration < _refresh_buffer:
50+
return 0
51+
# otherwise wait until 4 minutes before expiration for next refresh
52+
return duration - _refresh_buffer
53+
return duration // 2
54+
55+
56+
async def _is_valid(task: asyncio.Task) -> bool:
57+
try:
58+
result = await task
59+
# valid if current time is before cert expiration
60+
if datetime.now(timezone.utc) < result.expiration:
61+
return True
62+
except Exception:
63+
# suppress any errors from task
64+
logger.debug("Current refresh result is invalid.")
65+
return False

tests/unit/test_refresh.py renamed to tests/unit/test_connection_info.py

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,42 +23,12 @@
2323
from cryptography.hazmat.primitives.asymmetric import rsa
2424
from mocks import FakeInstance
2525

26-
from google.cloud.alloydb.connector.refresh import _seconds_until_refresh
27-
from google.cloud.alloydb.connector.refresh import RefreshResult
26+
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
2827

2928

30-
def test_seconds_until_refresh_over_1_hour() -> None:
29+
def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None:
3130
"""
32-
Test _seconds_until_refresh returns proper time in seconds.
33-
If expiration is over 1 hour, should return duration/2.
34-
"""
35-
now = datetime.now()
36-
assert _seconds_until_refresh(now + timedelta(minutes=62), now) == 31 * 60
37-
38-
39-
def test_seconds_until_refresh_under_1_hour_over_4_mins() -> None:
40-
"""
41-
Test _seconds_until_refresh returns proper time in seconds.
42-
If expiration is under 1 hour and over 4 minutes,
43-
should return duration-refresh_buffer (refresh_buffer = 4 minutes).
44-
"""
45-
now = datetime.now(timezone.utc)
46-
assert _seconds_until_refresh(now + timedelta(minutes=5), now) == 60
47-
48-
49-
def test_seconds_until_refresh_under_4_mins() -> None:
50-
"""
51-
Test _seconds_until_refresh returns proper time in seconds.
52-
If expiration is under 4 minutes, should return 0.
53-
"""
54-
assert (
55-
_seconds_until_refresh(datetime.now(timezone.utc) + timedelta(minutes=3)) == 0
56-
)
57-
58-
59-
def test_RefreshResult_init_(fake_instance: FakeInstance) -> None:
60-
"""
61-
Test to check whether the __init__ method of RefreshResult
31+
Test to check whether the __init__ method of ConnectionInfo
6232
can correctly initialize TLS context.
6333
"""
6434
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
@@ -79,7 +49,6 @@ def test_RefreshResult_init_(fake_instance: FakeInstance) -> None:
7949
"UTF-8"
8050
)
8151
certs = (ca_cert, [client_cert, intermediate_cert, root_cert])
82-
refresh = RefreshResult(fake_instance.ip_addrs, key, certs)
52+
refresh = ConnectionInfo(fake_instance.ip_addrs, key, certs)
8353
# verify TLS requirements
8454
assert refresh.context.minimum_version == ssl.TLSVersion.TLSv1_3
85-
assert refresh.context.request_ssl is False

tests/unit/test_instance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
from mocks import FakeAlloyDBClient
2222
import pytest
2323

24+
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
2425
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
2526
from google.cloud.alloydb.connector.exceptions import RefreshError
2627
from google.cloud.alloydb.connector.instance import _parse_instance_uri
2728
from google.cloud.alloydb.connector.instance import IPTypes
2829
from google.cloud.alloydb.connector.instance import RefreshAheadCache
29-
from google.cloud.alloydb.connector.refresh import _is_valid
30-
from google.cloud.alloydb.connector.refresh import RefreshResult
30+
from google.cloud.alloydb.connector.refresh_utils import _is_valid
3131
from google.cloud.alloydb.connector.utils import generate_keys
3232

3333

@@ -125,7 +125,7 @@ async def test_RefreshAheadCache_close() -> None:
125125

126126
@pytest.mark.asyncio
127127
async def test_perform_refresh() -> None:
128-
"""Test that _perform refresh returns valid RefreshResult"""
128+
"""Test that _perform refresh returns valid ConnectionInfo"""
129129
keys = asyncio.create_task(generate_keys())
130130
client = FakeAlloyDBClient()
131131
cache = RefreshAheadCache(
@@ -299,6 +299,6 @@ async def test_force_refresh_cancels_pending_refresh() -> None:
299299
assert await pending_refresh
300300
# verify pending_refresh has now been cancelled
301301
assert pending_refresh.cancelled() is True
302-
assert isinstance(await cache._current, RefreshResult)
302+
assert isinstance(await cache._current, ConnectionInfo)
303303
# close instance
304304
await cache.close()

tests/unit/test_refresh_utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from datetime import datetime
16+
from datetime import timedelta
17+
from datetime import timezone
18+
19+
from google.cloud.alloydb.connector.refresh_utils import _seconds_until_refresh
20+
21+
22+
def test_seconds_until_refresh_over_1_hour() -> None:
23+
"""
24+
Test _seconds_until_refresh returns proper time in seconds.
25+
If expiration is over 1 hour, should return duration/2.
26+
"""
27+
now = datetime.now()
28+
assert _seconds_until_refresh(now + timedelta(minutes=62), now) == 31 * 60
29+
30+
31+
def test_seconds_until_refresh_under_1_hour_over_4_mins() -> None:
32+
"""
33+
Test _seconds_until_refresh returns proper time in seconds.
34+
If expiration is under 1 hour and over 4 minutes,
35+
should return duration-refresh_buffer (refresh_buffer = 4 minutes).
36+
"""
37+
now = datetime.now(timezone.utc)
38+
assert _seconds_until_refresh(now + timedelta(minutes=5), now) == 60
39+
40+
41+
def test_seconds_until_refresh_under_4_mins() -> None:
42+
"""
43+
Test _seconds_until_refresh returns proper time in seconds.
44+
If expiration is under 4 minutes, should return 0.
45+
"""
46+
assert (
47+
_seconds_until_refresh(datetime.now(timezone.utc) + timedelta(minutes=3)) == 0
48+
)

0 commit comments

Comments
 (0)