|
16 | 16 | import itertools
|
17 | 17 | import json
|
18 | 18 | import logging
|
19 |
| -from typing import Dict, Iterable, Mapping, Optional, Tuple |
| 19 | +from typing import Dict, Iterable, Optional, Tuple |
20 | 20 |
|
| 21 | +from canonicaljson import encode_canonical_json |
21 | 22 | from signedjson.key import decode_verify_key_bytes
|
22 | 23 | from unpaddedbase64 import decode_base64
|
23 | 24 |
|
| 25 | +from synapse.storage.database import LoggingTransaction |
24 | 26 | from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
25 | 27 | from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
|
26 | 28 | from synapse.storage.types import Cursor
|
| 29 | +from synapse.types import JsonDict |
27 | 30 | from synapse.util.caches.descriptors import cached, cachedList
|
28 | 31 | from synapse.util.iterutils import batch_iter
|
29 | 32 |
|
|
36 | 39 | class KeyStore(CacheInvalidationWorkerStore):
|
37 | 40 | """Persistence for signature verification keys"""
|
38 | 41 |
|
39 |
| - @cached() |
40 |
| - def _get_server_signature_key( |
41 |
| - self, server_name_and_key_id: Tuple[str, str] |
42 |
| - ) -> FetchKeyResult: |
43 |
| - raise NotImplementedError() |
44 |
| - |
45 |
| - @cachedList( |
46 |
| - cached_method_name="_get_server_signature_key", |
47 |
| - list_name="server_name_and_key_ids", |
48 |
| - ) |
49 |
| - async def get_server_signature_keys( |
50 |
| - self, server_name_and_key_ids: Iterable[Tuple[str, str]] |
51 |
| - ) -> Dict[Tuple[str, str], FetchKeyResult]: |
52 |
| - """ |
53 |
| - Args: |
54 |
| - server_name_and_key_ids: |
55 |
| - iterable of (server_name, key-id) tuples to fetch keys for |
56 |
| -
|
57 |
| - Returns: |
58 |
| - A map from (server_name, key_id) -> FetchKeyResult, or None if the |
59 |
| - key is unknown |
60 |
| - """ |
61 |
| - keys = {} |
62 |
| - |
63 |
| - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: |
64 |
| - """Processes a batch of keys to fetch, and adds the result to `keys`.""" |
65 |
| - |
66 |
| - # batch_iter always returns tuples so it's safe to do len(batch) |
67 |
| - sql = """ |
68 |
| - SELECT server_name, key_id, verify_key, ts_valid_until_ms |
69 |
| - FROM server_signature_keys WHERE 1=0 |
70 |
| - """ + " OR (server_name=? AND key_id=?)" * len( |
71 |
| - batch |
72 |
| - ) |
73 |
| - |
74 |
| - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) |
75 |
| - |
76 |
| - for row in txn: |
77 |
| - server_name, key_id, key_bytes, ts_valid_until_ms = row |
78 |
| - |
79 |
| - if ts_valid_until_ms is None: |
80 |
| - # Old keys may be stored with a ts_valid_until_ms of null, |
81 |
| - # in which case we treat this as if it was set to `0`, i.e. |
82 |
| - # it won't match key requests that define a minimum |
83 |
| - # `ts_valid_until_ms`. |
84 |
| - ts_valid_until_ms = 0 |
85 |
| - |
86 |
| - keys[(server_name, key_id)] = FetchKeyResult( |
87 |
| - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), |
88 |
| - valid_until_ts=ts_valid_until_ms, |
89 |
| - ) |
90 |
| - |
91 |
| - def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: |
92 |
| - for batch in batch_iter(server_name_and_key_ids, 50): |
93 |
| - _get_keys(txn, batch) |
94 |
| - return keys |
95 |
| - |
96 |
| - return await self.db_pool.runInteraction("get_server_signature_keys", _txn) |
97 |
| - |
98 |
| - async def store_server_signature_keys( |
| 42 | + async def store_server_keys_response( |
99 | 43 | self,
|
| 44 | + server_name: str, |
100 | 45 | from_server: str,
|
101 | 46 | ts_added_ms: int,
|
102 |
| - verify_keys: Mapping[Tuple[str, str], FetchKeyResult], |
| 47 | + verify_keys: Dict[str, FetchKeyResult], |
| 48 | + response_json: JsonDict, |
103 | 49 | ) -> None:
|
104 |
| - """Stores NACL verification keys for remote servers. |
| 50 | + """Stores the keys for the given server that we got from `from_server`. |
| 51 | +
|
105 | 52 | Args:
|
106 |
| - from_server: Where the verification keys were looked up |
107 |
| - ts_added_ms: The time to record that the key was added |
108 |
| - verify_keys: |
109 |
| - keys to be stored. Each entry is a triplet of |
110 |
| - (server_name, key_id, key). |
| 53 | + server_name: The owner of the keys |
| 54 | + from_server: Which server we got the keys from |
| 55 | + ts_added_ms: When we're adding the keys |
| 56 | + verify_keys: The decoded keys |
| 57 | + response_json: The full *signed* response JSON that contains the keys. |
111 | 58 | """
|
112 |
| - key_values = [] |
113 |
| - value_values = [] |
114 |
| - invalidations = [] |
115 |
| - for (server_name, key_id), fetch_result in verify_keys.items(): |
116 |
| - key_values.append((server_name, key_id)) |
117 |
| - value_values.append( |
118 |
| - ( |
119 |
| - from_server, |
120 |
| - ts_added_ms, |
121 |
| - fetch_result.valid_until_ts, |
122 |
| - db_binary_type(fetch_result.verify_key.encode()), |
123 |
| - ) |
124 |
| - ) |
125 |
| - # invalidate takes a tuple corresponding to the params of |
126 |
| - # _get_server_signature_key. _get_server_signature_key only takes one |
127 |
| - # param, which is itself the 2-tuple (server_name, key_id). |
128 |
| - invalidations.append((server_name, key_id)) |
129 | 59 |
|
130 |
| - await self.db_pool.simple_upsert_many( |
131 |
| - table="server_signature_keys", |
132 |
| - key_names=("server_name", "key_id"), |
133 |
| - key_values=key_values, |
134 |
| - value_names=( |
135 |
| - "from_server", |
136 |
| - "ts_added_ms", |
137 |
| - "ts_valid_until_ms", |
138 |
| - "verify_key", |
139 |
| - ), |
140 |
| - value_values=value_values, |
141 |
| - desc="store_server_signature_keys", |
142 |
| - ) |
| 60 | + key_json_bytes = encode_canonical_json(response_json) |
| 61 | + |
| 62 | + def store_server_keys_response_txn(txn: LoggingTransaction) -> None: |
| 63 | + self.db_pool.simple_upsert_many_txn( |
| 64 | + txn, |
| 65 | + table="server_signature_keys", |
| 66 | + key_names=("server_name", "key_id"), |
| 67 | + key_values=[(server_name, key_id) for key_id in verify_keys], |
| 68 | + value_names=( |
| 69 | + "from_server", |
| 70 | + "ts_added_ms", |
| 71 | + "ts_valid_until_ms", |
| 72 | + "verify_key", |
| 73 | + ), |
| 74 | + value_values=[ |
| 75 | + ( |
| 76 | + from_server, |
| 77 | + ts_added_ms, |
| 78 | + fetch_result.valid_until_ts, |
| 79 | + db_binary_type(fetch_result.verify_key.encode()), |
| 80 | + ) |
| 81 | + for fetch_result in verify_keys.values() |
| 82 | + ], |
| 83 | + ) |
143 | 84 |
|
144 |
| - invalidate = self._get_server_signature_key.invalidate |
145 |
| - for i in invalidations: |
146 |
| - invalidate((i,)) |
| 85 | + self.db_pool.simple_upsert_many_txn( |
| 86 | + txn, |
| 87 | + table="server_keys_json", |
| 88 | + key_names=("server_name", "key_id", "from_server"), |
| 89 | + key_values=[ |
| 90 | + (server_name, key_id, from_server) for key_id in verify_keys |
| 91 | + ], |
| 92 | + value_names=( |
| 93 | + "ts_added_ms", |
| 94 | + "ts_valid_until_ms", |
| 95 | + "key_json", |
| 96 | + ), |
| 97 | + value_values=[ |
| 98 | + ( |
| 99 | + ts_added_ms, |
| 100 | + fetch_result.valid_until_ts, |
| 101 | + db_binary_type(key_json_bytes), |
| 102 | + ) |
| 103 | + for fetch_result in verify_keys.values() |
| 104 | + ], |
| 105 | + ) |
147 | 106 |
|
148 |
| - async def store_server_keys_json( |
149 |
| - self, |
150 |
| - server_name: str, |
151 |
| - key_id: str, |
152 |
| - from_server: str, |
153 |
| - ts_now_ms: int, |
154 |
| - ts_expires_ms: int, |
155 |
| - key_json_bytes: bytes, |
156 |
| - ) -> None: |
157 |
| - """Stores the JSON bytes for a set of keys from a server |
158 |
| - The JSON should be signed by the originating server, the intermediate |
159 |
| - server, and by this server. Updates the value for the |
160 |
| - (server_name, key_id, from_server) triplet if one already existed. |
161 |
| - Args: |
162 |
| - server_name: The name of the server. |
163 |
| - key_id: The identifier of the key this JSON is for. |
164 |
| - from_server: The server this JSON was fetched from. |
165 |
| - ts_now_ms: The time now in milliseconds. |
166 |
| - ts_valid_until_ms: The time when this json stops being valid. |
167 |
| - key_json_bytes: The encoded JSON. |
168 |
| - """ |
169 |
| - await self.db_pool.simple_upsert( |
170 |
| - table="server_keys_json", |
171 |
| - keyvalues={ |
172 |
| - "server_name": server_name, |
173 |
| - "key_id": key_id, |
174 |
| - "from_server": from_server, |
175 |
| - }, |
176 |
| - values={ |
177 |
| - "server_name": server_name, |
178 |
| - "key_id": key_id, |
179 |
| - "from_server": from_server, |
180 |
| - "ts_added_ms": ts_now_ms, |
181 |
| - "ts_valid_until_ms": ts_expires_ms, |
182 |
| - "key_json": db_binary_type(key_json_bytes), |
183 |
| - }, |
184 |
| - desc="store_server_keys_json", |
185 |
| - ) |
| 107 | + # invalidate takes a tuple corresponding to the params of |
| 108 | + # _get_server_keys_json. _get_server_keys_json only takes one |
| 109 | + # param, which is itself the 2-tuple (server_name, key_id). |
| 110 | + for key_id in verify_keys: |
| 111 | + self._invalidate_cache_and_stream( |
| 112 | + txn, self._get_server_keys_json, ((server_name, key_id),) |
| 113 | + ) |
| 114 | + self._invalidate_cache_and_stream( |
| 115 | + txn, self.get_server_key_json_for_remote, (server_name, key_id) |
| 116 | + ) |
186 | 117 |
|
187 |
| - # invalidate takes a tuple corresponding to the params of |
188 |
| - # _get_server_keys_json. _get_server_keys_json only takes one |
189 |
| - # param, which is itself the 2-tuple (server_name, key_id). |
190 |
| - await self.invalidate_cache_and_stream( |
191 |
| - "_get_server_keys_json", ((server_name, key_id),) |
192 |
| - ) |
193 |
| - await self.invalidate_cache_and_stream( |
194 |
| - "get_server_key_json_for_remote", (server_name, key_id) |
| 118 | + await self.db_pool.runInteraction( |
| 119 | + "store_server_keys_response", store_server_keys_response_txn |
195 | 120 | )
|
196 | 121 |
|
197 | 122 | @cached()
|
|
0 commit comments