13
13
# limitations under the License.
14
14
15
15
import logging
16
- from typing import TYPE_CHECKING , Any , Mapping , Optional , cast
16
+ from typing import TYPE_CHECKING , Any , Mapping , NoReturn , Optional , Tuple , cast
17
17
18
18
from synapse .storage .engines ._base import (
19
19
BaseDatabaseEngine ,
20
20
IncorrectDatabaseSetup ,
21
21
IsolationLevel ,
22
22
)
23
- from synapse .storage .types import Connection
23
+ from synapse .storage .types import Cursor
24
24
25
25
if TYPE_CHECKING :
26
26
import psycopg2 # noqa: F401
27
27
28
+ from synapse .storage .database import LoggingDatabaseConnection
29
+
30
+
28
31
logger = logging .getLogger (__name__ )
29
32
30
33
@@ -37,11 +40,11 @@ def __init__(self, database_config: Mapping[str, Any]):
37
40
38
41
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
39
42
# actually want to use bytes than wrap it in `bytearray`.
40
- def _disable_bytes_adapter (_ ) :
43
+ def _disable_bytes_adapter (_ : bytes ) -> NoReturn :
41
44
raise Exception ("Passing bytes to DB is disabled." )
42
45
43
46
psycopg2 .extensions .register_adapter (bytes , _disable_bytes_adapter )
44
- self .synchronous_commit = database_config .get ("synchronous_commit" , True )
47
+ self .synchronous_commit : bool = database_config .get ("synchronous_commit" , True )
45
48
self ._version : Optional [int ] = None # unknown as yet
46
49
47
50
self .isolation_level_map : Mapping [int , int ] = {
@@ -58,14 +61,16 @@ def _disable_bytes_adapter(_):
58
61
def single_threaded (self ) -> bool :
59
62
return False
60
63
61
- def get_db_locale (self , txn ) :
64
+ def get_db_locale (self , txn : Cursor ) -> Tuple [ str , str ] :
62
65
txn .execute (
63
66
"SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
64
67
)
65
- collation , ctype = txn .fetchone ()
68
+ collation , ctype = cast ( Tuple [ str , str ], txn .fetchone () )
66
69
return collation , ctype
67
70
68
- def check_database (self , db_conn , allow_outdated_version : bool = False ):
71
+ def check_database (
72
+ self , db_conn : "psycopg2.connection" , allow_outdated_version : bool = False
73
+ ) -> None :
69
74
# Get the version of PostgreSQL that we're using. As per the psycopg2
70
75
# docs: The number is formed by converting the major, minor, and
71
76
# revision numbers into two-decimal-digit numbers and appending them
@@ -113,7 +118,7 @@ def check_database(self, db_conn, allow_outdated_version: bool = False):
113
118
ctype ,
114
119
)
115
120
116
- def check_new_database (self , txn ) :
121
+ def check_new_database (self , txn : Cursor ) -> None :
117
122
"""Gets called when setting up a brand new database. This allows us to
118
123
apply stricter checks on new databases versus existing database.
119
124
"""
@@ -134,10 +139,10 @@ def check_new_database(self, txn):
134
139
"See docs/postgres.md for more information." % ("\n " .join (errors ))
135
140
)
136
141
137
- def convert_param_style (self , sql ) :
142
+ def convert_param_style (self , sql : str ) -> str :
138
143
return sql .replace ("?" , "%s" )
139
144
140
- def on_new_connection (self , db_conn ) :
145
+ def on_new_connection (self , db_conn : "LoggingDatabaseConnection" ) -> None :
141
146
db_conn .set_isolation_level (self .default_isolation_level )
142
147
143
148
# Set the bytea output to escape, vs the default of hex
@@ -154,14 +159,14 @@ def on_new_connection(self, db_conn):
154
159
db_conn .commit ()
155
160
156
161
@property
157
- def can_native_upsert (self ):
162
+ def can_native_upsert (self ) -> bool :
158
163
"""
159
164
Can we use native UPSERTs?
160
165
"""
161
166
return True
162
167
163
168
@property
164
- def supports_using_any_list (self ):
169
+ def supports_using_any_list (self ) -> bool :
165
170
"""Do we support using `a = ANY(?)` and passing a list"""
166
171
return True
167
172
@@ -170,7 +175,7 @@ def supports_returning(self) -> bool:
170
175
"""Do we support the `RETURNING` clause in insert/update/delete?"""
171
176
return True
172
177
173
- def is_deadlock (self , error ) :
178
+ def is_deadlock (self , error : Exception ) -> bool :
174
179
import psycopg2 .extensions
175
180
176
181
if isinstance (error , psycopg2 .DatabaseError ):
@@ -180,19 +185,15 @@ def is_deadlock(self, error):
180
185
return error .pgcode in ["40001" , "40P01" ]
181
186
return False
182
187
183
- def is_connection_closed (self , conn ) :
188
+ def is_connection_closed (self , conn : "psycopg2.connection" ) -> bool :
184
189
return bool (conn .closed )
185
190
186
- def lock_table (self , txn , table ) :
191
+ def lock_table (self , txn : Cursor , table : str ) -> None :
187
192
txn .execute ("LOCK TABLE %s in EXCLUSIVE MODE" % (table ,))
188
193
189
194
@property
190
- def server_version (self ):
191
- """Returns a string giving the server version. For example: '8.1.5'
192
-
193
- Returns:
194
- string
195
- """
195
+ def server_version (self ) -> str :
196
+ """Returns a string giving the server version. For example: '8.1.5'."""
196
197
# note that this is a bit of a hack because it relies on check_database
197
198
# having been called. Still, that should be a safe bet here.
198
199
numver = self ._version
@@ -204,19 +205,21 @@ def server_version(self):
204
205
else :
205
206
return "%i.%i.%i" % (numver / 10000 , (numver % 10000 ) / 100 , numver % 100 )
206
207
207
- def in_transaction (self , conn : Connection ) -> bool :
208
+ def in_transaction (self , conn : "psycopg2.connection" ) -> bool :
208
209
import psycopg2 .extensions
209
210
210
- return conn .status != psycopg2 .extensions .STATUS_READY # type: ignore
211
+ return conn .status != psycopg2 .extensions .STATUS_READY
211
212
212
- def attempt_to_set_autocommit (self , conn : Connection , autocommit : bool ):
213
- return conn .set_session (autocommit = autocommit ) # type: ignore
213
+ def attempt_to_set_autocommit (
214
+ self , conn : "psycopg2.connection" , autocommit : bool
215
+ ) -> None :
216
+ return conn .set_session (autocommit = autocommit )
214
217
215
218
def attempt_to_set_isolation_level (
216
- self , conn : Connection , isolation_level : Optional [int ]
217
- ):
219
+ self , conn : "psycopg2.connection" , isolation_level : Optional [int ]
220
+ ) -> None :
218
221
if isolation_level is None :
219
222
isolation_level = self .default_isolation_level
220
223
else :
221
224
isolation_level = self .isolation_level_map [isolation_level ]
222
- return conn .set_isolation_level (isolation_level ) # type: ignore
225
+ return conn .set_isolation_level (isolation_level )
0 commit comments