Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 85a09f8

Browse files
authored
Fix module API's get_user_ip_and_agents function when run on workers (#11112)
1 parent 2b82ec4 commit 85a09f8

File tree

3 files changed

+91
-40
lines changed

3 files changed

+91
-40
lines changed

changelog.d/11112.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a bug which caused the module API's `get_user_ip_and_agents` function to always fail on workers. `get_user_ip_and_agents` was introduced in 1.44.0 and did not function correctly on worker processes at the time.

synapse/module_api/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from synapse.logging.context import make_deferred_yieldable, run_in_background
4747
from synapse.metrics.background_process_metrics import run_as_background_process
4848
from synapse.rest.client.login import LoginResponse
49+
from synapse.storage import DataStore
4950
from synapse.storage.database import DatabasePool, LoggingTransaction
5051
from synapse.storage.databases.main.roommember import ProfileInfo
5152
from synapse.storage.state import StateFilter
@@ -61,6 +62,7 @@
6162
from synapse.util.caches.descriptors import cached
6263

6364
if TYPE_CHECKING:
65+
from synapse.app.generic_worker import GenericWorkerSlavedStore
6466
from synapse.server import HomeServer
6567

6668
"""
@@ -111,7 +113,9 @@ class ModuleApi:
111113
def __init__(self, hs: "HomeServer", auth_handler):
112114
self._hs = hs
113115

114-
self._store = hs.get_datastore()
116+
# TODO: Fix this type hint once the types for the data stores have been ironed
117+
# out.
118+
self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore()
115119
self._auth = hs.get_auth()
116120
self._auth_handler = auth_handler
117121
self._server_name = hs.hostname

synapse/storage/databases/main/client_ips.py

Lines changed: 85 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,58 @@ async def get_last_client_ip_by_device(
478478

479479
return {(d["user_id"], d["device_id"]): d for d in res}
480480

481+
async def get_user_ip_and_agents(
482+
self, user: UserID, since_ts: int = 0
483+
) -> List[LastConnectionInfo]:
484+
"""Fetch the IPs and user agents for a user since the given timestamp.
485+
486+
The result might be slightly out of date as client IPs are inserted in batches.
487+
488+
Args:
489+
user: The user for which to fetch IP addresses and user agents.
490+
since_ts: The timestamp after which to fetch IP addresses and user agents,
491+
in milliseconds.
492+
493+
Returns:
494+
A list of dictionaries, each containing:
495+
* `access_token`: The access token used.
496+
* `ip`: The IP address used.
497+
* `user_agent`: The last user agent seen for this access token and IP
498+
address combination.
499+
* `last_seen`: The timestamp at which this access token and IP address
500+
combination was last seen, in milliseconds.
501+
502+
Only the latest user agent for each access token and IP address combination
503+
is available.
504+
"""
505+
user_id = user.to_string()
506+
507+
def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
508+
txn.execute(
509+
"""
510+
SELECT access_token, ip, user_agent, last_seen FROM user_ips
511+
WHERE last_seen >= ? AND user_id = ?
512+
ORDER BY last_seen
513+
DESC
514+
""",
515+
(since_ts, user_id),
516+
)
517+
return cast(List[Tuple[str, str, str, int]], txn.fetchall())
518+
519+
rows = await self.db_pool.runInteraction(
520+
desc="get_user_ip_and_agents", func=get_recent
521+
)
522+
523+
return [
524+
{
525+
"access_token": access_token,
526+
"ip": ip,
527+
"user_agent": user_agent,
528+
"last_seen": last_seen,
529+
}
530+
for access_token, ip, user_agent, last_seen in rows
531+
]
532+
481533

482534
class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
483535
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -622,49 +674,43 @@ async def get_last_client_ip_by_device(
622674
async def get_user_ip_and_agents(
623675
self, user: UserID, since_ts: int = 0
624676
) -> List[LastConnectionInfo]:
677+
"""Fetch the IPs and user agents for a user since the given timestamp.
678+
679+
Args:
680+
user: The user for which to fetch IP addresses and user agents.
681+
since_ts: The timestamp after which to fetch IP addresses and user agents,
682+
in milliseconds.
683+
684+
Returns:
685+
A list of dictionaries, each containing:
686+
* `access_token`: The access token used.
687+
* `ip`: The IP address used.
688+
* `user_agent`: The last user agent seen for this access token and IP
689+
address combination.
690+
* `last_seen`: The timestamp at which this access token and IP address
691+
combination was last seen, in milliseconds.
692+
693+
Only the latest user agent for each access token and IP address combination
694+
is available.
625695
"""
626-
Fetch IP/User Agent connection since a given timestamp.
627-
"""
628-
user_id = user.to_string()
629-
results: Dict[Tuple[str, str], Tuple[str, int]] = {}
696+
results: Dict[Tuple[str, str], LastConnectionInfo] = {
697+
(connection["access_token"], connection["ip"]): connection
698+
for connection in await super().get_user_ip_and_agents(user, since_ts)
699+
}
630700

701+
# Overlay data that is pending insertion on top of the results from the
702+
# database.
703+
user_id = user.to_string()
631704
for key in self._batch_row_update:
632-
(
633-
uid,
634-
access_token,
635-
ip,
636-
) = key
705+
uid, access_token, ip = key
637706
if uid == user_id:
638707
user_agent, _, last_seen = self._batch_row_update[key]
639708
if last_seen >= since_ts:
640-
results[(access_token, ip)] = (user_agent, last_seen)
641-
642-
def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
643-
txn.execute(
644-
"""
645-
SELECT access_token, ip, user_agent, last_seen FROM user_ips
646-
WHERE last_seen >= ? AND user_id = ?
647-
ORDER BY last_seen
648-
DESC
649-
""",
650-
(since_ts, user_id),
651-
)
652-
return cast(List[Tuple[str, str, str, int]], txn.fetchall())
653-
654-
rows = await self.db_pool.runInteraction(
655-
desc="get_user_ip_and_agents", func=get_recent
656-
)
709+
results[(access_token, ip)] = {
710+
"access_token": access_token,
711+
"ip": ip,
712+
"user_agent": user_agent,
713+
"last_seen": last_seen,
714+
}
657715

658-
results.update(
659-
((access_token, ip), (user_agent, last_seen))
660-
for access_token, ip, user_agent, last_seen in rows
661-
)
662-
return [
663-
{
664-
"access_token": access_token,
665-
"ip": ip,
666-
"user_agent": user_agent,
667-
"last_seen": last_seen,
668-
}
669-
for (access_token, ip), (user_agent, last_seen) in results.items()
670-
]
716+
return list(results.values())

0 commit comments

Comments
 (0)