Skip to content

Commit

Permalink
Setup initial PyJWT 1.7.1 support (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew-Chen-Wang authored Mar 1, 2022
1 parent 09d6599 commit 304819c
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 17 deletions.
14 changes: 12 additions & 2 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import jwt
from django.utils.translation import gettext_lazy as _
from jwt import InvalidAlgorithmError, InvalidTokenError, PyJWKClient, algorithms
from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms

from .exceptions import TokenBackendError
from .utils import format_lazy

try:
from jwt import PyJWKClient

JWK_CLIENT_AVAILABLE = True
except ImportError:
JWK_CLIENT_AVAILABLE = False

ALLOWED_ALGORITHMS = {
"HS256",
"HS384",
Expand Down Expand Up @@ -37,7 +44,10 @@ def __init__(
self.audience = audience
self.issuer = issuer

self.jwks_client = PyJWKClient(jwk_url) if jwk_url else None
if JWK_CLIENT_AVAILABLE:
self.jwks_client = PyJWKClient(jwk_url) if jwk_url else None
else:
self.jwks_client = None
self.leeway = leeway

def _validate_algorithm(self, algorithm):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
install_requires=[
"django",
"djangorestframework",
"pyjwt>=2,<3",
"pyjwt>=1.7.1,<3",
],
python_requires=">=3.7",
extras_require=extras_require,
Expand Down
48 changes: 38 additions & 10 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from unittest.mock import patch

import jwt
import pytest
from django.test import TestCase
from jwt import PyJWS, algorithms
from jwt import PyJWS
from jwt import __version__ as jwt_version
from jwt import algorithms

from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.backends import JWK_CLIENT_AVAILABLE, TokenBackend
from rest_framework_simplejwt.exceptions import TokenBackendError
from rest_framework_simplejwt.utils import aware_utcnow, datetime_to_epoch, make_utc
from tests.keys import (
Expand All @@ -28,6 +31,8 @@

LEEWAY = 100

IS_OLD_JWT = jwt_version == "1.7.1"


class TestTokenBackend(TestCase):
def setUp(self):
Expand Down Expand Up @@ -159,7 +164,7 @@ def test_decode_with_expiry(self):
def test_decode_with_invalid_sig(self):
self.payload["exp"] = aware_utcnow() - timedelta(seconds=1)
for backend in self.backends:
with self.subTest("Test decode with invalid sig for f{backend.algorithm}"):
with self.subTest(f"Test decode with invalid sig for {backend.algorithm}"):
payload = self.payload.copy()
payload["exp"] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(
Expand All @@ -170,6 +175,10 @@ def test_decode_with_invalid_sig(self):
payload, backend.signing_key, algorithm=backend.algorithm
)

if IS_OLD_JWT:
token_1 = token_1.decode("utf-8")
token_2 = token_2.decode("utf-8")

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig
Expand All @@ -189,8 +198,12 @@ def test_decode_with_invalid_sig_no_verify(self):
token_2 = jwt.encode(
payload, backend.signing_key, algorithm=backend.algorithm
)
# Payload copied
payload["exp"] = datetime_to_epoch(payload["exp"])
if IS_OLD_JWT:
token_1 = token_1.decode("utf-8")
token_2 = token_2.decode("utf-8")
else:
# Payload copied
payload["exp"] = datetime_to_epoch(payload["exp"])

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
Expand All @@ -210,9 +223,13 @@ def test_decode_success(self):
token = jwt.encode(
self.payload, backend.signing_key, algorithm=backend.algorithm
)
# Payload copied
payload = self.payload.copy()
payload["exp"] = datetime_to_epoch(self.payload["exp"])
if IS_OLD_JWT:
token = token.decode("utf-8")
payload = self.payload
else:
# Payload copied
payload = self.payload.copy()
payload["exp"] = datetime_to_epoch(self.payload["exp"])

self.assertEqual(backend.decode(token), payload)

Expand All @@ -223,11 +240,18 @@ def test_decode_aud_iss_success(self):
self.payload["iss"] = ISSUER

token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
if IS_OLD_JWT:
token = token.decode("utf-8")
else:
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload)

@pytest.mark.skipif(
not JWK_CLIENT_AVAILABLE,
reason="PyJWT 1.7.1 doesn't have JWK client",
)
def test_decode_rsa_aud_iss_jwk_success(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
self.payload["foo"] = "baz"
Expand Down Expand Up @@ -261,6 +285,8 @@ def test_decode_rsa_aud_iss_jwk_success(self):

def test_decode_when_algorithm_not_available(self):
token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
if IS_OLD_JWT:
token = token.decode("utf-8")

pyjwt_without_rsa = PyJWS()
pyjwt_without_rsa.unregister_algorithm("RS256")
Expand All @@ -276,6 +302,8 @@ def _decode(jwt, key, algorithms, options, audience, issuer, leeway):

def test_decode_when_token_algorithm_does_not_match(self):
token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
if IS_OLD_JWT:
token = token.decode("utf-8")

with self.assertRaisesRegex(TokenBackendError, "Invalid algorithm specified"):
self.hmac_token_backend.decode(token)
Expand Down
9 changes: 5 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[tox]
envlist=
py{37,38,39}-dj22-drf310-tests
py{37,38,39,310}-dj{22,32}-drf{311,312,313}-tests
py{38,39,310}-dj{40,main}-drf313-tests
py{37,38,39}-dj22-drf310-pyjwt{171,2}-tests
py{37,38,39,310}-dj{22,32}-drf{311,312,313}-pyjwt{171,2}-tests
py{38,39,310}-dj{40,main}-drf313-pyjwt{171,2}-tests
docs

[gh-actions]
Expand All @@ -25,7 +25,6 @@ DRF=
3.13: drf313

[testenv]
usedevelop=True
commands = pytest {posargs:tests} --cov-append --cov-report=xml --cov=rest_framework_simplejwt
extras=
test
Expand All @@ -40,6 +39,8 @@ deps=
drf311: djangorestframework>=3.11,<3.12
drf312: djangorestframework>=3.12,<3.13
drf313: djangorestframework>=3.13,<3.14
pyjwt171: pyjwt>=1.7.1,<1.8
pyjwt2: pyjwt>=2,<3
djmain: https://github.com/django/django/archive/main.tar.gz
allowlist_externals=make

Expand Down

0 comments on commit 304819c

Please sign in to comment.