@@ -883,10 +883,24 @@ async def _get_state_after_missing_prev_event(
883
883
logger .debug ("We are also missing %i auth events" , len (missing_auth_events ))
884
884
885
885
missing_events = missing_desired_events | missing_auth_events
886
- logger .debug ("Fetching %i events from remote" , len (missing_events ))
887
- await self ._get_events_and_persist (
888
- destination = destination , room_id = room_id , event_ids = missing_events
889
- )
886
+
887
+ # Making an individual request for each of 1000s of events has a lot of
888
+ # overhead. On the other hand, we don't really want to fetch all of the events
889
+ # if we already have most of them.
890
+ #
891
+ # As an arbitrary heuristic, if we are missing more than 10% of the events, then
892
+ # we fetch the whole state.
893
+ #
894
+ # TODO: might it be better to have an API which lets us do an aggregate event
895
+ # request
896
+ if (len (missing_events ) * 10 ) >= len (auth_event_ids ) + len (state_event_ids ):
897
+ logger .debug ("Requesting complete state from remote" )
898
+ await self ._get_state_and_persist (destination , room_id , event_id )
899
+ else :
900
+ logger .debug ("Fetching %i events from remote" , len (missing_events ))
901
+ await self ._get_events_and_persist (
902
+ destination = destination , room_id = room_id , event_ids = missing_events
903
+ )
890
904
891
905
# we need to make sure we re-load from the database to get the rejected
892
906
# state correct.
@@ -945,6 +959,27 @@ async def _get_state_after_missing_prev_event(
945
959
946
960
return remote_state
947
961
962
+ async def _get_state_and_persist (
963
+ self , destination : str , room_id : str , event_id : str
964
+ ) -> None :
965
+ """Get the complete room state at a given event, and persist any new events
966
+ as outliers"""
967
+ room_version = await self ._store .get_room_version (room_id )
968
+ auth_events , state_events = await self ._federation_client .get_room_state (
969
+ destination , room_id , event_id = event_id , room_version = room_version
970
+ )
971
+ logger .info ("/state returned %i events" , len (auth_events ) + len (state_events ))
972
+
973
+ await self ._auth_and_persist_outliers (
974
+ room_id , itertools .chain (auth_events , state_events )
975
+ )
976
+
977
+ # we also need the event itself.
978
+ if not await self ._store .have_seen_events (room_id , event_id ):
979
+ await self ._get_events_and_persist (
980
+ destination = destination , room_id = room_id , event_ids = (event_id ,)
981
+ )
982
+
948
983
async def _process_received_pdu (
949
984
self ,
950
985
origin : str ,
0 commit comments