Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Add some type hints to datastore #12717

Merged
merged 16 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12717.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some type hints to datastore.
2 changes: 0 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/schema/

|tests/api/test_auth.py
Expand Down
24 changes: 17 additions & 7 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
import abc
import logging
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
)

import attr
from prometheus_client import Counter
Expand Down Expand Up @@ -409,7 +419,7 @@ async def handle_event(event: EventBase) -> None:
)
return

destinations: Optional[Set[str]] = None
destinations: Optional[Collection[str]] = None
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
Expand Down Expand Up @@ -444,7 +454,7 @@ async def handle_event(event: EventBase) -> None:
)
return

destinations = {
sharded_destinations = {
d
for d in destinations
if self._federation_shard_config.should_handle(
Expand All @@ -456,12 +466,12 @@ async def handle_event(event: EventBase) -> None:
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
# send the event to it.
destinations.discard(send_on_behalf_of)
sharded_destinations.discard(send_on_behalf_of)

logger.debug("Sending %s to %r", event, destinations)
logger.debug("Sending %s to %r", event, sharded_destinations)

if destinations:
await self._send_pdu(event, destinations)
if sharded_destinations:
await self._send_pdu(event, sharded_destinations)

now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,10 @@ async def current_sync_for_user(
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result

async def push_rules_for_user(self, user: UserID) -> JsonDict:
async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
user_id = user.to_string()
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules_raw)
return rules

async def ephemeral_by_room(
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDic
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rules = await self.store.get_push_rules_for_user(user_id)
rules_raw = await self.store.get_push_rules_for_user(user_id)

rules = format_push_rules_for_user(requester.user, rules)
rules = format_push_rules_for_user(requester.user, rules_raw)

path_parts = path.split("/")[1:]

Expand Down
4 changes: 2 additions & 2 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,13 @@ async def get_current_users_in_room(
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)

async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> Set[str]:
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids

Args:
Expand Down
8 changes: 1 addition & 7 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache

Expand Down Expand Up @@ -155,8 +151,6 @@ def __init__(
],
)

self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
Expand Down
56 changes: 28 additions & 28 deletions synapse/storage/databases/main/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
import calendar
import logging
import time
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List, Tuple, cast

from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -73,7 +76,7 @@ def __init__(

@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self) -> None:
def fetch(txn):
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
txn.execute(
"""
SELECT t1.c, t2.c
Expand All @@ -86,7 +89,7 @@ def fetch(txn):
) t2 ON t1.room_id = t2.room_id
"""
)
return txn.fetchall()
return cast(List[Tuple[int, int]], txn.fetchall())

res = await self.db_pool.runInteraction("read_forward_extremities", fetch)

Expand All @@ -104,20 +107,20 @@ async def count_daily_e2ee_messages(self) -> int:
call to this function, it will return None.
"""

def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)

async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
Expand All @@ -130,22 +133,22 @@ def _count_messages(txn):
"""

txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction(
"count_daily_sent_e2ee_messages", _count_messages
)

async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction(
Expand All @@ -160,20 +163,20 @@ async def count_daily_messages(self) -> int:
call to this function, it will return None.
"""

def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_messages", _count_messages)

async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
def _count_messages(txn: LoggingTransaction) -> int:
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
Expand All @@ -186,22 +189,22 @@ def _count_messages(txn):
"""

txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)

async def count_daily_active_rooms(self) -> int:
def _count(txn):
def _count(txn: LoggingTransaction) -> int:
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
return count

return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
Expand All @@ -227,7 +230,7 @@ async def count_monthly_users(self) -> int:
"count_monthly_users", self._count_users, thirty_days_ago
)

def _count_users(self, txn: Cursor, time_from: int) -> int:
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
Expand All @@ -242,7 +245,7 @@ def _count_users(self, txn: Cursor, time_from: int) -> int:
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = txn.fetchone() # type: ignore[misc]
(count,) = cast(Tuple[int], txn.fetchone())
return count

async def count_r30_users(self) -> Dict[str, int]:
Expand All @@ -256,7 +259,7 @@ async def count_r30_users(self) -> Dict[str, int]:
A mapping of counts globally as well as broken out by platform.
"""

def _count_r30_users(txn):
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
Expand Down Expand Up @@ -321,7 +324,7 @@ def _count_r30_users(txn):

txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))

(count,) = txn.fetchone()
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count

return results
Expand All @@ -348,7 +351,7 @@ async def count_r30v2_users(self) -> Dict[str, int]:
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""

def _count_r30v2_users(txn):
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
Expand Down Expand Up @@ -445,11 +448,8 @@ def _count_r30v2_users(txn):
thirty_days_in_secs * 1000,
),
)
row = txn.fetchone()
if row is None:
results["all"] = 0
else:
results["all"] = row[0]
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count

return results

Expand All @@ -471,7 +471,7 @@ async def generate_user_daily_visits(self) -> None:
Generates daily visit data for use in cohort/ retention analysis
"""

def _generate_user_daily_visits(txn):
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000
Expand Down
Loading