Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix uncaught exception with JWK #600

Merged
merged 3 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .utils import format_lazy

try:
from jwt import PyJWKClient
from jwt import PyJWKClient, PyJWKClientError

JWK_CLIENT_AVAILABLE = True
except ImportError:
Expand Down Expand Up @@ -96,7 +96,10 @@ def get_verifying_key(self, token):
return self.signing_key

if self.jwks_client:
return self.jwks_client.get_signing_key_from_jwt(token).key
try:
return self.jwks_client.get_signing_key_from_jwt(token).key
except PyJWKClientError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex

return self.verifying_key

Expand Down Expand Up @@ -145,5 +148,5 @@ def decode(self, token, verify=True):
)
except InvalidAlgorithmError as ex:
raise TokenBackendError(_("Invalid algorithm specified")) from ex
except InvalidTokenError:
raise TokenBackendError(_("Token is invalid or expired"))
except InvalidTokenError as ex:
raise TokenBackendError(_("Token is invalid or expired")) from ex
36 changes: 36 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,42 @@ def test_decode_rsa_aud_iss_jwk_success(self):

self.assertEqual(jwk_token_backend.decode(token), self.payload)

@pytest.mark.skipif(
not JWK_CLIENT_AVAILABLE,
reason="PyJWT 1.7.1 doesn't have JWK client",
)
def test_decode_jwk_missing_key_raises_tokenbackenderror(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
self.payload["foo"] = "baz"
self.payload["aud"] = AUDIENCE
self.payload["iss"] = ISSUER

token = jwt.encode(
self.payload,
PRIVATE_KEY_2,
algorithm="RS256",
headers={"kid": "230498151c214b788dd97f22b85410a5"},
)

mock_jwk_module = mock.MagicMock()
with patch("rest_framework_simplejwt.backends.PyJWKClient") as mock_jwk_module:
mock_jwk_client = mock.MagicMock()

mock_jwk_module.return_value = mock_jwk_client
mock_jwk_client.get_signing_key_from_jwt.side_effect = jwt.PyJWKClientError(
"Unable to find a signing key that matches"
)

# Note the PRIV,PUB care is intentially the original pairing
jwk_token_backend = TokenBackend(
"RS256", PRIVATE_KEY, PUBLIC_KEY, AUDIENCE, ISSUER, JWK_URL
)

with self.assertRaisesRegex(
TokenBackendError, "Token is invalid or expired"
):
jwk_token_backend.decode(token)

def test_decode_when_algorithm_not_available(self):
token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
if IS_OLD_JWT:
Expand Down