Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit eca7cff

Browse files
authored
Keep fallback key marked as used if it's re-uploaded (#11382)
1 parent e2e9bea commit eca7cff

File tree

3 files changed

+72
-12
lines changed

3 files changed

+72
-12
lines changed

changelog.d/11382.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Keep fallback key marked as used if it's re-uploaded.

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -408,29 +408,58 @@ async def set_e2e_fallback_keys(
408408
fallback_keys: the keys to set. This is a map from key ID (which is
409409
of the form "algorithm:id") to key data.
410410
"""
411+
await self.db_pool.runInteraction(
412+
"set_e2e_fallback_keys_txn",
413+
self._set_e2e_fallback_keys_txn,
414+
user_id,
415+
device_id,
416+
fallback_keys,
417+
)
418+
419+
await self.invalidate_cache_and_stream(
420+
"get_e2e_unused_fallback_key_types", (user_id, device_id)
421+
)
422+
423+
def _set_e2e_fallback_keys_txn(
424+
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
425+
) -> None:
411426
# fallback_keys will usually only have one item in it, so using a for
412427
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
413428
# FIXME: make sure that only one key per algorithm is uploaded
414429
for key_id, fallback_key in fallback_keys.items():
415430
algorithm, key_id = key_id.split(":", 1)
416-
await self.db_pool.simple_upsert(
417-
"e2e_fallback_keys_json",
431+
old_key_json = self.db_pool.simple_select_one_onecol_txn(
432+
txn,
433+
table="e2e_fallback_keys_json",
418434
keyvalues={
419435
"user_id": user_id,
420436
"device_id": device_id,
421437
"algorithm": algorithm,
422438
},
423-
values={
424-
"key_id": key_id,
425-
"key_json": json_encoder.encode(fallback_key),
426-
"used": False,
427-
},
428-
desc="set_e2e_fallback_key",
439+
retcol="key_json",
440+
allow_none=True,
429441
)
430442

431-
await self.invalidate_cache_and_stream(
432-
"get_e2e_unused_fallback_key_types", (user_id, device_id)
433-
)
443+
new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
444+
445+
# If the uploaded key is the same as the current fallback key,
446+
# don't do anything. This prevents marking the key as unused if it
447+
# was already used.
448+
if old_key_json != new_key_json:
449+
self.db_pool.simple_upsert_txn(
450+
txn,
451+
table="e2e_fallback_keys_json",
452+
keyvalues={
453+
"user_id": user_id,
454+
"device_id": device_id,
455+
"algorithm": algorithm,
456+
},
457+
values={
458+
"key_id": key_id,
459+
"key_json": json_encoder.encode(fallback_key),
460+
"used": False,
461+
},
462+
)
434463

435464
@cached(max_entries=10000)
436465
async def get_e2e_unused_fallback_key_types(

tests/handlers/test_e2e_keys.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def test_fallback_key(self):
162162
local_user = "@boris:" + self.hs.hostname
163163
device_id = "xyz"
164164
fallback_key = {"alg1:k1": "key1"}
165+
fallback_key2 = {"alg1:k2": "key2"}
165166
otk = {"alg1:k2": "key2"}
166167

167168
# we shouldn't have any unused fallback keys yet
@@ -213,6 +214,35 @@ def test_fallback_key(self):
213214
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
214215
)
215216

217+
# re-uploading the same fallback key should still result in no unused fallback
218+
# keys
219+
self.get_success(
220+
self.handler.upload_keys_for_user(
221+
local_user,
222+
device_id,
223+
{"org.matrix.msc2732.fallback_keys": fallback_key},
224+
)
225+
)
226+
227+
res = self.get_success(
228+
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
229+
)
230+
self.assertEqual(res, [])
231+
232+
# uploading a new fallback key should result in an unused fallback key
233+
self.get_success(
234+
self.handler.upload_keys_for_user(
235+
local_user,
236+
device_id,
237+
{"org.matrix.msc2732.fallback_keys": fallback_key2},
238+
)
239+
)
240+
241+
res = self.get_success(
242+
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
243+
)
244+
self.assertEqual(res, ["alg1"])
245+
216246
# if the user uploads a one-time key, the next claim should fetch the
217247
# one-time key, and then go back to the fallback
218248
self.get_success(
@@ -238,7 +268,7 @@ def test_fallback_key(self):
238268
)
239269
self.assertEqual(
240270
res,
241-
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
271+
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
242272
)
243273

244274
def test_replace_master_key(self):

0 commit comments

Comments
 (0)