15
15
# See the License for the specific language governing permissions and
16
16
# limitations under the License.
17
17
import logging
18
- import sys
19
18
import time
20
- from typing import Iterable , Tuple
19
+ from time import monotonic as monotonic_time
20
+ from typing import Any , Callable , Dict , Iterable , Iterator , List , Optional , Tuple
21
21
22
22
from six import iteritems , iterkeys , itervalues
23
23
from six .moves import intern , range
32
32
from synapse .logging .context import LoggingContext , make_deferred_yieldable
33
33
from synapse .metrics .background_process_metrics import run_as_background_process
34
34
from synapse .storage .background_updates import BackgroundUpdater
35
- from synapse .storage .engines import PostgresEngine , Sqlite3Engine
35
+ from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
36
+ from synapse .storage .types import Connection , Cursor
36
37
from synapse .util .stringutils import exception_to_unicode
37
38
38
- # import a function which will return a monotonic time, in seconds
39
- try :
40
- # on python 3, use time.monotonic, since time.clock can go backwards
41
- from time import monotonic as monotonic_time
42
- except ImportError :
43
- # ... but python 2 doesn't have it
44
- from time import clock as monotonic_time
45
-
46
39
logger = logging .getLogger (__name__ )
47
40
48
- try :
49
- MAX_TXN_ID = sys .maxint - 1
50
- except AttributeError :
51
- # python 3 does not have a maximum int value
52
- MAX_TXN_ID = 2 ** 63 - 1
41
+ # python 3 does not have a maximum int value
42
+ MAX_TXN_ID = 2 ** 63 - 1
53
43
54
44
sql_logger = logging .getLogger ("synapse.storage.SQL" )
55
45
transaction_logger = logging .getLogger ("synapse.storage.txn" )
77
67
78
68
79
69
def make_pool (
80
- reactor , db_config : DatabaseConnectionConfig , engine
70
+ reactor , db_config : DatabaseConnectionConfig , engine : BaseDatabaseEngine
81
71
) -> adbapi .ConnectionPool :
82
72
"""Get the connection pool for the database.
83
73
"""
@@ -90,7 +80,9 @@ def make_pool(
90
80
)
91
81
92
82
93
- def make_conn (db_config : DatabaseConnectionConfig , engine ):
83
+ def make_conn (
84
+ db_config : DatabaseConnectionConfig , engine : BaseDatabaseEngine
85
+ ) -> Connection :
94
86
"""Make a new connection to the database and return it.
95
87
96
88
Returns:
@@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
107
99
return db_conn
108
100
109
101
110
- class LoggingTransaction (object ):
102
+ # The type of entry which goes on our after_callbacks and exception_callbacks lists.
103
+ #
104
+ # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
105
+ # that mypy sees the type but the runtime python doesn't.
106
+ _CallbackListEntry = Tuple ["Callable[..., None]" , Iterable [Any ], Dict [str , Any ]]
107
+
108
+
109
+ class LoggingTransaction :
111
110
"""An object that almost-transparently proxies for the 'txn' object
112
111
passed to the constructor. Adds logging and metrics to the .execute()
113
112
method.
114
113
115
114
Args:
116
115
txn: The database transcation object to wrap.
117
- name (str) : The name of this transactions for logging.
118
- database_engine (Sqlite3Engine|PostgresEngine)
119
- after_callbacks(list|None) : A list that callbacks will be appended to
116
+ name: The name of this transactions for logging.
117
+ database_engine
118
+ after_callbacks: A list that callbacks will be appended to
120
119
that have been added by `call_after` which should be run on
121
120
successful completion of the transaction. None indicates that no
122
121
callbacks should be allowed to be scheduled to run.
123
- exception_callbacks(list|None) : A list that callbacks will be appended
122
+ exception_callbacks: A list that callbacks will be appended
124
123
to that have been added by `call_on_exception` which should be run
125
124
if transaction ends with an error. None indicates that no callbacks
126
125
should be allowed to be scheduled to run.
@@ -135,46 +134,67 @@ class LoggingTransaction(object):
135
134
]
136
135
137
136
def __init__ (
138
- self , txn , name , database_engine , after_callbacks = None , exception_callbacks = None
137
+ self ,
138
+ txn : Cursor ,
139
+ name : str ,
140
+ database_engine : BaseDatabaseEngine ,
141
+ after_callbacks : Optional [List [_CallbackListEntry ]] = None ,
142
+ exception_callbacks : Optional [List [_CallbackListEntry ]] = None ,
139
143
):
140
- object . __setattr__ ( self , " txn" , txn )
141
- object . __setattr__ ( self , " name" , name )
142
- object . __setattr__ ( self , " database_engine" , database_engine )
143
- object . __setattr__ ( self , " after_callbacks" , after_callbacks )
144
- object . __setattr__ ( self , " exception_callbacks" , exception_callbacks )
144
+ self . txn = txn
145
+ self . name = name
146
+ self . database_engine = database_engine
147
+ self . after_callbacks = after_callbacks
148
+ self . exception_callbacks = exception_callbacks
145
149
146
- def call_after (self , callback , * args , ** kwargs ):
150
+ def call_after (self , callback : "Callable[..., None]" , * args , ** kwargs ):
147
151
"""Call the given callback on the main twisted thread after the
148
152
transaction has finished. Used to invalidate the caches on the
149
153
correct thread.
150
154
"""
155
+ # if self.after_callbacks is None, that means that whatever constructed the
156
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
157
+ # is not the case.
158
+ assert self .after_callbacks is not None
151
159
self .after_callbacks .append ((callback , args , kwargs ))
152
160
153
- def call_on_exception (self , callback , * args , ** kwargs ):
161
+ def call_on_exception (self , callback : "Callable[..., None]" , * args , ** kwargs ):
162
+ # if self.exception_callbacks is None, that means that whatever constructed the
163
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
164
+ # is not the case.
165
+ assert self .exception_callbacks is not None
154
166
self .exception_callbacks .append ((callback , args , kwargs ))
155
167
156
- def __getattr__ (self , name ) :
157
- return getattr ( self .txn , name )
168
+ def fetchall (self ) -> List [ Tuple ] :
169
+ return self .txn . fetchall ( )
158
170
159
- def __setattr__ (self , name , value ) :
160
- setattr ( self .txn , name , value )
171
+ def fetchone (self ) -> Tuple :
172
+ return self .txn . fetchone ( )
161
173
162
- def __iter__ (self ):
174
+ def __iter__ (self ) -> Iterator [ Tuple ] :
163
175
return self .txn .__iter__ ()
164
176
177
+ @property
178
+ def rowcount (self ) -> int :
179
+ return self .txn .rowcount
180
+
181
+ @property
182
+ def description (self ) -> Any :
183
+ return self .txn .description
184
+
165
185
def execute_batch (self , sql , args ):
166
186
if isinstance (self .database_engine , PostgresEngine ):
167
- from psycopg2 .extras import execute_batch
187
+ from psycopg2 .extras import execute_batch # type: ignore
168
188
169
189
self ._do_execute (lambda * x : execute_batch (self .txn , * x ), sql , args )
170
190
else :
171
191
for val in args :
172
192
self .execute (sql , val )
173
193
174
- def execute (self , sql , * args ):
194
+ def execute (self , sql : str , * args : Any ):
175
195
self ._do_execute (self .txn .execute , sql , * args )
176
196
177
- def executemany (self , sql , * args ):
197
+ def executemany (self , sql : str , * args : Any ):
178
198
self ._do_execute (self .txn .executemany , sql , * args )
179
199
180
200
def _make_sql_one_line (self , sql ):
@@ -207,6 +227,9 @@ def _do_execute(self, func, sql, *args):
207
227
sql_logger .debug ("[SQL time] {%s} %f sec" , self .name , secs )
208
228
sql_query_timer .labels (sql .split ()[0 ]).observe (secs )
209
229
230
+ def close (self ):
231
+ self .txn .close ()
232
+
210
233
211
234
class PerformanceCounters (object ):
212
235
def __init__ (self ):
@@ -251,17 +274,19 @@ class Database(object):
251
274
252
275
_TXN_ID = 0
253
276
254
- def __init__ (self , hs , database_config : DatabaseConnectionConfig , engine ):
277
+ def __init__ (
278
+ self , hs , database_config : DatabaseConnectionConfig , engine : BaseDatabaseEngine
279
+ ):
255
280
self .hs = hs
256
281
self ._clock = hs .get_clock ()
257
282
self ._database_config = database_config
258
283
self ._db_pool = make_pool (hs .get_reactor (), database_config , engine )
259
284
260
285
self .updates = BackgroundUpdater (hs , self )
261
286
262
- self ._previous_txn_total_time = 0
263
- self ._current_txn_total_time = 0
264
- self ._previous_loop_ts = 0
287
+ self ._previous_txn_total_time = 0.0
288
+ self ._current_txn_total_time = 0.0
289
+ self ._previous_loop_ts = 0.0
265
290
266
291
# TODO(paul): These can eventually be removed once the metrics code
267
292
# is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +488,23 @@ def new_transaction(
463
488
sql_txn_timer .labels (desc ).observe (duration )
464
489
465
490
@defer .inlineCallbacks
466
- def runInteraction (self , desc , func , * args , ** kwargs ):
491
+ def runInteraction (self , desc : str , func : Callable , * args : Any , ** kwargs : Any ):
467
492
"""Starts a transaction on the database and runs a given function
468
493
469
494
Arguments:
470
- desc (str) : description of the transaction, for logging and metrics
471
- func (func) : callback function, which will be called with a
495
+ desc: description of the transaction, for logging and metrics
496
+ func: callback function, which will be called with a
472
497
database transaction (twisted.enterprise.adbapi.Transaction) as
473
498
its first argument, followed by `args` and `kwargs`.
474
499
475
- args (list) : positional args to pass to `func`
476
- kwargs (dict) : named args to pass to `func`
500
+ args: positional args to pass to `func`
501
+ kwargs: named args to pass to `func`
477
502
478
503
Returns:
479
504
Deferred: The result of func
480
505
"""
481
- after_callbacks = []
482
- exception_callbacks = []
506
+ after_callbacks = [] # type: List[_CallbackListEntry]
507
+ exception_callbacks = [] # type: List[_CallbackListEntry]
483
508
484
509
if LoggingContext .current_context () == LoggingContext .sentinel :
485
510
logger .warning ("Starting db txn '%s' from sentinel context" , desc )
@@ -505,15 +530,15 @@ def runInteraction(self, desc, func, *args, **kwargs):
505
530
return result
506
531
507
532
@defer .inlineCallbacks
508
- def runWithConnection (self , func , * args , ** kwargs ):
533
+ def runWithConnection (self , func : Callable , * args : Any , ** kwargs : Any ):
509
534
"""Wraps the .runWithConnection() method on the underlying db_pool.
510
535
511
536
Arguments:
512
- func (func) : callback function, which will be called with a
537
+ func: callback function, which will be called with a
513
538
database connection (twisted.enterprise.adbapi.Connection) as
514
539
its first argument, followed by `args` and `kwargs`.
515
- args (list) : positional args to pass to `func`
516
- kwargs (dict) : named args to pass to `func`
540
+ args: positional args to pass to `func`
541
+ kwargs: named args to pass to `func`
517
542
518
543
Returns:
519
544
Deferred: The result of func
@@ -800,7 +825,7 @@ def _getwhere(key):
800
825
return False
801
826
802
827
# We didn't find any existing rows, so insert a new one
803
- allvalues = {}
828
+ allvalues = {} # type: Dict[str, Any]
804
829
allvalues .update (keyvalues )
805
830
allvalues .update (values )
806
831
allvalues .update (insertion_values )
@@ -829,7 +854,7 @@ def simple_upsert_txn_native_upsert(
829
854
Returns:
830
855
None
831
856
"""
832
- allvalues = {}
857
+ allvalues = {} # type: Dict[str, Any]
833
858
allvalues .update (keyvalues )
834
859
allvalues .update (insertion_values )
835
860
@@ -916,7 +941,7 @@ def simple_upsert_many_txn_native_upsert(
916
941
Returns:
917
942
None
918
943
"""
919
- allnames = []
944
+ allnames = [] # type: List[str]
920
945
allnames .extend (key_names )
921
946
allnames .extend (value_names )
922
947
@@ -1100,7 +1125,7 @@ def simple_select_many_batch(
1100
1125
keyvalues : dict of column names and values to select the rows with
1101
1126
retcols : list of strings giving the names of the columns to return
1102
1127
"""
1103
- results = []
1128
+ results = [] # type: List[Dict[str, Any]]
1104
1129
1105
1130
if not iterable :
1106
1131
return results
@@ -1439,7 +1464,7 @@ def simple_select_list_paginate_txn(
1439
1464
raise ValueError ("order_direction must be one of 'ASC' or 'DESC'." )
1440
1465
1441
1466
where_clause = "WHERE " if filters or keyvalues else ""
1442
- arg_list = []
1467
+ arg_list = [] # type: List[Any]
1443
1468
if filters :
1444
1469
where_clause += " AND " .join ("%s LIKE ?" % (k ,) for k in filters )
1445
1470
arg_list += list (filters .values ())
0 commit comments