Skip to content

Commit 44d084b

Browse files
Fix unit tests
1 parent 51638a2 commit 44d084b

File tree

7 files changed

+67
-147
lines changed

7 files changed

+67
-147
lines changed

google/cloud/alloydb/connector/async_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
self,
6767
credentials: Optional[Credentials] = None,
6868
quota_project: Optional[str] = None,
69-
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
69+
alloydb_api_endpoint: str = "alloydb.googleapis.com",
7070
enable_iam_auth: bool = False,
7171
ip_type: str | IPTypes = IPTypes.PRIVATE,
7272
user_agent: Optional[str] = None,

google/cloud/alloydb/connector/client.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(
7575
A credentials object created from the google-auth Python library.
7676
Must have the AlloyDB Admin scopes. For more info check out
7777
https://google-auth.readthedocs.io/en/latest/.
78-
client (alloydb_v1.AlloyDBAdminAsyncClient): Async client used to
78+
client (alloydb_v1beta.AlloyDBAdminAsyncClient): Async client used to
7979
make requests to AlloyDB APIs.
8080
Optional, defaults to None and creates new client.
8181
driver (str): Database driver to be used by the client.
@@ -85,19 +85,17 @@ def __init__(
8585
self._client = client if client else alloydb_v1beta.AlloyDBAdminAsyncClient(
8686
credentials=credentials,
8787
client_options=ClientOptions(
88-
api_endpoint="alloydb.googleapis.com",
88+
api_endpoint=alloydb_api_endpoint,
8989
quota_project_id=quota_project,
9090
),
9191
client_info=ClientInfo(
9292
user_agent=user_agent,
9393
),
9494
)
9595
self._credentials = credentials
96-
self._alloydb_api_endpoint = alloydb_api_endpoint
9796
# asyncpg does not currently support using metadata exchange
9897
# only use metadata exchange for pg8000 driver
9998
self._use_metadata = True if driver == "pg8000" else False
100-
self._user_agent = user_agent
10199

102100
async def _get_metadata(
103101
self,
@@ -127,19 +125,6 @@ async def _get_metadata(
127125

128126
req = alloydb_v1beta.GetConnectionInfoRequest(parent=parent)
129127
resp = await self._client.get_connection_info(request=req)
130-
# # try to get response json for better error message
131-
# try:
132-
# resp_dict = await resp.json()
133-
# if resp.status >= 400:
134-
# # if detailed error message is in json response, use as error message
135-
# message = resp_dict.get("error", {}).get("message")
136-
# if message:
137-
# resp.reason = message
138-
# # skip, raise_for_status will catch all errors in finally block
139-
# except Exception:
140-
# pass
141-
# finally:
142-
# resp.raise_for_status()
143128

144129
# Remove trailing period from PSC DNS name.
145130
psc_dns = resp.psc_dns_name
@@ -187,20 +172,6 @@ async def _get_client_certificate(
187172
use_metadata_exchange=self._use_metadata,
188173
)
189174
resp = await self._client.generate_client_certificate(request=req)
190-
# # try to get response json for better error message
191-
# try:
192-
# resp_dict = await resp.json()
193-
# if resp.status >= 400:
194-
# # if detailed error message is in json response, use as error message
195-
# message = resp_dict.get("error", {}).get("message")
196-
# if message:
197-
# resp.reason = message
198-
# # skip, raise_for_status will catch all errors in finally block
199-
# except Exception:
200-
# pass
201-
# finally:
202-
# resp.raise_for_status()
203-
204175
return (resp.ca_cert, resp.pem_certificate_chain)
205176

206177
async def get_connection_info(
@@ -269,5 +240,4 @@ async def get_connection_info(
269240

270241
async def close(self) -> None:
271242
"""Close AlloyDBClient gracefully."""
272-
logger.debug("Waiting for connector's http client to close")
273-
logger.debug("Closed connector's http client")
243+
logger.debug("Closed AlloyDBClient")

google/cloud/alloydb/connector/connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class Connector:
6262
billing purposes.
6363
Defaults to None, picking up project from environment.
6464
alloydb_api_endpoint (str): Base URL to use when calling
65-
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
65+
the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com".
6666
enable_iam_auth (bool): Enables automatic IAM database authentication.
6767
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
6868
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
@@ -77,7 +77,7 @@ def __init__(
7777
self,
7878
credentials: Optional[Credentials] = None,
7979
quota_project: Optional[str] = None,
80-
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
80+
alloydb_api_endpoint: str = "alloydb.googleapis.com",
8181
enable_iam_auth: bool = False,
8282
ip_type: str | IPTypes = IPTypes.PRIVATE,
8383
user_agent: Optional[str] = None,

tests/unit/mocks.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.auth.credentials import TokenState
3131
from google.auth.transport import requests
3232

33+
from google.cloud import alloydb_v1beta
3334
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
3435
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb
3536

@@ -378,3 +379,34 @@ async def force_refresh(self) -> None:
378379

379380
async def close(self) -> None:
380381
self._close_called = True
382+
383+
384+
class FakeAlloyDBAdminAsyncClient:
385+
async def get_connection_info(self, request: alloydb_v1beta.GetConnectionInfoRequest) -> alloydb_v1beta.types.resources.ConnectionInfo:
386+
ci = alloydb_v1beta.types.resources.ConnectionInfo()
387+
ci.ip_address = "10.0.0.1"
388+
ci.public_ip_address = "127.0.0.1"
389+
ci.instance_uid = "123456789"
390+
ci.psc_dns_name = "x.y.alloydb.goog"
391+
392+
parent = request.parent
393+
instance = parent.split("/")[-1]
394+
if instance == "test-instance":
395+
ci.public_ip_address = ""
396+
ci.psc_dns_name = ""
397+
return ci
398+
elif instance == "public-instance":
399+
ci.psc_dns_name = ""
400+
return ci
401+
else:
402+
ci.ip_address = ""
403+
ci.public_ip_address = ""
404+
return ci
405+
406+
async def generate_client_certificate(self, request: alloydb_v1beta.GenerateClientCertificateRequest) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
407+
ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
408+
ccr.ca_cert = "This is the CA cert"
409+
ccr.pem_certificate_chain.append("This is the client cert")
410+
ccr.pem_certificate_chain.append("This is the intermediate cert")
411+
ccr.pem_certificate_chain.append("This is the root cert")
412+
return ccr

tests/unit/test_async_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
2828
from google.cloud.alloydb.connector.instance import RefreshAheadCache
2929

30-
ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com"
30+
ALLOYDB_API_ENDPOINT = "alloydb.googleapis.com"
3131

3232

3333
@pytest.mark.asyncio

tests/unit/test_client.py

Lines changed: 24 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -12,73 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import json
1615
from typing import Any, Optional
1716

1817
from aiohttp import ClientResponseError
19-
from aiohttp import web
2018
from aioresponses import aioresponses
21-
from mocks import FakeCredentials
19+
from mocks import FakeAlloyDBAdminAsyncClient, FakeCredentials
2220
import pytest
2321

22+
from google.api_core.exceptions import RetryError
2423
from google.cloud import alloydb_v1beta
2524
from google.cloud.alloydb.connector.client import AlloyDBClient
2625
from google.cloud.alloydb.connector.utils import generate_keys
2726
from google.cloud.alloydb.connector.version import __version__ as version
2827

2928

30-
async def connectionInfo(request: Any) -> alloydb_v1beta.types.resources.ConnectionInfo:
31-
ci = alloydb_v1beta.types.resources.ConnectionInfo()
32-
ci.ip_address = "10.0.0.1"
33-
ci.instance_uid = "123456789"
34-
return ci
35-
36-
37-
async def connectionInfoPublicIP(request: Any) -> alloydb_v1beta.types.resources.ConnectionInfo:
38-
ci = alloydb_v1beta.types.resources.ConnectionInfo()
39-
ci.ip_address = "10.0.0.1"
40-
ci.public_ip_address = "127.0.0.1"
41-
ci.instance_uid = "123456789"
42-
return ci
43-
44-
45-
async def connectionInfoPsc(request: Any) -> alloydb_v1beta.types.resources.ConnectionInfo:
46-
ci = alloydb_v1beta.types.resources.ConnectionInfo()
47-
ci.psc_dns_name = "x.y.alloydb.goog"
48-
ci.instance_uid = "123456789"
49-
return ci
50-
51-
52-
async def generateClientCertificate(request: Any) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse:
53-
ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse()
54-
ccr.ca_cert = "This is the CA cert"
55-
ccr.pem_certificate_chain.append("This is the client cert")
56-
ccr.pem_certificate_chain.append("This is the intermediate cert")
57-
ccr.pem_certificate_chain.append("This is the root cert")
58-
return ccr
59-
60-
61-
class MockAlloyDBAdminAsyncClient:
62-
async def get_connection_info(self, request: alloydb_v1beta.GetConnectionInfoRequest) -> alloydb_v1beta.types.resources.ConnectionInfo:
63-
parent = request.parent
64-
instance = parent.split("/")[-1]
65-
if instance == "test-instance":
66-
return connectionInfo(request)
67-
elif instance == "public-instance":
68-
return connectionInfoPublicIP(request)
69-
else:
70-
return connectionInfoPsc(request)
71-
72-
async def generate_client_certificate(self, request: alloydb_v1beta.GenerateClientCertificateRequest) -> web.Response:
73-
return generateClientCertificate(request)
74-
75-
7629
@pytest.mark.asyncio
7730
async def test__get_metadata(credentials: FakeCredentials) -> None:
7831
"""
7932
Test _get_metadata returns successfully.
8033
"""
81-
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
34+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
8235
ip_addrs = await test_client._get_metadata(
8336
"test-project",
8437
"test-region",
@@ -99,7 +52,7 @@ async def test__get_metadata_with_public_ip(
9952
"""
10053
Test _get_metadata returns successfully with Public IP.
10154
"""
102-
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
55+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
10356
ip_addrs = await test_client._get_metadata(
10457
"test-project",
10558
"test-region",
@@ -120,7 +73,7 @@ async def test__get_metadata_with_psc(
12073
"""
12174
Test _get_metadata returns successfully with PSC DNS name.
12275
"""
123-
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
76+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
12477
ip_addrs = await test_client._get_metadata(
12578
"test-project",
12679
"test-region",
@@ -140,34 +93,14 @@ async def test__get_metadata_error(
14093
"""
14194
Test that AlloyDB API error messages are raised for _get_metadata.
14295
"""
143-
# mock AlloyDB API calls with exceptions
14496
client = AlloyDBClient(
145-
alloydb_api_endpoint="https://alloydb.googleapis.com",
97+
alloydb_api_endpoint="alloydb.googleapis.com",
14698
quota_project=None,
14799
credentials=credentials,
148100
)
149-
get_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance/connectionInfo"
150-
resp_body = {
151-
"error": {
152-
"code": 403,
153-
"message": "AlloyDB API has not been used in project 123456789 before or it is disabled",
154-
}
155-
}
156-
with aioresponses() as mocked:
157-
mocked.get(
158-
get_url,
159-
status=403,
160-
payload=resp_body,
161-
repeat=True,
162-
)
163-
with pytest.raises(ClientResponseError) as exc_info:
164-
await client._get_metadata(
165-
"my-project", "my-region", "my-cluster", "my-instance"
166-
)
167-
assert exc_info.value.status == 403
168-
assert (
169-
exc_info.value.message
170-
== "AlloyDB API has not been used in project 123456789 before or it is disabled"
101+
with pytest.raises(RetryError) as exc_info:
102+
await client._get_metadata(
103+
"my-project", "my-region", "my-cluster", "my-instance"
171104
)
172105
await client.close()
173106

@@ -179,7 +112,7 @@ async def test__get_client_certificate(
179112
"""
180113
Test _get_client_certificate returns successfully.
181114
"""
182-
test_client = AlloyDBClient("", "", credentials, MockAlloyDBAdminAsyncClient())
115+
test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient())
183116
keys = await generate_keys()
184117
certs = await test_client._get_client_certificate(
185118
"test-project", "test-region", "test-cluster", keys[1]
@@ -197,32 +130,16 @@ async def test__get_client_certificate_error(
197130
"""
198131
Test that AlloyDB API error messages are raised for _get_client_certificate.
199132
"""
200-
# mock AlloyDB API calls with exceptions
201133
client = AlloyDBClient(
202-
alloydb_api_endpoint="https://alloydb.googleapis.com",
134+
alloydb_api_endpoint="alloydb.googleapis.com",
203135
quota_project=None,
204136
credentials=credentials,
205137
)
206-
post_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster:generateClientCertificate"
207-
resp_body = {
208-
"error": {
209-
"code": 404,
210-
"message": "The AlloyDB instance does not exist.",
211-
}
212-
}
213-
with aioresponses() as mocked:
214-
mocked.post(
215-
post_url,
216-
status=404,
217-
payload=resp_body,
218-
repeat=True,
138+
with pytest.raises(RetryError) as exc_info:
139+
await client._get_client_certificate(
140+
"my-project", "my-region", "my-cluster", ""
219141
)
220-
with pytest.raises(ClientResponseError) as exc_info:
221-
await client._get_client_certificate(
222-
"my-project", "my-region", "my-cluster", ""
223-
)
224-
assert exc_info.value.status == 404
225-
assert exc_info.value.message == "The AlloyDB instance does not exist."
142+
print(exc_info)
226143
await client.close()
227144

228145

@@ -234,10 +151,11 @@ async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None:
234151
"""
235152
client = AlloyDBClient("www.test-endpoint.com", "my-quota-project", credentials)
236153
# verify base endpoint is set
237-
assert client._alloydb_api_endpoint == "www.test-endpoint.com"
154+
assert client._client.api_endpoint == "www.test-endpoint.com"
238155
# verify proper headers are set
239-
assert client._client.headers["User-Agent"] == f"alloydb-python-connector/{version}"
240-
assert client._client.headers["x-goog-user-project"] == "my-quota-project"
156+
got_user_agent = client._client.transport._wrapped_methods[client._client.transport.list_clusters]._metadata[0][1]
157+
assert got_user_agent.startswith(f"alloydb-python-connector/{version}")
158+
assert client._client._client._client_options.quota_project_id == "my-quota-project"
241159
# close client
242160
await client.close()
243161

@@ -255,10 +173,8 @@ async def test_AlloyDBClient_init_custom_user_agent(
255173
credentials,
256174
user_agent="custom-agent/v1.0.0 other-agent/v2.0.0",
257175
)
258-
assert (
259-
client._client.headers["User-Agent"]
260-
== f"alloydb-python-connector/{version} custom-agent/v1.0.0 other-agent/v2.0.0"
261-
)
176+
got_user_agent = client._client.transport._wrapped_methods[client._client.transport.list_clusters]._metadata[0][1]
177+
assert got_user_agent.startswith(f"alloydb-python-connector/{version} custom-agent/v1.0.0 other-agent/v2.0.0")
262178
await client.close()
263179

264180

@@ -277,10 +193,11 @@ async def test_AlloyDBClient_user_agent(
277193
client = AlloyDBClient(
278194
"www.test-endpoint.com", "my-quota-project", credentials, driver=driver
279195
)
196+
got_user_agent = client._client.transport._wrapped_methods[client._client.transport.list_clusters]._metadata[0][1]
280197
if driver is None:
281-
assert client._user_agent == f"alloydb-python-connector/{version}"
198+
assert got_user_agent.startswith(f"alloydb-python-connector/{version}")
282199
else:
283-
assert client._user_agent == f"alloydb-python-connector/{version}+{driver}"
200+
assert got_user_agent.startswith(f"alloydb-python-connector/{version}+{driver}")
284201
# close client
285202
await client.close()
286203

0 commit comments

Comments
 (0)