3
3
import yaml
4
4
5
5
from synapse .config import ConfigError
6
- from synapse .config .api import ApiConfig , StateKeyFilter
7
-
8
- DEFAULT_PREJOIN_STATE = {
9
- "m.room.join_rules" : StateKeyFilter .only ("" ),
10
- "m.room.canonical_alias" : StateKeyFilter .only ("" ),
11
- "m.room.avatar" : StateKeyFilter .only ("" ),
12
- "m.room.encryption" : StateKeyFilter .only ("" ),
13
- "m.room.name" : StateKeyFilter .only ("" ),
14
- "m.room.create" : StateKeyFilter .only ("" ),
15
- "m.room.topic" : StateKeyFilter .only ("" ),
6
+ from synapse .config .api import ApiConfig
7
+ from synapse .types .state import StateFilter
8
+
9
+ DEFAULT_PREJOIN_STATE_PAIRS = {
10
+ ("m.room.join_rules" , "" ),
11
+ ("m.room.canonical_alias" , "" ),
12
+ ("m.room.avatar" , "" ),
13
+ ("m.room.encryption" , "" ),
14
+ ("m.room.name" , "" ),
15
+ ("m.room.create" , "" ),
16
+ ("m.room.topic" , "" ),
16
17
}
17
18
18
19
19
20
class TestRoomPrejoinState (StdlibTestCase ):
20
- def test_state_key_filter (self ) -> None :
21
- """Sanity check the StateKeyFilter class."""
22
- s = StateKeyFilter .only ("foo" )
23
- self .assertIn ("foo" , s )
24
- self .assertNotIn ("bar" , s )
25
- self .assertNotIn ("baz" , s )
26
- s .add ("bar" )
27
- self .assertIn ("foo" , s )
28
- self .assertIn ("bar" , s )
29
- self .assertNotIn ("baz" , s )
30
-
31
- s = StateKeyFilter .any ()
32
- self .assertIn ("foo" , s )
33
- self .assertIn ("bar" , s )
34
- self .assertIn ("baz" , s )
35
- s .add ("bar" )
36
- self .assertIn ("foo" , s )
37
- self .assertIn ("bar" , s )
38
- self .assertIn ("baz" , s )
39
-
40
21
def read_config (self , source : str ) -> ApiConfig :
41
22
config = ApiConfig ()
42
23
config .read_config (yaml .safe_load (source ))
43
24
return config
44
25
45
26
def test_no_prejoin_state (self ) -> None :
46
27
config = self .read_config ("foo: bar" )
47
- self .assertEqual (config .room_prejoin_state , DEFAULT_PREJOIN_STATE )
28
+ self .assertFalse (config .room_prejoin_state .has_wildcards ())
29
+ self .assertEqual (
30
+ set (config .room_prejoin_state .concrete_types ()), DEFAULT_PREJOIN_STATE_PAIRS
31
+ )
48
32
49
33
def test_disable_default_event_types (self ) -> None :
50
34
config = self .read_config (
@@ -53,7 +37,7 @@ def test_disable_default_event_types(self) -> None:
53
37
disable_default_event_types: true
54
38
"""
55
39
)
56
- self .assertEqual (config .room_prejoin_state , {} )
40
+ self .assertEqual (config .room_prejoin_state , StateFilter . none () )
57
41
58
42
def test_event_without_state_key (self ) -> None :
59
43
config = self .read_config (
@@ -64,7 +48,8 @@ def test_event_without_state_key(self) -> None:
64
48
- foo
65
49
"""
66
50
)
67
- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .any ()})
51
+ self .assertEqual (config .room_prejoin_state .wildcard_types (), ["foo" ])
52
+ self .assertEqual (config .room_prejoin_state .concrete_types (), [])
68
53
69
54
def test_event_with_specific_state_key (self ) -> None :
70
55
config = self .read_config (
@@ -75,7 +60,11 @@ def test_event_with_specific_state_key(self) -> None:
75
60
- [foo, bar]
76
61
"""
77
62
)
78
- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .only ("bar" )})
63
+ self .assertFalse (config .room_prejoin_state .has_wildcards ())
64
+ self .assertEqual (
65
+ set (config .room_prejoin_state .concrete_types ()),
66
+ {("foo" , "bar" )},
67
+ )
79
68
80
69
def test_repeated_event_with_specific_state_key (self ) -> None :
81
70
config = self .read_config (
@@ -87,8 +76,10 @@ def test_repeated_event_with_specific_state_key(self) -> None:
87
76
- [foo, baz]
88
77
"""
89
78
)
79
+ self .assertFalse (config .room_prejoin_state .has_wildcards ())
90
80
self .assertEqual (
91
- config .room_prejoin_state , {"foo" : StateKeyFilter ({"bar" , "baz" })}
81
+ set (config .room_prejoin_state .concrete_types ()),
82
+ {("foo" , "bar" ), ("foo" , "baz" )},
92
83
)
93
84
94
85
def test_no_specific_state_key_overrides_specific_state_key (self ) -> None :
@@ -101,7 +92,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
101
92
- foo
102
93
"""
103
94
)
104
- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .any ()})
95
+ self .assertEqual (config .room_prejoin_state .wildcard_types (), ["foo" ])
96
+ self .assertEqual (config .room_prejoin_state .concrete_types (), [])
105
97
106
98
config = self .read_config (
107
99
"""
@@ -112,7 +104,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
112
104
- [foo, bar]
113
105
"""
114
106
)
115
- self .assertEqual (config .room_prejoin_state , {"foo" : StateKeyFilter .any ()})
107
+ self .assertEqual (config .room_prejoin_state .wildcard_types (), ["foo" ])
108
+ self .assertEqual (config .room_prejoin_state .concrete_types (), [])
116
109
117
110
def test_bad_event_type_entry_raises (self ) -> None :
118
111
with self .assertRaises (ConfigError ):
0 commit comments