diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index 49282b486..6bc2fafb0 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -1,10 +1,17 @@ import jwt from django.utils.translation import gettext_lazy as _ -from jwt import InvalidAlgorithmError, InvalidTokenError, PyJWKClient, algorithms +from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms from .exceptions import TokenBackendError from .utils import format_lazy +try: + from jwt import PyJWKClient + + JWK_CLIENT_AVAILABLE = True +except ImportError: + JWK_CLIENT_AVAILABLE = False + ALLOWED_ALGORITHMS = { "HS256", "HS384", @@ -37,7 +44,10 @@ def __init__( self.audience = audience self.issuer = issuer - self.jwks_client = PyJWKClient(jwk_url) if jwk_url else None + if JWK_CLIENT_AVAILABLE: + self.jwks_client = PyJWKClient(jwk_url) if jwk_url else None + else: + self.jwks_client = None self.leeway = leeway def _validate_algorithm(self, algorithm): diff --git a/setup.py b/setup.py index 286921f00..5687b6604 100755 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ install_requires=[ "django", "djangorestframework", - "pyjwt>=2,<3", + "pyjwt>=1.7.1,<3", ], python_requires=">=3.7", extras_require=extras_require, diff --git a/tests/test_backends.py b/tests/test_backends.py index 94a56829e..14bb6b9ad 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -3,10 +3,13 @@ from unittest.mock import patch import jwt +import pytest from django.test import TestCase -from jwt import PyJWS, algorithms +from jwt import PyJWS +from jwt import __version__ as jwt_version +from jwt import algorithms -from rest_framework_simplejwt.backends import TokenBackend +from rest_framework_simplejwt.backends import JWK_CLIENT_AVAILABLE, TokenBackend from rest_framework_simplejwt.exceptions import TokenBackendError from rest_framework_simplejwt.utils import aware_utcnow, datetime_to_epoch, make_utc from tests.keys import ( @@ -28,6 +31,8 @@ LEEWAY = 100 +IS_OLD_JWT = jwt_version == "1.7.1" + class TestTokenBackend(TestCase): def setUp(self): @@ -159,7 +164,7 @@ def test_decode_with_expiry(self): def test_decode_with_invalid_sig(self): self.payload["exp"] = aware_utcnow() - timedelta(seconds=1) for backend in self.backends: - with self.subTest("Test decode with invalid sig for f{backend.algorithm}"): + with self.subTest(f"Test decode with invalid sig for {backend.algorithm}"): payload = self.payload.copy() payload["exp"] = aware_utcnow() + timedelta(days=1) token_1 = jwt.encode( @@ -170,6 +175,10 @@ def test_decode_with_invalid_sig(self): payload, backend.signing_key, algorithm=backend.algorithm ) + if IS_OLD_JWT: + token_1 = token_1.decode("utf-8") + token_2 = token_2.decode("utf-8") + token_2_payload = token_2.rsplit(".", 1)[0] token_1_sig = token_1.rsplit(".", 1)[-1] invalid_token = token_2_payload + "." + token_1_sig @@ -189,8 +198,12 @@ def test_decode_with_invalid_sig_no_verify(self): token_2 = jwt.encode( payload, backend.signing_key, algorithm=backend.algorithm ) - # Payload copied - payload["exp"] = datetime_to_epoch(payload["exp"]) + if IS_OLD_JWT: + token_1 = token_1.decode("utf-8") + token_2 = token_2.decode("utf-8") + else: + # Payload copied + payload["exp"] = datetime_to_epoch(payload["exp"]) token_2_payload = token_2.rsplit(".", 1)[0] token_1_sig = token_1.rsplit(".", 1)[-1] @@ -210,9 +223,13 @@ def test_decode_success(self): token = jwt.encode( self.payload, backend.signing_key, algorithm=backend.algorithm ) - # Payload copied - payload = self.payload.copy() - payload["exp"] = datetime_to_epoch(self.payload["exp"]) + if IS_OLD_JWT: + token = token.decode("utf-8") + payload = self.payload + else: + # Payload copied + payload = self.payload.copy() + payload["exp"] = datetime_to_epoch(self.payload["exp"]) self.assertEqual(backend.decode(token), payload) @@ -223,11 +240,18 @@ def test_decode_aud_iss_success(self): self.payload["iss"] = ISSUER token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256") - # Payload copied - self.payload["exp"] = datetime_to_epoch(self.payload["exp"]) + if IS_OLD_JWT: + token = token.decode("utf-8") + else: + # Payload copied + self.payload["exp"] = datetime_to_epoch(self.payload["exp"]) self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload) + @pytest.mark.skipif( + not JWK_CLIENT_AVAILABLE, + reason="PyJWT 1.7.1 doesn't have JWK client", + ) def test_decode_rsa_aud_iss_jwk_success(self): self.payload["exp"] = aware_utcnow() + timedelta(days=1) self.payload["foo"] = "baz" @@ -261,6 +285,8 @@ def test_decode_rsa_aud_iss_jwk_success(self): def test_decode_when_algorithm_not_available(self): token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256") + if IS_OLD_JWT: + token = token.decode("utf-8") pyjwt_without_rsa = PyJWS() pyjwt_without_rsa.unregister_algorithm("RS256") @@ -276,6 +302,8 @@ def _decode(jwt, key, algorithms, options, audience, issuer, leeway): def test_decode_when_token_algorithm_does_not_match(self): token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256") + if IS_OLD_JWT: + token = token.decode("utf-8") with self.assertRaisesRegex(TokenBackendError, "Invalid algorithm specified"): self.hmac_token_backend.decode(token) diff --git a/tox.ini b/tox.ini index bcd2cd76c..757a9c84b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,8 @@ [tox] envlist= - py{37,38,39}-dj22-drf310-tests - py{37,38,39,310}-dj{22,32}-drf{311,312,313}-tests - py{38,39,310}-dj{40,main}-drf313-tests + py{37,38,39}-dj22-drf310-pyjwt{171,2}-tests + py{37,38,39,310}-dj{22,32}-drf{311,312,313}-pyjwt{171,2}-tests + py{38,39,310}-dj{40,main}-drf313-pyjwt{171,2}-tests docs [gh-actions] @@ -25,7 +25,6 @@ DRF= 3.13: drf313 [testenv] -usedevelop=True commands = pytest {posargs:tests} --cov-append --cov-report=xml --cov=rest_framework_simplejwt extras= test @@ -40,6 +39,8 @@ deps= drf311: djangorestframework>=3.11,<3.12 drf312: djangorestframework>=3.12,<3.13 drf313: djangorestframework>=3.13,<3.14 + pyjwt171: pyjwt>=1.7.1,<1.8 + pyjwt2: pyjwt>=2,<3 djmain: https://github.com/django/django/archive/main.tar.gz allowlist_externals=make