diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index bd8e88d3..210e5fa6 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -31,6 +31,7 @@ from google.cloud.alloydb.connector.lazy import LazyRefreshCache from google.cloud.alloydb.connector.types import CacheTypes from google.cloud.alloydb.connector.utils import generate_keys +from google.cloud.alloydb.connector.utils import strip_http_prefix if TYPE_CHECKING: from google.auth.credentials import Credentials @@ -51,7 +52,7 @@ class AsyncConnector: billing purposes. Defaults to None, picking up project from environment. alloydb_api_endpoint (str): Base URL to use when calling - the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". + the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com". enable_iam_auth (bool): Enables automatic IAM database authentication. ip_type (str | IPTypes): Default IP type for all AlloyDB connections. Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections. @@ -66,7 +67,7 @@ def __init__( self, credentials: Optional[Credentials] = None, quota_project: Optional[str] = None, - alloydb_api_endpoint: str = "https://alloydb.googleapis.com", + alloydb_api_endpoint: str = "alloydb.googleapis.com", enable_iam_auth: bool = False, ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, @@ -75,7 +76,7 @@ def __init__( self._cache: dict[str, CacheTypes] = {} # initialize default params self._quota_project = quota_project - self._alloydb_api_endpoint = alloydb_api_endpoint + self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint) self._enable_iam_auth = enable_iam_auth # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): @@ -235,5 +236,3 @@ async def close(self) -> None: """Helper function to cancel RefreshAheadCaches' tasks and close client.""" await asyncio.gather(*[cache.close() for cache in self._cache.values()]) - if self._client: - await self._client.close() diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index 3ed95683..99d8c0bf 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -18,10 +18,13 @@ import logging from typing import Optional, TYPE_CHECKING -import aiohttp from cryptography import x509 +from google.api_core.client_options import ClientOptions +from google.api_core.gapic_v1.client_info import ClientInfo from google.auth.credentials import TokenState from google.auth.transport import requests +import google.cloud.alloydb_v1beta as v1beta +from google.protobuf import duration_pb2 from google.cloud.alloydb.connector.connection_info import ConnectionInfo from google.cloud.alloydb.connector.version import __version__ as version @@ -55,7 +58,7 @@ def __init__( alloydb_api_endpoint: str, quota_project: Optional[str], credentials: Credentials, - client: Optional[aiohttp.ClientSession] = None, + client: Optional[v1beta.AlloyDBAdminAsyncClient] = None, driver: Optional[str] = None, user_agent: Optional[str] = None, ) -> None: @@ -72,23 +75,28 @@ def __init__( A credentials object created from the google-auth Python library. Must have the AlloyDB Admin scopes. For more info check out https://google-auth.readthedocs.io/en/latest/. - client (aiohttp.ClientSession): Async client used to make requests to - AlloyDB APIs. + client (v1beta.AlloyDBAdminAsyncClient): Async client used to make + requests to AlloyDB APIs. Optional, defaults to None and creates new client. driver (str): Database driver to be used by the client. """ user_agent = _format_user_agent(driver, user_agent) - headers = { - "x-goog-api-client": user_agent, - "User-Agent": user_agent, - "Content-Type": "application/json", - } - if quota_project: - headers["x-goog-user-project"] = quota_project - self._client = client if client else aiohttp.ClientSession(headers=headers) + self._client = ( + client + if client + else v1beta.AlloyDBAdminAsyncClient( + credentials=credentials, + client_options=ClientOptions( + api_endpoint=alloydb_api_endpoint, + quota_project_id=quota_project, + ), + client_info=ClientInfo( + user_agent=user_agent, + ), + ) + ) self._credentials = credentials - self._alloydb_api_endpoint = alloydb_api_endpoint # asyncpg does not currently support using metadata exchange # only use metadata exchange for pg8000 driver self._use_metadata = True if driver == "pg8000" else False @@ -118,35 +126,21 @@ async def _get_metadata( Returns: dict: IP addresses of the AlloyDB instance. """ - headers = { - "Authorization": f"Bearer {self._credentials.token}", - } + parent = ( + f"projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}" + ) - url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}/connectionInfo" - - resp = await self._client.get(url, headers=headers) - # try to get response json for better error message - try: - resp_dict = await resp.json() - if resp.status >= 400: - # if detailed error message is in json response, use as error message - message = resp_dict.get("error", {}).get("message") - if message: - resp.reason = message - # skip, raise_for_status will catch all errors in finally block - except Exception: - pass - finally: - resp.raise_for_status() + req = v1beta.GetConnectionInfoRequest(parent=parent) + resp = await self._client.get_connection_info(request=req) # Remove trailing period from PSC DNS name. - psc_dns = resp_dict.get("pscDnsName") + psc_dns = resp.psc_dns_name if psc_dns: psc_dns = psc_dns.rstrip(".") return { - "PRIVATE": resp_dict.get("ipAddress"), - "PUBLIC": resp_dict.get("publicIpAddress"), + "PRIVATE": resp.ip_address, + "PUBLIC": resp.public_ip_address, "PSC": psc_dns, } @@ -175,34 +169,17 @@ async def _get_client_certificate( tuple[str, list[str]]: tuple containing the CA certificate and certificate chain for the AlloyDB instance. """ - headers = { - "Authorization": f"Bearer {self._credentials.token}", - } - - url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate" - - data = { - "publicKey": pub_key, - "certDuration": "3600s", - "useMetadataExchange": self._use_metadata, - } - - resp = await self._client.post(url, headers=headers, json=data) - # try to get response json for better error message - try: - resp_dict = await resp.json() - if resp.status >= 400: - # if detailed error message is in json response, use as error message - message = resp_dict.get("error", {}).get("message") - if message: - resp.reason = message - # skip, raise_for_status will catch all errors in finally block - except Exception: - pass - finally: - resp.raise_for_status() - - return (resp_dict["caCert"], resp_dict["pemCertificateChain"]) + parent = f"projects/{project}/locations/{region}/clusters/{cluster}" + dur = duration_pb2.Duration() + dur.seconds = 3600 + req = v1beta.GenerateClientCertificateRequest( + parent=parent, + cert_duration=dur, + public_key=pub_key, + use_metadata_exchange=self._use_metadata, + ) + resp = await self._client.generate_client_certificate(request=req) + return (resp.ca_cert, resp.pem_certificate_chain) async def get_connection_info( self, @@ -267,9 +244,3 @@ async def get_connection_info( ip_addrs, expiration, ) - - async def close(self) -> None: - """Close AlloyDBClient gracefully.""" - logger.debug("Waiting for connector's http client to close") - await self._client.close() - logger.debug("Closed connector's http client") diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 41a13fa3..3e9f4003 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -37,6 +37,7 @@ import google.cloud.alloydb.connector.pg8000 as pg8000 from google.cloud.alloydb.connector.types import CacheTypes from google.cloud.alloydb.connector.utils import generate_keys +from google.cloud.alloydb.connector.utils import strip_http_prefix import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb if TYPE_CHECKING: @@ -64,7 +65,7 @@ class Connector: billing purposes. Defaults to None, picking up project from environment. alloydb_api_endpoint (str): Base URL to use when calling - the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". + the AlloyDB API endpoint. Defaults to "alloydb.googleapis.com". enable_iam_auth (bool): Enables automatic IAM database authentication. ip_type (str | IPTypes): Default IP type for all AlloyDB connections. Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections. @@ -85,7 +86,7 @@ def __init__( self, credentials: Optional[Credentials] = None, quota_project: Optional[str] = None, - alloydb_api_endpoint: str = "https://alloydb.googleapis.com", + alloydb_api_endpoint: str = "alloydb.googleapis.com", enable_iam_auth: bool = False, ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, @@ -99,7 +100,7 @@ def __init__( self._cache: dict[str, CacheTypes] = {} # initialize default params self._quota_project = quota_project - self._alloydb_api_endpoint = alloydb_api_endpoint + self._alloydb_api_endpoint = strip_http_prefix(alloydb_api_endpoint) self._enable_iam_auth = enable_iam_auth # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): @@ -392,5 +393,3 @@ async def close_async(self) -> None: """Helper function to cancel RefreshAheadCaches' tasks and close client.""" await asyncio.gather(*[cache.close() for cache in self._cache.values()]) - if self._client: - await self._client.close() diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index e4c99393..10a107a7 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -14,6 +14,8 @@ from __future__ import annotations +import re + import aiofiles from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -58,3 +60,13 @@ async def generate_keys() -> tuple[rsa.RSAPrivateKey, str]: .decode("UTF-8") ) return (priv_key, pub_key) + + +def strip_http_prefix(url: str) -> str: + """ + Returns a new URL with 'http://' or 'https://' prefix removed. + """ + m = re.search(r"^(https?://)?(.+)", url) + if m is None: + return "" + return m.group(2) diff --git a/mypy.ini b/mypy.ini index c1eec50d..46cc2984 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,3 +11,9 @@ ignore_missing_imports = True [mypy-asyncpg] ignore_missing_imports = True + +[mypy-google.cloud.alloydb_v1beta] +ignore_missing_imports = True + +[mypy-google.api_core.*] +ignore_missing_imports = True diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index cb64c819..5cd1a0f8 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -32,6 +32,7 @@ from google.auth.credentials import TokenState from google.auth.transport import requests +from google.cloud import alloydb_v1beta from google.cloud.alloydb.connector.connection_info import ConnectionInfo import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb @@ -232,7 +233,6 @@ def __init__( self, instance: Optional[FakeInstance] = None, driver: str = "pg8000" ) -> None: self.instance = FakeInstance() if instance is None else instance - self.closed = False self._user_agent = f"test-user-agent+{driver}" self._credentials = FakeCredentials() @@ -317,9 +317,6 @@ async def get_connection_info( expiration, ) - async def close(self) -> None: - self.closed = True - def metadata_exchange(sock: ssl.SSLSocket) -> None: """ @@ -448,3 +445,36 @@ def write_static_info(i: FakeInstance) -> io.StringIO: "pscInstanceConfig": {"pscDnsName": i.ip_addrs["PSC"]}, } return io.StringIO(json.dumps(static)) + + +class FakeAlloyDBAdminAsyncClient: + async def get_connection_info( + self, request: alloydb_v1beta.GetConnectionInfoRequest + ) -> alloydb_v1beta.types.resources.ConnectionInfo: + ci = alloydb_v1beta.types.resources.ConnectionInfo() + ci.ip_address = "10.0.0.1" + ci.public_ip_address = "127.0.0.1" + ci.instance_uid = "123456789" + ci.psc_dns_name = "x.y.alloydb.goog" + + parent = request.parent + instance = parent.split("/")[-1] + if instance == "test-instance": + ci.public_ip_address = "" + ci.psc_dns_name = "" + elif instance == "public-instance": + ci.psc_dns_name = "" + else: + ci.ip_address = "" + ci.public_ip_address = "" + return ci + + async def generate_client_certificate( + self, request: alloydb_v1beta.GenerateClientCertificateRequest + ) -> alloydb_v1beta.types.service.GenerateClientCertificateResponse: + ccr = alloydb_v1beta.types.service.GenerateClientCertificateResponse() + ccr.ca_cert = "This is the CA cert" + ccr.pem_certificate_chain.append("This is the client cert") + ccr.pem_certificate_chain.append("This is the intermediate cert") + ccr.pem_certificate_chain.append("This is the root cert") + return ccr diff --git a/tests/unit/test_async_connector.py b/tests/unit/test_async_connector.py index 0f150875..07450ea1 100644 --- a/tests/unit/test_async_connector.py +++ b/tests/unit/test_async_connector.py @@ -15,7 +15,7 @@ import asyncio from typing import Union -from aiohttp import ClientResponseError +from google.api_core.exceptions import RetryError from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeConnectionInfo @@ -27,7 +27,7 @@ from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.instance import RefreshAheadCache -ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com" +ALLOYDB_API_ENDPOINT = "alloydb.googleapis.com" @pytest.mark.asyncio @@ -109,6 +109,34 @@ async def test_AsyncConnector_init_bad_ip_type(credentials: FakeCredentials) -> ) +def test_AsyncConnector_init_alloydb_api_endpoint_with_http_prefix( + credentials: FakeCredentials, +) -> None: + """ + Test to check whether the __init__ method of AsyncConnector properly sets + alloydb_api_endpoint when its URL has an 'http://' prefix. + """ + connector = AsyncConnector( + alloydb_api_endpoint="http://alloydb.googleapis.com", credentials=credentials + ) + assert connector._alloydb_api_endpoint == "alloydb.googleapis.com" + connector.close() + + +def test_AsyncConnector_init_alloydb_api_endpoint_with_https_prefix( + credentials: FakeCredentials, +) -> None: + """ + Test to check whether the __init__ method of AsyncConnector properly sets + alloydb_api_endpoint when its URL has an 'https://' prefix. + """ + connector = AsyncConnector( + alloydb_api_endpoint="https://alloydb.googleapis.com", credentials=credentials + ) + assert connector._alloydb_api_endpoint == "alloydb.googleapis.com" + connector.close() + + @pytest.mark.asyncio async def test_AsyncConnector_context_manager( credentials: FakeCredentials, @@ -163,8 +191,6 @@ async def test_connect_and_close(credentials: FakeCredentials) -> None: # check connection is returned assert connection.result() is True - # outside of context manager check close cleaned up - assert connector._client.closed is True @pytest.mark.asyncio @@ -244,8 +270,6 @@ async def test_context_manager_connect_and_close( # check connection is returned assert connection.result() is True - # outside of context manager check close cleaned up - assert fake_client.closed is True @pytest.mark.asyncio @@ -309,7 +333,7 @@ async def test_Connector_remove_cached_bad_instance( """ instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance" async with AsyncConnector(credentials=credentials) as connector: - with pytest.raises(ClientResponseError): + with pytest.raises(RetryError): await connector.connect(instance_uri, "asyncpg") assert instance_uri not in connector._cache diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e4b2fdbb..9035a99b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -from typing import Any, Optional +from typing import Optional -from aiohttp import ClientResponseError -from aiohttp import web -from aioresponses import aioresponses +from mocks import FakeAlloyDBAdminAsyncClient from mocks import FakeCredentials import pytest @@ -26,65 +23,12 @@ from google.cloud.alloydb.connector.version import __version__ as version -async def connectionInfo(request: Any) -> web.Response: - response = { - "ipAddress": "10.0.0.1", - "instanceUid": "123456789", - } - return web.Response(content_type="application/json", body=json.dumps(response)) - - -async def connectionInfoPublicIP(request: Any) -> web.Response: - response = { - "ipAddress": "10.0.0.1", - "publicIpAddress": "127.0.0.1", - "instanceUid": "123456789", - } - return web.Response(content_type="application/json", body=json.dumps(response)) - - -async def connectionInfoPsc(request: Any) -> web.Response: - response = { - "ipAddress": None, - "publicIpAddress": None, - "pscDnsName": "x.y.alloydb.goog", - "instanceUid": "123456789", - } - return web.Response(content_type="application/json", body=json.dumps(response)) - - -async def generateClientCertificate(request: Any) -> web.Response: - response = { - "caCert": "This is the CA cert", - "pemCertificateChain": [ - "This is the client cert", - "This is the intermediate cert", - "This is the root cert", - ], - } - return web.Response(content_type="application/json", body=json.dumps(response)) - - -@pytest.fixture -async def client(aiohttp_client: Any) -> Any: - app = web.Application() - metadata_uri = "/v1beta/projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance/connectionInfo" - app.router.add_get(metadata_uri, connectionInfo) - metadata_public_ip_uri = "/v1beta/projects/test-project/locations/test-region/clusters/test-cluster/instances/public-instance/connectionInfo" - app.router.add_get(metadata_public_ip_uri, connectionInfoPublicIP) - metadata_psc_uri = "/v1beta/projects/test-project/locations/test-region/clusters/test-cluster/instances/psc-instance/connectionInfo" - app.router.add_get(metadata_psc_uri, connectionInfoPsc) - client_cert_uri = "/v1beta/projects/test-project/locations/test-region/clusters/test-cluster:generateClientCertificate" - app.router.add_post(client_cert_uri, generateClientCertificate) - return await aiohttp_client(app) - - @pytest.mark.asyncio -async def test__get_metadata(client: Any, credentials: FakeCredentials) -> None: +async def test__get_metadata(credentials: FakeCredentials) -> None: """ Test _get_metadata returns successfully. """ - test_client = AlloyDBClient("", "", credentials, client) + test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient()) ip_addrs = await test_client._get_metadata( "test-project", "test-region", @@ -93,19 +37,17 @@ async def test__get_metadata(client: Any, credentials: FakeCredentials) -> None: ) assert ip_addrs == { "PRIVATE": "10.0.0.1", - "PUBLIC": None, - "PSC": None, + "PUBLIC": "", + "PSC": "", } @pytest.mark.asyncio -async def test__get_metadata_with_public_ip( - client: Any, credentials: FakeCredentials -) -> None: +async def test__get_metadata_with_public_ip(credentials: FakeCredentials) -> None: """ Test _get_metadata returns successfully with Public IP. """ - test_client = AlloyDBClient("", "", credentials, client) + test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient()) ip_addrs = await test_client._get_metadata( "test-project", "test-region", @@ -115,18 +57,16 @@ async def test__get_metadata_with_public_ip( assert ip_addrs == { "PRIVATE": "10.0.0.1", "PUBLIC": "127.0.0.1", - "PSC": None, + "PSC": "", } @pytest.mark.asyncio -async def test__get_metadata_with_psc( - client: Any, credentials: FakeCredentials -) -> None: +async def test__get_metadata_with_psc(credentials: FakeCredentials) -> None: """ Test _get_metadata returns successfully with PSC DNS name. """ - test_client = AlloyDBClient("", "", credentials, client) + test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient()) ip_addrs = await test_client._get_metadata( "test-project", "test-region", @@ -134,89 +74,18 @@ async def test__get_metadata_with_psc( "psc-instance", ) assert ip_addrs == { - "PRIVATE": None, - "PUBLIC": None, + "PRIVATE": "", + "PUBLIC": "", "PSC": "x.y.alloydb.goog", } -async def test__get_metadata_error( - credentials: FakeCredentials, -) -> None: - """ - Test that AlloyDB API error messages are raised for _get_metadata. - """ - # mock AlloyDB API calls with exceptions - client = AlloyDBClient( - alloydb_api_endpoint="https://alloydb.googleapis.com", - quota_project=None, - credentials=credentials, - ) - get_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance/connectionInfo" - resp_body = { - "error": { - "code": 403, - "message": "AlloyDB API has not been used in project 123456789 before or it is disabled", - } - } - with aioresponses() as mocked: - mocked.get( - get_url, - status=403, - payload=resp_body, - repeat=True, - ) - with pytest.raises(ClientResponseError) as exc_info: - await client._get_metadata( - "my-project", "my-region", "my-cluster", "my-instance" - ) - assert exc_info.value.status == 403 - assert ( - exc_info.value.message - == "AlloyDB API has not been used in project 123456789 before or it is disabled" - ) - await client.close() - - -async def test__get_metadata_error_parsing_json( - credentials: FakeCredentials, -) -> None: - """ - Test that aiohttp default error messages are raised when _get_metadata gets - a bad JSON response. - """ - # mock AlloyDB API calls with exceptions - client = AlloyDBClient( - alloydb_api_endpoint="https://alloydb.googleapis.com", - quota_project=None, - credentials=credentials, - ) - get_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance/connectionInfo" - resp_body = ["error"] # invalid json - with aioresponses() as mocked: - mocked.get( - get_url, - status=403, - payload=resp_body, - repeat=True, - ) - with pytest.raises(ClientResponseError) as exc_info: - await client._get_metadata( - "my-project", "my-region", "my-cluster", "my-instance" - ) - assert exc_info.value.status == 403 - assert exc_info.value.message == "Forbidden" - await client.close() - - @pytest.mark.asyncio -async def test__get_client_certificate( - client: Any, credentials: FakeCredentials -) -> None: +async def test__get_client_certificate(credentials: FakeCredentials) -> None: """ Test _get_client_certificate returns successfully. """ - test_client = AlloyDBClient("", "", credentials, client) + test_client = AlloyDBClient("", "", credentials, FakeAlloyDBAdminAsyncClient()) keys = await generate_keys() certs = await test_client._get_client_certificate( "test-project", "test-region", "test-cluster", keys[1] @@ -228,72 +97,6 @@ async def test__get_client_certificate( assert cert_chain[2] == "This is the root cert" -async def test__get_client_certificate_error( - credentials: FakeCredentials, -) -> None: - """ - Test that AlloyDB API error messages are raised for _get_client_certificate. - """ - # mock AlloyDB API calls with exceptions - client = AlloyDBClient( - alloydb_api_endpoint="https://alloydb.googleapis.com", - quota_project=None, - credentials=credentials, - ) - post_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster:generateClientCertificate" - resp_body = { - "error": { - "code": 404, - "message": "The AlloyDB instance does not exist.", - } - } - with aioresponses() as mocked: - mocked.post( - post_url, - status=404, - payload=resp_body, - repeat=True, - ) - with pytest.raises(ClientResponseError) as exc_info: - await client._get_client_certificate( - "my-project", "my-region", "my-cluster", "" - ) - assert exc_info.value.status == 404 - assert exc_info.value.message == "The AlloyDB instance does not exist." - await client.close() - - -async def test__get_client_certificate_error_parsing_json( - credentials: FakeCredentials, -) -> None: - """ - Test that aiohttp default error messages are raised when - _get_client_certificate gets a bad JSON response. - """ - # mock AlloyDB API calls with exceptions - client = AlloyDBClient( - alloydb_api_endpoint="https://alloydb.googleapis.com", - quota_project=None, - credentials=credentials, - ) - post_url = "https://alloydb.googleapis.com/v1beta/projects/my-project/locations/my-region/clusters/my-cluster:generateClientCertificate" - resp_body = ["error"] # invalid json - with aioresponses() as mocked: - mocked.post( - post_url, - status=404, - payload=resp_body, - repeat=True, - ) - with pytest.raises(ClientResponseError) as exc_info: - await client._get_client_certificate( - "my-project", "my-region", "my-cluster", "" - ) - assert exc_info.value.status == 404 - assert exc_info.value.message == "Not Found" - await client.close() - - @pytest.mark.asyncio async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None: """ @@ -302,12 +105,10 @@ async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None: """ client = AlloyDBClient("www.test-endpoint.com", "my-quota-project", credentials) # verify base endpoint is set - assert client._alloydb_api_endpoint == "www.test-endpoint.com" + assert client._client.api_endpoint == "www.test-endpoint.com" # verify proper headers are set - assert client._client.headers["User-Agent"] == f"alloydb-python-connector/{version}" - assert client._client.headers["x-goog-user-project"] == "my-quota-project" - # close client - await client.close() + assert client._user_agent.startswith(f"alloydb-python-connector/{version}") + assert client._client._client._client_options.quota_project_id == "my-quota-project" @pytest.mark.asyncio @@ -323,11 +124,9 @@ async def test_AlloyDBClient_init_custom_user_agent( credentials, user_agent="custom-agent/v1.0.0 other-agent/v2.0.0", ) - assert ( - client._client.headers["User-Agent"] - == f"alloydb-python-connector/{version} custom-agent/v1.0.0 other-agent/v2.0.0" + assert client._user_agent.startswith( + f"alloydb-python-connector/{version} custom-agent/v1.0.0 other-agent/v2.0.0" ) - await client.close() @pytest.mark.parametrize( @@ -346,11 +145,11 @@ async def test_AlloyDBClient_user_agent( "www.test-endpoint.com", "my-quota-project", credentials, driver=driver ) if driver is None: - assert client._user_agent == f"alloydb-python-connector/{version}" + assert client._user_agent.startswith(f"alloydb-python-connector/{version}") else: - assert client._user_agent == f"alloydb-python-connector/{version}+{driver}" - # close client - await client.close() + assert client._user_agent.startswith( + f"alloydb-python-connector/{version}+{driver}" + ) @pytest.mark.parametrize( @@ -369,5 +168,3 @@ async def test_AlloyDBClient_use_metadata( "www.test-endpoint.com", "my-quota-project", credentials, driver=driver ) assert client._use_metadata == expected - # close client - await client.close() diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index c7660d1c..733c303f 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -16,7 +16,7 @@ from threading import Thread from typing import Union -from aiohttp import ClientResponseError +from google.api_core.exceptions import RetryError from mock import patch from mocks import FakeAlloyDBClient from mocks import FakeCredentials @@ -37,7 +37,7 @@ def test_Connector_init(credentials: FakeCredentials) -> None: """ connector = Connector(credentials) assert connector._quota_project is None - assert connector._alloydb_api_endpoint == "https://alloydb.googleapis.com" + assert connector._alloydb_api_endpoint == "alloydb.googleapis.com" assert connector._client is None assert connector._credentials == credentials connector.close() @@ -107,6 +107,34 @@ def test_Connector_init_ip_type( connector.close() +def test_Connector_init_alloydb_api_endpoint_with_http_prefix( + credentials: FakeCredentials, +) -> None: + """ + Test to check whether the __init__ method of Connector properly sets + alloydb_api_endpoint when its URL has an 'http://' prefix. + """ + connector = Connector( + alloydb_api_endpoint="http://alloydb.googleapis.com", credentials=credentials + ) + assert connector._alloydb_api_endpoint == "alloydb.googleapis.com" + connector.close() + + +def test_Connector_init_alloydb_api_endpoint_with_https_prefix( + credentials: FakeCredentials, +) -> None: + """ + Test to check whether the __init__ method of Connector properly sets + alloydb_api_endpoint when its URL has an 'https://' prefix. + """ + connector = Connector( + alloydb_api_endpoint="https://alloydb.googleapis.com", credentials=credentials + ) + assert connector._alloydb_api_endpoint == "alloydb.googleapis.com" + connector.close() + + def test_Connector_context_manager(credentials: FakeCredentials) -> None: """ Test to check whether the __init__ method of Connector @@ -114,7 +142,7 @@ def test_Connector_context_manager(credentials: FakeCredentials) -> None: """ with Connector(credentials) as connector: assert connector._quota_project is None - assert connector._alloydb_api_endpoint == "https://alloydb.googleapis.com" + assert connector._alloydb_api_endpoint == "alloydb.googleapis.com" assert connector._client is None assert connector._credentials == credentials @@ -220,7 +248,7 @@ def test_Connector_remove_cached_bad_instance( """ instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance" with Connector(credentials) as connector: - with pytest.raises(ClientResponseError): + with pytest.raises(RetryError): connector.connect(instance_uri, "pg8000") assert instance_uri not in connector._cache diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 00000000..689786c0 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,27 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.alloydb.connector.utils import strip_http_prefix + + +def test_strip_http_prefix_with_empty_url() -> None: + assert strip_http_prefix("") == "" + + +def test_strip_http_prefix_with_url_having_http_prefix() -> None: + assert strip_http_prefix("http://google.com") == "google.com" + + +def test_strip_http_prefix_with_url_having_https_prefix() -> None: + assert strip_http_prefix("https://google.com") == "google.com"