11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- from typing import TYPE_CHECKING , Optional
14
+ from typing import TYPE_CHECKING , Collection , Optional
15
15
16
16
from synapse .api .constants import EventTypes , JoinRules , Membership
17
17
from synapse .api .errors import AuthError
@@ -59,32 +59,76 @@ async def check_restricted_join_rules(
59
59
):
60
60
return
61
61
62
+ # This is not a room with a restricted join rule, so we don't need to do the
63
+ # restricted room specific checks.
64
+ #
65
+ # Note: We'll be applying the standard join rule checks later, which will
66
+ # catch the cases of e.g. trying to join private rooms without an invite.
67
+ if not await self .has_restricted_join_rules (state_ids , room_version ):
68
+ return
69
+
70
+ # Get the spaces which allow access to this room and check if the user is
71
+ # in any of them.
72
+ allowed_spaces = await self .get_spaces_that_allow_join (state_ids )
73
+ if not await self .is_user_in_rooms (allowed_spaces , user_id ):
74
+ raise AuthError (
75
+ 403 ,
76
+ "You do not belong to any of the required spaces to join this room." ,
77
+ )
78
+
79
+ async def has_restricted_join_rules (
80
+ self , state_ids : StateMap [str ], room_version : RoomVersion
81
+ ) -> bool :
82
+ """
83
+ Return if the room has the proper join rules set for access via spaces.
84
+
85
+ Args:
86
+ state_ids: The state of the room as it currently is.
87
+ room_version: The room version of the room to query.
88
+
89
+ Returns:
90
+ True if the proper room version and join rules are set for restricted access.
91
+ """
62
92
# This only applies to room versions which support the new join rule.
63
93
if not room_version .msc3083_join_rules :
64
- return
94
+ return False
65
95
66
96
# If there's no join rule, then it defaults to invite (so this doesn't apply).
67
97
join_rules_event_id = state_ids .get ((EventTypes .JoinRules , "" ), None )
68
98
if not join_rules_event_id :
69
- return
99
+ return False
100
+
101
+ # If the join rule is not restricted, this doesn't apply.
102
+ join_rules_event = await self ._store .get_event (join_rules_event_id )
103
+ return join_rules_event .content .get ("join_rule" ) == JoinRules .MSC3083_RESTRICTED
104
+
105
+ async def get_spaces_that_allow_join (
106
+ self , state_ids : StateMap [str ]
107
+ ) -> Collection [str ]:
108
+ """
109
+ Generate a list of spaces which allow access to a room.
110
+
111
+ Args:
112
+ state_ids: The state of the room as it currently is.
113
+
114
+ Returns:
115
+ A collection of spaces which provide membership to the room.
116
+ """
117
+ # If there's no join rule, then it defaults to invite (so this doesn't apply).
118
+ join_rules_event_id = state_ids .get ((EventTypes .JoinRules , "" ), None )
119
+ if not join_rules_event_id :
120
+ return ()
70
121
71
122
# If the join rule is not restricted, this doesn't apply.
72
123
join_rules_event = await self ._store .get_event (join_rules_event_id )
73
- if join_rules_event .content .get ("join_rule" ) != JoinRules .MSC3083_RESTRICTED :
74
- return
75
124
76
125
# If allowed is of the wrong form, then only allow invited users.
77
126
allowed_spaces = join_rules_event .content .get ("allow" , [])
78
127
if not isinstance (allowed_spaces , list ):
79
- allowed_spaces = ()
80
-
81
- # Get the list of joined rooms and see if there's an overlap.
82
- if allowed_spaces :
83
- joined_rooms = await self ._store .get_rooms_for_user (user_id )
84
- else :
85
- joined_rooms = ()
128
+ return ()
86
129
87
130
# Pull out the other room IDs, invalid data gets filtered.
131
+ result = []
88
132
for space in allowed_spaces :
89
133
if not isinstance (space , dict ):
90
134
continue
@@ -93,13 +137,31 @@ async def check_restricted_join_rules(
93
137
if not isinstance (space_id , str ):
94
138
continue
95
139
96
- # The user was joined to one of the spaces specified, they can join
97
- # this room!
98
- if space_id in joined_rooms :
99
- return
140
+ result .append (space_id )
141
+
142
+ return result
143
+
144
+ async def is_user_in_rooms (self , room_ids : Collection [str ], user_id : str ) -> bool :
145
+ """
146
+ Check whether a user is a member of any of the provided rooms.
147
+
148
+ Args:
149
+ room_ids: The rooms to check for membership.
150
+ user_id: The user to check.
151
+
152
+ Returns:
153
+ True if the user is in any of the rooms, false otherwise.
154
+ """
155
+ if not room_ids :
156
+ return False
157
+
158
+ # Get the list of joined rooms and see if there's an overlap.
159
+ joined_rooms = await self ._store .get_rooms_for_user (user_id )
160
+
161
+ # Check each room and see if the user is in it.
162
+ for room_id in room_ids :
163
+ if room_id in joined_rooms :
164
+ return True
100
165
101
- # The user was not in any of the required spaces.
102
- raise AuthError (
103
- 403 ,
104
- "You do not belong to any of the required spaces to join this room." ,
105
- )
166
+ # The user was not in any of the rooms.
167
+ return False
0 commit comments