diff --git a/jwt/api_jwk.py b/jwt/api_jwk.py index d35b9558..02f4679c 100644 --- a/jwt/api_jwk.py +++ b/jwt/api_jwk.py @@ -5,7 +5,13 @@ from typing import Any from .algorithms import get_default_algorithms, has_crypto, requires_cryptography -from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError +from .exceptions import ( + InvalidKeyError, + MissingCryptographyError, + PyJWKError, + PyJWKSetError, + PyJWTError, +) from .types import JWKDict @@ -50,7 +56,9 @@ def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None: raise InvalidKeyError(f"Unsupported kty: {kty}") if not has_crypto and algorithm in requires_cryptography: - raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.") + raise MissingCryptographyError( + f"{algorithm} requires 'cryptography' to be installed." + ) self.algorithm_name = algorithm @@ -96,7 +104,9 @@ def __init__(self, keys: list[JWKDict]) -> None: for key in keys: try: self.keys.append(PyJWK(key)) - except PyJWTError: + except PyJWTError as error: + if isinstance(error, MissingCryptographyError): + raise error # skip unusable keys continue diff --git a/jwt/exceptions.py b/jwt/exceptions.py index 8ac6ecf7..0d985882 100644 --- a/jwt/exceptions.py +++ b/jwt/exceptions.py @@ -58,6 +58,10 @@ class PyJWKError(PyJWTError): pass +class MissingCryptographyError(PyJWKError): + pass + + class PyJWKSetError(PyJWTError): pass diff --git a/tests/test_api_jwk.py b/tests/test_api_jwk.py index 326e012a..55deca52 100644 --- a/tests/test_api_jwk.py +++ b/tests/test_api_jwk.py @@ -4,7 +4,12 @@ from jwt.algorithms import has_crypto from jwt.api_jwk import PyJWK, PyJWKSet -from jwt.exceptions import InvalidKeyError, PyJWKError, PyJWKSetError +from jwt.exceptions import ( + InvalidKeyError, + MissingCryptographyError, + PyJWKError, + PyJWKSetError, +) from .utils import crypto_required, key_path, no_crypto_required @@ -212,9 +217,14 @@ def test_missing_crypto_library_good_error_message(self): PyJWK({"kty": "dummy"}, algorithm="RS256") assert "cryptography" in str(exc.value) + @no_crypto_required + def test_missing_crypto_library_raises_missing_cryptography_error(self): + with pytest.raises(MissingCryptographyError): + PyJWK({"kty": "dummy"}, algorithm="RS256") + -@crypto_required class TestPyJWKSet: + @crypto_required def test_should_load_keys_from_jwk_data_dict(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -236,6 +246,7 @@ def test_should_load_keys_from_jwk_data_dict(self): assert jwk.key_id == "keyid-abc123" assert jwk.public_key_use == "sig" + @crypto_required def test_should_load_keys_from_jwk_data_json_string(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -257,6 +268,7 @@ def test_should_load_keys_from_jwk_data_json_string(self): assert jwk.key_id == "keyid-abc123" assert jwk.public_key_use == "sig" + @crypto_required def test_keyset_should_index_by_kid(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -279,6 +291,7 @@ def test_keyset_should_index_by_kid(self): with pytest.raises(KeyError): _ = jwk_set["this-kid-does-not-exist"] + @crypto_required def test_keyset_with_unknown_alg(self): # first keyset with unusable key and usable key with open(key_path("jwk_keyset_with_unknown_alg.json")) as keyfile: @@ -296,12 +309,19 @@ def test_keyset_with_unknown_alg(self): with pytest.raises(PyJWKSetError): _ = PyJWKSet.from_json(jwks_text) + @crypto_required def test_invalid_keys_list(self): with pytest.raises(PyJWKSetError) as err: PyJWKSet(keys="string") # type: ignore assert str(err.value) == "Invalid JWK Set value" + @crypto_required def test_empty_keys_list(self): with pytest.raises(PyJWKSetError) as err: PyJWKSet(keys=[]) assert str(err.value) == "The JWK Set did not contain any keys" + + @no_crypto_required + def test_missing_crypto_library_raises_when_required(self): + with pytest.raises(MissingCryptographyError): + PyJWKSet(keys=[{"kty": "RSA"}])