diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index b293de75ab2..ccf87dc64df 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation. def view(request): return Response({"message": "Will not appear in schema!"}) +# Async Views + +When using Django 4.1 and above, REST framework allows you to work with async class and function based views. + +For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async. + +For example: + + class AsyncView(APIView): + async def get(self, request): + return Response({"message": "This is an async class based view."}) + + + @api_view(['GET']) + async def async_view(request): + return Response({"message": "This is an async function based view."}) [cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html diff --git a/rest_framework/compat.py b/rest_framework/compat.py index ac5cbc572a8..87cd0b8f041 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -41,6 +41,17 @@ def distinct(queryset, base): uritemplate = None +# async_to_sync is required for async view support +if django.VERSION >= (4, 1): + from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async +else: + async_to_sync = None + sync_to_async = None + + def iscoroutinefunction(func): + return False + + # coreschema is optional try: import coreschema diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 3b572c09ef8..5a378833b1b 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,6 +10,7 @@ from django.forms.utils import pretty_name +from rest_framework.compat import iscoroutinefunction from rest_framework.views import APIView @@ -46,8 +47,12 @@ def decorator(func): allowed_methods = set(http_method_names) | {'options'} WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] - def handler(self, *args, **kwargs): - return func(*args, **kwargs) + if iscoroutinefunction(func): + async def handler(self, *args, **kwargs): + return await func(*args, **kwargs) + else: + def handler(self, *args, **kwargs): + return func(*args, **kwargs) for method in http_method_names: setattr(WrappedAPIView, method.lower(), handler) diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fdc5..a3689f8660d 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,6 +1,8 @@ """ Provides an APIView class that is the base of all views in REST framework. """ +import asyncio + from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import connections, models @@ -12,6 +14,9 @@ from django.views.generic import View from rest_framework import exceptions, status +from rest_framework.compat import ( + async_to_sync, iscoroutinefunction, sync_to_async +) from rest_framework.request import Request from rest_framework.response import Response from rest_framework.schemas import DefaultSchema @@ -328,13 +333,52 @@ def check_permissions(self, request): Check if the request should be permitted. Raises an appropriate exception if the request is not permitted. """ + async_permissions, sync_permissions = [], [] for permission in self.get_permissions(): - if not permission.has_permission(request, self): - self.permission_denied( - request, - message=getattr(permission, 'message', None), - code=getattr(permission, 'code', None) - ) + if iscoroutinefunction(permission.has_permission): + async_permissions.append(permission) + else: + sync_permissions.append(permission) + + async def check_async(): + results = await asyncio.gather( + *(permission.has_permission(request, self) for permission in + async_permissions), return_exceptions=True + ) + + for idx in range(len(async_permissions)): + if isinstance(results[idx], Exception): + raise results[idx] + elif not results[idx]: + self.permission_denied( + request, + message=getattr(async_permissions[idx], "message", None), + code=getattr(async_permissions[idx], "code", None), + ) + + def check_sync(): + for permission in sync_permissions: + if not permission.has_permission(request, self): + self.permission_denied( + request, + message=getattr(permission, 'message', None), + code=getattr(permission, 'code', None) + ) + + if getattr(self, 'view_is_async', False): + + async def func(): + if async_permissions: + await check_async() + if sync_permissions: + await sync_to_async(check_sync)() + + return func() + else: + if sync_permissions: + check_sync() + if async_permissions: + async_to_sync(check_async) def check_object_permissions(self, request, obj): """ @@ -354,21 +398,65 @@ def check_throttles(self, request): Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ - throttle_durations = [] + async_throttle_durations, sync_throttle_durations = [], [] for throttle in self.get_throttles(): - if not throttle.allow_request(request, self): - throttle_durations.append(throttle.wait()) + if iscoroutinefunction(throttle.allow_request): + async_throttle_durations.append(throttle) + else: + sync_throttle_durations.append(throttle) + + async def async_throttles(): + for throttle in async_throttle_durations: + if not await throttle.allow_request(request, self): + yield throttle.wait() + + def sync_throttles(): + for throttle in sync_throttle_durations: + if not throttle.allow_request(request, self): + yield throttle.wait() + + if getattr(self, 'view_is_async', False): - if throttle_durations: - # Filter out `None` values which may happen in case of config / rate - # changes, see #1438 - durations = [ - duration for duration in throttle_durations - if duration is not None - ] + async def func(): + throttle_durations = [] - duration = max(durations, default=None) - self.throttled(request, duration) + if async_throttle_durations: + throttle_durations.extend(duration async for duration in async_throttles()) + + if sync_throttle_durations: + throttle_durations.extend(duration for duration in sync_throttles()) + + if throttle_durations: + # Filter out `None` values which may happen in case of config / rate + # changes, see #1438 + durations = [ + duration for duration in throttle_durations + if duration is not None + ] + + duration = max(durations, default=None) + self.throttled(request, duration) + + return func() + else: + throttle_durations = [] + + if sync_throttle_durations: + throttle_durations.extend(sync_throttles()) + + if async_throttle_durations: + throttle_durations.extend(async_to_sync(async_throttles)()) + + if throttle_durations: + # Filter out `None` values which may happen in case of config / rate + # changes, see #1438 + durations = [ + duration for duration in throttle_durations + if duration is not None + ] + + duration = max(durations, default=None) + self.throttled(request, duration) def determine_version(self, request, *args, **kwargs): """ @@ -410,10 +498,20 @@ def initial(self, request, *args, **kwargs): version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme - # Ensure that the incoming request is permitted - self.perform_authentication(request) - self.check_permissions(request) - self.check_throttles(request) + if getattr(self, 'view_is_async', False): + + async def func(): + # Ensure that the incoming request is permitted + await sync_to_async(self.perform_authentication)(request) + await self.check_permissions(request) + await self.check_throttles(request) + + return func() + else: + # Ensure that the incoming request is permitted + self.perform_authentication(request) + self.check_permissions(request) + self.check_throttles(request) def finalize_response(self, request, response, *args, **kwargs): """ @@ -469,7 +567,15 @@ def handle_exception(self, exc): self.raise_uncaught_exception(exc) response.exception = True - return response + + if getattr(self, 'view_is_async', False): + + async def func(): + return response + + return func() + else: + return response def raise_uncaught_exception(self, exc): if settings.DEBUG: @@ -493,23 +599,49 @@ def dispatch(self, request, *args, **kwargs): self.request = request self.headers = self.default_response_headers # deprecate? - try: - self.initial(request, *args, **kwargs) + if getattr(self, 'view_is_async', False): - # Get the appropriate handler method - if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), - self.http_method_not_allowed) - else: - handler = self.http_method_not_allowed + async def func(): + + try: + await self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + response = await handler(request, *args, **kwargs) - response = handler(request, *args, **kwargs) + except Exception as exc: + response = await self.handle_exception(exc) - except Exception as exc: - response = self.handle_exception(exc) + return self.finalize_response(request, response, *args, **kwargs) - self.response = self.finalize_response(request, response, *args, **kwargs) - return self.response + self.response = func() + + return self.response + else: + try: + self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + response = handler(request, *args, **kwargs) + + except Exception as exc: + response = self.handle_exception(exc) + + self.response = self.finalize_response(request, response, *args, **kwargs) + + return self.response def options(self, request, *args, **kwargs): """ @@ -518,4 +650,12 @@ def options(self, request, *args, **kwargs): if self.metadata_class is None: return self.http_method_not_allowed(request, *args, **kwargs) data = self.metadata_class().determine_metadata(request, self) - return Response(data, status=status.HTTP_200_OK) + + if getattr(self, 'view_is_async', False): + + async def func(): + return Response(data, status=status.HTTP_200_OK) + + return func() + else: + return Response(data, status=status.HTTP_200_OK) diff --git a/tests/test_throttling.py b/tests/test_throttling.py index d5a61232d92..cba16f6fba9 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -1,7 +1,7 @@ """ Tests for the throttling implementations in the permissions module. """ - +import django import pytest from django.contrib.auth.models import User from django.core.cache import cache @@ -9,6 +9,7 @@ from django.http import HttpRequest from django.test import TestCase +from rest_framework.compat import async_to_sync from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -43,6 +44,14 @@ def allow_request(self, request, view): return False +class NonTimeAsyncThrottle(BaseThrottle): + def allow_request(self, request, view): + if not hasattr(self.__class__, 'called'): + self.__class__.called = True + return True + return False + + class MockView_DoubleThrottling(APIView): throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) @@ -50,6 +59,13 @@ def get(self, request): return Response('foo') +class MockAsyncView_DoubleThrottling(APIView): + throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) + + async def get(self, request): + return Response('foo') + + class MockView(APIView): throttle_classes = (User3SecRateThrottle,) @@ -57,6 +73,13 @@ def get(self, request): return Response('foo') +class MockAsyncView(APIView): + throttle_classes = (User3SecRateThrottle,) + + async def get(self, request): + return Response('foo') + + class MockView_MinuteThrottling(APIView): throttle_classes = (User3MinRateThrottle,) @@ -64,6 +87,13 @@ def get(self, request): return Response('foo') +class MockAsyncView_MinuteThrottling(APIView): + throttle_classes = (User3MinRateThrottle,) + + async def get(self, request): + return Response('foo') + + class MockView_NonTimeThrottling(APIView): throttle_classes = (NonTimeThrottle,) @@ -71,6 +101,13 @@ def get(self, request): return Response('foo') +class MockAsyncView_NonTimeThrottling(APIView): + throttle_classes = (NonTimeAsyncThrottle,) + + async def get(self, request): + return Response('foo') + + class ThrottlingTests(TestCase): def setUp(self): """ @@ -252,6 +289,191 @@ def test_non_time_throttle(self): self.assertFalse('Retry-After' in response) +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class AsyncThrottlingTests(TestCase): + def setUp(self): + """ + Reset the cache so that no throttles will be active + """ + cache.clear() + self.factory = APIRequestFactory() + + def test_requests_are_throttled(self): + """ + Ensure request rate is limited + """ + request = self.factory.get('/') + for dummy in range(4): + response = async_to_sync(MockAsyncView.as_view())(request) + assert response.status_code == 429 + + def set_throttle_timer(self, view, value): + """ + Explicitly set the timer, overriding time.time() + """ + for cls in view.throttle_classes: + cls.timer = lambda self: value + + def test_request_throttling_expires(self): + """ + Ensure request rate is limited for a limited duration only + """ + self.set_throttle_timer(MockAsyncView, 0) + + request = self.factory.get('/') + for dummy in range(4): + response = async_to_sync(MockAsyncView.as_view())(request) + assert response.status_code == 429 + + # Advance the timer by one second + self.set_throttle_timer(MockAsyncView, 1) + + response = async_to_sync(MockAsyncView.as_view())(request) + assert response.status_code == 200 + + async def ensure_is_throttled(self, view, expect): + request = self.factory.get('/') + request.user = await User.objects.acreate(username='a') + for dummy in range(3): + await view.as_view()(request) + request.user = await User.objects.acreate(username='b') + response = await view.as_view()(request) + assert response.status_code == expect + + def test_request_throttling_is_per_user(self): + """ + Ensure request rate is only limited per user, not globally for + PerUserThrottles + """ + async_to_sync(self.ensure_is_throttled)(MockAsyncView, 200) + + def test_request_throttling_multiple_throttles(self): + """ + Ensure all throttle classes see each request even when the request is + already being throttled + """ + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0) + request = self.factory.get('/') + for dummy in range(4): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 1 + + # At this point our client made 4 requests (one was throttled) in a + # second. If we advance the timer by one additional second, the client + # should be allowed to make 2 more before being throttled by the 2nd + # throttle class, which has a limit of 6 per minute. + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 1) + for dummy in range(2): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 200 + + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 59 + + # Just to make sure check again after two more seconds. + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 2) + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 58 + + def test_throttle_rate_change_negative(self): + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0) + request = self.factory.get('/') + for dummy in range(24): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 60 + + previous_rate = User3SecRateThrottle.rate + try: + User3SecRateThrottle.rate = '1/sec' + + for dummy in range(24): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + + assert response.status_code == 429 + assert int(response['retry-after']) == 60 + finally: + # reset + User3SecRateThrottle.rate = previous_rate + + async def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): + """ + Ensure the response returns an Retry-After field with status and next attributes + set properly. + """ + request = self.factory.get('/') + for timer, expect in expected_headers: + self.set_throttle_timer(view, timer) + response = await view.as_view()(request) + if expect is not None: + assert response['Retry-After'] == expect + else: + assert not'Retry-After' in response + + def test_seconds_fields(self): + """ + Ensure for second based throttles. + """ + async_to_sync(self.ensure_response_header_contains_proper_throttle_field)( + MockAsyncView, ( + (0, None), + (0, None), + (0, None), + (0, '1') + ) + ) + + def test_minutes_fields(self): + """ + Ensure for minute based throttles. + """ + async_to_sync(self.ensure_response_header_contains_proper_throttle_field)( + MockAsyncView_MinuteThrottling, ( + (0, None), + (0, None), + (0, None), + (0, '60') + ) + ) + + def test_next_rate_remains_constant_if_followed(self): + """ + If a client follows the recommended next request rate, + the throttling rate should stay constant. + """ + async_to_sync(self.ensure_response_header_contains_proper_throttle_field)( + MockAsyncView_MinuteThrottling, ( + (0, None), + (20, None), + (40, None), + (60, None), + (80, None) + ) + ) + + def test_non_time_throttle(self): + """ + Ensure for second based throttles. + """ + request = self.factory.get('/') + + self.assertFalse(hasattr(MockAsyncView_NonTimeThrottling.throttle_classes[0], 'called')) + + response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request) + self.assertFalse('Retry-After' in response) + + self.assertTrue(MockAsyncView_NonTimeThrottling.throttle_classes[0].called) + + response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request) + self.assertFalse('Retry-After' in response) + + class ScopedRateThrottleTests(TestCase): """ Tests for ScopedRateThrottle. diff --git a/tests/test_views.py b/tests/test_views.py index 2648c9fb38d..c1cc5a04e68 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,8 +1,12 @@ import copy +import django +import pytest +from django.contrib.auth.models import User from django.test import TestCase from rest_framework import status +from rest_framework.compat import async_to_sync from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings @@ -22,16 +26,36 @@ def post(self, request, *args, **kwargs): return Response({'method': 'POST', 'data': request.data}) +class BasicAsyncView(APIView): + async def get(self, request, *args, **kwargs): + return Response({'method': 'GET'}) + + async def post(self, request, *args, **kwargs): + return Response({'method': 'POST', 'data': request.data}) + + @api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': - return {'method': 'GET'} + return Response({'method': 'GET'}) elif request.method == 'POST': - return {'method': 'POST', 'data': request.data} + return Response({'method': 'POST', 'data': request.data}) elif request.method == 'PUT': - return {'method': 'PUT', 'data': request.data} + return Response({'method': 'PUT', 'data': request.data}) elif request.method == 'PATCH': - return {'method': 'PATCH', 'data': request.data} + return Response({'method': 'PATCH', 'data': request.data}) + + +@api_view(['GET', 'POST', 'PUT', 'PATCH']) +async def basic_async_view(request): + if request.method == 'GET': + return Response({'method': 'GET'}) + elif request.method == 'POST': + return Response({'method': 'POST', 'data': request.data}) + elif request.method == 'PUT': + return Response({'method': 'PUT', 'data': request.data}) + elif request.method == 'PATCH': + return Response({'method': 'PATCH', 'data': request.data}) class ErrorView(APIView): @@ -72,6 +96,36 @@ class ClassBasedViewIntegrationTests(TestCase): def setUp(self): self.view = BasicView.as_view() + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -82,10 +136,88 @@ def test_400_parse_error(self): assert sanitise_json_error(response.data) == expected +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class ClassBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = BasicAsyncView.as_view() + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + class FunctionBasedViewIntegrationTests(TestCase): def setUp(self): self.view = basic_view + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -96,6 +228,54 @@ def test_400_parse_error(self): assert sanitise_json_error(response.data) == expected +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class FunctionBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = basic_async_view + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + class TestCustomExceptionHandler(TestCase): def setUp(self): self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER