Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async view implementation #8978

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ jobs:
runs-on: ubuntu-20.04

strategy:
fail-fast: false
matrix:
python-version:
- '3.6'
Expand Down
16 changes: 16 additions & 0 deletions docs/api-guide/views.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
def view(request):
return Response({"message": "Will not appear in schema!"})

# Async Views

When using Django 4.1 and above, REST framework allows you to work with async class and function based views.

For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.

For example:

class AsyncView(APIView):
async def get(self, request):
return Response({"message": "This is an async class based view."})


@api_view(['GET'])
async def async_view(request):
return Response({"message": "This is an async function based view."})

[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
Expand Down
11 changes: 11 additions & 0 deletions rest_framework/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def distinct(queryset, base):
uritemplate = None


# async_to_sync is required for async view support
if django.VERSION >= (4, 1):
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
else:
async_to_sync = None
sync_to_async = None

def iscoroutinefunction(func):
return False


# coreschema is optional
try:
import coreschema
Expand Down
9 changes: 7 additions & 2 deletions rest_framework/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from django.forms.utils import pretty_name

from rest_framework.compat import iscoroutinefunction
from rest_framework.views import APIView


Expand Down Expand Up @@ -46,8 +47,12 @@ def decorator(func):
allowed_methods = set(http_method_names) | {'options'}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]

def handler(self, *args, **kwargs):
return func(*args, **kwargs)
if iscoroutinefunction(func):
async def handler(self, *args, **kwargs):
return await func(*args, **kwargs)
else:
def handler(self, *args, **kwargs):
return func(*args, **kwargs)

for method in http_method_names:
setattr(WrappedAPIView, method.lower(), handler)
Expand Down
15 changes: 14 additions & 1 deletion rest_framework/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test.client import RequestFactory as DjangoRequestFactory

if django.VERSION >= (4, 1):
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory

from django.utils.encoding import force_bytes
from django.utils.http import urlencode

Expand Down Expand Up @@ -136,7 +140,7 @@ def CoreAPIClient(*args, **kwargs):
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')


class APIRequestFactory(DjangoRequestFactory):
class APIRequestFactoryMixin:
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

Expand Down Expand Up @@ -240,6 +244,15 @@ def request(self, **kwargs):
return request


class APIRequestFactory(APIRequestFactoryMixin, DjangoRequestFactory):
pass


if django.VERSION >= (4, 1):
class APIAsyncRequestFactory(APIRequestFactoryMixin, DjangoAsyncRequestFactory):
pass


class ForceAuthClientHandler(ClientHandler):
"""
A patched version of ClientHandler that can enforce authentication
Expand Down
105 changes: 79 additions & 26 deletions rest_framework/throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured

from rest_framework.compat import (
async_to_sync, iscoroutinefunction, sync_to_async
)
from rest_framework.settings import api_settings


Expand Down Expand Up @@ -64,6 +67,8 @@ class SimpleRateThrottle(BaseThrottle):
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
sync_capable = True
async_capable = True

def __init__(self):
if not getattr(self, 'rate', None):
Expand Down Expand Up @@ -113,23 +118,52 @@ def allow_request(self, request, view):
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
self.now = self.timer()

# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
if getattr(view, 'view_is_async', False):

async def func():
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
if iscoroutinefunction(self.timer):
self.now = await self.timer()
else:
self.now = await sync_to_async(self.timer)()

# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()

return func()
else:
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
if iscoroutinefunction(self.timer):
self.now = async_to_sync(self.timer)()
else:
self.now = self.timer()

# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()

def throttle_success(self):
"""
Expand Down Expand Up @@ -210,6 +244,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
user id of the request, and the scope of the view being accessed.
"""
scope_attr = 'throttle_scope'
sync_capable = True
async_capable = True

def __init__(self):
# Override the usual SimpleRateThrottle, because we can't determine
Expand All @@ -220,17 +256,34 @@ def allow_request(self, request, view):
# We can only determine the scope once we're called by the view.
self.scope = getattr(view, self.scope_attr, None)

# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True
if getattr(view, 'view_is_async', False):

# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
async def func(allow_request):
# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True

# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)

# We can now proceed as normal.
return await allow_request(request, view)

return func(super().allow_request)
else:
# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True

# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)

# We can now proceed as normal.
return super().allow_request(request, view)
# We can now proceed as normal.
return super().allow_request(request, view)

def get_cache_key(self, request, view):
"""
Expand Down
Loading