Skip to content

Commit a2bd4c3

Browse files
richvdhphil-flex
authored andcommitted
Add some type annotations in synapse.storage (matrix-org#6987)
I cracked, and added some type definitions in synapse.storage.
1 parent 971ab4b commit a2bd4c3

File tree

8 files changed

+270
-84
lines changed

8 files changed

+270
-84
lines changed

changelog.d/6987.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add some type annotations to the database storage classes.

synapse/storage/database.py

Lines changed: 84 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
import logging
18-
import sys
1918
import time
20-
from typing import Iterable, Tuple
19+
from time import monotonic as monotonic_time
20+
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
2121

2222
from six import iteritems, iterkeys, itervalues
2323
from six.moves import intern, range
@@ -32,24 +32,14 @@
3232
from synapse.logging.context import LoggingContext, make_deferred_yieldable
3333
from synapse.metrics.background_process_metrics import run_as_background_process
3434
from synapse.storage.background_updates import BackgroundUpdater
35-
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
35+
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
36+
from synapse.storage.types import Connection, Cursor
3637
from synapse.util.stringutils import exception_to_unicode
3738

38-
# import a function which will return a monotonic time, in seconds
39-
try:
40-
# on python 3, use time.monotonic, since time.clock can go backwards
41-
from time import monotonic as monotonic_time
42-
except ImportError:
43-
# ... but python 2 doesn't have it
44-
from time import clock as monotonic_time
45-
4639
logger = logging.getLogger(__name__)
4740

48-
try:
49-
MAX_TXN_ID = sys.maxint - 1
50-
except AttributeError:
51-
# python 3 does not have a maximum int value
52-
MAX_TXN_ID = 2 ** 63 - 1
41+
# python 3 does not have a maximum int value
42+
MAX_TXN_ID = 2 ** 63 - 1
5343

5444
sql_logger = logging.getLogger("synapse.storage.SQL")
5545
transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -77,7 +67,7 @@
7767

7868

7969
def make_pool(
80-
reactor, db_config: DatabaseConnectionConfig, engine
70+
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
8171
) -> adbapi.ConnectionPool:
8272
"""Get the connection pool for the database.
8373
"""
@@ -90,7 +80,9 @@ def make_pool(
9080
)
9181

9282

93-
def make_conn(db_config: DatabaseConnectionConfig, engine):
83+
def make_conn(
84+
db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
85+
) -> Connection:
9486
"""Make a new connection to the database and return it.
9587
9688
Returns:
@@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
10799
return db_conn
108100

109101

110-
class LoggingTransaction(object):
102+
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
103+
#
104+
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
105+
# that mypy sees the type but the runtime python doesn't.
106+
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
107+
108+
109+
class LoggingTransaction:
111110
"""An object that almost-transparently proxies for the 'txn' object
112111
passed to the constructor. Adds logging and metrics to the .execute()
113112
method.
114113
115114
Args:
116115
txn: The database transcation object to wrap.
117-
name (str): The name of this transactions for logging.
118-
database_engine (Sqlite3Engine|PostgresEngine)
119-
after_callbacks(list|None): A list that callbacks will be appended to
116+
name: The name of this transactions for logging.
117+
database_engine
118+
after_callbacks: A list that callbacks will be appended to
120119
that have been added by `call_after` which should be run on
121120
successful completion of the transaction. None indicates that no
122121
callbacks should be allowed to be scheduled to run.
123-
exception_callbacks(list|None): A list that callbacks will be appended
122+
exception_callbacks: A list that callbacks will be appended
124123
to that have been added by `call_on_exception` which should be run
125124
if transaction ends with an error. None indicates that no callbacks
126125
should be allowed to be scheduled to run.
@@ -135,46 +134,67 @@ class LoggingTransaction(object):
135134
]
136135

137136
def __init__(
138-
self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
137+
self,
138+
txn: Cursor,
139+
name: str,
140+
database_engine: BaseDatabaseEngine,
141+
after_callbacks: Optional[List[_CallbackListEntry]] = None,
142+
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
139143
):
140-
object.__setattr__(self, "txn", txn)
141-
object.__setattr__(self, "name", name)
142-
object.__setattr__(self, "database_engine", database_engine)
143-
object.__setattr__(self, "after_callbacks", after_callbacks)
144-
object.__setattr__(self, "exception_callbacks", exception_callbacks)
144+
self.txn = txn
145+
self.name = name
146+
self.database_engine = database_engine
147+
self.after_callbacks = after_callbacks
148+
self.exception_callbacks = exception_callbacks
145149

146-
def call_after(self, callback, *args, **kwargs):
150+
def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
147151
"""Call the given callback on the main twisted thread after the
148152
transaction has finished. Used to invalidate the caches on the
149153
correct thread.
150154
"""
155+
# if self.after_callbacks is None, that means that whatever constructed the
156+
# LoggingTransaction isn't expecting there to be any callbacks; assert that
157+
# is not the case.
158+
assert self.after_callbacks is not None
151159
self.after_callbacks.append((callback, args, kwargs))
152160

153-
def call_on_exception(self, callback, *args, **kwargs):
161+
def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
162+
# if self.exception_callbacks is None, that means that whatever constructed the
163+
# LoggingTransaction isn't expecting there to be any callbacks; assert that
164+
# is not the case.
165+
assert self.exception_callbacks is not None
154166
self.exception_callbacks.append((callback, args, kwargs))
155167

156-
def __getattr__(self, name):
157-
return getattr(self.txn, name)
168+
def fetchall(self) -> List[Tuple]:
169+
return self.txn.fetchall()
158170

159-
def __setattr__(self, name, value):
160-
setattr(self.txn, name, value)
171+
def fetchone(self) -> Tuple:
172+
return self.txn.fetchone()
161173

162-
def __iter__(self):
174+
def __iter__(self) -> Iterator[Tuple]:
163175
return self.txn.__iter__()
164176

177+
@property
178+
def rowcount(self) -> int:
179+
return self.txn.rowcount
180+
181+
@property
182+
def description(self) -> Any:
183+
return self.txn.description
184+
165185
def execute_batch(self, sql, args):
166186
if isinstance(self.database_engine, PostgresEngine):
167-
from psycopg2.extras import execute_batch
187+
from psycopg2.extras import execute_batch # type: ignore
168188

169189
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
170190
else:
171191
for val in args:
172192
self.execute(sql, val)
173193

174-
def execute(self, sql, *args):
194+
def execute(self, sql: str, *args: Any):
175195
self._do_execute(self.txn.execute, sql, *args)
176196

177-
def executemany(self, sql, *args):
197+
def executemany(self, sql: str, *args: Any):
178198
self._do_execute(self.txn.executemany, sql, *args)
179199

180200
def _make_sql_one_line(self, sql):
@@ -207,6 +227,9 @@ def _do_execute(self, func, sql, *args):
207227
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
208228
sql_query_timer.labels(sql.split()[0]).observe(secs)
209229

230+
def close(self):
231+
self.txn.close()
232+
210233

211234
class PerformanceCounters(object):
212235
def __init__(self):
@@ -251,17 +274,19 @@ class Database(object):
251274

252275
_TXN_ID = 0
253276

254-
def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
277+
def __init__(
278+
self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
279+
):
255280
self.hs = hs
256281
self._clock = hs.get_clock()
257282
self._database_config = database_config
258283
self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
259284

260285
self.updates = BackgroundUpdater(hs, self)
261286

262-
self._previous_txn_total_time = 0
263-
self._current_txn_total_time = 0
264-
self._previous_loop_ts = 0
287+
self._previous_txn_total_time = 0.0
288+
self._current_txn_total_time = 0.0
289+
self._previous_loop_ts = 0.0
265290

266291
# TODO(paul): These can eventually be removed once the metrics code
267292
# is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +488,23 @@ def new_transaction(
463488
sql_txn_timer.labels(desc).observe(duration)
464489

465490
@defer.inlineCallbacks
466-
def runInteraction(self, desc, func, *args, **kwargs):
491+
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
467492
"""Starts a transaction on the database and runs a given function
468493
469494
Arguments:
470-
desc (str): description of the transaction, for logging and metrics
471-
func (func): callback function, which will be called with a
495+
desc: description of the transaction, for logging and metrics
496+
func: callback function, which will be called with a
472497
database transaction (twisted.enterprise.adbapi.Transaction) as
473498
its first argument, followed by `args` and `kwargs`.
474499
475-
args (list): positional args to pass to `func`
476-
kwargs (dict): named args to pass to `func`
500+
args: positional args to pass to `func`
501+
kwargs: named args to pass to `func`
477502
478503
Returns:
479504
Deferred: The result of func
480505
"""
481-
after_callbacks = []
482-
exception_callbacks = []
506+
after_callbacks = [] # type: List[_CallbackListEntry]
507+
exception_callbacks = [] # type: List[_CallbackListEntry]
483508

484509
if LoggingContext.current_context() == LoggingContext.sentinel:
485510
logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -505,15 +530,15 @@ def runInteraction(self, desc, func, *args, **kwargs):
505530
return result
506531

507532
@defer.inlineCallbacks
508-
def runWithConnection(self, func, *args, **kwargs):
533+
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
509534
"""Wraps the .runWithConnection() method on the underlying db_pool.
510535
511536
Arguments:
512-
func (func): callback function, which will be called with a
537+
func: callback function, which will be called with a
513538
database connection (twisted.enterprise.adbapi.Connection) as
514539
its first argument, followed by `args` and `kwargs`.
515-
args (list): positional args to pass to `func`
516-
kwargs (dict): named args to pass to `func`
540+
args: positional args to pass to `func`
541+
kwargs: named args to pass to `func`
517542
518543
Returns:
519544
Deferred: The result of func
@@ -800,7 +825,7 @@ def _getwhere(key):
800825
return False
801826

802827
# We didn't find any existing rows, so insert a new one
803-
allvalues = {}
828+
allvalues = {} # type: Dict[str, Any]
804829
allvalues.update(keyvalues)
805830
allvalues.update(values)
806831
allvalues.update(insertion_values)
@@ -829,7 +854,7 @@ def simple_upsert_txn_native_upsert(
829854
Returns:
830855
None
831856
"""
832-
allvalues = {}
857+
allvalues = {} # type: Dict[str, Any]
833858
allvalues.update(keyvalues)
834859
allvalues.update(insertion_values)
835860

@@ -916,7 +941,7 @@ def simple_upsert_many_txn_native_upsert(
916941
Returns:
917942
None
918943
"""
919-
allnames = []
944+
allnames = [] # type: List[str]
920945
allnames.extend(key_names)
921946
allnames.extend(value_names)
922947

@@ -1100,7 +1125,7 @@ def simple_select_many_batch(
11001125
keyvalues : dict of column names and values to select the rows with
11011126
retcols : list of strings giving the names of the columns to return
11021127
"""
1103-
results = []
1128+
results = [] # type: List[Dict[str, Any]]
11041129

11051130
if not iterable:
11061131
return results
@@ -1439,7 +1464,7 @@ def simple_select_list_paginate_txn(
14391464
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
14401465

14411466
where_clause = "WHERE " if filters or keyvalues else ""
1442-
arg_list = []
1467+
arg_list = [] # type: List[Any]
14431468
if filters:
14441469
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
14451470
arg_list += list(filters.values())

synapse/storage/engines/__init__.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,31 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
16-
import importlib
1715
import platform
1816

19-
from ._base import IncorrectDatabaseSetup
17+
from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
2018
from .postgres import PostgresEngine
2119
from .sqlite import Sqlite3Engine
2220

23-
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
24-
2521

26-
def create_engine(database_config):
22+
def create_engine(database_config) -> BaseDatabaseEngine:
2723
name = database_config["name"]
28-
engine_class = SUPPORTED_MODULE.get(name, None)
2924

30-
if engine_class:
25+
if name == "sqlite3":
26+
import sqlite3
27+
28+
return Sqlite3Engine(sqlite3, database_config)
29+
30+
if name == "psycopg2":
3131
# pypy requires psycopg2cffi rather than psycopg2
32-
if name == "psycopg2" and platform.python_implementation() == "PyPy":
33-
name = "psycopg2cffi"
34-
module = importlib.import_module(name)
35-
return engine_class(module, database_config)
32+
if platform.python_implementation() == "PyPy":
33+
import psycopg2cffi as psycopg2 # type: ignore
34+
else:
35+
import psycopg2 # type: ignore
36+
37+
return PostgresEngine(psycopg2, database_config)
3638

3739
raise RuntimeError("Unsupported database engine '%s'" % (name,))
3840

3941

40-
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
42+
__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]

0 commit comments

Comments
 (0)