From 17ef9f8bee66b55bf580db8cbcdabdc8caa3c519 Mon Sep 17 00:00:00 2001 From: Jeremy Mayeres <1524722+jerr0328@users.noreply.github.com> Date: Mon, 15 Aug 2022 20:34:45 +0300 Subject: [PATCH] Fix uncaught exception with JWK (#600) * Fix uncaught exception with JWK * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Allow tests to run on older JWT versions Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- rest_framework_simplejwt/backends.py | 11 +++++---- tests/test_backends.py | 36 ++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index 420b5d8cc..da41528c3 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -10,7 +10,7 @@ from .utils import format_lazy try: - from jwt import PyJWKClient + from jwt import PyJWKClient, PyJWKClientError JWK_CLIENT_AVAILABLE = True except ImportError: @@ -96,7 +96,10 @@ def get_verifying_key(self, token): return self.signing_key if self.jwks_client: - return self.jwks_client.get_signing_key_from_jwt(token).key + try: + return self.jwks_client.get_signing_key_from_jwt(token).key + except PyJWKClientError as ex: + raise TokenBackendError(_("Token is invalid or expired")) from ex return self.verifying_key @@ -145,5 +148,5 @@ def decode(self, token, verify=True): ) except InvalidAlgorithmError as ex: raise TokenBackendError(_("Invalid algorithm specified")) from ex - except InvalidTokenError: - raise TokenBackendError(_("Token is invalid or expired")) + except InvalidTokenError as ex: + raise TokenBackendError(_("Token is invalid or expired")) from ex diff --git a/tests/test_backends.py b/tests/test_backends.py index cd3a6fc04..0314ba671 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -292,6 +292,42 @@ def test_decode_rsa_aud_iss_jwk_success(self): self.assertEqual(jwk_token_backend.decode(token), self.payload) + @pytest.mark.skipif( + not JWK_CLIENT_AVAILABLE, + reason="PyJWT 1.7.1 doesn't have JWK client", + ) + def test_decode_jwk_missing_key_raises_tokenbackenderror(self): + self.payload["exp"] = aware_utcnow() + timedelta(days=1) + self.payload["foo"] = "baz" + self.payload["aud"] = AUDIENCE + self.payload["iss"] = ISSUER + + token = jwt.encode( + self.payload, + PRIVATE_KEY_2, + algorithm="RS256", + headers={"kid": "230498151c214b788dd97f22b85410a5"}, + ) + + mock_jwk_module = mock.MagicMock() + with patch("rest_framework_simplejwt.backends.PyJWKClient") as mock_jwk_module: + mock_jwk_client = mock.MagicMock() + + mock_jwk_module.return_value = mock_jwk_client + mock_jwk_client.get_signing_key_from_jwt.side_effect = jwt.PyJWKClientError( + "Unable to find a signing key that matches" + ) + + # Note the PRIV,PUB care is intentially the original pairing + jwk_token_backend = TokenBackend( + "RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL + ) + + with self.assertRaisesRegex( + TokenBackendError, "Token is invalid or expired" + ): + jwk_token_backend.decode(token) + def test_decode_when_algorithm_not_available(self): token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256") if IS_OLD_JWT: