Skip to content

Commit f4e912f

Browse files
committed
Make JWT require to know what to expect
This is needed to address CVE-2022-3102. Thanks to Tom tervoort from Secura for finding and reporting this issue. Also test that "unepxected" token types are not validated Signed-off-by: Simo Sorce <[email protected]>
1 parent 5a13cfc commit f4e912f

File tree

3 files changed

+101
-11
lines changed

3 files changed

+101
-11
lines changed

docs/source/jwt.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Now decrypt and verify::
4242
>>> k = {"k": "Wal4ZHCBsml0Al_Y8faoNTKsXCkw8eefKXYFuwTBOpA", "kty": "oct"}
4343
>>> key = jwk.JWK(**k)
4444
>>> e = 'eyJhbGciOiJBMjU2S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIn0.ST5RmjqDLj696xo7YFTFuKUhcd3naCrm6yMjBM3cqWiFD6U8j2JIsbclsF7ryNg8Ktmt1kQJRKavV6DaTl1T840tP3sIs1qz.wSxVhZH5GyzbJnPBAUMdzQ.6uiVYwrRBzAm7Uge9rEUjExPWGbgerF177A7tMuQurJAqBhgk3_5vee5DRH84kHSapFOxcEuDdMBEQLI7V2E0F57-d01TFStHzwtgtSmeZRQ6JSIL5XlgJouwHfSxn9Z_TGl5xxq4TksORHED1vnRA.5jPyPWanJVqlOohApEbHmxi3JHp1MXbmvQe2_dVd8FI'
45-
>>> ET = jwt.JWT(key=key, jwt=e)
45+
>>> ET = jwt.JWT(key=key, jwt=e, expected_type="JWE")
4646
>>> ST = jwt.JWT(key=key, jwt=ET.claims)
4747
>>> ST.claims
4848
'{"info":"I\'m a signed token"}'

jwcrypto/jwt.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from jwcrypto.common import JWException, JWKeyNotFound
1010
from jwcrypto.common import json_decode, json_encode
1111
from jwcrypto.jwe import JWE
12+
from jwcrypto.jwe import default_allowed_algs as jwe_algs
1213
from jwcrypto.jws import JWS
1314

1415

@@ -153,7 +154,8 @@ class JWT:
153154
"""
154155

155156
def __init__(self, header=None, claims=None, jwt=None, key=None,
156-
algs=None, default_claims=None, check_claims=None):
157+
algs=None, default_claims=None, check_claims=None,
158+
expected_type=None):
157159
"""Creates a JWT object.
158160
159161
:param header: A dict or a JSON string with the JWT Header data.
@@ -169,6 +171,12 @@ def __init__(self, header=None, claims=None, jwt=None, key=None,
169171
:param check_claims: An optional dict of claims that must be
170172
present in the token, if the value is not None the claim must
171173
match exactly.
174+
:param expected_type: An optional string that defines what kind
175+
of token to expect when validating a deserialized token.
176+
Supported values: "JWS" or "JWE"
177+
If left to None the code will try to detect what the expected
178+
type is based on other parameters like 'algs' and will default
179+
to JWS if no hints are found. It has no effect on token creation.
172180
173181
Note: either the header,claims or jwt,key parameters should be
174182
provided as a deserialization operation (which occurs if the jwt
@@ -190,6 +198,7 @@ def __init__(self, header=None, claims=None, jwt=None, key=None,
190198
self._leeway = 60 # 1 minute clock skew allowed
191199
self._validity = 600 # 10 minutes validity (up to 11 with leeway)
192200
self.deserializelog = None
201+
self._expected_type = expected_type
193202

194203
if header:
195204
self.header = header
@@ -276,6 +285,33 @@ def validity(self):
276285
def validity(self, v):
277286
self._validity = int(v)
278287

288+
@property
289+
def expected_type(self):
290+
if self._expected_type is not None:
291+
return self._expected_type
292+
293+
# If no expected type is set we default to accept only JWSs,
294+
# however to improve backwards compatibility we try some
295+
# heuristic to see if there has been strong indication of
296+
# what the expected token type is.
297+
if self._expected_type is None and self._algs:
298+
if set(self._algs).issubset(jwe_algs + ['RSA1_5']):
299+
self._expected_type = "JWE"
300+
if self._expected_type is None and self._header:
301+
if "enc" in json_decode(self._header):
302+
self._expected_type = "JWE"
303+
if self._expected_type is None:
304+
self._expected_type = "JWS"
305+
306+
return self._expected_type
307+
308+
@expected_type.setter
309+
def expected_type(self, v):
310+
if v in ["JWS", "JWE"]:
311+
self._expected_type = v
312+
else:
313+
raise ValueError("Invalid value, must be 'JWS' or 'JWE'")
314+
279315
def _add_optional_claim(self, name, claims):
280316
if name in claims:
281317
return
@@ -472,6 +508,7 @@ def make_signed_token(self, key):
472508
t.allowed_algs = self._algs
473509
t.add_signature(key, protected=self.header)
474510
self.token = t
511+
self._expected_type = "JWS"
475512

476513
def make_encrypted_token(self, key):
477514
"""Encrypts the payload.
@@ -488,6 +525,7 @@ def make_encrypted_token(self, key):
488525
t.allowed_algs = self._algs
489526
t.add_recipient(key)
490527
self.token = t
528+
self._expected_type = "JWE"
491529

492530
def validate(self, key):
493531
"""Validate a JWT token that was deserialized w/o providing a key
@@ -500,13 +538,23 @@ def validate(self, key):
500538
if self.token is None:
501539
raise ValueError("Token empty")
502540

541+
et = self.expected_type
542+
validate_fn = None
543+
544+
if isinstance(self.token, JWS):
545+
if et != "JWS":
546+
raise TypeError("Expected {}, got JWS".format(et))
547+
validate_fn = self.token.verify
548+
elif isinstance(self.token, JWE):
549+
if et != "JWE":
550+
print("algs: {}".format(self._algs))
551+
raise TypeError("Expected {}, got JWE".format(et))
552+
validate_fn = self.token.decrypt
553+
else:
554+
raise ValueError("Token format unrecognized")
555+
503556
try:
504-
if isinstance(self.token, JWS):
505-
self.token.verify(key)
506-
elif isinstance(self.token, JWE):
507-
self.token.decrypt(key)
508-
else:
509-
raise ValueError("Token format unrecognized")
557+
validate_fn(key)
510558
self.deserializelog.append("Success")
511559
except Exception as e: # pylint: disable=broad-except
512560
if isinstance(self.token, JWS):
@@ -520,7 +568,10 @@ def validate(self, key):
520568
raise
521569

522570
self.header = self.token.jose_header
523-
self.claims = self.token.payload.decode('utf-8')
571+
payload = self.token.payload
572+
if isinstance(payload, bytes):
573+
payload = payload.decode('utf-8')
574+
self.claims = payload
524575
self._check_provided_claims()
525576

526577
def deserialize(self, jwt, key=None):

jwcrypto/tests.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from jwcrypto.common import json_decode, json_encode
2222

2323
jwe_algs_and_rsa1_5 = jwe.default_allowed_algs + ['RSA1_5']
24-
jws_algs_and_rsa1_5 = jws.default_allowed_algs + ['RSA1_5']
2524

2625
# RFC 7517 - A.1
2726
PublicKeys = {"keys": [
@@ -1531,9 +1530,11 @@ def test_A2(self):
15311530
tinner = jwt.JWT(jwt=touter.claims, key=sigkey, check_claims=False)
15321531
self.assertEqual(A1_claims, json_decode(tinner.claims))
15331532

1533+
# Test Exception throwing when token is encrypted with
1534+
# algorithms not in the allowed set
15341535
with self.assertRaises(jwe.InvalidJWEData):
15351536
jwt.JWT(jwt=A2_token, key=E_A2_ex['key'],
1536-
algs=jws_algs_and_rsa1_5)
1537+
algs=['A192KW', 'A192CBC-HS384', 'RSA1_5'])
15371538

15381539
def test_decrypt_keyset(self):
15391540
key = jwk.JWK(kid='testkey', **E_A2_key)
@@ -1738,6 +1739,43 @@ def test_Issue_277(self):
17381739
jwt=sertok, check_claims={"aud": ["nomatch",
17391740
"failmatch"]})
17401741

1742+
def test_unexpected(self):
1743+
key = jwk.JWK(generate='oct', size=256)
1744+
claims = {"testclaim": "test"}
1745+
token = jwt.JWT(header={"alg": "HS256"}, claims=claims)
1746+
token.make_signed_token(key)
1747+
sertok = token.serialize()
1748+
1749+
token.validate(key)
1750+
token.expected_type = "JWS"
1751+
token.validate(key)
1752+
token.expected_type = "JWE"
1753+
with self.assertRaises(TypeError):
1754+
token.validate(key)
1755+
1756+
jwt.JWT(jwt=sertok, key=key)
1757+
jwt.JWT(jwt=sertok, key=key, expected_type='JWS')
1758+
with self.assertRaises(TypeError):
1759+
jwt.JWT(jwt=sertok, key=key, expected_type='JWE')
1760+
1761+
token = jwt.JWT(header={"alg": "A256KW", "enc": "A256GCM"},
1762+
claims=claims)
1763+
token.make_encrypted_token(key)
1764+
enctok = token.serialize()
1765+
1766+
token.validate(key)
1767+
token.expected_type = "JWE"
1768+
token.validate(key)
1769+
token.expected_type = "JWS"
1770+
with self.assertRaises(TypeError):
1771+
token.validate(key)
1772+
1773+
jwt.JWT(jwt=enctok, key=key, expected_type='JWE')
1774+
with self.assertRaises(TypeError):
1775+
jwt.JWT(jwt=enctok, key=key)
1776+
with self.assertRaises(TypeError):
1777+
jwt.JWT(jwt=enctok, key=key, expected_type='JWS')
1778+
17411779

17421780
class ConformanceTests(unittest.TestCase):
17431781

@@ -2107,6 +2145,7 @@ def test_jwt_equality(self):
21072145

21082146
ect = jwt.JWT.from_jose_token(ea.serialize())
21092147
self.assertNotEqual(ea, ect)
2148+
ect.expected_type = "JWE"
21102149
ect.validate(key)
21112150
self.assertEqual(ea, ect)
21122151

0 commit comments

Comments
 (0)