Skip to content

Commit 546e96c

Browse files
committed
Add "algorithm mismatch" error to improve jws
Upstream libraries that depend on `jws.verify()` break when the upstream keys contain a mixed set of algorithms. This is a nominal occurance for OIDC servers and should be properly handled.
1 parent 96474ec commit 546e96c

File tree

6 files changed

+27
-20
lines changed

6 files changed

+27
-20
lines changed

jose/backends/cryptography_backend.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from cryptography.x509 import load_pem_x509_certificate
1616

1717
from ..constants import ALGORITHMS
18-
from ..exceptions import JWEError, JWKError
18+
from ..exceptions import JWEError, JWKError, JWKAlgMismatchError
1919
from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64
2020
from .base import Key
2121

@@ -52,7 +52,7 @@ class CryptographyECKey(Key):
5252

5353
def __init__(self, key, algorithm, cryptography_backend=default_backend):
5454
if algorithm not in ALGORITHMS.EC:
55-
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
55+
raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm)
5656

5757
self.hash_alg = {
5858
ALGORITHMS.ES256: self.SHA256,
@@ -97,7 +97,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):
9797

9898
def _process_jwk(self, jwk_dict):
9999
if not jwk_dict.get("kty") == "EC":
100-
raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
100+
raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
101101

102102
if not all(k in jwk_dict for k in ["x", "y", "crv"]):
103103
raise JWKError("Mandatory parameters are missing")
@@ -226,7 +226,7 @@ class CryptographyRSAKey(Key):
226226

227227
def __init__(self, key, algorithm, cryptography_backend=default_backend):
228228
if algorithm not in ALGORITHMS.RSA:
229-
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
229+
raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm)
230230

231231
self.hash_alg = {
232232
ALGORITHMS.RS256: self.SHA256,
@@ -273,7 +273,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):
273273

274274
def _process_jwk(self, jwk_dict):
275275
if not jwk_dict.get("kty") == "RSA":
276-
raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
276+
raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
277277

278278
e = base64_to_long(jwk_dict.get("e", 256))
279279
n = base64_to_long(jwk_dict.get("n"))
@@ -441,9 +441,9 @@ class CryptographyAESKey(Key):
441441

442442
def __init__(self, key, algorithm):
443443
if algorithm not in ALGORITHMS.AES:
444-
raise JWKError("%s is not a valid AES algorithm" % algorithm)
444+
raise JWKAlgMismatchError("%s is not a valid AES algorithm" % algorithm)
445445
if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO):
446-
raise JWKError("%s is not a supported algorithm" % algorithm)
446+
raise JWKAlgMismatchError("%s is not a supported algorithm" % algorithm)
447447

448448
self._algorithm = algorithm
449449
self._mode = self.MODES.get(self._algorithm)
@@ -538,7 +538,7 @@ class CryptographyHMACKey(Key):
538538

539539
def __init__(self, key, algorithm):
540540
if algorithm not in ALGORITHMS.HMAC:
541-
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
541+
raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm)
542542
self._algorithm = algorithm
543543
self._hash_alg = self.ALG_MAP.get(algorithm)
544544

@@ -569,7 +569,7 @@ def __init__(self, key, algorithm):
569569

570570
def _process_jwk(self, jwk_dict):
571571
if not jwk_dict.get("kty") == "oct":
572-
raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
572+
raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
573573

574574
k = jwk_dict.get("k")
575575
k = k.encode("utf-8")

jose/backends/ecdsa_backend.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from jose.backends.base import Key
66
from jose.constants import ALGORITHMS
7-
from jose.exceptions import JWKError
7+
from jose.exceptions import JWKError, JWKAlgMismatchError
88
from jose.utils import base64_to_long, long_to_base64
99

1010

@@ -35,7 +35,7 @@ class ECDSAECKey(Key):
3535

3636
def __init__(self, key, algorithm):
3737
if algorithm not in ALGORITHMS.EC:
38-
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
38+
raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm)
3939

4040
self.hash_alg = {
4141
ALGORITHMS.ES256: self.SHA256,
@@ -75,7 +75,7 @@ def __init__(self, key, algorithm):
7575

7676
def _process_jwk(self, jwk_dict):
7777
if not jwk_dict.get("kty") == "EC":
78-
raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
78+
raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty"))
7979

8080
if not all(k in jwk_dict for k in ["x", "y", "crv"]):
8181
raise JWKError("Mandatory parameters are missing")

jose/backends/native.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from jose.backends.base import Key
66
from jose.constants import ALGORITHMS
7-
from jose.exceptions import JWKError
7+
from jose.exceptions import JWKError, JWKAlgMismatchError
88
from jose.utils import base64url_decode, base64url_encode
99

1010

@@ -22,7 +22,7 @@ class HMACKey(Key):
2222

2323
def __init__(self, key, algorithm):
2424
if algorithm not in ALGORITHMS.HMAC:
25-
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
25+
raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm)
2626
self._algorithm = algorithm
2727
self._hash_alg = self.HASHES.get(algorithm)
2828

@@ -53,7 +53,7 @@ def __init__(self, key, algorithm):
5353

5454
def _process_jwk(self, jwk_dict):
5555
if not jwk_dict.get("kty") == "oct":
56-
raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
56+
raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty"))
5757

5858
k = jwk_dict.get("k")
5959
k = k.encode("utf-8")

jose/backends/rsa_backend.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from jose.backends.base import Key
1515
from jose.constants import ALGORITHMS
16-
from jose.exceptions import JWEError, JWKError
16+
from jose.exceptions import JWEError, JWKError, JWKAlgMismatchError
1717
from jose.utils import base64_to_long, long_to_base64
1818

1919
ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP) # RSA OAEP not supported
@@ -124,7 +124,7 @@ class RSAKey(Key):
124124

125125
def __init__(self, key, algorithm):
126126
if algorithm not in ALGORITHMS.RSA:
127-
raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm)
127+
raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm)
128128

129129
if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5:
130130
raise JWKError("alg: %s is not supported by the RSA backend" % algorithm)
@@ -174,7 +174,7 @@ def __init__(self, key, algorithm):
174174

175175
def _process_jwk(self, jwk_dict):
176176
if not jwk_dict.get("kty") == "RSA":
177-
raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
177+
raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty"))
178178

179179
e = base64_to_long(jwk_dict.get("e"))
180180
n = base64_to_long(jwk_dict.get("n"))

jose/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ class ExpiredSignatureError(JWTError):
2929
class JWKError(JOSEError):
3030
pass
3131

32+
class JWKAlgMismatchError(JWKError):
33+
'''JWK Key type doesn't support the given algorithm.'''
34+
pass
35+
3236

3337
class JWEError(JOSEError):
3438
"""Base error for all JWE errors"""

jose/jws.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jose import jwk
66
from jose.backends.base import Key
77
from jose.constants import ALGORITHMS
8-
from jose.exceptions import JWSError, JWSSignatureError
8+
from jose.exceptions import JWSError, JWSSignatureError, JWKAlgMismatchError
99
from jose.utils import base64url_decode, base64url_encode
1010

1111

@@ -205,7 +205,10 @@ def _load(jwt):
205205
def _sig_matches_keys(keys, signing_input, signature, alg):
206206
for key in keys:
207207
if not isinstance(key, Key):
208-
key = jwk.construct(key, alg)
208+
try:
209+
key = jwk.construct(key, alg)
210+
except JWKAlgMismatchError:
211+
continue
209212
try:
210213
if key.verify(signing_input, signature):
211214
return True

0 commit comments

Comments
 (0)