From 1dfc074342aeca3b0554364491220d0f14b9b2c0 Mon Sep 17 00:00:00 2001 From: James Hilliard Date: Wed, 10 May 2023 08:51:42 -0600 Subject: [PATCH] Async view implementation --- docs/api-guide/views.md | 16 +++ rest_framework/compat.py | 11 ++ rest_framework/decorators.py | 9 +- rest_framework/views.py | 214 +++++++++++++++++++++++++++++------ tests/test_views.py | 188 +++++++++++++++++++++++++++++- 5 files changed, 395 insertions(+), 43 deletions(-) diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index b293de75ab2..ccf87dc64df 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation. def view(request): return Response({"message": "Will not appear in schema!"}) +# Async Views + +When using Django 4.1 and above, REST framework allows you to work with async class and function based views. + +For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async. + +For example: + + class AsyncView(APIView): + async def get(self, request): + return Response({"message": "This is an async class based view."}) + + + @api_view(['GET']) + async def async_view(request): + return Response({"message": "This is an async function based view."}) [cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html diff --git a/rest_framework/compat.py b/rest_framework/compat.py index ac5cbc572a8..87cd0b8f041 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -41,6 +41,17 @@ def distinct(queryset, base): uritemplate = None +# async_to_sync is required for async view support +if django.VERSION >= (4, 1): + from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async +else: + async_to_sync = None + sync_to_async = None + + def iscoroutinefunction(func): + return False + + # coreschema is optional try: import coreschema diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 3b572c09ef8..5a378833b1b 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,6 +10,7 @@ from django.forms.utils import pretty_name +from rest_framework.compat import iscoroutinefunction from rest_framework.views import APIView @@ -46,8 +47,12 @@ def decorator(func): allowed_methods = set(http_method_names) | {'options'} WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] - def handler(self, *args, **kwargs): - return func(*args, **kwargs) + if iscoroutinefunction(func): + async def handler(self, *args, **kwargs): + return await func(*args, **kwargs) + else: + def handler(self, *args, **kwargs): + return func(*args, **kwargs) for method in http_method_names: setattr(WrappedAPIView, method.lower(), handler) diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fdc5..e37a05fb025 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,6 +1,8 @@ """ Provides an APIView class that is the base of all views in REST framework. """ +import asyncio + from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import connections, models @@ -12,6 +14,9 @@ from django.views.generic import View from rest_framework import exceptions, status +from rest_framework.compat import ( + async_to_sync, iscoroutinefunction, sync_to_async +) from rest_framework.request import Request from rest_framework.response import Response from rest_framework.schemas import DefaultSchema @@ -328,13 +333,52 @@ def check_permissions(self, request): Check if the request should be permitted. Raises an appropriate exception if the request is not permitted. """ + async_permissions, sync_permissions = [], [] for permission in self.get_permissions(): - if not permission.has_permission(request, self): - self.permission_denied( - request, - message=getattr(permission, 'message', None), - code=getattr(permission, 'code', None) - ) + if iscoroutinefunction(permission.has_permission): + async_permissions.append(permission) + else: + sync_permissions.append(permission) + + async def check_async(): + results = await asyncio.gather( + *(permission.has_permission(request, self) for permission in + async_permissions), return_exceptions=True + ) + + for idx in range(len(async_permissions)): + if isinstance(results[idx], Exception): + raise results[idx] + elif not results[idx]: + self.permission_denied( + request, + message=getattr(async_permissions[idx], "message", None), + code=getattr(async_permissions[idx], "code", None), + ) + + def check_sync(): + for permission in sync_permissions: + if not permission.has_permission(request, self): + self.permission_denied( + request, + message=getattr(permission, 'message', None), + code=getattr(permission, 'code', None) + ) + + if getattr(self, 'view_is_async', False): + + async def func(): + if async_permissions: + await check_async() + if sync_permissions: + await sync_to_async(check_sync)() + + return func() + else: + if sync_permissions: + check_sync() + if async_permissions: + async_to_sync(check_async) def check_object_permissions(self, request, obj): """ @@ -354,21 +398,65 @@ def check_throttles(self, request): Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ - throttle_durations = [] + async_throttle_durations, sync_throttle_durations = [], [] for throttle in self.get_throttles(): - if not throttle.allow_request(request, self): - throttle_durations.append(throttle.wait()) + if iscoroutinefunction(throttle.allow_request): + async_throttle_durations.append(throttle) + else: + sync_throttle_durations.append(throttle) + + async def async_throttles(): + for throttle in async_throttle_durations: + if not await throttle.allow_request(request, self): + yield throttle.wait() + + def sync_throttles(): + for throttle in sync_throttle_durations: + if not throttle.allow_request(request, self): + yield throttle.wait() + + if getattr(self, 'view_is_async', False): - if throttle_durations: - # Filter out `None` values which may happen in case of config / rate - # changes, see #1438 - durations = [ - duration for duration in throttle_durations - if duration is not None - ] + async def func(): + throttle_durations = [] - duration = max(durations, default=None) - self.throttled(request, duration) + if async_throttle_durations: + throttle_durations.extend(duration async for duration in async_throttles()) + + if sync_throttle_durations: + throttle_durations.extend(duration async for duration in await sync_to_async(sync_throttles)()) + + if throttle_durations: + # Filter out `None` values which may happen in case of config / rate + # changes, see #1438 + durations = [ + duration for duration in throttle_durations + if duration is not None + ] + + duration = max(durations, default=None) + self.throttled(request, duration) + + return func() + else: + throttle_durations = [] + + if sync_throttle_durations: + throttle_durations.extend(sync_throttles()) + + if async_throttle_durations: + throttle_durations.extend(async_to_sync(async_throttles)()) + + if throttle_durations: + # Filter out `None` values which may happen in case of config / rate + # changes, see #1438 + durations = [ + duration for duration in throttle_durations + if duration is not None + ] + + duration = max(durations, default=None) + self.throttled(request, duration) def determine_version(self, request, *args, **kwargs): """ @@ -410,10 +498,20 @@ def initial(self, request, *args, **kwargs): version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme - # Ensure that the incoming request is permitted - self.perform_authentication(request) - self.check_permissions(request) - self.check_throttles(request) + if getattr(self, 'view_is_async', False): + + async def func(): + # Ensure that the incoming request is permitted + await sync_to_async(self.perform_authentication)(request) + await self.check_permissions(request) + await self.check_throttles(request) + + return func() + else: + # Ensure that the incoming request is permitted + self.perform_authentication(request) + self.check_permissions(request) + self.check_throttles(request) def finalize_response(self, request, response, *args, **kwargs): """ @@ -469,7 +567,15 @@ def handle_exception(self, exc): self.raise_uncaught_exception(exc) response.exception = True - return response + + if getattr(self, 'view_is_async', False): + + async def func(): + return response + + return func() + else: + return response def raise_uncaught_exception(self, exc): if settings.DEBUG: @@ -493,23 +599,49 @@ def dispatch(self, request, *args, **kwargs): self.request = request self.headers = self.default_response_headers # deprecate? - try: - self.initial(request, *args, **kwargs) + if getattr(self, 'view_is_async', False): - # Get the appropriate handler method - if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), - self.http_method_not_allowed) - else: - handler = self.http_method_not_allowed + async def func(): + + try: + await self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + response = await handler(request, *args, **kwargs) - response = handler(request, *args, **kwargs) + except Exception as exc: + response = await self.handle_exception(exc) - except Exception as exc: - response = self.handle_exception(exc) + return self.finalize_response(request, response, *args, **kwargs) - self.response = self.finalize_response(request, response, *args, **kwargs) - return self.response + self.response = func() + + return self.response + else: + try: + self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + response = handler(request, *args, **kwargs) + + except Exception as exc: + response = self.handle_exception(exc) + + self.response = self.finalize_response(request, response, *args, **kwargs) + + return self.response def options(self, request, *args, **kwargs): """ @@ -518,4 +650,12 @@ def options(self, request, *args, **kwargs): if self.metadata_class is None: return self.http_method_not_allowed(request, *args, **kwargs) data = self.metadata_class().determine_metadata(request, self) - return Response(data, status=status.HTTP_200_OK) + + if getattr(self, 'view_is_async', False): + + async def func(): + return Response(data, status=status.HTTP_200_OK) + + return func() + else: + return Response(data, status=status.HTTP_200_OK) diff --git a/tests/test_views.py b/tests/test_views.py index 2648c9fb38d..e1c07f7a2e3 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,8 +1,12 @@ import copy +import django +import pytest +from django.contrib.auth.models import User from django.test import TestCase from rest_framework import status +from rest_framework.compat import async_to_sync from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings @@ -22,16 +26,36 @@ def post(self, request, *args, **kwargs): return Response({'method': 'POST', 'data': request.data}) +class BasicAsyncView(APIView): + async def get(self, request, *args, **kwargs): + return Response({'method': 'GET'}) + + async def post(self, request, *args, **kwargs): + return Response({'method': 'POST', 'data': request.data}) + + @api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': - return {'method': 'GET'} + return Response({'method': 'GET'}) elif request.method == 'POST': - return {'method': 'POST', 'data': request.data} + return Response({'method': 'POST', 'data': request.data}) elif request.method == 'PUT': - return {'method': 'PUT', 'data': request.data} + return Response({'method': 'PUT', 'data': request.data}) elif request.method == 'PATCH': - return {'method': 'PATCH', 'data': request.data} + return Response({'method': 'PATCH', 'data': request.data}) + + +@api_view(['GET', 'POST', 'PUT', 'PATCH']) +async def basic_async_view(request): + if request.method == 'GET': + return Response({'method': 'GET'}) + elif request.method == 'POST': + return Response({'method': 'POST', 'data': request.data}) + elif request.method == 'PUT': + return Response({'method': 'PUT', 'data': request.data}) + elif request.method == 'PATCH': + return Response({'method': 'PATCH', 'data': request.data}) class ErrorView(APIView): @@ -72,6 +96,36 @@ class ClassBasedViewIntegrationTests(TestCase): def setUp(self): self.view = BasicView.as_view() + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -86,6 +140,36 @@ class FunctionBasedViewIntegrationTests(TestCase): def setUp(self): self.view = basic_view + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -96,6 +180,102 @@ def test_400_parse_error(self): assert sanitise_json_error(response.data) == expected +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class ClassBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = BasicAsyncView.as_view() + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class FunctionBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = basic_async_view + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + class TestCustomExceptionHandler(TestCase): def setUp(self): self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER