From 549320946dfdaac6ccb6c0d3150c46db92637187 Mon Sep 17 00:00:00 2001 From: Dingning <49514587+HappyDingning@users.noreply.github.com> Date: Tue, 12 Dec 2023 22:09:18 +0800 Subject: [PATCH] Fixed #35030 -- Made django.contrib.auth decorators to work with async functions. --- django/contrib/auth/decorators.py | 80 ++++++++--- docs/releases/5.1.txt | 5 + docs/topics/auth/default.txt | 13 ++ tests/auth_tests/test_decorators.py | 209 ++++++++++++++++++++++++++++ 4 files changed, 288 insertions(+), 19 deletions(-) diff --git a/django/contrib/auth/decorators.py b/django/contrib/auth/decorators.py index cfcc4a2d3a10..b220cc2bd39e 100644 --- a/django/contrib/auth/decorators.py +++ b/django/contrib/auth/decorators.py @@ -1,6 +1,9 @@ +import asyncio from functools import wraps from urllib.parse import urlparse +from asgiref.sync import async_to_sync, sync_to_async + from django.conf import settings from django.contrib.auth import REDIRECT_FIELD_NAME from django.core.exceptions import PermissionDenied @@ -17,10 +20,7 @@ def user_passes_test( """ def decorator(view_func): - @wraps(view_func) - def _wrapper_view(request, *args, **kwargs): - if test_func(request.user): - return view_func(request, *args, **kwargs) + def _redirect_to_login(request): path = request.build_absolute_uri() resolved_login_url = resolve_url(login_url or settings.LOGIN_URL) # If the login url is the same scheme and net location then just @@ -35,7 +35,32 @@ def _wrapper_view(request, *args, **kwargs): return redirect_to_login(path, resolved_login_url, redirect_field_name) - return _wrapper_view + if asyncio.iscoroutinefunction(view_func): + + async def _view_wrapper(request, *args, **kwargs): + auser = await request.auser() + if asyncio.iscoroutinefunction(test_func): + test_pass = await test_func(auser) + else: + test_pass = await sync_to_async(test_func)(auser) + + if test_pass: + return await view_func(request, *args, **kwargs) + return _redirect_to_login(request) + + else: + + def _view_wrapper(request, *args, **kwargs): + if asyncio.iscoroutinefunction(test_func): + test_pass = async_to_sync(test_func)(request.user) + else: + test_pass = test_func(request.user) + + if test_pass: + return view_func(request, *args, **kwargs) + return _redirect_to_login(request) + + return wraps(view_func)(_view_wrapper) return decorator @@ -64,19 +89,36 @@ def permission_required(perm, login_url=None, raise_exception=False): If the raise_exception parameter is given the PermissionDenied exception is raised. """ + if isinstance(perm, str): + perms = (perm,) + else: + perms = perm + + def decorator(view_func): + if asyncio.iscoroutinefunction(view_func): + + async def check_perms(user): + # First check if the user has the permission (even anon users). + if await sync_to_async(user.has_perms)(perms): + return True + # In case the 403 handler should be called raise the exception. + if raise_exception: + raise PermissionDenied + # As the last resort, show the login form. + return False - def check_perms(user): - if isinstance(perm, str): - perms = (perm,) else: - perms = perm - # First check if the user has the permission (even anon users) - if user.has_perms(perms): - return True - # In case the 403 handler should be called raise the exception - if raise_exception: - raise PermissionDenied - # As the last resort, show the login form - return False - - return user_passes_test(check_perms, login_url=login_url) + + def check_perms(user): + # First check if the user has the permission (even anon users). + if user.has_perms(perms): + return True + # In case the 403 handler should be called raise the exception. + if raise_exception: + raise PermissionDenied + # As the last resort, show the login form. + return False + + return user_passes_test(check_perms, login_url=login_url)(view_func) + + return decorator diff --git a/docs/releases/5.1.txt b/docs/releases/5.1.txt index 4eab41394605..e7a99e8d7b80 100644 --- a/docs/releases/5.1.txt +++ b/docs/releases/5.1.txt @@ -52,6 +52,11 @@ Minor features form save. This is now available in the admin when visiting the user creation and password change pages. +* :func:`~.django.contrib.auth.decorators.login_required`, + :func:`~.django.contrib.auth.decorators.permission_required`, and + :func:`~.django.contrib.auth.decorators.user_passes_test` decorators now + support wrapping asynchronous view functions. + :mod:`django.contrib.contenttypes` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/topics/auth/default.txt b/docs/topics/auth/default.txt index 2b57f62f1372..795a1bdacc68 100644 --- a/docs/topics/auth/default.txt +++ b/docs/topics/auth/default.txt @@ -617,6 +617,10 @@ The ``login_required`` decorator :func:`django.contrib.admin.views.decorators.staff_member_required` decorator a useful alternative to ``login_required()``. +.. versionchanged:: 5.1 + + Support for wrapping asynchronous view functions was added. + .. currentmodule:: django.contrib.auth.mixins The ``LoginRequiredMixin`` mixin @@ -714,6 +718,11 @@ email in the desired domain and if not, redirects to the login page:: @user_passes_test(email_check, login_url="/login/") def my_view(request): ... + .. versionchanged:: 5.1 + + Support for wrapping asynchronous view functions and using asynchronous + test callables was added. + .. currentmodule:: django.contrib.auth.mixins .. class:: UserPassesTestMixin @@ -818,6 +827,10 @@ The ``permission_required`` decorator ``redirect_authenticated_user=True`` and the logged-in user doesn't have all of the required permissions. +.. versionchanged:: 5.1 + + Support for wrapping asynchronous view functions was added. + .. currentmodule:: django.contrib.auth.mixins The ``PermissionRequiredMixin`` mixin diff --git a/tests/auth_tests/test_decorators.py b/tests/auth_tests/test_decorators.py index 6cc92302d639..48fa915c5cb4 100644 --- a/tests/auth_tests/test_decorators.py +++ b/tests/auth_tests/test_decorators.py @@ -1,3 +1,7 @@ +from asyncio import iscoroutinefunction + +from asgiref.sync import sync_to_async + from django.conf import settings from django.contrib.auth import models from django.contrib.auth.decorators import ( @@ -19,6 +23,22 @@ class LoginRequiredTestCase(AuthViewsTestCase): Tests the login_required decorators """ + factory = RequestFactory() + + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = login_required(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = login_required(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_callable(self): """ login_required is assignable to callable objects. @@ -63,6 +83,35 @@ def test_login_required_next_url(self): view_url="/login_required_login_url/", login_url="/somewhere/" ) + async def test_login_required_async_view(self, login_url=None): + async def async_view(request): + return HttpResponse() + + async def auser_anonymous(): + return models.AnonymousUser() + + async def auser(): + return self.u1 + + if login_url is None: + async_view = login_required(async_view) + login_url = settings.LOGIN_URL + else: + async_view = login_required(async_view, login_url=login_url) + + request = self.factory.get("/rand") + request.auser = auser_anonymous + response = await async_view(request) + self.assertEqual(response.status_code, 302) + self.assertIn(login_url, response.url) + + request.auser = auser + response = await async_view(request) + self.assertEqual(response.status_code, 200) + + async def test_login_required_next_url_async_view(self): + await self.test_login_required_async_view(login_url="/somewhere/") + class PermissionsRequiredDecoratorTest(TestCase): """ @@ -80,6 +129,24 @@ def setUpTestData(cls): ) cls.user.user_permissions.add(*perms) + @classmethod + async def auser(cls): + return cls.user + + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = permission_required([])(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = permission_required([])(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_many_permissions_pass(self): @permission_required( ["auth_tests.add_customuser", "auth_tests.change_customuser"] @@ -147,6 +214,73 @@ def a_view(request): with self.assertRaises(PermissionDenied): a_view(request) + async def test_many_permissions_pass_async_view(self): + @permission_required( + ["auth_tests.add_customuser", "auth_tests.change_customuser"] + ) + async def async_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.auser = self.auser + response = await async_view(request) + self.assertEqual(response.status_code, 200) + + async def test_many_permissions_in_set_pass_async_view(self): + @permission_required( + {"auth_tests.add_customuser", "auth_tests.change_customuser"} + ) + async def async_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.auser = self.auser + response = await async_view(request) + self.assertEqual(response.status_code, 200) + + async def test_single_permission_pass_async_view(self): + @permission_required("auth_tests.add_customuser") + async def async_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.auser = self.auser + response = await async_view(request) + self.assertEqual(response.status_code, 200) + + async def test_permissioned_denied_redirect_async_view(self): + @permission_required( + [ + "auth_tests.add_customuser", + "auth_tests.change_customuser", + "nonexistent-permission", + ] + ) + async def async_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.auser = self.auser + response = await async_view(request) + self.assertEqual(response.status_code, 302) + + async def test_permissioned_denied_exception_raised_async_view(self): + @permission_required( + [ + "auth_tests.add_customuser", + "auth_tests.change_customuser", + "nonexistent-permission", + ], + raise_exception=True, + ) + async def async_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.auser = self.auser + with self.assertRaises(PermissionDenied): + await async_view(request) + class UserPassesTestDecoratorTest(TestCase): factory = RequestFactory() @@ -162,6 +296,28 @@ def setUpTestData(cls): ) cls.user_pass.user_permissions.add(*perms) + @classmethod + async def auser_pass(cls): + return cls.user_pass + + @classmethod + async def auser_deny(cls): + return cls.user_deny + + def test_wrapped_sync_function_is_not_coroutine_function(self): + def sync_view(request): + return HttpResponse() + + wrapped_view = user_passes_test(lambda user: True)(sync_view) + self.assertIs(iscoroutinefunction(wrapped_view), False) + + def test_wrapped_async_function_is_coroutine_function(self): + async def async_view(request): + return HttpResponse() + + wrapped_view = user_passes_test(lambda user: True)(async_view) + self.assertIs(iscoroutinefunction(wrapped_view), True) + def test_decorator(self): def sync_test_func(user): return bool( @@ -180,3 +336,56 @@ def sync_view(request): request.user = self.user_deny response = sync_view(request) self.assertEqual(response.status_code, 302) + + def test_decorator_async_test_func(self): + async def async_test_func(user): + return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) + + @user_passes_test(async_test_func) + def sync_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.user = self.user_pass + response = sync_view(request) + self.assertEqual(response.status_code, 200) + + request.user = self.user_deny + response = sync_view(request) + self.assertEqual(response.status_code, 302) + + async def test_decorator_async_view(self): + def sync_test_func(user): + return bool( + models.Group.objects.filter(name__istartswith=user.username).exists() + ) + + @user_passes_test(sync_test_func) + async def async_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.auser = self.auser_pass + response = await async_view(request) + self.assertEqual(response.status_code, 200) + + request.auser = self.auser_deny + response = await async_view(request) + self.assertEqual(response.status_code, 302) + + async def test_decorator_async_view_async_test_func(self): + async def async_test_func(user): + return await sync_to_async(user.has_perms)(["auth_tests.add_customuser"]) + + @user_passes_test(async_test_func) + async def async_view(request): + return HttpResponse() + + request = self.factory.get("/rand") + request.auser = self.auser_pass + response = await async_view(request) + self.assertEqual(response.status_code, 200) + + request.auser = self.auser_deny + response = await async_view(request) + self.assertEqual(response.status_code, 302)