Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 11124ed

Browse files
author
David Robertson
committed
Use StateFilter
1 parent 1bc0f13 commit 11124ed

File tree

3 files changed

+50
-108
lines changed

3 files changed

+50
-108
lines changed

synapse/config/api.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,71 +13,30 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Any, Container, Dict, Iterable, Mapping, Optional, Set, Tuple, Type
17-
18-
import attr
16+
from typing import Any, Iterable, Optional, Tuple
1917

2018
from synapse.api.constants import EventTypes
2119
from synapse.config._base import Config, ConfigError
2220
from synapse.config._util import validate_config
2321
from synapse.types import JsonDict
22+
from synapse.types.state import StateFilter
2423

2524
logger = logging.getLogger(__name__)
2625

2726

28-
@attr.s(auto_attribs=True)
29-
class StateKeyFilter(Container[str]):
30-
"""A simpler version of StateFilter which ignores event types.
31-
32-
Represents an optional constraint that state_keys must belong to a given set of
33-
strings called `options`. An empty set of `options` means that there are no
34-
restrictions.
35-
"""
36-
37-
options: Set[str]
38-
39-
@classmethod
40-
def any(cls: Type["StateKeyFilter"]) -> "StateKeyFilter":
41-
return cls(set())
42-
43-
@classmethod
44-
def only(cls: Type["StateKeyFilter"], state_key: str) -> "StateKeyFilter":
45-
return cls({state_key})
46-
47-
def __contains__(self, state_key: object) -> bool:
48-
return not self.options or state_key in self.options
49-
50-
def add(self, state_key: Optional[str]) -> None:
51-
if state_key is None:
52-
self.options = set()
53-
elif self.options:
54-
self.options.add(state_key)
55-
56-
5727
class ApiConfig(Config):
5828
section = "api"
5929

60-
room_prejoin_state: Mapping[str, StateKeyFilter]
30+
room_prejoin_state: StateFilter
6131
track_puppetted_users_ips: bool
6232

6333
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
6434
validate_config(_MAIN_SCHEMA, config, ())
65-
self.room_prejoin_state = self._build_prejoin_state(config)
35+
self.room_prejoin_state = StateFilter.from_types(
36+
self._get_prejoin_state_entries(config)
37+
)
6638
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
6739

68-
def _build_prejoin_state(self, config: JsonDict) -> Dict[str, StateKeyFilter]:
69-
prejoin_events = {}
70-
for event_type, state_key in self._get_prejoin_state_entries(config):
71-
if event_type not in prejoin_events:
72-
if state_key is None:
73-
filter = StateKeyFilter.any()
74-
else:
75-
filter = StateKeyFilter.only(state_key)
76-
prejoin_events[event_type] = filter
77-
else:
78-
prejoin_events[event_type].add(state_key)
79-
return prejoin_events
80-
8140
def _get_prejoin_state_entries(
8241
self, config: JsonDict
8342
) -> Iterable[Tuple[str, Optional[str]]]:

synapse/storage/databases/main/events_worker.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
import threading
1717
import weakref
1818
from enum import Enum, auto
19+
from itertools import chain
1920
from typing import (
2021
TYPE_CHECKING,
2122
Any,
2223
Collection,
2324
Dict,
2425
Iterable,
2526
List,
26-
Mapping,
2727
MutableMapping,
2828
Optional,
2929
Set,
@@ -46,7 +46,6 @@
4646
RoomVersion,
4747
RoomVersions,
4848
)
49-
from synapse.config.api import StateKeyFilter
5049
from synapse.events import EventBase, make_event_from_dict
5150
from synapse.events.snapshot import EventContext
5251
from synapse.events.utils import prune_event
@@ -77,6 +76,7 @@
7776
)
7877
from synapse.storage.util.sequence import build_sequence_generator
7978
from synapse.types import JsonDict, get_domain_from_id
79+
from synapse.types.state import StateFilter
8080
from synapse.util import unwrapFirstError
8181
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
8282
from synapse.util.caches.descriptors import cached, cachedList
@@ -880,7 +880,7 @@ def _get_events_from_local_cache(
880880
async def get_stripped_room_state_from_event_context(
881881
self,
882882
context: EventContext,
883-
state_keys_to_include: Mapping[str, StateKeyFilter],
883+
state_keys_to_include: StateFilter,
884884
membership_user_id: Optional[str] = None,
885885
) -> List[JsonDict]:
886886
"""
@@ -902,31 +902,21 @@ async def get_stripped_room_state_from_event_context(
902902
Returns:
903903
A list of dictionaries, each representing a stripped state event from the room.
904904
"""
905-
current_state_ids = await context.get_current_state_ids()
905+
if membership_user_id:
906+
types = chain(
907+
state_keys_to_include.to_types(),
908+
[(EventTypes.Member, membership_user_id)],
909+
)
910+
filter = StateFilter.from_types(types)
911+
else:
912+
filter = state_keys_to_include
913+
selected_state_ids = await context.get_current_state_ids(filter)
906914

907915
# We know this event is not an outlier, so this must be
908916
# non-None.
909-
assert current_state_ids is not None
910-
911-
def should_include(t: str, s: str) -> bool:
912-
if t in state_keys_to_include and s in state_keys_to_include[t]:
913-
return True
914-
if (
915-
membership_user_id
916-
and t == EventTypes.Member
917-
and s == membership_user_id
918-
):
919-
return True
920-
return False
921-
922-
# The state to include
923-
state_to_include_ids = [
924-
e_id
925-
for (event_type, state_key), e_id in current_state_ids.items()
926-
if should_include(event_type, state_key)
927-
]
917+
assert selected_state_ids is not None
928918

929-
state_to_include = await self.get_events(state_to_include_ids)
919+
state_to_include = await self.get_events(selected_state_ids.values())
930920

931921
return [
932922
{

tests/config/test_api.py

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,48 +3,32 @@
33
import yaml
44

55
from synapse.config import ConfigError
6-
from synapse.config.api import ApiConfig, StateKeyFilter
7-
8-
DEFAULT_PREJOIN_STATE = {
9-
"m.room.join_rules": StateKeyFilter.only(""),
10-
"m.room.canonical_alias": StateKeyFilter.only(""),
11-
"m.room.avatar": StateKeyFilter.only(""),
12-
"m.room.encryption": StateKeyFilter.only(""),
13-
"m.room.name": StateKeyFilter.only(""),
14-
"m.room.create": StateKeyFilter.only(""),
15-
"m.room.topic": StateKeyFilter.only(""),
6+
from synapse.config.api import ApiConfig
7+
from synapse.types.state import StateFilter
8+
9+
DEFAULT_PREJOIN_STATE_PAIRS = {
10+
("m.room.join_rules", ""),
11+
("m.room.canonical_alias", ""),
12+
("m.room.avatar", ""),
13+
("m.room.encryption", ""),
14+
("m.room.name", ""),
15+
("m.room.create", ""),
16+
("m.room.topic", ""),
1617
}
1718

1819

1920
class TestRoomPrejoinState(StdlibTestCase):
20-
def test_state_key_filter(self) -> None:
21-
"""Sanity check the StateKeyFilter class."""
22-
s = StateKeyFilter.only("foo")
23-
self.assertIn("foo", s)
24-
self.assertNotIn("bar", s)
25-
self.assertNotIn("baz", s)
26-
s.add("bar")
27-
self.assertIn("foo", s)
28-
self.assertIn("bar", s)
29-
self.assertNotIn("baz", s)
30-
31-
s = StateKeyFilter.any()
32-
self.assertIn("foo", s)
33-
self.assertIn("bar", s)
34-
self.assertIn("baz", s)
35-
s.add("bar")
36-
self.assertIn("foo", s)
37-
self.assertIn("bar", s)
38-
self.assertIn("baz", s)
39-
4021
def read_config(self, source: str) -> ApiConfig:
4122
config = ApiConfig()
4223
config.read_config(yaml.safe_load(source))
4324
return config
4425

4526
def test_no_prejoin_state(self) -> None:
4627
config = self.read_config("foo: bar")
47-
self.assertEqual(config.room_prejoin_state, DEFAULT_PREJOIN_STATE)
28+
self.assertFalse(config.room_prejoin_state.has_wildcards())
29+
self.assertEqual(
30+
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
31+
)
4832

4933
def test_disable_default_event_types(self) -> None:
5034
config = self.read_config(
@@ -53,7 +37,7 @@ def test_disable_default_event_types(self) -> None:
5337
disable_default_event_types: true
5438
"""
5539
)
56-
self.assertEqual(config.room_prejoin_state, {})
40+
self.assertEqual(config.room_prejoin_state, StateFilter.none())
5741

5842
def test_event_without_state_key(self) -> None:
5943
config = self.read_config(
@@ -64,7 +48,8 @@ def test_event_without_state_key(self) -> None:
6448
- foo
6549
"""
6650
)
67-
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()})
51+
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
52+
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
6853

6954
def test_event_with_specific_state_key(self) -> None:
7055
config = self.read_config(
@@ -75,7 +60,11 @@ def test_event_with_specific_state_key(self) -> None:
7560
- [foo, bar]
7661
"""
7762
)
78-
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.only("bar")})
63+
self.assertFalse(config.room_prejoin_state.has_wildcards())
64+
self.assertEqual(
65+
set(config.room_prejoin_state.concrete_types()),
66+
{("foo", "bar")},
67+
)
7968

8069
def test_repeated_event_with_specific_state_key(self) -> None:
8170
config = self.read_config(
@@ -87,8 +76,10 @@ def test_repeated_event_with_specific_state_key(self) -> None:
8776
- [foo, baz]
8877
"""
8978
)
79+
self.assertFalse(config.room_prejoin_state.has_wildcards())
9080
self.assertEqual(
91-
config.room_prejoin_state, {"foo": StateKeyFilter({"bar", "baz"})}
81+
set(config.room_prejoin_state.concrete_types()),
82+
{("foo", "bar"), ("foo", "baz")},
9283
)
9384

9485
def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
@@ -101,7 +92,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
10192
- foo
10293
"""
10394
)
104-
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()})
95+
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
96+
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
10597

10698
config = self.read_config(
10799
"""
@@ -112,7 +104,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
112104
- [foo, bar]
113105
"""
114106
)
115-
self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()})
107+
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
108+
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
116109

117110
def test_bad_event_type_entry_raises(self) -> None:
118111
with self.assertRaises(ConfigError):

0 commit comments

Comments
 (0)