Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Save the OIDC session ID (sid) with the device on login #11482

Merged
merged 11 commits into from
Dec 6, 2021
1 change: 1 addition & 0 deletions changelog.d/11482.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Save the OpenID Connect session ID on login.
34 changes: 31 additions & 3 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import bcrypt
import pymacaroons
import unpaddedbase64
from pymacaroons.exceptions import MacaroonVerificationFailedException

from twisted.web.server import Request

Expand Down Expand Up @@ -182,8 +183,11 @@ class LoginTokenAttributes:

user_id = attr.ib(type=str)

# the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
"""The SSO Identity Provider that the user authenticated with, to get this token."""

auth_provider_session_id = attr.ib(type=Optional[str])
"""The session ID advertised by the SSO Identity Provider."""


class AuthHandler:
Expand Down Expand Up @@ -1650,6 +1654,7 @@ async def complete_sso_login(
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request

Expand All @@ -1665,6 +1670,7 @@ async def complete_sso_login(
during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just
registered.
auth_provider_session_id: The session ID got during login from the SSO IdP.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
Expand All @@ -1685,6 +1691,7 @@ async def complete_sso_login(
extra_attributes,
new_user=new_user,
user_profile_data=profile,
auth_provider_session_id=auth_provider_session_id,
)

def _complete_sso_login(
Expand All @@ -1696,6 +1703,7 @@ def _complete_sso_login(
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""
The synchronous portion of complete_sso_login.
Expand All @@ -1717,7 +1725,9 @@ def _complete_sso_login(

# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id, auth_provider_id=auth_provider_id
registered_user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)

# Append the login token to the original redirect URL (i.e. with its query
Expand Down Expand Up @@ -1822,6 +1832,7 @@ def generate_short_term_login_token(
self,
user_id: str,
auth_provider_id: str,
auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
Expand All @@ -1830,6 +1841,10 @@ def generate_short_term_login_token(
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
if auth_provider_session_id is not None:
macaroon.add_first_party_caveat(
"auth_provider_session_id = %s" % (auth_provider_session_id,)
)
return macaroon.serialize()

def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
Expand All @@ -1851,15 +1866,28 @@ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")

auth_provider_session_id: Optional[str] = None
try:
auth_provider_session_id = get_value_from_macaroon(
macaroon, "auth_provider_session_id"
)
except MacaroonVerificationFailedException:
pass

v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)

return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
return LoginTokenAttributes(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)

def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
Expand Down
8 changes: 8 additions & 0 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ async def check_device_registered(
user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> str:
"""
If the given device has not been registered, register it with the
Expand All @@ -312,6 +314,8 @@ async def check_device_registered(
user_id: @user:id
device_id: device id supplied by client
initial_device_display_name: device display name from client
auth_provider_id: The SSO IdP the user used, if any.
auth_provider_session_id: The session ID (sid) got from a OIDC login.
Returns:
device id (generated if none was supplied)
"""
Expand All @@ -323,6 +327,8 @@ async def check_device_registered(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
Expand All @@ -337,6 +343,8 @@ async def check_device_registered(
user_id=user_id,
device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [new_device_id])
Expand Down
58 changes: 35 additions & 23 deletions synapse/handlers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
Expand Down Expand Up @@ -117,7 +117,8 @@ async def load_metadata(self) -> None:
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
await p.load_jwks()
if not p._uses_userinfo:
await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
Expand Down Expand Up @@ -498,10 +499,6 @@ async def load_jwks(self, force: bool = False) -> JWKS:
return await self._jwks.get()

async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}

metadata = await self.load_metadata()

# Load the JWKS using the `jwks_uri` metadata.
Expand Down Expand Up @@ -663,7 +660,7 @@ async def _fetch_userinfo(self, token: Token) -> UserInfo:

return UserInfo(resp)

async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``.

Args:
Expand All @@ -673,7 +670,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
request. This value should match the one inside the token.

Returns:
An object representing the user.
The decoded claims in the ID token.
"""
metadata = await self.load_metadata()
claims_params = {
Expand All @@ -684,9 +681,6 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking again at what CodeIDToken and ImplicitIDToken verify in authlib, it really did not make sense to have that logic here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meaning that we should always be using CodeIDToken? Will this have any sort of user visible change? I think the result of this is given to mapping providers?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only things it changes is the verification done on claims in the ID token. CodeIDToken is for the flow we're using, ImplicitIDToken is for the implicit flow (where you get the access token/ID token directly in the callback params).


alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
Expand All @@ -703,7 +697,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
Expand All @@ -713,15 +707,16 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)

logger.debug("Decoded id_token JWT %r; validating", claims)

claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)

return claims

async def handle_redirect_request(
self,
Expand Down Expand Up @@ -837,22 +832,37 @@ async def handle_oidc_callback(

logger.debug("Successfully obtained OAuth2 token data: %r", token)

# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
# If there is an id_token, it should be validated, regardless if the
# userinfo endpoint is used or not.
if token.get("id_token") is not None:
try:
id_token = await self._parse_id_token(token, nonce=session_data.nonce)
sid = id_token.get("sid")
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
else:
id_token = None
sid = None

# Now that we have a token, get the userinfo either from the `id_token`
# claims or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
try:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e))
return
elif id_token is not None:
userinfo = UserInfo(id_token)
else:
try:
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
logger.error("Missing id_token in token response")
self._sso_handler.render_error(
request, "invalid_token", "Missing id_token in token response"
)
return

# first check if we're doing a UIA
if session_data.ui_auth_session_id:
Expand Down Expand Up @@ -884,7 +894,7 @@ async def handle_oidc_callback(
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
userinfo, token, request, session_data.client_redirect_url
userinfo, token, request, session_data.client_redirect_url, sid
)
except MappingException as e:
logger.exception("Could not map user")
Expand All @@ -896,6 +906,7 @@ async def _complete_oidc_login(
token: Token,
request: SynapseRequest,
client_redirect_url: str,
sid: Optional[str],
) -> None:
"""Given a UserInfo response, complete the login flow

Expand Down Expand Up @@ -1008,6 +1019,7 @@ async def grandfather_existing_users() -> Optional[str]:
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
auth_provider_session_id=sid,
)

def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
Expand Down
15 changes: 12 additions & 3 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ async def register_device(
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.

Expand All @@ -752,9 +753,9 @@ async def register_device(
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
auth_provider_id: The SSO IdP the user used, if any.
should_issue_refresh_token: Whether it should also issue a refresh token
auth_provider_session_id: The session ID got during login from the SSO IdP.
Returns:
Tuple of device ID, access token, access token expiration time and refresh token
"""
Expand All @@ -765,6 +766,8 @@ async def register_device(
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)

login_counter.labels(
Expand All @@ -787,6 +790,8 @@ async def register_device_inner(
is_guest: bool = False,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> LoginDict:
"""Helper for register_device

Expand All @@ -806,7 +811,11 @@ class and RegisterDeviceReplicationServlet.
refresh_token_id = None

registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
user_id,
device_id,
initial_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if is_guest:
assert access_token_expiry is None
Expand Down
4 changes: 4 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ async def complete_sso_login_request(
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
Expand Down Expand Up @@ -415,6 +416,8 @@ async def complete_sso_login_request(
extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response.

auth_provider_session_id: An optional session ID from the OIDC login

Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
Expand Down Expand Up @@ -490,6 +493,7 @@ async def complete_sso_login_request(
client_redirect_url,
extra_login_attributes,
new_user=new_user,
auth_provider_session_id=auth_provider_session_id,
)

async def _call_attribute_mapper(
Expand Down
2 changes: 2 additions & 0 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def generate_short_term_login_token(
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: str = "",
auth_provider_session_id: Optional[str] = None,
) -> str:
"""Generate a login token suitable for m.login.token authentication

Expand All @@ -643,6 +644,7 @@ def generate_short_term_login_token(
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id,
auth_provider_id,
auth_provider_session_id,
duration_in_ms,
)

Expand Down
Loading