Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 85bfd47

Browse files
authored
Return an immutable value from get_latest_event_ids_in_room. (#16326)
1 parent 63d28a8 commit 85bfd47

File tree

12 files changed

+48
-40
lines changed

12 files changed

+48
-40
lines changed

changelog.d/16326.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve type hints.

synapse/events/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def is_state(self) -> bool:
103103

104104
async def build(
105105
self,
106-
prev_event_ids: StrCollection,
106+
prev_event_ids: List[str],
107107
auth_event_ids: Optional[List[str]],
108108
depth: Optional[int] = None,
109109
) -> EventBase:

synapse/handlers/federation_event.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -723,12 +723,11 @@ async def _get_missing_events_for_pdu(
723723
if not prevs - seen:
724724
return
725725

726-
latest_list = await self._store.get_latest_event_ids_in_room(room_id)
726+
latest_frozen = await self._store.get_latest_event_ids_in_room(room_id)
727727

728728
# We add the prev events that we have seen to the latest
729729
# list to ensure the remote server doesn't give them to us
730-
latest = set(latest_list)
731-
latest |= seen
730+
latest = seen | latest_frozen
732731

733732
logger.info(
734733
"Requesting missing events between %s and %s",
@@ -1976,8 +1975,7 @@ async def _check_for_soft_fail(
19761975
# partial and full state and may not be accurate.
19771976
return
19781977

1979-
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
1980-
extrem_ids = set(extrem_ids_list)
1978+
extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id)
19811979
prev_event_ids = set(event.prev_event_ids())
19821980

19831981
if extrem_ids == prev_event_ids:

synapse/storage/controllers/persist_events.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from collections import deque
2020
from typing import (
2121
TYPE_CHECKING,
22+
AbstractSet,
2223
Any,
2324
Awaitable,
2425
Callable,
@@ -618,7 +619,7 @@ async def _persist_event_batch(
618619
)
619620

620621
for room_id, ev_ctx_rm in events_by_room.items():
621-
latest_event_ids = set(
622+
latest_event_ids = (
622623
await self.main_store.get_latest_event_ids_in_room(room_id)
623624
)
624625
new_latest_event_ids = await self._calculate_new_extremities(
@@ -740,7 +741,7 @@ async def _calculate_new_extremities(
740741
self,
741742
room_id: str,
742743
event_contexts: List[Tuple[EventBase, EventContext]],
743-
latest_event_ids: Collection[str],
744+
latest_event_ids: AbstractSet[str],
744745
) -> Set[str]:
745746
"""Calculates the new forward extremities for a room given events to
746747
persist.
@@ -758,8 +759,6 @@ async def _calculate_new_extremities(
758759
and not event.internal_metadata.is_soft_failed()
759760
]
760761

761-
latest_event_ids = set(latest_event_ids)
762-
763762
# start with the existing forward extremities
764763
result = set(latest_event_ids)
765764

@@ -798,7 +797,7 @@ async def _get_new_state_after_events(
798797
self,
799798
room_id: str,
800799
events_context: List[Tuple[EventBase, EventContext]],
801-
old_latest_event_ids: Set[str],
800+
old_latest_event_ids: AbstractSet[str],
802801
new_latest_event_ids: Set[str],
803802
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
804803
"""Calculate the current state dict after adding some new events to

synapse/storage/databases/main/event_federation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
TYPE_CHECKING,
2020
Collection,
2121
Dict,
22+
FrozenSet,
2223
Iterable,
2324
List,
2425
Optional,
@@ -47,7 +48,7 @@
4748
from synapse.storage.databases.main.events_worker import EventsWorkerStore
4849
from synapse.storage.databases.main.signatures import SignatureWorkerStore
4950
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
50-
from synapse.types import JsonDict, StrCollection, StrSequence
51+
from synapse.types import JsonDict, StrCollection
5152
from synapse.util import json_encoder
5253
from synapse.util.caches.descriptors import cached
5354
from synapse.util.caches.lrucache import LruCache
@@ -1179,13 +1180,14 @@ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
11791180
)
11801181

11811182
@cached(max_entries=5000, iterable=True)
1182-
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
1183-
return await self.db_pool.simple_select_onecol(
1183+
async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]:
1184+
event_ids = await self.db_pool.simple_select_onecol(
11841185
table="event_forward_extremities",
11851186
keyvalues={"room_id": room_id},
11861187
retcol="event_id",
11871188
desc="get_latest_event_ids_in_room",
11881189
)
1190+
return frozenset(event_ids)
11891191

11901192
async def get_min_depth(self, room_id: str) -> Optional[int]:
11911193
"""For the given room, get the minimum depth we have seen for it."""

synapse/storage/databases/main/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ async def _persist_events_and_state_updates(
222222

223223
for room_id, latest_event_ids in new_forward_extremities.items():
224224
self.store.get_latest_event_ids_in_room.prefill(
225-
(room_id,), list(latest_event_ids)
225+
(room_id,), frozenset(latest_event_ids)
226226
)
227227

228228
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:

tests/handlers/test_presence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1858,7 +1858,7 @@ def _add_new_user(self, room_id: str, user_id: str) -> None:
18581858
)
18591859

18601860
event = self.get_success(
1861-
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
1861+
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
18621862
)
18631863

18641864
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))

tests/replication/storage/test_events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def tearDown(self) -> None:
9090
def test_get_latest_event_ids_in_room(self) -> None:
9191
create = self.persist(type="m.room.create", key="", creator=USER_ID)
9292
self.replicate()
93-
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
93+
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id})
9494

9595
join = self.persist(
9696
type="m.room.member",
@@ -99,7 +99,7 @@ def test_get_latest_event_ids_in_room(self) -> None:
9999
prev_events=[(create.event_id, {})],
100100
)
101101
self.replicate()
102-
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
102+
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id})
103103

104104
def test_redactions(self) -> None:
105105
self.persist(type="m.room.create", key="", creator=USER_ID)

tests/replication/tcp/streams/test_events.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, List, Optional, Sequence
15+
from typing import Any, List, Optional
1616

1717
from twisted.test.proto_helpers import MemoryReactor
1818

@@ -139,7 +139,7 @@ def test_update_function_huge_state_change(self) -> None:
139139
)
140140

141141
# this is the point in the DAG where we make a fork
142-
fork_point: Sequence[str] = self.get_success(
142+
fork_point = self.get_success(
143143
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
144144
)
145145

@@ -294,7 +294,7 @@ def test_update_function_state_row_limit(self) -> None:
294294
)
295295

296296
# this is the point in the DAG where we make a fork
297-
fork_point: Sequence[str] = self.get_success(
297+
fork_point = self.get_success(
298298
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
299299
)
300300

@@ -316,14 +316,14 @@ def test_update_function_state_row_limit(self) -> None:
316316
self.test_handler.received_rdata_rows.clear()
317317

318318
# now roll back all that state by de-modding the users
319-
prev_events = fork_point
319+
prev_events = list(fork_point)
320320
pl_events = []
321321
for u in user_ids:
322322
pls["users"][u] = 0
323323
e = self.get_success(
324324
inject_event(
325325
self.hs,
326-
prev_event_ids=list(prev_events),
326+
prev_event_ids=prev_events,
327327
type=EventTypes.PowerLevels,
328328
state_key="",
329329
sender=self.user_id,

tests/replication/test_federation_sender_shard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def create_room_with_remote_server(
261261

262262
builder = factory.for_room_version(room_version, event_dict)
263263
join_event = self.get_success(
264-
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
264+
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
265265
)
266266

267267
self.get_success(federation.on_send_membership_event(remote_server, join_event))

tests/storage/test_cleanup_extrems.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_soft_failed_extremities_handled_correctly(self) -> None:
120120
self.store.get_latest_event_ids_in_room(self.room_id)
121121
)
122122

123-
self.assertEqual(latest_event_ids, [event_id_4])
123+
self.assertEqual(latest_event_ids, {event_id_4})
124124

125125
def test_basic_cleanup(self) -> None:
126126
"""Test that extremities are correctly calculated in the presence of
@@ -147,15 +147,15 @@ def test_basic_cleanup(self) -> None:
147147
latest_event_ids = self.get_success(
148148
self.store.get_latest_event_ids_in_room(self.room_id)
149149
)
150-
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
150+
self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
151151

152152
# Run the background update and check it did the right thing
153153
self.run_background_update()
154154

155155
latest_event_ids = self.get_success(
156156
self.store.get_latest_event_ids_in_room(self.room_id)
157157
)
158-
self.assertEqual(latest_event_ids, [event_id_b])
158+
self.assertEqual(latest_event_ids, {event_id_b})
159159

160160
def test_chain_of_fail_cleanup(self) -> None:
161161
"""Test that extremities are correctly calculated in the presence of
@@ -185,15 +185,15 @@ def test_chain_of_fail_cleanup(self) -> None:
185185
latest_event_ids = self.get_success(
186186
self.store.get_latest_event_ids_in_room(self.room_id)
187187
)
188-
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
188+
self.assertEqual(latest_event_ids, {event_id_a, event_id_b})
189189

190190
# Run the background update and check it did the right thing
191191
self.run_background_update()
192192

193193
latest_event_ids = self.get_success(
194194
self.store.get_latest_event_ids_in_room(self.room_id)
195195
)
196-
self.assertEqual(latest_event_ids, [event_id_b])
196+
self.assertEqual(latest_event_ids, {event_id_b})
197197

198198
def test_forked_graph_cleanup(self) -> None:
199199
r"""Test that extremities are correctly calculated in the presence of
@@ -240,15 +240,15 @@ def test_forked_graph_cleanup(self) -> None:
240240
latest_event_ids = self.get_success(
241241
self.store.get_latest_event_ids_in_room(self.room_id)
242242
)
243-
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
243+
self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c})
244244

245245
# Run the background update and check it did the right thing
246246
self.run_background_update()
247247

248248
latest_event_ids = self.get_success(
249249
self.store.get_latest_event_ids_in_room(self.room_id)
250250
)
251-
self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
251+
self.assertEqual(latest_event_ids, {event_id_b, event_id_c})
252252

253253

254254
class CleanupExtremDummyEventsTestCase(HomeserverTestCase):

tests/test_federation.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,15 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
5151
self.store = self.hs.get_datastores().main
5252

5353
# Figure out what the most recent event is
54-
most_recent = self.get_success(
55-
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
56-
)[0]
54+
most_recent = next(
55+
iter(
56+
self.get_success(
57+
self.hs.get_datastores().main.get_latest_event_ids_in_room(
58+
self.room_id
59+
)
60+
)
61+
)
62+
)
5763

5864
join_event = make_event_from_dict(
5965
{
@@ -100,8 +106,8 @@ async def _check_sigs_and_hash_for_pulled_events_and_fetch(
100106

101107
# Make sure we actually joined the room
102108
self.assertEqual(
103-
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
104-
"$join:test.serv",
109+
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
110+
{"$join:test.serv"},
105111
)
106112

107113
def test_cant_hide_direct_ancestors(self) -> None:
@@ -127,9 +133,11 @@ async def post_json(
127133
self.http_client.post_json = post_json
128134

129135
# Figure out what the most recent event is
130-
most_recent = self.get_success(
131-
self.store.get_latest_event_ids_in_room(self.room_id)
132-
)[0]
136+
most_recent = next(
137+
iter(
138+
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
139+
)
140+
)
133141

134142
# Now lie about an event
135143
lying_event = make_event_from_dict(
@@ -165,7 +173,7 @@ async def post_json(
165173

166174
# Make sure the invalid event isn't there
167175
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
168-
self.assertEqual(extrem[0], "$join:test.serv")
176+
self.assertEqual(extrem, {"$join:test.serv"})
169177

170178
def test_retry_device_list_resync(self) -> None:
171179
"""Tests that device lists are marked as stale if they couldn't be synced, and

0 commit comments

Comments
 (0)