Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Fix import cycle #11965

Merged
merged 2 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11965.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix an import cycle in `synapse.event_auth`.
54 changes: 31 additions & 23 deletions synapse/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
import typing
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

from canonicaljson import encode_canonical_json
Expand All @@ -34,15 +35,18 @@
EventFormatVersions,
RoomVersion,
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.types import StateMap, UserID, get_domain_from_id

if typing.TYPE_CHECKING:
# conditional imports to avoid import cycle
from synapse.events import EventBase
from synapse.events.builder import EventBuilder

logger = logging.getLogger(__name__)


def validate_event_for_room_version(
room_version_obj: RoomVersion, event: EventBase
room_version_obj: RoomVersion, event: "EventBase"
) -> None:
"""Ensure that the event complies with the limits, and has the right signatures

Expand Down Expand Up @@ -113,7 +117,9 @@ def validate_event_for_room_version(


def check_auth_rules_for_event(
room_version_obj: RoomVersion, event: EventBase, auth_events: Iterable[EventBase]
room_version_obj: RoomVersion,
event: "EventBase",
auth_events: Iterable["EventBase"],
) -> None:
"""Check that an event complies with the auth rules

Expand Down Expand Up @@ -256,7 +262,7 @@ def check_auth_rules_for_event(
logger.debug("Allowing! %s", event)


def _check_size_limits(event: EventBase) -> None:
def _check_size_limits(event: "EventBase") -> None:
if len(event.user_id) > 255:
raise EventSizeError("'user_id' too large")
if len(event.room_id) > 255:
Expand All @@ -271,7 +277,7 @@ def _check_size_limits(event: EventBase) -> None:
raise EventSizeError("event too large")


def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
creation_event = auth_events.get((EventTypes.Create, ""))
# There should always be a creation event, but if not don't federate.
if not creation_event:
Expand All @@ -281,7 +287,7 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:


def _is_membership_change_allowed(
room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
room_version: RoomVersion, event: "EventBase", auth_events: StateMap["EventBase"]
) -> None:
"""
Confirms that the event which changes membership is an allowed change.
Expand Down Expand Up @@ -471,23 +477,25 @@ def _is_membership_change_allowed(


def _check_event_sender_in_room(
event: EventBase, auth_events: StateMap[EventBase]
event: "EventBase", auth_events: StateMap["EventBase"]
) -> None:
key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)

_check_joined_room(member_event, event.user_id, event.room_id)


def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
def _check_joined_room(
member: Optional["EventBase"], user_id: str, room_id: str
) -> None:
if not member or member.membership != Membership.JOIN:
raise AuthError(
403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
)


def get_send_level(
etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
etype: str, state_key: Optional[str], power_levels_event: Optional["EventBase"]
) -> int:
"""Get the power level required to send an event of a given type

Expand Down Expand Up @@ -523,7 +531,7 @@ def get_send_level(
return int(send_level)


def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
power_levels_event = get_power_level_event(auth_events)

send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
Expand All @@ -547,8 +555,8 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:

def check_redaction(
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
event: "EventBase",
auth_events: StateMap["EventBase"],
) -> bool:
"""Check whether the event sender is allowed to redact the target event.

Expand Down Expand Up @@ -585,8 +593,8 @@ def check_redaction(

def check_historical(
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
event: "EventBase",
auth_events: StateMap["EventBase"],
) -> None:
"""Check whether the event sender is allowed to send historical related
events like "insertion", "batch", and "marker".
Expand Down Expand Up @@ -616,8 +624,8 @@ def check_historical(

def _check_power_levels(
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
event: "EventBase",
auth_events: StateMap["EventBase"],
) -> None:
user_list = event.content.get("users", {})
# Validate users
Expand Down Expand Up @@ -710,11 +718,11 @@ def _check_power_levels(
)


def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
def get_power_level_event(auth_events: StateMap["EventBase"]) -> Optional["EventBase"]:
return auth_events.get((EventTypes.PowerLevels, ""))


def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> int:
"""Get a user's power level

Args:
Expand Down Expand Up @@ -750,7 +758,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
return 0


def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
def get_named_level(auth_events: StateMap["EventBase"], name: str, default: int) -> int:
power_level_event = get_power_level_event(auth_events)

if not power_level_event:
Expand All @@ -764,7 +772,7 @@ def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -


def _verify_third_party_invite(
event: EventBase, auth_events: StateMap[EventBase]
event: "EventBase", auth_events: StateMap["EventBase"]
) -> bool:
"""
Validates that the invite event is authorized by a previous third-party invite.
Expand Down Expand Up @@ -829,7 +837,7 @@ def _verify_third_party_invite(
return False


def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
def get_public_keys(invite_event: "EventBase") -> List[Dict[str, Any]]:
public_keys = []
if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]}
Expand All @@ -841,7 +849,7 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:


def auth_types_for_event(
room_version: RoomVersion, event: Union[EventBase, EventBuilder]
room_version: RoomVersion, event: Union["EventBase", "EventBuilder"]
) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
Expand Down