|
20 | 20 |
|
21 | 21 | from synapse.api.constants import LoginType
|
22 | 22 | from synapse.api.errors import Codes, LoginError, SynapseError
|
| 23 | +from synapse.types import UserID |
23 | 24 | from synapse.util import json_decoder
|
24 | 25 |
|
25 | 26 | if TYPE_CHECKING:
|
@@ -314,12 +315,94 @@ async def check_auth(self, authdict: dict, clientip: str) -> Any:
|
314 | 315 | )
|
315 | 316 |
|
316 | 317 |
|
| 318 | +class JwtAuthChecker(UserInteractiveAuthChecker): |
| 319 | + AUTH_TYPE = LoginType.JWT |
| 320 | + |
| 321 | + def __init__(self, hs: "HomeServer"): |
| 322 | + super().__init__(hs) |
| 323 | + self.hs = hs |
| 324 | + |
| 325 | + def is_enabled(self) -> bool: |
| 326 | + return bool(self.hs.config.jwt.jwt_enabled) |
| 327 | + |
| 328 | + async def check_auth(self, authdict: dict, clientip: str) -> Any: |
| 329 | + token = authdict.get("token", None) |
| 330 | + if token is None: |
| 331 | + raise LoginError( |
| 332 | + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN |
| 333 | + ) |
| 334 | + |
| 335 | + from authlib.jose import JsonWebToken, JWTClaims |
| 336 | + from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError |
| 337 | + |
| 338 | + jwt = JsonWebToken([self.hs.config.jwt.jwt_algorithm]) |
| 339 | + claim_options = {} |
| 340 | + if self.hs.config.jwt.jwt_issuer is not None: |
| 341 | + claim_options["iss"] = { |
| 342 | + "value": self.hs.config.jwt.jwt_issuer, |
| 343 | + "essential": True, |
| 344 | + } |
| 345 | + if self.hs.config.jwt.jwt_audiences is not None: |
| 346 | + claim_options["aud"] = { |
| 347 | + "values": self.hs.config.jwt.jwt_audiences, |
| 348 | + "essential": True, |
| 349 | + } |
| 350 | + |
| 351 | + try: |
| 352 | + claims = jwt.decode( |
| 353 | + token, |
| 354 | + key=self.hs.config.jwt.jwt_secret, |
| 355 | + claims_cls=JWTClaims, |
| 356 | + claims_options=claim_options, |
| 357 | + ) |
| 358 | + except BadSignatureError: |
| 359 | + # We handle this case separately to provide a better error message |
| 360 | + raise LoginError( |
| 361 | + 403, |
| 362 | + "JWT validation failed: Signature verification failed", |
| 363 | + errcode=Codes.FORBIDDEN, |
| 364 | + ) |
| 365 | + except JoseError as e: |
| 366 | + # A JWT error occurred, return some info back to the client. |
| 367 | + raise LoginError( |
| 368 | + 403, |
| 369 | + "JWT validation failed: %s" % (str(e),), |
| 370 | + errcode=Codes.FORBIDDEN, |
| 371 | + ) |
| 372 | + |
| 373 | + try: |
| 374 | + claims.validate(leeway=120) # allows 2 min of clock skew |
| 375 | + |
| 376 | + # Enforce the old behavior which is rolled out in productive |
| 377 | + # servers: if the JWT contains an 'aud' claim but none is |
| 378 | + # configured, the login attempt will fail |
| 379 | + if claims.get("aud") is not None: |
| 380 | + if ( |
| 381 | + self.hs.config.jwt.jwt_audiences is None |
| 382 | + or len(self.hs.config.jwt.jwt_audiences) == 0 |
| 383 | + ): |
| 384 | + raise InvalidClaimError("aud") |
| 385 | + except JoseError as e: |
| 386 | + raise LoginError( |
| 387 | + 403, |
| 388 | + "JWT validation failed: %s" % (str(e),), |
| 389 | + errcode=Codes.FORBIDDEN, |
| 390 | + ) |
| 391 | + |
| 392 | + user = claims.get(self.hs.config.jwt.jwt_subject_claim, None) |
| 393 | + if user is None: |
| 394 | + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) |
| 395 | + |
| 396 | + return UserID(user, self.hs.hostname).to_string() |
| 397 | + |
| 398 | + |
317 | 399 | INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
|
318 | 400 | DummyAuthChecker,
|
319 | 401 | TermsAuthChecker,
|
320 | 402 | RecaptchaAuthChecker,
|
321 | 403 | EmailIdentityAuthChecker,
|
322 | 404 | MsisdnAuthChecker,
|
323 | 405 | RegistrationTokenAuthChecker,
|
| 406 | + JwtAuthChecker, |
324 | 407 | ]
|
325 | 408 | """A list of UserInteractiveAuthChecker classes"""
|
0 commit comments