Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup initial PyJWT 1.7.1 support #536

Merged
merged 3 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions rest_framework_simplejwt/backends.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import jwt
from django.utils.translation import gettext_lazy as _
from jwt import InvalidAlgorithmError, InvalidTokenError, PyJWKClient, algorithms
from jwt import InvalidAlgorithmError, InvalidTokenError, algorithms

from .exceptions import TokenBackendError
from .utils import format_lazy

try:
from jwt import PyJWKClient

JWK_CLIENT_AVAILABLE = True
except ImportError:
JWK_CLIENT_AVAILABLE = False

ALLOWED_ALGORITHMS = {
"HS256",
"HS384",
Expand Down Expand Up @@ -37,7 +44,10 @@ def __init__(
self.audience = audience
self.issuer = issuer

self.jwks_client = PyJWKClient(jwk_url) if jwk_url else None
if JWK_CLIENT_AVAILABLE:
self.jwks_client = PyJWKClient(jwk_url) if jwk_url else None
else:
self.jwks_client = None
self.leeway = leeway

def _validate_algorithm(self, algorithm):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
install_requires=[
"django",
"djangorestframework",
"pyjwt>=2,<3",
"pyjwt>=1.7.1,<3",
],
python_requires=">=3.7",
extras_require=extras_require,
Expand Down
48 changes: 38 additions & 10 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from unittest.mock import patch

import jwt
import pytest
from django.test import TestCase
from jwt import PyJWS, algorithms
from jwt import PyJWS
from jwt import __version__ as jwt_version
from jwt import algorithms

from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.backends import JWK_CLIENT_AVAILABLE, TokenBackend
from rest_framework_simplejwt.exceptions import TokenBackendError
from rest_framework_simplejwt.utils import aware_utcnow, datetime_to_epoch, make_utc
from tests.keys import (
Expand All @@ -28,6 +31,8 @@

LEEWAY = 100

IS_OLD_JWT = jwt_version == "1.7.1"


class TestTokenBackend(TestCase):
def setUp(self):
Expand Down Expand Up @@ -159,7 +164,7 @@ def test_decode_with_expiry(self):
def test_decode_with_invalid_sig(self):
self.payload["exp"] = aware_utcnow() - timedelta(seconds=1)
for backend in self.backends:
with self.subTest("Test decode with invalid sig for f{backend.algorithm}"):
with self.subTest(f"Test decode with invalid sig for {backend.algorithm}"):
payload = self.payload.copy()
payload["exp"] = aware_utcnow() + timedelta(days=1)
token_1 = jwt.encode(
Expand All @@ -170,6 +175,10 @@ def test_decode_with_invalid_sig(self):
payload, backend.signing_key, algorithm=backend.algorithm
)

if IS_OLD_JWT:
token_1 = token_1.decode("utf-8")
token_2 = token_2.decode("utf-8")

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
invalid_token = token_2_payload + "." + token_1_sig
Expand All @@ -189,8 +198,12 @@ def test_decode_with_invalid_sig_no_verify(self):
token_2 = jwt.encode(
payload, backend.signing_key, algorithm=backend.algorithm
)
# Payload copied
payload["exp"] = datetime_to_epoch(payload["exp"])
if IS_OLD_JWT:
token_1 = token_1.decode("utf-8")
token_2 = token_2.decode("utf-8")
else:
# Payload copied
payload["exp"] = datetime_to_epoch(payload["exp"])

token_2_payload = token_2.rsplit(".", 1)[0]
token_1_sig = token_1.rsplit(".", 1)[-1]
Expand All @@ -210,9 +223,13 @@ def test_decode_success(self):
token = jwt.encode(
self.payload, backend.signing_key, algorithm=backend.algorithm
)
# Payload copied
payload = self.payload.copy()
payload["exp"] = datetime_to_epoch(self.payload["exp"])
if IS_OLD_JWT:
token = token.decode("utf-8")
payload = self.payload
else:
# Payload copied
payload = self.payload.copy()
payload["exp"] = datetime_to_epoch(self.payload["exp"])

self.assertEqual(backend.decode(token), payload)

Expand All @@ -223,11 +240,18 @@ def test_decode_aud_iss_success(self):
self.payload["iss"] = ISSUER

token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
if IS_OLD_JWT:
token = token.decode("utf-8")
else:
# Payload copied
self.payload["exp"] = datetime_to_epoch(self.payload["exp"])

self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload)

@pytest.mark.skipif(
not JWK_CLIENT_AVAILABLE,
reason="PyJWT 1.7.1 doesn't have JWK client",
)
def test_decode_rsa_aud_iss_jwk_success(self):
self.payload["exp"] = aware_utcnow() + timedelta(days=1)
self.payload["foo"] = "baz"
Expand Down Expand Up @@ -261,6 +285,8 @@ def test_decode_rsa_aud_iss_jwk_success(self):

def test_decode_when_algorithm_not_available(self):
token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
if IS_OLD_JWT:
token = token.decode("utf-8")

pyjwt_without_rsa = PyJWS()
pyjwt_without_rsa.unregister_algorithm("RS256")
Expand All @@ -276,6 +302,8 @@ def _decode(jwt, key, algorithms, options, audience, issuer, leeway):

def test_decode_when_token_algorithm_does_not_match(self):
token = jwt.encode(self.payload, PRIVATE_KEY, algorithm="RS256")
if IS_OLD_JWT:
token = token.decode("utf-8")

with self.assertRaisesRegex(TokenBackendError, "Invalid algorithm specified"):
self.hmac_token_backend.decode(token)
Expand Down
9 changes: 5 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[tox]
envlist=
py{37,38,39}-dj22-drf310-tests
py{37,38,39,310}-dj{22,32}-drf{311,312,313}-tests
py{38,39,310}-dj{40,main}-drf313-tests
py{37,38,39}-dj22-drf310-pyjwt{171,2}-tests
py{37,38,39,310}-dj{22,32}-drf{311,312,313}-pyjwt{171,2}-tests
py{38,39,310}-dj{40,main}-drf313-pyjwt{171,2}-tests
docs

[gh-actions]
Expand All @@ -25,7 +25,6 @@ DRF=
3.13: drf313

[testenv]
usedevelop=True
commands = pytest {posargs:tests} --cov-append --cov-report=xml --cov=rest_framework_simplejwt
extras=
test
Expand All @@ -40,6 +39,8 @@ deps=
drf311: djangorestframework>=3.11,<3.12
drf312: djangorestframework>=3.12,<3.13
drf313: djangorestframework>=3.13,<3.14
pyjwt171: pyjwt>=1.7.1,<1.8
pyjwt2: pyjwt>=2,<3
djmain: https://github.com/django/django/archive/main.tar.gz
allowlist_externals=make

Expand Down