Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 2b35626

Browse files
authored
Refactor storing of server keys (#16261)
1 parent 9400dc0 commit 2b35626

File tree

6 files changed

+106
-365
lines changed

6 files changed

+106
-365
lines changed

changelog.d/16261.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Simplify server key storage.

synapse/crypto/keyring.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,7 @@
2323
get_verify_key,
2424
is_signing_algorithm_supported,
2525
)
26-
from signedjson.sign import (
27-
SignatureVerifyException,
28-
encode_canonical_json,
29-
signature_ids,
30-
verify_signed_json,
31-
)
26+
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json
3227
from signedjson.types import VerifyKey
3328
from unpaddedbase64 import decode_base64
3429

@@ -596,24 +591,12 @@ async def process_v2_response(
596591
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
597592
)
598593

599-
key_json_bytes = encode_canonical_json(response_json)
600-
601-
await make_deferred_yieldable(
602-
defer.gatherResults(
603-
[
604-
run_in_background(
605-
self.store.store_server_keys_json,
606-
server_name=server_name,
607-
key_id=key_id,
608-
from_server=from_server,
609-
ts_now_ms=time_added_ms,
610-
ts_expires_ms=ts_valid_until_ms,
611-
key_json_bytes=key_json_bytes,
612-
)
613-
for key_id in verify_keys
614-
],
615-
consumeErrors=True,
616-
).addErrback(unwrapFirstError)
594+
await self.store.store_server_keys_response(
595+
server_name=server_name,
596+
from_server=from_server,
597+
ts_added_ms=time_added_ms,
598+
verify_keys=verify_keys,
599+
response_json=response_json,
617600
)
618601

619602
return verify_keys
@@ -775,10 +758,6 @@ async def get_server_verify_key_v2_indirect(
775758

776759
keys.setdefault(server_name, {}).update(processed_response)
777760

778-
await self.store.store_server_signature_keys(
779-
perspective_name, time_now_ms, added_keys
780-
)
781-
782761
return keys
783762

784763
def _validate_perspectives_response(

synapse/storage/databases/main/keys.py

Lines changed: 72 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616
import itertools
1717
import json
1818
import logging
19-
from typing import Dict, Iterable, Mapping, Optional, Tuple
19+
from typing import Dict, Iterable, Optional, Tuple
2020

21+
from canonicaljson import encode_canonical_json
2122
from signedjson.key import decode_verify_key_bytes
2223
from unpaddedbase64 import decode_base64
2324

25+
from synapse.storage.database import LoggingTransaction
2426
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
2527
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
2628
from synapse.storage.types import Cursor
29+
from synapse.types import JsonDict
2730
from synapse.util.caches.descriptors import cached, cachedList
2831
from synapse.util.iterutils import batch_iter
2932

@@ -36,162 +39,84 @@
3639
class KeyStore(CacheInvalidationWorkerStore):
3740
"""Persistence for signature verification keys"""
3841

39-
@cached()
40-
def _get_server_signature_key(
41-
self, server_name_and_key_id: Tuple[str, str]
42-
) -> FetchKeyResult:
43-
raise NotImplementedError()
44-
45-
@cachedList(
46-
cached_method_name="_get_server_signature_key",
47-
list_name="server_name_and_key_ids",
48-
)
49-
async def get_server_signature_keys(
50-
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
51-
) -> Dict[Tuple[str, str], FetchKeyResult]:
52-
"""
53-
Args:
54-
server_name_and_key_ids:
55-
iterable of (server_name, key-id) tuples to fetch keys for
56-
57-
Returns:
58-
A map from (server_name, key_id) -> FetchKeyResult, or None if the
59-
key is unknown
60-
"""
61-
keys = {}
62-
63-
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
64-
"""Processes a batch of keys to fetch, and adds the result to `keys`."""
65-
66-
# batch_iter always returns tuples so it's safe to do len(batch)
67-
sql = """
68-
SELECT server_name, key_id, verify_key, ts_valid_until_ms
69-
FROM server_signature_keys WHERE 1=0
70-
""" + " OR (server_name=? AND key_id=?)" * len(
71-
batch
72-
)
73-
74-
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
75-
76-
for row in txn:
77-
server_name, key_id, key_bytes, ts_valid_until_ms = row
78-
79-
if ts_valid_until_ms is None:
80-
# Old keys may be stored with a ts_valid_until_ms of null,
81-
# in which case we treat this as if it was set to `0`, i.e.
82-
# it won't match key requests that define a minimum
83-
# `ts_valid_until_ms`.
84-
ts_valid_until_ms = 0
85-
86-
keys[(server_name, key_id)] = FetchKeyResult(
87-
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
88-
valid_until_ts=ts_valid_until_ms,
89-
)
90-
91-
def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
92-
for batch in batch_iter(server_name_and_key_ids, 50):
93-
_get_keys(txn, batch)
94-
return keys
95-
96-
return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
97-
98-
async def store_server_signature_keys(
42+
async def store_server_keys_response(
9943
self,
44+
server_name: str,
10045
from_server: str,
10146
ts_added_ms: int,
102-
verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
47+
verify_keys: Dict[str, FetchKeyResult],
48+
response_json: JsonDict,
10349
) -> None:
104-
"""Stores NACL verification keys for remote servers.
50+
"""Stores the keys for the given server that we got from `from_server`.
51+
10552
Args:
106-
from_server: Where the verification keys were looked up
107-
ts_added_ms: The time to record that the key was added
108-
verify_keys:
109-
keys to be stored. Each entry is a triplet of
110-
(server_name, key_id, key).
53+
server_name: The owner of the keys
54+
from_server: Which server we got the keys from
55+
ts_added_ms: When we're adding the keys
56+
verify_keys: The decoded keys
57+
response_json: The full *signed* response JSON that contains the keys.
11158
"""
112-
key_values = []
113-
value_values = []
114-
invalidations = []
115-
for (server_name, key_id), fetch_result in verify_keys.items():
116-
key_values.append((server_name, key_id))
117-
value_values.append(
118-
(
119-
from_server,
120-
ts_added_ms,
121-
fetch_result.valid_until_ts,
122-
db_binary_type(fetch_result.verify_key.encode()),
123-
)
124-
)
125-
# invalidate takes a tuple corresponding to the params of
126-
# _get_server_signature_key. _get_server_signature_key only takes one
127-
# param, which is itself the 2-tuple (server_name, key_id).
128-
invalidations.append((server_name, key_id))
12959

130-
await self.db_pool.simple_upsert_many(
131-
table="server_signature_keys",
132-
key_names=("server_name", "key_id"),
133-
key_values=key_values,
134-
value_names=(
135-
"from_server",
136-
"ts_added_ms",
137-
"ts_valid_until_ms",
138-
"verify_key",
139-
),
140-
value_values=value_values,
141-
desc="store_server_signature_keys",
142-
)
60+
key_json_bytes = encode_canonical_json(response_json)
61+
62+
def store_server_keys_response_txn(txn: LoggingTransaction) -> None:
63+
self.db_pool.simple_upsert_many_txn(
64+
txn,
65+
table="server_signature_keys",
66+
key_names=("server_name", "key_id"),
67+
key_values=[(server_name, key_id) for key_id in verify_keys],
68+
value_names=(
69+
"from_server",
70+
"ts_added_ms",
71+
"ts_valid_until_ms",
72+
"verify_key",
73+
),
74+
value_values=[
75+
(
76+
from_server,
77+
ts_added_ms,
78+
fetch_result.valid_until_ts,
79+
db_binary_type(fetch_result.verify_key.encode()),
80+
)
81+
for fetch_result in verify_keys.values()
82+
],
83+
)
14384

144-
invalidate = self._get_server_signature_key.invalidate
145-
for i in invalidations:
146-
invalidate((i,))
85+
self.db_pool.simple_upsert_many_txn(
86+
txn,
87+
table="server_keys_json",
88+
key_names=("server_name", "key_id", "from_server"),
89+
key_values=[
90+
(server_name, key_id, from_server) for key_id in verify_keys
91+
],
92+
value_names=(
93+
"ts_added_ms",
94+
"ts_valid_until_ms",
95+
"key_json",
96+
),
97+
value_values=[
98+
(
99+
ts_added_ms,
100+
fetch_result.valid_until_ts,
101+
db_binary_type(key_json_bytes),
102+
)
103+
for fetch_result in verify_keys.values()
104+
],
105+
)
147106

148-
async def store_server_keys_json(
149-
self,
150-
server_name: str,
151-
key_id: str,
152-
from_server: str,
153-
ts_now_ms: int,
154-
ts_expires_ms: int,
155-
key_json_bytes: bytes,
156-
) -> None:
157-
"""Stores the JSON bytes for a set of keys from a server
158-
The JSON should be signed by the originating server, the intermediate
159-
server, and by this server. Updates the value for the
160-
(server_name, key_id, from_server) triplet if one already existed.
161-
Args:
162-
server_name: The name of the server.
163-
key_id: The identifier of the key this JSON is for.
164-
from_server: The server this JSON was fetched from.
165-
ts_now_ms: The time now in milliseconds.
166-
ts_valid_until_ms: The time when this json stops being valid.
167-
key_json_bytes: The encoded JSON.
168-
"""
169-
await self.db_pool.simple_upsert(
170-
table="server_keys_json",
171-
keyvalues={
172-
"server_name": server_name,
173-
"key_id": key_id,
174-
"from_server": from_server,
175-
},
176-
values={
177-
"server_name": server_name,
178-
"key_id": key_id,
179-
"from_server": from_server,
180-
"ts_added_ms": ts_now_ms,
181-
"ts_valid_until_ms": ts_expires_ms,
182-
"key_json": db_binary_type(key_json_bytes),
183-
},
184-
desc="store_server_keys_json",
185-
)
107+
# invalidate takes a tuple corresponding to the params of
108+
# _get_server_keys_json. _get_server_keys_json only takes one
109+
# param, which is itself the 2-tuple (server_name, key_id).
110+
for key_id in verify_keys:
111+
self._invalidate_cache_and_stream(
112+
txn, self._get_server_keys_json, ((server_name, key_id),)
113+
)
114+
self._invalidate_cache_and_stream(
115+
txn, self.get_server_key_json_for_remote, (server_name, key_id)
116+
)
186117

187-
# invalidate takes a tuple corresponding to the params of
188-
# _get_server_keys_json. _get_server_keys_json only takes one
189-
# param, which is itself the 2-tuple (server_name, key_id).
190-
await self.invalidate_cache_and_stream(
191-
"_get_server_keys_json", ((server_name, key_id),)
192-
)
193-
await self.invalidate_cache_and_stream(
194-
"get_server_key_json_for_remote", (server_name, key_id)
118+
await self.db_pool.runInteraction(
119+
"store_server_keys_response", store_server_keys_response_txn
195120
)
196121

197122
@cached()

tests/crypto/test_keyring.py

Lines changed: 13 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import time
1515
from typing import Any, Dict, List, Optional, cast
16-
from unittest.mock import AsyncMock, Mock
16+
from unittest.mock import Mock
1717

1818
import attr
1919
import canonicaljson
@@ -189,23 +189,24 @@ def test_verify_json_for_server(self) -> None:
189189
kr = keyring.Keyring(self.hs)
190190

191191
key1 = signedjson.key.generate_signing_key("1")
192-
r = self.hs.get_datastores().main.store_server_keys_json(
192+
r = self.hs.get_datastores().main.store_server_keys_response(
193193
"server9",
194-
get_key_id(key1),
195194
from_server="test",
196-
ts_now_ms=int(time.time() * 1000),
197-
ts_expires_ms=1000,
195+
ts_added_ms=int(time.time() * 1000),
196+
verify_keys={
197+
get_key_id(key1): FetchKeyResult(
198+
verify_key=get_verify_key(key1), valid_until_ts=1000
199+
)
200+
},
198201
# The entire response gets signed & stored, just include the bits we
199202
# care about.
200-
key_json_bytes=canonicaljson.encode_canonical_json(
201-
{
202-
"verify_keys": {
203-
get_key_id(key1): {
204-
"key": encode_verify_key_base64(get_verify_key(key1))
205-
}
203+
response_json={
204+
"verify_keys": {
205+
get_key_id(key1): {
206+
"key": encode_verify_key_base64(get_verify_key(key1))
206207
}
207208
}
208-
),
209+
},
209210
)
210211
self.get_success(r)
211212

@@ -285,34 +286,6 @@ async def get_keys(
285286
d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
286287
self.get_success(d)
287288

288-
def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
289-
"""Tests that we correctly handle key requests for keys we've stored
290-
with a null `ts_valid_until_ms`
291-
"""
292-
mock_fetcher = Mock()
293-
mock_fetcher.get_keys = AsyncMock(return_value={})
294-
295-
key1 = signedjson.key.generate_signing_key("1")
296-
r = self.hs.get_datastores().main.store_server_signature_keys(
297-
"server9",
298-
int(time.time() * 1000),
299-
# None is not a valid value in FetchKeyResult, but we're abusing this
300-
# API to insert null values into the database. The nulls get converted
301-
# to 0 when fetched in KeyStore.get_server_signature_keys.
302-
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type]
303-
)
304-
self.get_success(r)
305-
306-
json1: JsonDict = {}
307-
signedjson.sign.sign_json(json1, "server9", key1)
308-
309-
# should succeed on a signed object with a 0 minimum_valid_until_ms
310-
d = self.hs.get_datastores().main.get_server_signature_keys(
311-
[("server9", get_key_id(key1))]
312-
)
313-
result = self.get_success(d)
314-
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
315-
316289
def test_verify_json_dedupes_key_requests(self) -> None:
317290
"""Two requests for the same key should be deduped."""
318291
key1 = signedjson.key.generate_signing_key("1")

0 commit comments

Comments
 (0)