Skip to content

Commit

Permalink
Async view implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed Jun 15, 2023
1 parent 71f87a5 commit d269dbb
Show file tree
Hide file tree
Showing 10 changed files with 1,030 additions and 71 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ jobs:
runs-on: ubuntu-20.04

strategy:
fail-fast: false
matrix:
python-version:
- '3.6'
Expand Down
16 changes: 16 additions & 0 deletions docs/api-guide/views.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
def view(request):
return Response({"message": "Will not appear in schema!"})

# Async Views

When using Django 4.1 and above, REST framework allows you to work with async class and function based views.

For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.

For example:

class AsyncView(APIView):
async def get(self, request):
return Response({"message": "This is an async class based view."})


@api_view(['GET'])
async def async_view(request):
return Response({"message": "This is an async function based view."})

[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
Expand Down
11 changes: 11 additions & 0 deletions rest_framework/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def distinct(queryset, base):
uritemplate = None


# async_to_sync is required for async view support
if django.VERSION >= (4, 1):
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
else:
async_to_sync = None
sync_to_async = None

def iscoroutinefunction(func):
return False


# coreschema is optional
try:
import coreschema
Expand Down
9 changes: 7 additions & 2 deletions rest_framework/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from django.forms.utils import pretty_name

from rest_framework.compat import iscoroutinefunction
from rest_framework.views import APIView


Expand Down Expand Up @@ -46,8 +47,12 @@ def decorator(func):
allowed_methods = set(http_method_names) | {'options'}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]

def handler(self, *args, **kwargs):
return func(*args, **kwargs)
if iscoroutinefunction(func):
async def handler(self, *args, **kwargs):
return await func(*args, **kwargs)
else:
def handler(self, *args, **kwargs):
return func(*args, **kwargs)

for method in http_method_names:
setattr(WrappedAPIView, method.lower(), handler)
Expand Down
1 change: 0 additions & 1 deletion rest_framework/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import contextlib
import warnings

from base64 import b64decode, b64encode
from collections import namedtuple
from urllib import parse
Expand Down
109 changes: 109 additions & 0 deletions rest_framework/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test.client import RequestFactory as DjangoRequestFactory

if django.VERSION >= (4, 1):
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory

from django.utils.encoding import force_bytes
from django.utils.http import urlencode

Expand Down Expand Up @@ -240,6 +244,111 @@ def request(self, **kwargs):
return request


if django.VERSION >= (4, 1):
class APIAsyncRequestFactory(DjangoAsyncRequestFactory):
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT

def __init__(self, enforce_csrf_checks=False, **defaults):
self.enforce_csrf_checks = enforce_csrf_checks
self.renderer_classes = {}
for cls in self.renderer_classes_list:
self.renderer_classes[cls.format] = cls
super().__init__(**defaults)

def _encode_data(self, data, format=None, content_type=None):
"""
Encode the data returning a two tuple of (bytes, content_type)
"""

if data is None:
return ('', content_type)

assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
)

if content_type:
# Content type specified explicitly, treat data as a raw bytestring
ret = force_bytes(data, settings.DEFAULT_CHARSET)

else:
format = format or self.default_format

assert format in self.renderer_classes, (
"Invalid format '{}'. Available formats are {}. "
"Set TEST_REQUEST_RENDERER_CLASSES to enable "
"extra request formats.".format(
format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
)
)

# Use format and render the data into a bytestring
renderer = self.renderer_classes[format]()
ret = renderer.render(data)

# Determine the content-type header from the renderer
content_type = renderer.media_type
if renderer.charset:
content_type = "{}; charset={}".format(
content_type, renderer.charset
)

# Coerce text to bytes if required.
if isinstance(ret, str):
ret = ret.encode(renderer.charset)

return ret, content_type

def get(self, path, data=None, **extra):
r = {
'QUERY_STRING': urlencode(data or {}, doseq=True),
}
if not data and '?' in path:
# Fix to support old behavior where you have the arguments in the
# url. See #1461.
query_string = force_bytes(path.split('?')[1])
query_string = query_string.decode('iso-8859-1')
r['QUERY_STRING'] = query_string
r.update(extra)
return self.generic('GET', path, **r)

def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra)

def put(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('PUT', path, data, content_type, **extra)

def patch(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('PATCH', path, data, content_type, **extra)

def delete(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('DELETE', path, data, content_type, **extra)

def options(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('OPTIONS', path, data, content_type, **extra)

def generic(self, method, path, data='',
content_type='application/octet-stream', secure=False, **extra):
# Include the CONTENT_TYPE, regardless of whether or not data is empty.
if content_type is not None:
extra['CONTENT_TYPE'] = str(content_type)

return super().generic(
method, path, data, content_type, secure, **extra)

def request(self, **kwargs):
request = super().request(**kwargs)
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
return request


class ForceAuthClientHandler(ClientHandler):
"""
A patched version of ClientHandler that can enforce authentication
Expand Down
105 changes: 79 additions & 26 deletions rest_framework/throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured

from rest_framework.compat import (
async_to_sync, iscoroutinefunction, sync_to_async
)
from rest_framework.settings import api_settings


Expand Down Expand Up @@ -64,6 +67,8 @@ class SimpleRateThrottle(BaseThrottle):
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
sync_capable = True
async_capable = True

def __init__(self):
if not getattr(self, 'rate', None):
Expand Down Expand Up @@ -113,23 +118,52 @@ def allow_request(self, request, view):
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
self.now = self.timer()

# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
if getattr(view, 'view_is_async', False):

async def func():
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
if iscoroutinefunction(self.timer):
self.now = await self.timer()
else:
self.now = await sync_to_async(self.timer)()

# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()

return func()
else:
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
if iscoroutinefunction(self.timer):
self.now = async_to_sync(self.timer)()
else:
self.now = self.timer()

# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()

def throttle_success(self):
"""
Expand Down Expand Up @@ -210,6 +244,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
user id of the request, and the scope of the view being accessed.
"""
scope_attr = 'throttle_scope'
sync_capable = True
async_capable = True

def __init__(self):
# Override the usual SimpleRateThrottle, because we can't determine
Expand All @@ -220,17 +256,34 @@ def allow_request(self, request, view):
# We can only determine the scope once we're called by the view.
self.scope = getattr(view, self.scope_attr, None)

# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True
if getattr(view, 'view_is_async', False):

# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
async def func(allow_request):
# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True

# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)

# We can now proceed as normal.
return await allow_request(request, view)

return func(super().allow_request)
else:
# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True

# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)

# We can now proceed as normal.
return super().allow_request(request, view)
# We can now proceed as normal.
return super().allow_request(request, view)

def get_cache_key(self, request, view):
"""
Expand Down
Loading

0 comments on commit d269dbb

Please sign in to comment.