Skip to content

Commit

Permalink
Async view implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed May 10, 2023
1 parent 99e8b40 commit 1dfc074
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 43 deletions.
16 changes: 16 additions & 0 deletions docs/api-guide/views.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
def view(request):
return Response({"message": "Will not appear in schema!"})

# Async Views

When using Django 4.1 and above, REST framework allows you to work with async class and function based views.

For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.

For example:

class AsyncView(APIView):
async def get(self, request):
return Response({"message": "This is an async class based view."})


@api_view(['GET'])
async def async_view(request):
return Response({"message": "This is an async function based view."})

[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
Expand Down
11 changes: 11 additions & 0 deletions rest_framework/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def distinct(queryset, base):
uritemplate = None


# async_to_sync is required for async view support
if django.VERSION >= (4, 1):
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
else:
async_to_sync = None
sync_to_async = None

def iscoroutinefunction(func):
return False


# coreschema is optional
try:
import coreschema
Expand Down
9 changes: 7 additions & 2 deletions rest_framework/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from django.forms.utils import pretty_name

from rest_framework.compat import iscoroutinefunction
from rest_framework.views import APIView


Expand Down Expand Up @@ -46,8 +47,12 @@ def decorator(func):
allowed_methods = set(http_method_names) | {'options'}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]

def handler(self, *args, **kwargs):
return func(*args, **kwargs)
if iscoroutinefunction(func):
async def handler(self, *args, **kwargs):
return await func(*args, **kwargs)
else:
def handler(self, *args, **kwargs):
return func(*args, **kwargs)

for method in http_method_names:
setattr(WrappedAPIView, method.lower(), handler)
Expand Down
214 changes: 177 additions & 37 deletions rest_framework/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Provides an APIView class that is the base of all views in REST framework.
"""
import asyncio

from django.conf import settings
from django.core.exceptions import PermissionDenied
from django.db import connections, models
Expand All @@ -12,6 +14,9 @@
from django.views.generic import View

from rest_framework import exceptions, status
from rest_framework.compat import (
async_to_sync, iscoroutinefunction, sync_to_async
)
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.schemas import DefaultSchema
Expand Down Expand Up @@ -328,13 +333,52 @@ def check_permissions(self, request):
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
"""
async_permissions, sync_permissions = [], []
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
if iscoroutinefunction(permission.has_permission):
async_permissions.append(permission)
else:
sync_permissions.append(permission)

async def check_async():
results = await asyncio.gather(
*(permission.has_permission(request, self) for permission in
async_permissions), return_exceptions=True
)

for idx in range(len(async_permissions)):
if isinstance(results[idx], Exception):
raise results[idx]
elif not results[idx]:
self.permission_denied(
request,
message=getattr(async_permissions[idx], "message", None),
code=getattr(async_permissions[idx], "code", None),
)

def check_sync():
for permission in sync_permissions:
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)

if getattr(self, 'view_is_async', False):

async def func():
if async_permissions:
await check_async()
if sync_permissions:
await sync_to_async(check_sync)()

return func()
else:
if sync_permissions:
check_sync()
if async_permissions:
async_to_sync(check_async)

def check_object_permissions(self, request, obj):
"""
Expand All @@ -354,21 +398,65 @@ def check_throttles(self, request):
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
"""
throttle_durations = []
async_throttle_durations, sync_throttle_durations = [], []
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
if iscoroutinefunction(throttle.allow_request):
async_throttle_durations.append(throttle)
else:
sync_throttle_durations.append(throttle)

async def async_throttles():
for throttle in async_throttle_durations:
if not await throttle.allow_request(request, self):
yield throttle.wait()

def sync_throttles():
for throttle in sync_throttle_durations:
if not throttle.allow_request(request, self):
yield throttle.wait()

if getattr(self, 'view_is_async', False):

if throttle_durations:
# Filter out `None` values which may happen in case of config / rate
# changes, see #1438
durations = [
duration for duration in throttle_durations
if duration is not None
]
async def func():
throttle_durations = []

duration = max(durations, default=None)
self.throttled(request, duration)
if async_throttle_durations:
throttle_durations.extend(duration async for duration in async_throttles())

if sync_throttle_durations:
throttle_durations.extend(duration async for duration in await sync_to_async(sync_throttles)())

if throttle_durations:
# Filter out `None` values which may happen in case of config / rate
# changes, see #1438
durations = [
duration for duration in throttle_durations
if duration is not None
]

duration = max(durations, default=None)
self.throttled(request, duration)

return func()
else:
throttle_durations = []

if sync_throttle_durations:
throttle_durations.extend(sync_throttles())

if async_throttle_durations:
throttle_durations.extend(async_to_sync(async_throttles)())

if throttle_durations:
# Filter out `None` values which may happen in case of config / rate
# changes, see #1438
durations = [
duration for duration in throttle_durations
if duration is not None
]

duration = max(durations, default=None)
self.throttled(request, duration)

def determine_version(self, request, *args, **kwargs):
"""
Expand Down Expand Up @@ -410,10 +498,20 @@ def initial(self, request, *args, **kwargs):
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme

# Ensure that the incoming request is permitted
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)
if getattr(self, 'view_is_async', False):

async def func():
# Ensure that the incoming request is permitted
await sync_to_async(self.perform_authentication)(request)
await self.check_permissions(request)
await self.check_throttles(request)

return func()
else:
# Ensure that the incoming request is permitted
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)

def finalize_response(self, request, response, *args, **kwargs):
"""
Expand Down Expand Up @@ -469,7 +567,15 @@ def handle_exception(self, exc):
self.raise_uncaught_exception(exc)

response.exception = True
return response

if getattr(self, 'view_is_async', False):

async def func():
return response

return func()
else:
return response

def raise_uncaught_exception(self, exc):
if settings.DEBUG:
Expand All @@ -493,23 +599,49 @@ def dispatch(self, request, *args, **kwargs):
self.request = request
self.headers = self.default_response_headers # deprecate?

try:
self.initial(request, *args, **kwargs)
if getattr(self, 'view_is_async', False):

# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
async def func():

try:
await self.initial(request, *args, **kwargs)

# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed

response = await handler(request, *args, **kwargs)

response = handler(request, *args, **kwargs)
except Exception as exc:
response = await self.handle_exception(exc)

except Exception as exc:
response = self.handle_exception(exc)
return self.finalize_response(request, response, *args, **kwargs)

self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response
self.response = func()

return self.response
else:
try:
self.initial(request, *args, **kwargs)

# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed

response = handler(request, *args, **kwargs)

except Exception as exc:
response = self.handle_exception(exc)

self.response = self.finalize_response(request, response, *args, **kwargs)

return self.response

def options(self, request, *args, **kwargs):
"""
Expand All @@ -518,4 +650,12 @@ def options(self, request, *args, **kwargs):
if self.metadata_class is None:
return self.http_method_not_allowed(request, *args, **kwargs)
data = self.metadata_class().determine_metadata(request, self)
return Response(data, status=status.HTTP_200_OK)

if getattr(self, 'view_is_async', False):

async def func():
return Response(data, status=status.HTTP_200_OK)

return func()
else:
return Response(data, status=status.HTTP_200_OK)
Loading

0 comments on commit 1dfc074

Please sign in to comment.