diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs
index eeb6074c10c8..da5665de791f 100644
--- a/rust/src/events/internal_metadata.rs
+++ b/rust/src/events/internal_metadata.rs
@@ -58,6 +58,7 @@ enum EventInternalMetadataData {
     TxnId(Box<str>),
     TokenId(i64),
     DeviceId(Box<str>),
+    SendAdditionalContext(bool),
 }
 
 impl EventInternalMetadataData {
@@ -115,6 +116,13 @@ impl EventInternalMetadataData {
                 pyo3::intern!(py, "device_id"),
                 o.into_pyobject(py).unwrap_infallible().into_any(),
             ),
+            EventInternalMetadataData::SendAdditionalContext(o) => (
+                pyo3::intern!(py, "send_additional_context"),
+                o.into_pyobject(py)
+                    .unwrap_infallible()
+                    .to_owned()
+                    .into_any(),
+            ),
         }
     }
 
@@ -177,6 +185,11 @@ impl EventInternalMetadataData {
                     .map(String::into_boxed_str)
                     .with_context(|| format!("'{key_str}' has invalid type"))?,
             ),
+            "send_additional_context" => EventInternalMetadataData::SendAdditionalContext(
+                value
+                    .extract()
+                    .with_context(|| format!("'{key_str}' has invalid type"))?,
+            ),
             _ => return Ok(None),
         };
 
@@ -370,6 +383,16 @@ impl EventInternalMetadata {
         get_property_opt!(self, Redacted).copied().unwrap_or(false)
     }
 
+    /// Whether this is a join that occurred at the same depth as another event.
+    ///
+    /// This is used to see if the joining server should proactively be sent other
+    /// events that occurred between the make_join and send_join. Defaults to `false`.
+    fn should_send_additional_context(&self) -> bool {
+        get_property_opt!(self, SendAdditionalContext)
+            .copied()
+            .unwrap_or(false)
+    }
+
     /// Whether this event can trigger a push notification
     fn is_notifiable(&self) -> bool {
         !self.outlier || self.is_out_of_band_membership()
@@ -437,6 +460,16 @@ impl EventInternalMetadata {
         set_property!(self, Redacted, obj);
     }
 
+    #[getter]
+    fn get_send_additional_context(&self) -> PyResult<bool> {
+        let bool = get_property!(self, SendAdditionalContext)?;
+        Ok(*bool)
+    }
+    #[setter]
+    fn set_send_additional_context(&mut self, obj: bool) {
+        set_property!(self, SendAdditionalContext, obj);
+    }
+
     /// The transaction ID, if it was set when the event was created.
     #[getter]
     fn get_txn_id(&self) -> PyResult<&str> {
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index b95b3c629d44..a6e4a09bbc50 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -132,6 +132,7 @@
 import abc
 import logging
 from collections import OrderedDict
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Collection,
@@ -150,6 +151,7 @@
 from twisted.internet import defer
 
 import synapse.metrics
+from synapse.api.errors import StoreError
 from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
 from synapse.federation.sender.per_destination_queue import (
@@ -169,6 +171,7 @@
     run_as_background_process,
     wrap_as_background_process,
 )
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.types import (
     JsonDict,
     ReadReceipt,
@@ -480,10 +483,63 @@ async def _process_event_queue_loop(self) -> None:
                 )
 
                 event_ids = event_to_received_ts.keys()
-                event_entries = await self.store.get_unredacted_events_from_cache_or_db(
-                    event_ids
-                )
+                event_entries: Dict[
+                    str, EventCacheEntry
+                ] = await self.store.get_unredacted_events_from_cache_or_db(event_ids)
+
+                # Do a quick check inside the events metadata, see if we need to gather
+                # more events to send (such as additional forward extremities during a join)
+                _extra_event_entries: Dict[str, EventCacheEntry] = {}
+
+                # Pre-trigger the destinations set below, because if we need to send a
+                # forward extremity, it likely only needs to go to the server that
+                # the join came from.
+                # Mapping of event_id -> set[remote server]
+                _special_destinations: Dict[str, set[str]] = {}
+                for event_cache_entry in event_entries.values():
+                    event_metadata = event_cache_entry.event.internal_metadata
+                    if (
+                        event_metadata.stream_ordering
+                        and event_metadata.should_send_additional_context()
+                    ):
+                        try:
+                            maybe_forward_extremities = set(
+                                await self.store._get_forward_extremeties_for_room(
+                                    event_cache_entry.event.room_id,
+                                    event_metadata.stream_ordering,
+                                )
+                            )
+                            # Strike out any of the event_ids we were already going to
+                            # send. TODO: maybe parse them anyway for the special destinations?
+                            maybe_forward_extremities.difference_update(event_ids)
+                            if maybe_forward_extremities:
+                                logger.debug(
+                                    "Found additional events that need to be sent: %r",
+                                    maybe_forward_extremities,
+                                )
+                                # Good, add first to the cached EventBases being sent out
+                                new_entries = await self.store.get_unredacted_events_from_cache_or_db(
+                                    maybe_forward_extremities
+                                )
+                                _extra_event_entries.update(new_entries)
+                                # These extra events were struck from `event_ids`, so they
+                                # have no entry in `event_to_received_ts`.
+                                # TODO: decide whether a received timestamp should be
+                                # recorded for them.
+                                for _event_entry in new_entries:
+                                    _special_destinations.setdefault(
+                                        _event_entry, set()
+                                    ).add(
+                                        get_domain_from_id(
+                                            event_cache_entry.event.sender
+                                        )
+                                    )
+                        except StoreError:
+                            logger.debug(
+                                "Skipping additional event context to send related to a join"
+                            )
+                event_entries.update(_extra_event_entries)
 
                 logger.debug(
                     "Handling %i -> %i: %i events to send (current id %i)",
                     last_token,
@@ -492,6 +548,8 @@ async def _process_event_queue_loop(self) -> None:
                     self._last_poked_id,
                 )
 
+                # This presents a problem. The _last_poked_id will be higher than the
+                # next_token here, so the loop may stop. Maybe not, this may be fine
                 if not event_entries and next_token >= self._last_poked_id:
                     logger.debug("All events processed")
                     break
@@ -550,6 +608,9 @@ async def handle_event(event: EventBase) -> None:
                         )
                         return
 
+                    # destinations is going to be difficult too. State at the event will
+                    # not reflect that the newly-joined host is in the room. Thanks
+                    # stream ordering
                     destinations: Optional[Collection[str]] = None
                     if not event.prev_event_ids():
                         # If there are no prev event IDs then the state is empty
@@ -571,6 +632,14 @@ async def handle_event(event: EventBase) -> None:
                         if partial_state_destinations is not None:
                             destinations = partial_state_destinations
 
+                    if destinations is None:
+                        destinations = _special_destinations.get(event.event_id, None)
+                        logger.debug(
+                            "Grabbing special destinations of %r for %s",
+                            destinations,
+                            event.event_id,
+                        )
+
                     if destinations is None:
                         # We check the external cache for the destinations, which is
                         # stored per state group.
@@ -634,11 +703,11 @@ async def handle_event(event: EventBase) -> None:
                         await self._send_pdu(event, sharded_destinations)
 
                         now = self.clock.time_msec()
-                        ts = event_to_received_ts[event.event_id]
-                        assert ts is not None
-                        synapse.metrics.event_processing_lag_by_event.labels(
-                            "federation_sender"
-                        ).observe((now - ts) / 1000)
+                        ts = event_to_received_ts.get(event.event_id)
+                        if ts is not None:
+                            synapse.metrics.event_processing_lag_by_event.labels(
+                                "federation_sender"
+                            ).observe((now - ts) / 1000)
 
                 async def handle_room_events(events: List[EventBase]) -> None:
                     logger.debug(
@@ -650,7 +719,7 @@ async def handle_room_events(events: List[EventBase]) -> None:
 
                 events_by_room: Dict[str, List[EventBase]] = {}
 
-                for event_id in event_ids:
+                for event_id in chain(event_ids, _extra_event_entries.keys()):
                     # `event_entries` is unsorted, so we have to iterate over `event_ids`
                     # to ensure the events are in the right order
                     event_cache = event_entries.get(event_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d8d8c8a0fe80..eeb4977dcd5a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -229,7 +229,10 @@ async def maybe_backfill(
         # linearizer lock queue in the timing
         processing_start_time = self.clock.time_msec() if record_time else 0
 
-        async with self._room_backfill.queue(room_id):
+        async with (
+            self._room_backfill.queue(room_id),
+            self._federation_event_handler._room_pdu_linearizer.queue(room_id),
+        ):
             async with self._worker_locks.acquire_read_write_lock(
                 PURGE_PAGINATION_LOCK_NAME, room_id, write=False
             ):
@@ -1836,11 +1839,14 @@ async def _sync_partial_state_room_wrapper() -> None:
             self._active_partial_state_syncs.add(room_id)
 
             try:
-                await self._sync_partial_state_room(
-                    initial_destination=initial_destination,
-                    other_destinations=other_destinations,
-                    room_id=room_id,
-                )
+                async with self._federation_event_handler._room_pdu_linearizer.queue(
+                    room_id
+                ):
+                    await self._sync_partial_state_room(
+                        initial_destination=initial_destination,
+                        other_destinations=other_destinations,
+                        room_id=room_id,
+                    )
             finally:
                 # Read the room's partial state flag while we still hold the claim to
                 # being the active partial state sync (so that another partial state
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 1e738f484f95..bc47be512e9f 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -268,7 +268,13 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
 
         # Try to fetch any missing prev events to fill in gaps in the graph
         prevs = set(pdu.prev_event_ids())
+        logger.debug(
+            "JASON: on_receive_pdu: initial try to collect prev_events: %r", prevs
+        )
         seen = await self._store.have_events_in_timeline(prevs)
+        logger.debug(
+            "JASON: on_receive_pdu: initial view on events that are seen: %r", seen
+        )
         missing_prevs = prevs - seen
 
         if missing_prevs:
@@ -405,6 +411,18 @@ async def on_send_membership_event(
         # the room, so we send it on their behalf.
         event.internal_metadata.send_on_behalf_of = origin
 
+        # If the join handshake started and another event was persisted before the join
+        # finished, there will be another forward extremity at that depth that is not
+        # referenced by the join's prev_events. We also do not know the stream ordering
+        # yet for this join, to do the lookup for it. Set the flag now, to avoid having
+        # to check while persisting if this is:
+        # 1. A new join (and not a join->join transition like what happens with display
+        #    names)
+        # 2. Not a backfilled join, as we don't care about those for proactively sending
+        #    related events
+        if event.membership == Membership.JOIN:
+            event.internal_metadata.send_additional_context = True
+
         context = await self._state_handler.compute_event_context(event)
         await self._check_event_auth(origin, event, context)
         if context.rejected:
@@ -728,9 +746,13 @@ async def _get_missing_events_for_pdu(
 
         room_id = pdu.room_id
         event_id = pdu.event_id
+        logger.debug("JASON: _get_missing_events_for_pdu: prevs declared: %r", prevs)
 
         seen = await self._store.have_events_in_timeline(prevs)
-
+        logger.debug(
+            "JASON: _get_missing_events_for_pdu: inner view on events that are seen: %r",
+            seen,
+        )
         if not prevs - seen:
             return
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 26fbc1a483f2..0a70e772c824 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -336,6 +336,20 @@ async def _persist_events_and_state_updates(
             # `stream_ordering` from the first time it was persisted).
             event.internal_metadata.stream_ordering = stream
             event.internal_metadata.instance_name = self._instance_name
+            if (
+                new_forward_extremities
+                and event.internal_metadata.should_send_additional_context()
+            ):
+                # It's a legit new join (because otherwise this may be a backwards
+                # extremity and we don't care)...
+                if event.event_id in new_forward_extremities:
+                    # ...if there are other forwards (pass a set: a bare str is iterated per character)...
+                    forward_extremities_that_remain = (
+                        new_forward_extremities.difference({event.event_id})
+                    )
+                    if not forward_extremities_that_remain:
+                        # ...whoops, false positive. Don't actually need to send anything else
+                        event.internal_metadata.send_additional_context = False
 
         sliding_sync_table_changes = None
         if state_delta_for_room is not None:
diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi
index 7d3422572ddb..aea83b40d1b7 100644
--- a/synapse/synapse_rust/events.pyi
+++ b/synapse/synapse_rust/events.pyi
@@ -39,6 +39,8 @@ class EventInternalMetadata:
     """The access token ID of the user who sent this event, if any."""
     device_id: str
     """The device ID of the user who sent this event, if any."""
+    send_additional_context: bool
+    """Whether to check for additional forward extremities adjacent to a join that should be sent to the joining server"""
 
     def get_dict(self) -> JsonDict: ...
     def is_outlier(self) -> bool: ...
@@ -103,6 +105,11 @@ class EventInternalMetadata:
         marked as redacted without needing to make another database call.
         """
 
+    def should_send_additional_context(self) -> bool:
+        """Whether this is a join and a look up should be done for additional
+        forward extremities.
+        """
+
     def is_notifiable(self) -> bool:
         """Whether this event can trigger a push notification"""
 