Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async implementation #8617

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rest_framework/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ def distinct(queryset, base):

# async_to_sync is required for async view support
if django.VERSION >= (4, 1):
from asgiref.sync import async_to_sync
from asgiref.sync import async_to_sync, sync_to_async
else:
async_to_sync = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about

def async_to_sync(*args, **kwargs):
    raise NotImplementedError("DRF async only supports Django >= 4.1")

Copy link

@dongyuzheng dongyuzheng Sep 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A type checker would not be happy about async_to_sync = None.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

async_to_sync is available in Django since version 3.1. In this case I'm using it just to run the test. If pytest-asyncio is added then it can be removed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would personally rather we take a testing dependency on pytest-asyncio rather than implement a test-only compat helper.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's best to defer the decision to add an additional dependency to project leads. Especially if the new dependency can be easily circumvented, such as in this case.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an a user of drf I would not happy to have any async tests related dependency in my project

sync_to_async = None


# coreschema is optional
Expand Down
20 changes: 15 additions & 5 deletions rest_framework/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View

from rest_framework.compat import sync_to_async
from rest_framework import exceptions, status
from rest_framework.request import Request
from rest_framework.response import Response
Expand Down Expand Up @@ -524,7 +525,7 @@ async def async_dispatch(self, request, *args, **kwargs):
self.headers = self.default_response_headers # deprecate?

try:
self.initial(request, *args, **kwargs)
sync_to_async(self.initial)(request, *args, **kwargs)
em1208 marked this conversation as resolved.
Show resolved Hide resolved

# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
Expand Down Expand Up @@ -555,7 +556,16 @@ def options(self, request, *args, **kwargs):
"""
Handler method for HTTP 'OPTIONS' request.
"""
if self.metadata_class is None:
return self.http_method_not_allowed(request, *args, **kwargs)
data = self.metadata_class().determine_metadata(request, self)
return Response(data, status=status.HTTP_200_OK)
def func():
if self.metadata_class is None:
return self.http_method_not_allowed(request, *args, **kwargs)
data = self.metadata_class().determine_metadata(request, self)
return Response(data, status=status.HTTP_200_OK)

if hasattr(self, 'view_is_async') and self.view_is_async:
em1208 marked this conversation as resolved.
Show resolved Hide resolved
async def handler():
return func()
Copy link

@Archmonger Archmonger Sep 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use sync_to_async on func as an optimization. Unless we believe self.metadata_class().determine_metadata(request, self) does nothing beyond reading in-memory variables (haven't checked).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the code and under some conditions self.metadata_class().determine_metadata(request, self) can make a database query, so I wrapped it with sync_to_async.

else:
def handler():
return func()
return handler()
57 changes: 57 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import django
import pytest
from django.test import TestCase
from django.contrib.auth.models import User

from rest_framework import status
from rest_framework.compat import async_to_sync
Expand Down Expand Up @@ -101,6 +102,15 @@ def test_get_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', '[email protected]', 'password')
request = factory.get('/')
del user.is_active
em1208 marked this conversation as resolved.
Show resolved Hide resolved
request.user = user
em1208 marked this conversation as resolved.
Show resolved Hide resolved
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = self.view(request)
Expand All @@ -111,6 +121,11 @@ def test_post_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == expected

def test_options_succeeds(self):
request = factory.options('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK

def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
Expand All @@ -131,6 +146,15 @@ def test_get_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', '[email protected]', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = self.view(request)
Expand All @@ -141,6 +165,11 @@ def test_post_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == expected

def test_options_succeeds(self):
request = factory.options('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK

def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
Expand All @@ -165,6 +194,15 @@ def test_get_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', '[email protected]', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = async_to_sync(self.view)(request)
Expand All @@ -175,6 +213,11 @@ def test_post_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == expected

def test_options_succeeds(self):
request = factory.options('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK

def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = async_to_sync(self.view)(request)
Expand All @@ -199,6 +242,15 @@ def test_get_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', '[email protected]', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}

def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = async_to_sync(self.view)(request)
Expand All @@ -209,6 +261,11 @@ def test_post_succeeds(self):
assert response.status_code == status.HTTP_200_OK
assert response.data == expected

def test_options_succeeds(self):
request = factory.options('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK

def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = async_to_sync(self.view)(request)
Expand Down