Skip to content

Commit 6dc23b3

Browse files
Validate sub and jti claims for the token (#1005)
* feat(jwt): Both JTI and sub are now being validated, test cases added, changelog updated * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 310962b commit 6dc23b3

File tree

5 files changed

+188
-1
lines changed

5 files changed

+188
-1
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,6 @@ target/
6363
.mypy_cache
6464
pip-wheel-metadata/
6565
.venv/
66+
67+
68+
.idea

CHANGELOG.rst

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ Changed
2727
jwt.encode({"payload":"abc"}, key=None, algorithm='none')
2828
```
2929

30+
- Added validation for 'sub' (subject) and 'jti' (JWT ID) claims in tokens
31+
3032
Fixed
3133
~~~~~
3234

jwt/api_jwt.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
InvalidAudienceError,
1616
InvalidIssuedAtError,
1717
InvalidIssuerError,
18+
InvalidJTIError,
19+
InvalidSubjectError,
1820
MissingRequiredClaimError,
1921
)
2022
from .warnings import RemovedInPyjwt3Warning
@@ -39,6 +41,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
3941
"verify_iat": True,
4042
"verify_aud": True,
4143
"verify_iss": True,
44+
"verify_sub": True,
45+
"verify_jti": True,
4246
"require": [],
4347
}
4448

@@ -112,6 +116,7 @@ def decode_complete(
112116
# consider putting in options
113117
audience: str | Iterable[str] | None = None,
114118
issuer: str | Sequence[str] | None = None,
119+
subject: str | None = None,
115120
leeway: float | timedelta = 0,
116121
# kwargs
117122
**kwargs: Any,
@@ -145,6 +150,8 @@ def decode_complete(
145150
options.setdefault("verify_iat", False)
146151
options.setdefault("verify_aud", False)
147152
options.setdefault("verify_iss", False)
153+
options.setdefault("verify_sub", False)
154+
options.setdefault("verify_jti", False)
148155

149156
decoded = api_jws.decode_complete(
150157
jwt,
@@ -158,7 +165,12 @@ def decode_complete(
158165

159166
merged_options = {**self.options, **options}
160167
self._validate_claims(
161-
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
168+
payload,
169+
merged_options,
170+
audience=audience,
171+
issuer=issuer,
172+
leeway=leeway,
173+
subject=subject,
162174
)
163175

164176
decoded["payload"] = payload
@@ -193,6 +205,7 @@ def decode(
193205
# passthrough arguments to _validate_claims
194206
# consider putting in options
195207
audience: str | Iterable[str] | None = None,
208+
subject: str | None = None,
196209
issuer: str | Sequence[str] | None = None,
197210
leeway: float | timedelta = 0,
198211
# kwargs
@@ -214,6 +227,7 @@ def decode(
214227
verify=verify,
215228
detached_payload=detached_payload,
216229
audience=audience,
230+
subject=subject,
217231
issuer=issuer,
218232
leeway=leeway,
219233
)
@@ -225,6 +239,7 @@ def _validate_claims(
225239
options: dict[str, Any],
226240
audience=None,
227241
issuer=None,
242+
subject: str | None = None,
228243
leeway: float | timedelta = 0,
229244
) -> None:
230245
if isinstance(leeway, timedelta):
@@ -254,6 +269,12 @@ def _validate_claims(
254269
payload, audience, strict=options.get("strict_aud", False)
255270
)
256271

272+
if options["verify_sub"]:
273+
self._validate_sub(payload, subject)
274+
275+
if options["verify_jti"]:
276+
self._validate_jti(payload)
277+
257278
def _validate_required_claims(
258279
self,
259280
payload: dict[str, Any],
@@ -263,6 +284,39 @@ def _validate_required_claims(
263284
if payload.get(claim) is None:
264285
raise MissingRequiredClaimError(claim)
265286

287+
def _validate_sub(self, payload: dict[str, Any], subject=None) -> None:
288+
"""
289+
Checks whether "sub" if in the payload is valid ot not.
290+
This is an Optional claim
291+
292+
:param payload(dict): The payload which needs to be validated
293+
:param subject(str): The subject of the token
294+
"""
295+
296+
if "sub" not in payload:
297+
return
298+
299+
if not isinstance(payload["sub"], str):
300+
raise InvalidSubjectError("Subject must be a string")
301+
302+
if subject is not None:
303+
if payload.get("sub") != subject:
304+
raise InvalidSubjectError("Invalid subject")
305+
306+
def _validate_jti(self, payload: dict[str, Any]) -> None:
307+
"""
308+
Checks whether "jti" if in the payload is valid ot not
309+
This is an Optional claim
310+
311+
:param payload(dict): The payload which needs to be validated
312+
"""
313+
314+
if "jti" not in payload:
315+
return
316+
317+
if not isinstance(payload.get("jti"), str):
318+
raise InvalidJTIError("JWT ID must be a string")
319+
266320
def _validate_iat(
267321
self,
268322
payload: dict[str, Any],

jwt/exceptions.py

+8
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,11 @@ class PyJWKClientError(PyJWTError):
7272

7373
class PyJWKClientConnectionError(PyJWKClientError):
7474
pass
75+
76+
77+
class InvalidSubjectError(InvalidTokenError):
78+
pass
79+
80+
81+
class InvalidJTIError(InvalidTokenError):
82+
pass

tests/test_api_jwt.py

+120
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
InvalidAudienceError,
1515
InvalidIssuedAtError,
1616
InvalidIssuerError,
17+
InvalidJTIError,
18+
InvalidSubjectError,
1719
MissingRequiredClaimError,
1820
)
1921
from jwt.utils import base64url_decode
@@ -816,3 +818,121 @@ def test_decode_strict_ok(self, jwt, payload):
816818
options={"strict_aud": True},
817819
algorithms=["HS256"],
818820
)
821+
822+
# -------------------- Sub Claim Tests --------------------
823+
824+
def test_encode_decode_sub_claim(self, jwt):
825+
payload = {
826+
"sub": "user123",
827+
}
828+
secret = "your-256-bit-secret"
829+
token = jwt.encode(payload, secret, algorithm="HS256")
830+
decoded = jwt.decode(token, secret, algorithms=["HS256"])
831+
832+
assert decoded["sub"] == "user123"
833+
834+
def test_decode_without_and_not_required_sub_claim(self, jwt):
835+
payload = {}
836+
secret = "your-256-bit-secret"
837+
token = jwt.encode(payload, secret, algorithm="HS256")
838+
839+
decoded = jwt.decode(token, secret, algorithms=["HS256"])
840+
841+
assert "sub" not in decoded
842+
843+
def test_decode_missing_sub_but_required_claim(self, jwt):
844+
payload = {}
845+
secret = "your-256-bit-secret"
846+
token = jwt.encode(payload, secret, algorithm="HS256")
847+
848+
with pytest.raises(MissingRequiredClaimError):
849+
jwt.decode(
850+
token, secret, algorithms=["HS256"], options={"require": ["sub"]}
851+
)
852+
853+
def test_decode_invalid_int_sub_claim(self, jwt):
854+
payload = {
855+
"sub": 1224344,
856+
}
857+
secret = "your-256-bit-secret"
858+
token = jwt.encode(payload, secret, algorithm="HS256")
859+
860+
with pytest.raises(InvalidSubjectError):
861+
jwt.decode(token, secret, algorithms=["HS256"])
862+
863+
def test_decode_with_valid_sub_claim(self, jwt):
864+
payload = {
865+
"sub": "user123",
866+
}
867+
secret = "your-256-bit-secret"
868+
token = jwt.encode(payload, secret, algorithm="HS256")
869+
870+
decoded = jwt.decode(token, secret, algorithms=["HS256"], subject="user123")
871+
872+
assert decoded["sub"] == "user123"
873+
874+
def test_decode_with_invalid_sub_claim(self, jwt):
875+
payload = {
876+
"sub": "user123",
877+
}
878+
secret = "your-256-bit-secret"
879+
token = jwt.encode(payload, secret, algorithm="HS256")
880+
881+
with pytest.raises(InvalidSubjectError) as exc_info:
882+
jwt.decode(token, secret, algorithms=["HS256"], subject="user456")
883+
884+
assert "Invalid subject" in str(exc_info.value)
885+
886+
def test_decode_with_sub_claim_and_none_subject(self, jwt):
887+
payload = {
888+
"sub": "user789",
889+
}
890+
secret = "your-256-bit-secret"
891+
token = jwt.encode(payload, secret, algorithm="HS256")
892+
893+
decoded = jwt.decode(token, secret, algorithms=["HS256"], subject=None)
894+
assert decoded["sub"] == "user789"
895+
896+
# -------------------- JTI Claim Tests --------------------
897+
898+
def test_encode_decode_with_valid_jti_claim(self, jwt):
899+
payload = {
900+
"jti": "unique-id-456",
901+
}
902+
secret = "your-256-bit-secret"
903+
token = jwt.encode(payload, secret, algorithm="HS256")
904+
decoded = jwt.decode(token, secret, algorithms=["HS256"])
905+
906+
assert decoded["jti"] == "unique-id-456"
907+
908+
def test_decode_missing_jti_when_required_claim(self, jwt):
909+
payload = {"name": "Bob", "admin": False}
910+
secret = "your-256-bit-secret"
911+
token = jwt.encode(payload, secret, algorithm="HS256")
912+
913+
with pytest.raises(MissingRequiredClaimError) as exc_info:
914+
jwt.decode(
915+
token, secret, algorithms=["HS256"], options={"require": ["jti"]}
916+
)
917+
918+
assert "jti" in str(exc_info.value)
919+
920+
def test_decode_missing_jti_claim(self, jwt):
921+
payload = {}
922+
secret = "your-256-bit-secret"
923+
token = jwt.encode(payload, secret, algorithm="HS256")
924+
925+
decoded = jwt.decode(token, secret, algorithms=["HS256"])
926+
927+
assert decoded.get("jti") is None
928+
929+
def test_jti_claim_with_invalid_int_value(self, jwt):
930+
special_jti = 12223
931+
payload = {
932+
"jti": special_jti,
933+
}
934+
secret = "your-256-bit-secret"
935+
token = jwt.encode(payload, secret, algorithm="HS256")
936+
937+
with pytest.raises(InvalidJTIError):
938+
jwt.decode(token, secret, algorithms=["HS256"])

0 commit comments

Comments
 (0)