@@ -897,10 +897,24 @@ async def _get_state_after_missing_prev_event(
897
897
logger .debug ("We are also missing %i auth events" , len (missing_auth_events ))
898
898
899
899
missing_events = missing_desired_events | missing_auth_events
900
- logger .debug ("Fetching %i events from remote" , len (missing_events ))
901
- await self ._get_events_and_persist (
902
- destination = destination , room_id = room_id , event_ids = missing_events
903
- )
900
+
901
+ # Making an individual request for each of 1000s of events has a lot of
902
+ # overhead. On the other hand, we don't really want to fetch all of the events
903
+ # if we already have most of them.
904
+ #
905
+ # As an arbitrary heuristic, if we are missing more than 10% of the events, then
906
+ # we fetch the whole state.
907
+ #
908
+ # TODO: might it be better to have an API which lets us do an aggregate event
909
+ # request
910
+ if (len (missing_events ) * 10 ) >= len (auth_event_ids ) + len (state_event_ids ):
911
+ logger .debug ("Requesting complete state from remote" )
912
+ await self ._get_state_and_persist (destination , room_id , event_id )
913
+ else :
914
+ logger .debug ("Fetching %i events from remote" , len (missing_events ))
915
+ await self ._get_events_and_persist (
916
+ destination = destination , room_id = room_id , event_ids = missing_events
917
+ )
904
918
905
919
# we need to make sure we re-load from the database to get the rejected
906
920
# state correct.
@@ -959,6 +973,27 @@ async def _get_state_after_missing_prev_event(
959
973
960
974
return remote_state
961
975
976
async def _get_state_and_persist(
    self, destination: str, room_id: str, event_id: str
) -> None:
    """Get the complete room state at a given event, and persist any new events
    as outliers.

    The state is requested from `destination` via the federation client; the
    returned auth events and state events are then handed, as one stream, to
    `_auth_and_persist_outliers`. Finally, if the store has not already seen
    `event_id` itself, that event is fetched and persisted too.

    Args:
        destination: the remote server to request the state (and, if needed,
            the event itself) from.
        room_id: the room the event belongs to.
        event_id: the event at which the room state is requested.
    """
    room_version = await self._store.get_room_version(room_id)
    auth_events, state_events = await self._federation_client.get_room_state(
        destination, room_id, event_id=event_id, room_version=room_version
    )
    logger.info("/state returned %i events", len(auth_events) + len(state_events))

    await self._auth_and_persist_outliers(
        room_id, itertools.chain(auth_events, state_events)
    )

    # we also need the event itself.
    #
    # The ID is passed as a one-element tuple, mirroring the
    # `event_ids=(event_id,)` call below: a bare string is itself an iterable of
    # single characters, so a lookup that iterates its argument would never
    # match the real event ID and we would always re-fetch the event.
    if not await self._store.have_seen_events(room_id, (event_id,)):
        await self._get_events_and_persist(
            destination=destination, room_id=room_id, event_ids=(event_id,)
        )
996
+
962
997
async def _process_received_pdu (
963
998
self ,
964
999
origin : str ,
0 commit comments