Skip to content

Commit

Permalink
Fix user-passes-test decorator for permission_required and others (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
KundaPanda authored Oct 5, 2021
1 parent d516104 commit b318bbb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 25 deletions.
44 changes: 21 additions & 23 deletions strawberry_django_jwt/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,12 @@
]


def login_required(target):
def with_info(target):
def signature_add_fn(self, info: Info, *args, **kwargs):
# Only called when no info should be passed, no need to check
return target(self, *args, **kwargs)

# get_result is used by strawberry-graphql-django model mutations
get_result = getattr(target, "get_result", None)

if get_result is not None and callable(get_result):
target.get_result = login_required(target.get_result)
return target

# Create a fake target function with info argument

target_inspection = inspect.signature(target)
target_clean = target
if "info" not in target_inspection.parameters.keys():
Expand All @@ -67,29 +59,34 @@ def signature_add_fn(self, info: Info, *args, **kwargs):
# Copy annotations as well
signature_add_fn.__annotations__ = target.__annotations__
target_clean = signature_add_fn
wrapped = user_passes_test(lambda u: u.is_authenticated)(target_clean)
return wrapped
return target_clean


def context(f):
def decorator(func):
def wrapper(*args, **kwargs):
info = kwargs.get("info")
ctx = get_context(info)
return func(ctx, *args, **kwargs)

return wrapper
def context(func):
def wrapper(*args, **kwargs):
info = kwargs.get("info")
ctx = get_context(info)
return func(ctx, *args, **kwargs)

return decorator
return wrapper


def user_passes_test(test_func, exc=exceptions.PermissionDenied):
def decorator(f):
@wraps(f)
@context(f)
# get_result is used by strawberry-graphql-django model mutations
get_result = getattr(f, "get_result", None)

if get_result is not None and callable(get_result):
f.get_result = decorator(f.get_result)
return f

f_with_info = with_info(f)

@wraps(f_with_info)
@context
def wrapper(context, *args, **kwargs):
if context and test_func(context.user):
return dispose_extra_kwargs(f)(*args, **kwargs)
return dispose_extra_kwargs(f_with_info)(*args, **kwargs)
raise exc

return wrapper
Expand All @@ -99,6 +96,7 @@ def wrapper(context, *args, **kwargs):

staff_member_required = user_passes_test(lambda u: u.is_staff)
superuser_required = user_passes_test(lambda u: u.is_superuser)
login_required = user_passes_test(lambda u: u.is_authenticated)


def login_field(fn=None):
Expand Down
34 changes: 33 additions & 1 deletion tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django.contrib.auth.models import AnonymousUser
from strawberry.types import Info

from strawberry_django_jwt.decorators import dispose_extra_kwargs
from strawberry_django_jwt.decorators import dispose_extra_kwargs, permission_required
from strawberry_django_jwt.decorators import login_required
from strawberry_django_jwt.mixins import JSONWebTokenMixin
from strawberry_django_jwt.model_object_types import UserType
Expand Down Expand Up @@ -218,6 +218,38 @@ class Query(JSONWebTokenMixin):
self.assertEqual(data["testModel"], [])
self.assertIsNone(response.errors)

def test_permission_required(self):
@strawberry.type
class Query(JSONWebTokenMixin):
@strawberry.field
@permission_required("tests.run_tests")
def test(self) -> str:
return "TEST"

@strawberry.field
@permission_required("tests.run_tests")
def test_info(self, info: Info) -> str:
return "TEST"

self.client.schema(query=Query, mutation=self.Mutation)

query = """
query Test {
test
testInfo
}
"""

headers = {
jwt_settings.JWT_AUTH_HEADER_NAME: f"{jwt_settings.JWT_AUTH_HEADER_PREFIX} {self.token}",
}

response = self.client.execute(query, **headers)
data = response.data

self.assertEqual(data["test"], "TEST")
self.assertIsNone(response.errors)


if django.VERSION[:2] >= (3, 1):
from .testcases import AsyncSchemaTestCase
Expand Down
11 changes: 10 additions & 1 deletion tests/testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import django
import strawberry
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Permission, User
from django.contrib.contenttypes.models import ContentType
from django.test import RequestFactory
from django.test import testcases
from graphql.execution.execute import GraphQLResolveInfo
Expand All @@ -15,14 +17,21 @@
from strawberry_django_jwt.testcases import JSONWebTokenTestCase
from strawberry_django_jwt.utils import jwt_encode
from strawberry_django_jwt.utils import jwt_payload
from tests.models import MyTestModel


class UserTestCase(testcases.TestCase):
def setUp(self):
self.user = get_user_model().objects.create_user(
self.user: User = get_user_model().objects.create_user(
username="test",
password="dolphins",
)
self.test_permission = Permission.objects.create(
codename="run_tests",
name="Can run tests",
content_type=ContentType.objects.get_for_model(MyTestModel),
)
self.user.user_permissions.add(self.test_permission)


class TestCase(UserTestCase):
Expand Down

0 comments on commit b318bbb

Please sign in to comment.