23
23
from typing import (
24
24
TYPE_CHECKING ,
25
25
Any ,
26
+ Awaitable ,
26
27
Callable ,
27
28
Collection ,
28
29
Dict ,
57
58
from synapse .storage .background_updates import BackgroundUpdater
58
59
from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
59
60
from synapse .storage .types import Connection , Cursor
60
- from synapse .util .async_helpers import delay_cancellation , maybe_awaitable
61
+ from synapse .util .async_helpers import delay_cancellation
61
62
from synapse .util .iterutils import batch_iter
62
63
63
64
if TYPE_CHECKING :
@@ -168,6 +169,7 @@ def cursor(
168
169
* ,
169
170
txn_name : Optional [str ] = None ,
170
171
after_callbacks : Optional [List ["_CallbackListEntry" ]] = None ,
172
+ async_after_callbacks : Optional [List ["_AsyncCallbackListEntry" ]] = None ,
171
173
exception_callbacks : Optional [List ["_CallbackListEntry" ]] = None ,
172
174
) -> "LoggingTransaction" :
173
175
if not txn_name :
@@ -178,6 +180,7 @@ def cursor(
178
180
name = txn_name ,
179
181
database_engine = self .engine ,
180
182
after_callbacks = after_callbacks ,
183
+ async_after_callbacks = async_after_callbacks ,
181
184
exception_callbacks = exception_callbacks ,
182
185
)
183
186
@@ -209,6 +212,9 @@ def __getattr__(self, name: str) -> Any:
209
212
210
213
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
211
214
_CallbackListEntry = Tuple [Callable [..., object ], Tuple [object , ...], Dict [str , object ]]
215
+ _AsyncCallbackListEntry = Tuple [
216
+ Callable [..., Awaitable ], Tuple [object , ...], Dict [str , object ]
217
+ ]
212
218
213
219
P = ParamSpec ("P" )
214
220
R = TypeVar ("R" )
@@ -227,6 +233,10 @@ class LoggingTransaction:
227
233
that have been added by `call_after` which should be run on
228
234
successful completion of the transaction. None indicates that no
229
235
callbacks should be allowed to be scheduled to run.
236
+ async_after_callbacks: A list that asynchronous callbacks will be appended
237
+ to by `async_call_after` which should run, before after_callbacks, on
238
+ successful completion of the transaction. None indicates that no
239
+ callbacks should be allowed to be scheduled to run.
230
240
exception_callbacks: A list that callbacks will be appended
231
241
to that have been added by `call_on_exception` which should be run
232
242
if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
238
248
"name" ,
239
249
"database_engine" ,
240
250
"after_callbacks" ,
251
+ "async_after_callbacks" ,
241
252
"exception_callbacks" ,
242
253
]
243
254
@@ -247,12 +258,14 @@ def __init__(
247
258
name : str ,
248
259
database_engine : BaseDatabaseEngine ,
249
260
after_callbacks : Optional [List [_CallbackListEntry ]] = None ,
261
+ async_after_callbacks : Optional [List [_AsyncCallbackListEntry ]] = None ,
250
262
exception_callbacks : Optional [List [_CallbackListEntry ]] = None ,
251
263
):
252
264
self .txn = txn
253
265
self .name = name
254
266
self .database_engine = database_engine
255
267
self .after_callbacks = after_callbacks
268
+ self .async_after_callbacks = async_after_callbacks
256
269
self .exception_callbacks = exception_callbacks
257
270
258
271
def call_after (
@@ -277,6 +290,28 @@ def call_after(
277
290
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
278
291
self .after_callbacks .append ((callback , args , kwargs )) # type: ignore[arg-type]
279
292
293
+ def async_call_after (
294
+ self , callback : Callable [P , Awaitable ], * args : P .args , ** kwargs : P .kwargs
295
+ ) -> None :
296
+ """Call the given asynchronous callback on the main twisted thread after
297
+ the transaction has finished (but before those added in `call_after`).
298
+
299
+ Mostly used to invalidate remote caches after transactions.
300
+
301
+ Note that transactions may be retried a few times if they encounter database
302
+ errors such as serialization failures. Callbacks given to `async_call_after`
303
+ will accumulate across transaction attempts and will _all_ be called once a
304
+ transaction attempt succeeds, regardless of whether previous transaction
305
+ attempts failed. Otherwise, if all transaction attempts fail, all
306
+ `call_on_exception` callbacks will be run instead.
307
+ """
308
+ # if self.async_after_callbacks is None, that means that whatever constructed the
309
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
310
+ # is not the case.
311
+ assert self .async_after_callbacks is not None
312
+ # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
313
+ self .async_after_callbacks .append ((callback , args , kwargs )) # type: ignore[arg-type]
314
+
280
315
def call_on_exception (
281
316
self , callback : Callable [P , object ], * args : P .args , ** kwargs : P .kwargs
282
317
) -> None :
@@ -574,6 +609,7 @@ def new_transaction(
574
609
conn : LoggingDatabaseConnection ,
575
610
desc : str ,
576
611
after_callbacks : List [_CallbackListEntry ],
612
+ async_after_callbacks : List [_AsyncCallbackListEntry ],
577
613
exception_callbacks : List [_CallbackListEntry ],
578
614
func : Callable [Concatenate [LoggingTransaction , P ], R ],
579
615
* args : P .args ,
@@ -597,6 +633,7 @@ def new_transaction(
597
633
conn
598
634
desc
599
635
after_callbacks
636
+ async_after_callbacks
600
637
exception_callbacks
601
638
func
602
639
*args
@@ -659,6 +696,7 @@ def new_transaction(
659
696
cursor = conn .cursor (
660
697
txn_name = name ,
661
698
after_callbacks = after_callbacks ,
699
+ async_after_callbacks = async_after_callbacks ,
662
700
exception_callbacks = exception_callbacks ,
663
701
)
664
702
try :
@@ -798,6 +836,7 @@ async def runInteraction(
798
836
799
837
async def _runInteraction () -> R :
800
838
after_callbacks : List [_CallbackListEntry ] = []
839
+ async_after_callbacks : List [_AsyncCallbackListEntry ] = []
801
840
exception_callbacks : List [_CallbackListEntry ] = []
802
841
803
842
if not current_context ():
@@ -809,6 +848,7 @@ async def _runInteraction() -> R:
809
848
self .new_transaction ,
810
849
desc ,
811
850
after_callbacks ,
851
+ async_after_callbacks ,
812
852
exception_callbacks ,
813
853
func ,
814
854
* args ,
@@ -817,15 +857,17 @@ async def _runInteraction() -> R:
817
857
** kwargs ,
818
858
)
819
859
860
+ # We order these assuming that async functions call out to external
861
+ # systems (e.g. to invalidate a cache) and the sync functions make these
862
+ # changes on any local in-memory caches/similar, and thus must be second.
863
+ for async_callback , async_args , async_kwargs in async_after_callbacks :
864
+ await async_callback (* async_args , ** async_kwargs )
820
865
for after_callback , after_args , after_kwargs in after_callbacks :
821
- await maybe_awaitable (after_callback (* after_args , ** after_kwargs ))
822
-
866
+ after_callback (* after_args , ** after_kwargs )
823
867
return cast (R , result )
824
868
except Exception :
825
869
for exception_callback , after_args , after_kwargs in exception_callbacks :
826
- await maybe_awaitable (
827
- exception_callback (* after_args , ** after_kwargs )
828
- )
870
+ exception_callback (* after_args , ** after_kwargs )
829
871
raise
830
872
831
873
# To handle cancellation, we ensure that `after_callback`s and
0 commit comments