Skip to content

Commit 8b166cc

Browse files
committed
feat: limitless plugin implementation
1 parent c8bfb0d commit 8b166cc

File tree

9 files changed

+670
-36
lines changed

9 files changed

+670
-36
lines changed

aws_advanced_python_wrapper/database_dialect.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ def is_reader_query(self) -> str:
9898
return self._IS_READER_QUERY
9999

100100

101+
@runtime_checkable
102+
class AuroraLimitlessDialect(Protocol):
103+
_LIMITLESS_ROUTER_ENDPOINT_QUERY: str
104+
105+
@property
106+
def limitless_router_endpoint_query(self) -> str:
107+
return self._LIMITLESS_ROUTER_ENDPOINT_QUERY
108+
109+
101110
class DatabaseDialect(Protocol):
102111
"""
103112
Database dialects help the AWS Advanced Python Driver determine what kind of underlying database is being used,
@@ -342,7 +351,7 @@ def get_host_list_provider_supplier(self) -> Callable:
342351
return lambda provider_service, props: RdsHostListProvider(provider_service, props)
343352

344353

345-
class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
354+
class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect, AuroraLimitlessDialect):
346355
_DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (DialectCode.MULTI_AZ_PG,)
347356

348357
_EXTENSIONS_QUERY = "SELECT (setting LIKE '%aurora_stat_utils%') AS aurora_stat_utils " \
@@ -359,6 +368,7 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
359368

360369
_HOST_ID_QUERY = "SELECT aurora_db_instance_identifier()"
361370
_IS_READER_QUERY = "SELECT pg_is_in_recovery()"
371+
_LIMITLESS_ROUTER_ENDPOINT_QUERY = "SELECT router_endpoint, load FROM aurora_limitless_router_endpoints()"
362372

363373
@property
364374
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:

aws_advanced_python_wrapper/default_plugin.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, List
1818

1919
if TYPE_CHECKING:
2020
from aws_advanced_python_wrapper.connection_provider import (ConnectionProvider,
@@ -129,6 +129,15 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo:
129129

130130
return self._connection_provider_manager.get_host_info_by_strategy(hosts, role, strategy, self._plugin_service.props)
131131

132+
def get_host_info_from_input_by_strategy(self, host_list: List[HostInfo], role: HostRole, strategy: str) -> HostInfo:
133+
if HostRole.UNKNOWN == role:
134+
raise AwsWrapperError(Messages.get("DefaultPlugin.UnknownHosts"))
135+
136+
if len(host_list) < 1:
137+
raise AwsWrapperError(Messages.get("DefaultPlugin.EmptyHosts"))
138+
139+
return self._connection_provider_manager.get_host_info_by_strategy(tuple(host_list), role, strategy, self._plugin_service.props)
140+
132141
@property
133142
def subscribed_methods(self) -> Set[str]:
134143
return DefaultPlugin._SUBSCRIBED_METHODS

aws_advanced_python_wrapper/host_list_provider.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def get_host_role(self, connection: Connection) -> HostRole:
6969
def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]:
7070
...
7171

72+
def get_cluster_id(self) -> str:
73+
...
74+
7275

7376
@runtime_checkable
7477
class DynamicHostListProvider(HostListProvider, Protocol):
@@ -519,6 +522,10 @@ def _identify_connection(self, conn: Connection):
519522
cursor.execute(self._dialect.host_id_query)
520523
return cursor.fetchone()
521524

525+
def get_cluster_id(self):
526+
self._initialize()
527+
return self._cluster_id
528+
522529
@dataclass()
523530
class ClusterIdSuggestion:
524531
cluster_id: str
@@ -646,3 +653,6 @@ def get_host_role(self, connection: Connection) -> HostRole:
646653
def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]:
647654
raise UnsupportedOperationError(
648655
Messages.get_formatted("ConnectionStringHostListProvider.UnsupportedMethod", "identify_connection"))
656+
657+
def get_cluster_id(self):
658+
return "<none>"

0 commit comments

Comments
 (0)