Skip to content

Commit

Permalink
Fix uncaught exception with JWK (#600)
Browse files Browse the repository at this point in the history
* Fix uncaught exception with JWK

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Allow tests to run on older JWT versions

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jerr0328 and pre-commit-ci[bot] authored Aug 15, 2022
1 parent 9b6f1de commit 17ef9f8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
11 changes: 7 additions & 4 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .utils import format_lazy

try:
from jwt import PyJWKClient
from jwt import PyJWKClient, PyJWKClientError

JWK_CLIENT_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -96,7 +96,10 @@ def get_verifying_key(self, token):
return self.signing_key

if self.jwks_client:
return self.jwks_client.get_signing_key_from_jwt(token).key
try:
return self.jwks_client.get_signing_key_from_jwt(token).key
except PyJWKClientError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex

return self.verifying_key

Expand Down Expand Up @@ -145,5 +148,5 @@ def decode(self, token, verify=True):
)
except InvalidAlgorithmError as ex:
raise TokenBackendError(_("Invalid algorithm specified")) from ex
except InvalidTokenError:
raise TokenBackendError(_("Token is invalid or expired"))
except InvalidTokenError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex
36 changes: 36 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,42 @@ def test_decode_rsa_aud_iss_jwk_success(self):

self.assertEqual(jwk_token_backend.decode(token), self.payload)

@pytest.mark.skipif(
not JWK_CLIENT_AVAILABLE,
reason="PyJWT 1.7.1 doesn't have JWK client",
)
def test_decode_jwk_missing_key_raises_tokenbackenderror(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
self.payload["foo"] = "baz"
self.payload["aud"] = AUDIENCE
self.payload["iss"] = ISSUER

token = jwt.encode(
self.payload,
PRIVATE_KEY_2,
algorithm="RS256",
headers={"kid": "230498151c214b788dd97f22b85410a5"},
)

mock_jwk_module = mock.MagicMock()
with patch("rest_framework_simplejwt.backends.PyJWKClient") as mock_jwk_module:
mock_jwk_client = mock.MagicMock()

mock_jwk_module.return_value = mock_jwk_client
mock_jwk_client.get_signing_key_from_jwt.side_effect = jwt.PyJWKClientError(
"Unable to find a signing key that matches"
)

# Note the PRIV,PUB care is intentially the original pairing
jwk_token_backend = TokenBackend(
"RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
)

with self.assertRaisesRegex(
TokenBackendError, "Token is invalid or expired"
):
jwk_token_backend.decode(token)

def test_decode_when_algorithm_not_available(self):
token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
if IS_OLD_JWT:
Expand Down

0 comments on commit 17ef9f8

Please sign in to comment.