Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 1189be4

Browse files
committed
Factor _AccountHandler proxy out to ModuleApi
We're going to need to use this from places that aren't password auth, so let's move it to a proper class.
1 parent b19d9e2 commit 1189be4

File tree

3 files changed

+83
-70
lines changed

3 files changed

+83
-70
lines changed

docs/password_auth_providers.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Password auth provider classes must provide the following methods:
2727
*class* ``SomeProvider``\(*config*, *account_handler*)
2828

2929
The constructor is passed the config object returned by ``parse_config``,
30-
and a ``synapse.handlers.auth._AccountHandler`` object which allows the
30+
and a ``synapse.module_api.ModuleApi`` object which allows the
3131
password provider to check if accounts exist and/or create new ones.
3232

3333
Optional methods

synapse/handlers/auth.py

Lines changed: 3 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
1716
from twisted.internet import defer
1817

1918
from ._base import BaseHandler
2019
from synapse.api.constants import LoginType
21-
from synapse.types import UserID
2220
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
21+
from synapse.module_api import ModuleApi
22+
from synapse.types import UserID
2323
from synapse.util.async import run_on_reactor
2424
from synapse.util.caches.expiringcache import ExpiringCache
2525

@@ -63,10 +63,7 @@ def __init__(self, hs):
6363
reset_expiry_on_get=True,
6464
)
6565

66-
account_handler = _AccountHandler(
67-
hs, check_user_exists=self.check_user_exists
68-
)
69-
66+
account_handler = ModuleApi(hs, self)
7067
self.password_providers = [
7168
module(config=config, account_handler=account_handler)
7269
for module, config in hs.config.password_providers
@@ -843,66 +840,3 @@ def _generate_base_macaroon(self, user_id):
843840
macaroon.add_first_party_caveat("gen = 1")
844841
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
845842
return macaroon
846-
847-
848-
class _AccountHandler(object):
849-
"""A proxy object that gets passed to password auth providers so they
850-
can register new users etc if necessary.
851-
"""
852-
def __init__(self, hs, check_user_exists):
853-
self.hs = hs
854-
855-
self._check_user_exists = check_user_exists
856-
self._store = hs.get_datastore()
857-
858-
def get_qualified_user_id(self, username):
859-
"""Qualify a user id, if necessary
860-
861-
Takes a user id provided by the user and adds the @ and :domain to
862-
qualify it, if necessary
863-
864-
Args:
865-
username (str): provided user id
866-
867-
Returns:
868-
str: qualified @user:id
869-
"""
870-
if username.startswith('@'):
871-
return username
872-
return UserID(username, self.hs.hostname).to_string()
873-
874-
def check_user_exists(self, user_id):
875-
"""Check if user exists.
876-
877-
Args:
878-
user_id (str): Complete @user:id
879-
880-
Returns:
881-
Deferred[str|None]: Canonical (case-corrected) user_id, or None
882-
if the user is not registered.
883-
"""
884-
return self._check_user_exists(user_id)
885-
886-
def register(self, localpart):
887-
"""Registers a new user with given localpart
888-
889-
Returns:
890-
Deferred: a 2-tuple of (user_id, access_token)
891-
"""
892-
reg = self.hs.get_handlers().registration_handler
893-
return reg.register(localpart=localpart)
894-
895-
def run_db_interaction(self, desc, func, *args, **kwargs):
896-
"""Run a function with a database connection
897-
898-
Args:
899-
desc (str): description for the transaction, for metrics etc
900-
func (func): function to be run. Passed a database cursor object
901-
as well as *args and **kwargs
902-
*args: positional args to be passed to func
903-
**kwargs: named args to be passed to func
904-
905-
Returns:
906-
Deferred[object]: result of func
907-
"""
908-
return self._store.runInteraction(desc, func, *args, **kwargs)

synapse/module_api/__init__.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2017 New Vector Ltd
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from synapse.types import UserID
17+
18+
19+
class ModuleApi(object):
20+
"""A proxy object that gets passed to password auth providers so they
21+
can register new users etc if necessary.
22+
"""
23+
def __init__(self, hs, auth_handler):
24+
self.hs = hs
25+
26+
self._store = hs.get_datastore()
27+
self._auth_handler = auth_handler
28+
29+
def get_qualified_user_id(self, username):
30+
"""Qualify a user id, if necessary
31+
32+
Takes a user id provided by the user and adds the @ and :domain to
33+
qualify it, if necessary
34+
35+
Args:
36+
username (str): provided user id
37+
38+
Returns:
39+
str: qualified @user:id
40+
"""
41+
if username.startswith('@'):
42+
return username
43+
return UserID(username, self.hs.hostname).to_string()
44+
45+
def check_user_exists(self, user_id):
46+
"""Check if user exists.
47+
48+
Args:
49+
user_id (str): Complete @user:id
50+
51+
Returns:
52+
Deferred[str|None]: Canonical (case-corrected) user_id, or None
53+
if the user is not registered.
54+
"""
55+
return self._auth_handler.check_user_exists(user_id)
56+
57+
def register(self, localpart):
58+
"""Registers a new user with given localpart
59+
60+
Returns:
61+
Deferred: a 2-tuple of (user_id, access_token)
62+
"""
63+
reg = self.hs.get_handlers().registration_handler
64+
return reg.register(localpart=localpart)
65+
66+
def run_db_interaction(self, desc, func, *args, **kwargs):
67+
"""Run a function with a database connection
68+
69+
Args:
70+
desc (str): description for the transaction, for metrics etc
71+
func (func): function to be run. Passed a database cursor object
72+
as well as *args and **kwargs
73+
*args: positional args to be passed to func
74+
**kwargs: named args to be passed to func
75+
76+
Returns:
77+
Deferred[object]: result of func
78+
"""
79+
return self._store.runInteraction(desc, func, *args, **kwargs)

0 commit comments

Comments
 (0)