Skip to content

Commit 057bc6f

Browse files
committed
jwt.decode: Require algorithms keyword argument
On decode, require algorithms to be specified to avoid algorithm confusion when verify_signature is True. This is similar to what pyJWT is doing in https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py#L146-L149 See #346
1 parent 34bd82c commit 057bc6f

File tree

2 files changed

+54
-42
lines changed

2 files changed

+54
-42
lines changed

jose/jwt.py

+8
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
141141

142142
verify_signature = defaults.get("verify_signature", True)
143143

144+
# Forbid the usage of the jwt.decode without alogrightms parameter
145+
# See https://github.com/mpdavis/python-jose/issues/346 for more
146+
# information CVE-2024-33663
147+
if verify_signature and algorithms is None:
148+
raise JWTError("It is required that you pass in a value for "
149+
'the "algorithms" argument when calling '
150+
"decode().")
151+
144152
try:
145153
payload = jws.verify(token, key, algorithms, verify=verify_signature)
146154
except JWSError as e:

tests/test_jwt.py

+46-42
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from jose import jws, jwt
8+
from jose.constants import ALGORITHMS
89
from jose.exceptions import JWTError, JWKError
910

1011

@@ -56,7 +57,7 @@ def test_no_alg(self, claims, key):
5657
],
5758
)
5859
def test_numeric_key(self, key, token):
59-
token_info = jwt.decode(token, key)
60+
token_info = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED)
6061
assert token_info == {"name": "test"}
6162

6263
def test_invalid_claims_json(self):
@@ -108,7 +109,7 @@ def test_no_alg_default_headers(self, claims, key, headers):
108109

109110
def test_non_default_headers(self, claims, key, headers):
110111
encoded = jwt.encode(claims, key, headers=headers)
111-
decoded = jwt.decode(encoded, key)
112+
decoded = jwt.decode(encoded, key, algorithms=ALGORITHMS.HS256)
112113
assert claims == decoded
113114
all_headers = jwt.get_unverified_headers(encoded)
114115
for k, v in headers.items():
@@ -159,7 +160,7 @@ def test_encode(self, claims, key):
159160
def test_decode(self, claims, key):
160161
token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" ".eyJhIjoiYiJ9" ".jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8"
161162

162-
decoded = jwt.decode(token, key)
163+
decoded = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED)
163164

164165
assert decoded == claims
165166

@@ -190,31 +191,31 @@ def test_leeway_is_timedelta(self, claims, key):
190191
options = {"leeway": leeway}
191192

192193
token = jwt.encode(claims, key)
193-
jwt.decode(token, key, options=options)
194+
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
194195

195196
def test_iat_not_int(self, key):
196197
claims = {"iat": "test"}
197198

198199
token = jwt.encode(claims, key)
199200

200201
with pytest.raises(JWTError):
201-
jwt.decode(token, key)
202+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
202203

203204
def test_nbf_not_int(self, key):
204205
claims = {"nbf": "test"}
205206

206207
token = jwt.encode(claims, key)
207208

208209
with pytest.raises(JWTError):
209-
jwt.decode(token, key)
210+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
210211

211212
def test_nbf_datetime(self, key):
212213
nbf = datetime.utcnow() - timedelta(seconds=5)
213214

214215
claims = {"nbf": nbf}
215216

216217
token = jwt.encode(claims, key)
217-
jwt.decode(token, key)
218+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
218219

219220
def test_nbf_with_leeway(self, key):
220221
nbf = datetime.utcnow() + timedelta(seconds=5)
@@ -226,7 +227,7 @@ def test_nbf_with_leeway(self, key):
226227
options = {"leeway": 10}
227228

228229
token = jwt.encode(claims, key)
229-
jwt.decode(token, key, options=options)
230+
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
230231

231232
def test_nbf_in_future(self, key):
232233
nbf = datetime.utcnow() + timedelta(seconds=5)
@@ -236,7 +237,7 @@ def test_nbf_in_future(self, key):
236237
token = jwt.encode(claims, key)
237238

238239
with pytest.raises(JWTError):
239-
jwt.decode(token, key)
240+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
240241

241242
def test_nbf_skip(self, key):
242243
nbf = datetime.utcnow() + timedelta(seconds=5)
@@ -246,27 +247,27 @@ def test_nbf_skip(self, key):
246247
token = jwt.encode(claims, key)
247248

248249
with pytest.raises(JWTError):
249-
jwt.decode(token, key)
250+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
250251

251252
options = {"verify_nbf": False}
252253

253-
jwt.decode(token, key, options=options)
254+
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
254255

255256
def test_exp_not_int(self, key):
256257
claims = {"exp": "test"}
257258

258259
token = jwt.encode(claims, key)
259260

260261
with pytest.raises(JWTError):
261-
jwt.decode(token, key)
262+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
262263

263264
def test_exp_datetime(self, key):
264265
exp = datetime.utcnow() + timedelta(seconds=5)
265266

266267
claims = {"exp": exp}
267268

268269
token = jwt.encode(claims, key)
269-
jwt.decode(token, key)
270+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
270271

271272
def test_exp_with_leeway(self, key):
272273
exp = datetime.utcnow() - timedelta(seconds=5)
@@ -278,7 +279,7 @@ def test_exp_with_leeway(self, key):
278279
options = {"leeway": 10}
279280

280281
token = jwt.encode(claims, key)
281-
jwt.decode(token, key, options=options)
282+
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
282283

283284
def test_exp_in_past(self, key):
284285
exp = datetime.utcnow() - timedelta(seconds=5)
@@ -288,7 +289,7 @@ def test_exp_in_past(self, key):
288289
token = jwt.encode(claims, key)
289290

290291
with pytest.raises(JWTError):
291-
jwt.decode(token, key)
292+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
292293

293294
def test_exp_skip(self, key):
294295
exp = datetime.utcnow() - timedelta(seconds=5)
@@ -298,35 +299,35 @@ def test_exp_skip(self, key):
298299
token = jwt.encode(claims, key)
299300

300301
with pytest.raises(JWTError):
301-
jwt.decode(token, key)
302+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
302303

303304
options = {"verify_exp": False}
304305

305-
jwt.decode(token, key, options=options)
306+
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)
306307

307308
def test_aud_string(self, key):
308309
aud = "audience"
309310

310311
claims = {"aud": aud}
311312

312313
token = jwt.encode(claims, key)
313-
jwt.decode(token, key, audience=aud)
314+
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
314315

315316
def test_aud_list(self, key):
316317
aud = "audience"
317318

318319
claims = {"aud": [aud]}
319320

320321
token = jwt.encode(claims, key)
321-
jwt.decode(token, key, audience=aud)
322+
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
322323

323324
def test_aud_list_multiple(self, key):
324325
aud = "audience"
325326

326327
claims = {"aud": [aud, "another"]}
327328

328329
token = jwt.encode(claims, key)
329-
jwt.decode(token, key, audience=aud)
330+
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
330331

331332
def test_aud_list_is_strings(self, key):
332333
aud = "audience"
@@ -335,7 +336,7 @@ def test_aud_list_is_strings(self, key):
335336

336337
token = jwt.encode(claims, key)
337338
with pytest.raises(JWTError):
338-
jwt.decode(token, key, audience=aud)
339+
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
339340

340341
def test_aud_case_sensitive(self, key):
341342
aud = "audience"
@@ -344,13 +345,13 @@ def test_aud_case_sensitive(self, key):
344345

345346
token = jwt.encode(claims, key)
346347
with pytest.raises(JWTError):
347-
jwt.decode(token, key, audience="AUDIENCE")
348+
jwt.decode(token, key, audience="AUDIENCE", algorithms=ALGORITHMS.HS256)
348349

349350
def test_aud_empty_claim(self, claims, key):
350351
aud = "audience"
351352

352353
token = jwt.encode(claims, key)
353-
jwt.decode(token, key, audience=aud)
354+
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)
354355

355356
def test_aud_not_string_or_list(self, key):
356357
aud = 1
@@ -359,7 +360,7 @@ def test_aud_not_string_or_list(self, key):
359360

360361
token = jwt.encode(claims, key)
361362
with pytest.raises(JWTError):
362-
jwt.decode(token, key)
363+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
363364

364365
def test_aud_given_number(self, key):
365366
aud = "audience"
@@ -368,31 +369,31 @@ def test_aud_given_number(self, key):
368369

369370
token = jwt.encode(claims, key)
370371
with pytest.raises(JWTError):
371-
jwt.decode(token, key, audience=1)
372+
jwt.decode(token, key, audience=1, algorithms=ALGORITHMS.HS256)
372373

373374
def test_iss_string(self, key):
374375
iss = "issuer"
375376

376377
claims = {"iss": iss}
377378

378379
token = jwt.encode(claims, key)
379-
jwt.decode(token, key, issuer=iss)
380+
jwt.decode(token, key, issuer=iss, algorithms=ALGORITHMS.HS256)
380381

381382
def test_iss_list(self, key):
382383
iss = "issuer"
383384

384385
claims = {"iss": iss}
385386

386387
token = jwt.encode(claims, key)
387-
jwt.decode(token, key, issuer=["https://issuer", "issuer"])
388+
jwt.decode(token, key, issuer=["https://issuer", "issuer"], algorithms=ALGORITHMS.HS256)
388389

389390
def test_iss_tuple(self, key):
390391
iss = "issuer"
391392

392393
claims = {"iss": iss}
393394

394395
token = jwt.encode(claims, key)
395-
jwt.decode(token, key, issuer=("https://issuer", "issuer"))
396+
jwt.decode(token, key, issuer=("https://issuer", "issuer"), algorithms=ALGORITHMS.HS256)
396397

397398
def test_iss_invalid(self, key):
398399
iss = "issuer"
@@ -401,15 +402,15 @@ def test_iss_invalid(self, key):
401402

402403
token = jwt.encode(claims, key)
403404
with pytest.raises(JWTError):
404-
jwt.decode(token, key, issuer="another")
405+
jwt.decode(token, key, issuer="another", algorithms=ALGORITHMS.HS256)
405406

406407
def test_sub_string(self, key):
407408
sub = "subject"
408409

409410
claims = {"sub": sub}
410411

411412
token = jwt.encode(claims, key)
412-
jwt.decode(token, key)
413+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
413414

414415
def test_sub_invalid(self, key):
415416
sub = 1
@@ -418,15 +419,15 @@ def test_sub_invalid(self, key):
418419

419420
token = jwt.encode(claims, key)
420421
with pytest.raises(JWTError):
421-
jwt.decode(token, key)
422+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
422423

423424
def test_sub_correct(self, key):
424425
sub = "subject"
425426

426427
claims = {"sub": sub}
427428

428429
token = jwt.encode(claims, key)
429-
jwt.decode(token, key, subject=sub)
430+
jwt.decode(token, key, subject=sub, algorithms=ALGORITHMS.HS256)
430431

431432
def test_sub_incorrect(self, key):
432433
sub = "subject"
@@ -435,15 +436,15 @@ def test_sub_incorrect(self, key):
435436

436437
token = jwt.encode(claims, key)
437438
with pytest.raises(JWTError):
438-
jwt.decode(token, key, subject="another")
439+
jwt.decode(token, key, subject="another", algorithms=ALGORITHMS.HS256)
439440

440441
def test_jti_string(self, key):
441442
jti = "JWT ID"
442443

443444
claims = {"jti": jti}
444445

445446
token = jwt.encode(claims, key)
446-
jwt.decode(token, key)
447+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
447448

448449
def test_jti_invalid(self, key):
449450
jti = 1
@@ -452,33 +453,33 @@ def test_jti_invalid(self, key):
452453

453454
token = jwt.encode(claims, key)
454455
with pytest.raises(JWTError):
455-
jwt.decode(token, key)
456+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
456457

457458
def test_at_hash(self, claims, key):
458459
access_token = "<ACCESS_TOKEN>"
459460
token = jwt.encode(claims, key, access_token=access_token)
460-
payload = jwt.decode(token, key, access_token=access_token)
461+
payload = jwt.decode(token, key, access_token=access_token, algorithms=ALGORITHMS.HS256)
461462
assert "at_hash" in payload
462463

463464
def test_at_hash_invalid(self, claims, key):
464465
token = jwt.encode(claims, key, access_token="<ACCESS_TOKEN>")
465466
with pytest.raises(JWTError):
466-
jwt.decode(token, key, access_token="<OTHER_TOKEN>")
467+
jwt.decode(token, key, access_token="<OTHER_TOKEN>", algorithms=ALGORITHMS.HS256)
467468

468469
def test_at_hash_missing_access_token(self, claims, key):
469470
token = jwt.encode(claims, key, access_token="<ACCESS_TOKEN>")
470471
with pytest.raises(JWTError):
471-
jwt.decode(token, key)
472+
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)
472473

473474
def test_at_hash_missing_claim(self, claims, key):
474475
token = jwt.encode(claims, key)
475-
payload = jwt.decode(token, key, access_token="<ACCESS_TOKEN>")
476+
payload = jwt.decode(token, key, access_token="<ACCESS_TOKEN>", algorithms=ALGORITHMS.HS256)
476477
assert "at_hash" not in payload
477478

478479
def test_at_hash_unable_to_calculate(self, claims, key):
479480
token = jwt.encode(claims, key, access_token="<ACCESS_TOKEN>")
480481
with pytest.raises(JWTError):
481-
jwt.decode(token, key, access_token="\xe2")
482+
jwt.decode(token, key, access_token="\xe2", algorithms=ALGORITHMS.HS256)
482483

483484
def test_bad_claims(self):
484485
bad_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck"
@@ -516,12 +517,12 @@ def test_require(self, claims, key, claim, value):
516517

517518
token = jwt.encode(claims, key)
518519
with pytest.raises(JWTError):
519-
jwt.decode(token, key, options=options, audience=str(value))
520+
jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256)
520521

521522
new_claims = dict(claims)
522523
new_claims[claim] = value
523524
token = jwt.encode(new_claims, key)
524-
jwt.decode(token, key, options=options, audience=str(value))
525+
jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256)
525526

526527
def test_CVE_2024_33663(self):
527528
"""Test based on https://github.com/mpdavis/python-jose/issues/346"""
@@ -554,4 +555,7 @@ def test_CVE_2024_33663(self):
554555
# algorithm field is left unspecified
555556
# but the library will happily still verify without warning, trusting the user-controlled alg field of the token header
556557
with pytest.raises(JWKError):
558+
data = jwt.decode(evil_token, PUBKEY, algorithms=ALGORITHMS.HS256)
559+
560+
with pytest.raises(JWTError, match='.*required.*"algorithms".*'):
557561
data = jwt.decode(evil_token, PUBKEY)

0 commit comments

Comments
 (0)