From 830b266ebfeea3d2f62adddfd52f05f3c4fee81c Mon Sep 17 00:00:00 2001 From: James Hilliard Date: Wed, 10 May 2023 08:51:42 -0600 Subject: [PATCH] Async view implementation --- docs/api-guide/views.md | 16 + requirements/requirements-documentation.txt | 2 +- rest_framework/compat.py | 11 + rest_framework/decorators.py | 9 +- rest_framework/pagination.py | 1 - rest_framework/test.py | 109 +++++ rest_framework/throttling.py | 105 +++-- rest_framework/views.py | 228 +++++++++-- tests/test_throttling.py | 433 +++++++++++++++++++- tests/test_views.py | 188 ++++++++- 10 files changed, 1030 insertions(+), 72 deletions(-) diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index b293de75ab2..ccf87dc64df 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation. def view(request): return Response({"message": "Will not appear in schema!"}) +# Async Views + +When using Django 4.1 and above, REST framework allows you to work with async class and function based views. + +For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async. + +For example: + + class AsyncView(APIView): + async def get(self, request): + return Response({"message": "This is an async class based view."}) + + + @api_view(['GET']) + async def async_view(request): + return Response({"message": "This is an async function based view."}) [cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html diff --git a/requirements/requirements-documentation.txt b/requirements/requirements-documentation.txt index cf2dc26e884..962c3ac3df1 100644 --- a/requirements/requirements-documentation.txt +++ b/requirements/requirements-documentation.txt @@ -1,3 +1,3 @@ # MkDocs to build our documentation. -mkdocs>=1.1.2,<1.2 +mkdocs>=1.4.3,<1.5 jinja2>=2.10,<3.1.0 # contextfilter has been renamed diff --git a/rest_framework/compat.py b/rest_framework/compat.py index ac5cbc572a8..87cd0b8f041 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -41,6 +41,17 @@ def distinct(queryset, base): uritemplate = None +# async_to_sync is required for async view support +if django.VERSION >= (4, 1): + from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async +else: + async_to_sync = None + sync_to_async = None + + def iscoroutinefunction(func): + return False + + # coreschema is optional try: import coreschema diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 3b572c09ef8..5a378833b1b 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -10,6 +10,7 @@ from django.forms.utils import pretty_name +from rest_framework.compat import iscoroutinefunction from rest_framework.views import APIView @@ -46,8 +47,12 @@ def decorator(func): allowed_methods = set(http_method_names) | {'options'} WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] - def handler(self, *args, **kwargs): - return func(*args, **kwargs) + if iscoroutinefunction(func): + async def handler(self, *args, **kwargs): + return await func(*args, **kwargs) + else: + def handler(self, *args, **kwargs): + return func(*args, **kwargs) for method in http_method_names: setattr(WrappedAPIView, method.lower(), handler) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index ce87785472a..7303890b03f 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -5,7 +5,6 @@ import contextlib import warnings - from base64 import b64decode, b64encode from collections import namedtuple from urllib import parse diff --git a/rest_framework/test.py b/rest_framework/test.py index 04409f9621d..3b350b43967 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -11,6 +11,10 @@ from django.test.client import Client as DjangoClient from django.test.client import ClientHandler from django.test.client import RequestFactory as DjangoRequestFactory + +if django.VERSION >= (4, 1): + from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory + from django.utils.encoding import force_bytes from django.utils.http import urlencode @@ -240,6 +244,111 @@ def request(self, **kwargs): return request +if django.VERSION >= (4, 1): + class APIAsyncRequestFactory(DjangoAsyncRequestFactory): + renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES + default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + + def __init__(self, enforce_csrf_checks=False, **defaults): + self.enforce_csrf_checks = enforce_csrf_checks + self.renderer_classes = {} + for cls in self.renderer_classes_list: + self.renderer_classes[cls.format] = cls + super().__init__(**defaults) + + def _encode_data(self, data, format=None, content_type=None): + """ + Encode the data returning a two tuple of (bytes, content_type) + """ + + if data is None: + return ('', content_type) + + assert format is None or content_type is None, ( + 'You may not set both `format` and `content_type`.' + ) + + if content_type: + # Content type specified explicitly, treat data as a raw bytestring + ret = force_bytes(data, settings.DEFAULT_CHARSET) + + else: + format = format or self.default_format + + assert format in self.renderer_classes, ( + "Invalid format '{}'. Available formats are {}. " + "Set TEST_REQUEST_RENDERER_CLASSES to enable " + "extra request formats.".format( + format, + ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes]) + ) + ) + + # Use format and render the data into a bytestring + renderer = self.renderer_classes[format]() + ret = renderer.render(data) + + # Determine the content-type header from the renderer + content_type = renderer.media_type + if renderer.charset: + content_type = "{}; charset={}".format( + content_type, renderer.charset + ) + + # Coerce text to bytes if required. + if isinstance(ret, str): + ret = ret.encode(renderer.charset) + + return ret, content_type + + def get(self, path, data=None, **extra): + r = { + 'QUERY_STRING': urlencode(data or {}, doseq=True), + } + if not data and '?' in path: + # Fix to support old behavior where you have the arguments in the + # url. See #1461. + query_string = force_bytes(path.split('?')[1]) + query_string = query_string.decode('iso-8859-1') + r['QUERY_STRING'] = query_string + r.update(extra) + return self.generic('GET', path, **r) + + def post(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('POST', path, data, content_type, **extra) + + def put(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PUT', path, data, content_type, **extra) + + def patch(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PATCH', path, data, content_type, **extra) + + def delete(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('DELETE', path, data, content_type, **extra) + + def options(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('OPTIONS', path, data, content_type, **extra) + + def generic(self, method, path, data='', + content_type='application/octet-stream', secure=False, **extra): + # Include the CONTENT_TYPE, regardless of whether or not data is empty. + if content_type is not None: + extra['CONTENT_TYPE'] = str(content_type) + + return super().generic( + method, path, data, content_type, secure, **extra) + + def request(self, **kwargs): + request = super().request(**kwargs) + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + return request + + class ForceAuthClientHandler(ClientHandler): """ A patched version of ClientHandler that can enforce authentication diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index c0d6cf42fe0..c7990221b36 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -6,6 +6,9 @@ from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured +from rest_framework.compat import ( + async_to_sync, iscoroutinefunction, sync_to_async +) from rest_framework.settings import api_settings @@ -64,6 +67,8 @@ class SimpleRateThrottle(BaseThrottle): cache_format = 'throttle_%(scope)s_%(ident)s' scope = None THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES + sync_capable = True + async_capable = True def __init__(self): if not getattr(self, 'rate', None): @@ -113,23 +118,52 @@ def allow_request(self, request, view): On success calls `throttle_success`. On failure calls `throttle_failure`. """ - if self.rate is None: - return True - - self.key = self.get_cache_key(request, view) - if self.key is None: - return True - - self.history = self.cache.get(self.key, []) - self.now = self.timer() - - # Drop any requests from the history which have now passed the - # throttle duration - while self.history and self.history[-1] <= self.now - self.duration: - self.history.pop() - if len(self.history) >= self.num_requests: - return self.throttle_failure() - return self.throttle_success() + if getattr(view, 'view_is_async', False): + + async def func(): + if self.rate is None: + return True + + self.key = self.get_cache_key(request, view) + if self.key is None: + return True + + self.history = self.cache.get(self.key, []) + if iscoroutinefunction(self.timer): + self.now = await self.timer() + else: + self.now = await sync_to_async(self.timer)() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + return self.throttle_failure() + return self.throttle_success() + + return func() + else: + if self.rate is None: + return True + + self.key = self.get_cache_key(request, view) + if self.key is None: + return True + + self.history = self.cache.get(self.key, []) + if iscoroutinefunction(self.timer): + self.now = async_to_sync(self.timer)() + else: + self.now = self.timer() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + return self.throttle_failure() + return self.throttle_success() def throttle_success(self): """ @@ -210,6 +244,8 @@ class ScopedRateThrottle(SimpleRateThrottle): user id of the request, and the scope of the view being accessed. """ scope_attr = 'throttle_scope' + sync_capable = True + async_capable = True def __init__(self): # Override the usual SimpleRateThrottle, because we can't determine @@ -220,17 +256,34 @@ def allow_request(self, request, view): # We can only determine the scope once we're called by the view. self.scope = getattr(view, self.scope_attr, None) - # If a view does not have a `throttle_scope` always allow the request - if not self.scope: - return True + if getattr(view, 'view_is_async', False): - # Determine the allowed request rate as we normally would during - # the `__init__` call. - self.rate = self.get_rate() - self.num_requests, self.duration = self.parse_rate(self.rate) + async def func(allow_request): + # If a view does not have a `throttle_scope` always allow the request + if not self.scope: + return True + + # Determine the allowed request rate as we normally would during + # the `__init__` call. + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) + + # We can now proceed as normal. + return await allow_request(request, view) + + return func(super().allow_request) + else: + # If a view does not have a `throttle_scope` always allow the request + if not self.scope: + return True + + # Determine the allowed request rate as we normally would during + # the `__init__` call. + self.rate = self.get_rate() + self.num_requests, self.duration = self.parse_rate(self.rate) - # We can now proceed as normal. - return super().allow_request(request, view) + # We can now proceed as normal. + return super().allow_request(request, view) def get_cache_key(self, request, view): """ diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fdc5..a3a105c267d 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,6 +1,8 @@ """ Provides an APIView class that is the base of all views in REST framework. """ +import asyncio + from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import connections, models @@ -12,6 +14,9 @@ from django.views.generic import View from rest_framework import exceptions, status +from rest_framework.compat import ( + async_to_sync, iscoroutinefunction, sync_to_async +) from rest_framework.request import Request from rest_framework.response import Response from rest_framework.schemas import DefaultSchema @@ -328,13 +333,52 @@ def check_permissions(self, request): Check if the request should be permitted. Raises an appropriate exception if the request is not permitted. """ + async_permissions, sync_permissions = [], [] for permission in self.get_permissions(): - if not permission.has_permission(request, self): - self.permission_denied( - request, - message=getattr(permission, 'message', None), - code=getattr(permission, 'code', None) - ) + if iscoroutinefunction(permission.has_permission): + async_permissions.append(permission) + else: + sync_permissions.append(permission) + + async def check_async(): + results = await asyncio.gather( + *(permission.has_permission(request, self) for permission in + async_permissions), return_exceptions=True + ) + + for idx in range(len(async_permissions)): + if isinstance(results[idx], Exception): + raise results[idx] + elif not results[idx]: + self.permission_denied( + request, + message=getattr(async_permissions[idx], "message", None), + code=getattr(async_permissions[idx], "code", None), + ) + + def check_sync(): + for permission in sync_permissions: + if not permission.has_permission(request, self): + self.permission_denied( + request, + message=getattr(permission, 'message', None), + code=getattr(permission, 'code', None) + ) + + if getattr(self, 'view_is_async', False): + + async def func(): + if async_permissions: + await check_async() + if sync_permissions: + await sync_to_async(check_sync)() + + return func() + else: + if sync_permissions: + check_sync() + if async_permissions: + async_to_sync(check_async) def check_object_permissions(self, request, obj): """ @@ -354,21 +398,79 @@ def check_throttles(self, request): Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ - throttle_durations = [] + async_throttle_durations, sync_throttle_durations = [], [] + view_is_async = getattr(self, 'view_is_async', False) for throttle in self.get_throttles(): - if not throttle.allow_request(request, self): - throttle_durations.append(throttle.wait()) + throttle_can_sync = getattr(throttle, "sync_capable", True) + throttle_can_async = getattr(throttle, "async_capable", False) + if not throttle_can_sync and not throttle_can_async: + raise RuntimeError( + "Throttle %s must have at least one of " + "sync_capable/async_capable set to True." % throttle.__class__.__name__ + ) + elif not view_is_async and throttle_can_sync: + throttle_is_async = False + elif iscoroutinefunction(throttle.allow_request): + throttle_is_async = True + else: + throttle_is_async = throttle_can_async + if throttle_is_async: + async_throttle_durations.append(throttle) + else: + sync_throttle_durations.append(throttle) + + async def async_throttles(): + for throttle in async_throttle_durations: + if not await throttle.allow_request(request, self): + yield throttle.wait() + + def sync_throttles(): + for throttle in sync_throttle_durations: + if not throttle.allow_request(request, self): + yield throttle.wait() + + if view_is_async: + + async def func(): + throttle_durations = [] + + if async_throttle_durations: + throttle_durations.extend([duration async for duration in async_throttles()]) + + if sync_throttle_durations: + throttle_durations.extend(duration for duration in sync_throttles()) + + if throttle_durations: + # Filter out `None` values which may happen in case of config / rate + # changes, see #1438 + durations = [ + duration for duration in throttle_durations + if duration is not None + ] + + duration = max(durations, default=None) + self.throttled(request, duration) + + return func() + else: + throttle_durations = [] + + if sync_throttle_durations: + throttle_durations.extend(sync_throttles()) + + if async_throttle_durations: + throttle_durations.extend(async_to_sync(async_throttles)()) - if throttle_durations: - # Filter out `None` values which may happen in case of config / rate - # changes, see #1438 - durations = [ - duration for duration in throttle_durations - if duration is not None - ] + if throttle_durations: + # Filter out `None` values which may happen in case of config / rate + # changes, see #1438 + durations = [ + duration for duration in throttle_durations + if duration is not None + ] - duration = max(durations, default=None) - self.throttled(request, duration) + duration = max(durations, default=None) + self.throttled(request, duration) def determine_version(self, request, *args, **kwargs): """ @@ -410,10 +512,20 @@ def initial(self, request, *args, **kwargs): version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme - # Ensure that the incoming request is permitted - self.perform_authentication(request) - self.check_permissions(request) - self.check_throttles(request) + if getattr(self, 'view_is_async', False): + + async def func(): + # Ensure that the incoming request is permitted + await sync_to_async(self.perform_authentication)(request) + await self.check_permissions(request) + await self.check_throttles(request) + + return func() + else: + # Ensure that the incoming request is permitted + self.perform_authentication(request) + self.check_permissions(request) + self.check_throttles(request) def finalize_response(self, request, response, *args, **kwargs): """ @@ -469,7 +581,15 @@ def handle_exception(self, exc): self.raise_uncaught_exception(exc) response.exception = True - return response + + if getattr(self, 'view_is_async', False): + + async def func(): + return response + + return func() + else: + return response def raise_uncaught_exception(self, exc): if settings.DEBUG: @@ -493,23 +613,49 @@ def dispatch(self, request, *args, **kwargs): self.request = request self.headers = self.default_response_headers # deprecate? - try: - self.initial(request, *args, **kwargs) + if getattr(self, 'view_is_async', False): - # Get the appropriate handler method - if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), - self.http_method_not_allowed) - else: - handler = self.http_method_not_allowed + async def func(): + + try: + await self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed - response = handler(request, *args, **kwargs) + response = await handler(request, *args, **kwargs) - except Exception as exc: - response = self.handle_exception(exc) + except Exception as exc: + response = await self.handle_exception(exc) - self.response = self.finalize_response(request, response, *args, **kwargs) - return self.response + return self.finalize_response(request, response, *args, **kwargs) + + self.response = func() + + return self.response + else: + try: + self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + response = handler(request, *args, **kwargs) + + except Exception as exc: + response = self.handle_exception(exc) + + self.response = self.finalize_response(request, response, *args, **kwargs) + + return self.response def options(self, request, *args, **kwargs): """ @@ -518,4 +664,12 @@ def options(self, request, *args, **kwargs): if self.metadata_class is None: return self.http_method_not_allowed(request, *args, **kwargs) data = self.metadata_class().determine_metadata(request, self) - return Response(data, status=status.HTTP_200_OK) + + if getattr(self, 'view_is_async', False): + + async def func(): + return Response(data, status=status.HTTP_200_OK) + + return func() + else: + return Response(data, status=status.HTTP_200_OK) diff --git a/tests/test_throttling.py b/tests/test_throttling.py index d5a61232d92..07cccc77a56 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -1,7 +1,7 @@ """ Tests for the throttling implementations in the permissions module. """ - +import django import pytest from django.contrib.auth.models import User from django.core.cache import cache @@ -9,10 +9,15 @@ from django.http import HttpRequest from django.test import TestCase +from rest_framework.compat import async_to_sync from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory, force_authenticate + +if django.VERSION >= (4, 1): + from rest_framework.test import APIAsyncRequestFactory + from rest_framework.throttling import ( AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle, UserRateThrottle @@ -43,6 +48,14 @@ def allow_request(self, request, view): return False +class NonTimeAsyncThrottle(BaseThrottle): + def allow_request(self, request, view): + if not hasattr(self.__class__, 'called'): + self.__class__.called = True + return True + return False + + class MockView_DoubleThrottling(APIView): throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) @@ -50,6 +63,13 @@ def get(self, request): return Response('foo') +class MockAsyncView_DoubleThrottling(APIView): + throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,) + + async def get(self, request): + return Response('foo') + + class MockView(APIView): throttle_classes = (User3SecRateThrottle,) @@ -57,6 +77,13 @@ def get(self, request): return Response('foo') +class MockAsyncView(APIView): + throttle_classes = (User3SecRateThrottle,) + + async def get(self, request): + return Response('foo') + + class MockView_MinuteThrottling(APIView): throttle_classes = (User3MinRateThrottle,) @@ -64,6 +91,13 @@ def get(self, request): return Response('foo') +class MockAsyncView_MinuteThrottling(APIView): + throttle_classes = (User3MinRateThrottle,) + + async def get(self, request): + return Response('foo') + + class MockView_NonTimeThrottling(APIView): throttle_classes = (NonTimeThrottle,) @@ -71,6 +105,13 @@ def get(self, request): return Response('foo') +class MockAsyncView_NonTimeThrottling(APIView): + throttle_classes = (NonTimeAsyncThrottle,) + + async def get(self, request): + return Response('foo') + + class ThrottlingTests(TestCase): def setUp(self): """ @@ -252,12 +293,198 @@ def test_non_time_throttle(self): self.assertFalse('Retry-After' in response) +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class AsyncThrottlingTests(TestCase): + def setUp(self): + """ + Reset the cache so that no throttles will be active + """ + cache.clear() + self.factory = APIAsyncRequestFactory() + + def test_requests_are_throttled(self): + """ + Ensure request rate is limited + """ + request = self.factory.get('/') + for dummy in range(4): + response = async_to_sync(MockAsyncView.as_view())(request) + assert response.status_code == 429 + + def set_throttle_timer(self, view, value): + """ + Explicitly set the timer, overriding time.time() + """ + for cls in view.throttle_classes: + cls.timer = lambda self: value + + def test_request_throttling_expires(self): + """ + Ensure request rate is limited for a limited duration only + """ + self.set_throttle_timer(MockAsyncView, 0) + + request = self.factory.get('/') + for dummy in range(4): + response = async_to_sync(MockAsyncView.as_view())(request) + assert response.status_code == 429 + + # Advance the timer by one second + self.set_throttle_timer(MockAsyncView, 1) + + response = async_to_sync(MockAsyncView.as_view())(request) + assert response.status_code == 200 + + async def ensure_is_throttled(self, view, expect): + request = self.factory.get('/') + request.user = await User.objects.acreate(username='a') + for dummy in range(3): + await view.as_view()(request) + request.user = await User.objects.acreate(username='b') + response = await view.as_view()(request) + assert response.status_code == expect + + def test_request_throttling_is_per_user(self): + """ + Ensure request rate is only limited per user, not globally for + PerUserThrottles + """ + async_to_sync(self.ensure_is_throttled)(MockAsyncView, 200) + + def test_request_throttling_multiple_throttles(self): + """ + Ensure all throttle classes see each request even when the request is + already being throttled + """ + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0) + request = self.factory.get('/') + for dummy in range(4): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 1 + + # At this point our client made 4 requests (one was throttled) in a + # second. If we advance the timer by one additional second, the client + # should be allowed to make 2 more before being throttled by the 2nd + # throttle class, which has a limit of 6 per minute. + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 1) + for dummy in range(2): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 200 + + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 59 + + # Just to make sure check again after two more seconds. + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 2) + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 58 + + def test_throttle_rate_change_negative(self): + self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0) + request = self.factory.get('/') + for dummy in range(24): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + assert response.status_code == 429 + assert int(response['retry-after']) == 60 + + previous_rate = User3SecRateThrottle.rate + try: + User3SecRateThrottle.rate = '1/sec' + + for dummy in range(24): + response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request) + + assert response.status_code == 429 + assert int(response['retry-after']) == 60 + finally: + # reset + User3SecRateThrottle.rate = previous_rate + + async def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): + """ + Ensure the response returns an Retry-After field with status and next attributes + set properly. + """ + request = self.factory.get('/') + for timer, expect in expected_headers: + self.set_throttle_timer(view, timer) + response = await view.as_view()(request) + if expect is not None: + assert response['Retry-After'] == expect + else: + assert not'Retry-After' in response + + def test_seconds_fields(self): + """ + Ensure for second based throttles. + """ + async_to_sync(self.ensure_response_header_contains_proper_throttle_field)( + MockAsyncView, ( + (0, None), + (0, None), + (0, None), + (0, '1') + ) + ) + + def test_minutes_fields(self): + """ + Ensure for minute based throttles. + """ + async_to_sync(self.ensure_response_header_contains_proper_throttle_field)( + MockAsyncView_MinuteThrottling, ( + (0, None), + (0, None), + (0, None), + (0, '60') + ) + ) + + def test_next_rate_remains_constant_if_followed(self): + """ + If a client follows the recommended next request rate, + the throttling rate should stay constant. + """ + async_to_sync(self.ensure_response_header_contains_proper_throttle_field)( + MockAsyncView_MinuteThrottling, ( + (0, None), + (20, None), + (40, None), + (60, None), + (80, None) + ) + ) + + def test_non_time_throttle(self): + """ + Ensure for second based throttles. + """ + request = self.factory.get('/') + + self.assertFalse(hasattr(MockAsyncView_NonTimeThrottling.throttle_classes[0], 'called')) + + response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request) + self.assertFalse('Retry-After' in response) + + self.assertTrue(MockAsyncView_NonTimeThrottling.throttle_classes[0].called) + + response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request) + self.assertFalse('Retry-After' in response) + + class ScopedRateThrottleTests(TestCase): """ Tests for ScopedRateThrottle. """ def setUp(self): + cache.clear() self.throttle = ScopedRateThrottle() class XYScopedRateThrottle(ScopedRateThrottle): @@ -372,6 +599,131 @@ class DummyView: assert cache_key == 'throttle_user_%s' % user.pk +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class AsyncScopedRateThrottleTests(TestCase): + """ + Tests for ScopedRateThrottle. + """ + + def setUp(self): + cache.clear() + self.throttle = ScopedRateThrottle() + + class XYScopedRateThrottle(ScopedRateThrottle): + TIMER_SECONDS = 0 + THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} + + async def timer(self): + return self.TIMER_SECONDS + + class XView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'x' + + async def get(self, request): + return Response('x') + + class YView(APIView): + throttle_classes = (XYScopedRateThrottle,) + throttle_scope = 'y' + + async def get(self, request): + return Response('y') + + class UnscopedView(APIView): + throttle_classes = (XYScopedRateThrottle,) + + async def get(self, request): + return Response('y') + + self.throttle_class = XYScopedRateThrottle + self.factory = APIAsyncRequestFactory() + self.x_view = XView.as_view() + self.y_view = YView.as_view() + self.unscoped_view = UnscopedView.as_view() + + def increment_timer(self, seconds=1): + self.throttle_class.TIMER_SECONDS += seconds + + def test_scoped_rate_throttle(self): + request = self.factory.get('/') + + # Should be able to hit x view 3 times per minute. + response = async_to_sync(self.x_view)(request) + assert response.status_code == 200 + + self.increment_timer() + response = async_to_sync(self.x_view)(request) + assert response.status_code == 200 + + self.increment_timer() + response = async_to_sync(self.x_view)(request) + assert response.status_code == 200 + self.increment_timer() + response = async_to_sync(self.x_view)(request) + assert response.status_code == 429 + + # Should be able to hit y view 1 time per minute. + self.increment_timer() + response = async_to_sync(self.y_view)(request) + assert response.status_code == 200 + + self.increment_timer() + response = async_to_sync(self.y_view)(request) + assert response.status_code == 429 + + # Ensure throttles properly reset by advancing the rest of the minute + self.increment_timer(55) + + # Should still be able to hit x view 3 times per minute. + response = async_to_sync(self.x_view)(request) + assert response.status_code == 200 + + self.increment_timer() + response = async_to_sync(self.x_view)(request) + assert response.status_code == 200 + + self.increment_timer() + response = async_to_sync(self.x_view)(request) + assert response.status_code == 200 + + self.increment_timer() + response = async_to_sync(self.x_view)(request) + assert response.status_code == 429 + + # Should still be able to hit y view 1 time per minute. + self.increment_timer() + response = async_to_sync(self.y_view)(request) + assert response.status_code == 200 + + self.increment_timer() + response = async_to_sync(self.y_view)(request) + assert response.status_code == 429 + + def test_unscoped_view_not_throttled(self): + request = self.factory.get('/') + + for idx in range(10): + self.increment_timer() + response = async_to_sync(self.unscoped_view)(request) + assert response.status_code == 200 + + def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self): + class DummyView: + throttle_scope = 'user' + + request = Request(HttpRequest()) + user = User.objects.create(username='test') + force_authenticate(request, user) + request.user = user + self.throttle.allow_request(request, DummyView()) + cache_key = self.throttle.get_cache_key(request, view=DummyView()) + assert cache_key == 'throttle_user_%s' % user.pk + + class XffTestingBase(TestCase): def setUp(self): @@ -400,6 +752,34 @@ def config_proxy(self, num_proxies): setattr(api_settings, 'NUM_PROXIES', num_proxies) +class AsyncXffTestingBase(TestCase): + def setUp(self): + + class Throttle(ScopedRateThrottle): + THROTTLE_RATES = {'test_limit': '1/day'} + TIMER_SECONDS = 0 + + async def timer(self): + return self.TIMER_SECONDS + + class View(APIView): + throttle_classes = (Throttle,) + throttle_scope = 'test_limit' + + async def get(self, request): + return Response('test_limit') + + cache.clear() + self.throttle = Throttle() + self.view = View.as_view() + self.request = APIAsyncRequestFactory().get('/some_uri') + self.request.META['REMOTE_ADDR'] = '3.3.3.3' + self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2' + + def config_proxy(self, num_proxies): + setattr(api_settings, 'NUM_PROXIES', num_proxies) + + class IdWithXffBasicTests(XffTestingBase): def test_accepts_request_under_limit(self): self.config_proxy(0) @@ -411,6 +791,21 @@ def test_denies_request_over_limit(self): assert self.view(self.request).status_code == 429 +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class AsyncIdWithXffBasicTests(AsyncXffTestingBase): + def test_accepts_request_under_limit(self): + self.config_proxy(0) + assert async_to_sync(self.view)(self.request).status_code == 200 + + def test_denies_request_over_limit(self): + self.config_proxy(0) + async_to_sync(self.view)(self.request) + assert async_to_sync(self.view)(self.request).status_code == 429 + + class XffSpoofingTests(XffTestingBase): def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self): self.config_proxy(1) @@ -425,6 +820,24 @@ def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self): assert self.view(self.request).status_code == 429 +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class AsyncXffSpoofingTests(AsyncXffTestingBase): + def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self): + self.config_proxy(1) + async_to_sync(self.view)(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2' + assert async_to_sync(self.view)(self.request).status_code == 429 + + def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self): + self.config_proxy(2) + async_to_sync(self.view)(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2' + assert async_to_sync(self.view)(self.request).status_code == 429 + + class XffUniqueMachinesTest(XffTestingBase): def test_unique_clients_are_counted_independently_with_one_proxy(self): self.config_proxy(1) @@ -439,6 +852,24 @@ def test_unique_clients_are_counted_independently_with_two_proxies(self): assert self.view(self.request).status_code == 200 +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class AsyncXffUniqueMachinesTest(AsyncXffTestingBase): + def test_unique_clients_are_counted_independently_with_one_proxy(self): + self.config_proxy(1) + async_to_sync(self.view)(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7' + assert async_to_sync(self.view)(self.request).status_code == 200 + + def test_unique_clients_are_counted_independently_with_two_proxies(self): + self.config_proxy(2) + async_to_sync(self.view)(self.request) + self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2' + assert async_to_sync(self.view)(self.request).status_code == 200 + + class BaseThrottleTests(TestCase): def test_allow_request_raises_not_implemented_error(self): diff --git a/tests/test_views.py b/tests/test_views.py index 2648c9fb38d..c1cc5a04e68 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,8 +1,12 @@ import copy +import django +import pytest +from django.contrib.auth.models import User from django.test import TestCase from rest_framework import status +from rest_framework.compat import async_to_sync from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings @@ -22,16 +26,36 @@ def post(self, request, *args, **kwargs): return Response({'method': 'POST', 'data': request.data}) +class BasicAsyncView(APIView): + async def get(self, request, *args, **kwargs): + return Response({'method': 'GET'}) + + async def post(self, request, *args, **kwargs): + return Response({'method': 'POST', 'data': request.data}) + + @api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': - return {'method': 'GET'} + return Response({'method': 'GET'}) elif request.method == 'POST': - return {'method': 'POST', 'data': request.data} + return Response({'method': 'POST', 'data': request.data}) elif request.method == 'PUT': - return {'method': 'PUT', 'data': request.data} + return Response({'method': 'PUT', 'data': request.data}) elif request.method == 'PATCH': - return {'method': 'PATCH', 'data': request.data} + return Response({'method': 'PATCH', 'data': request.data}) + + +@api_view(['GET', 'POST', 'PUT', 'PATCH']) +async def basic_async_view(request): + if request.method == 'GET': + return Response({'method': 'GET'}) + elif request.method == 'POST': + return Response({'method': 'POST', 'data': request.data}) + elif request.method == 'PUT': + return Response({'method': 'PUT', 'data': request.data}) + elif request.method == 'PATCH': + return Response({'method': 'PATCH', 'data': request.data}) class ErrorView(APIView): @@ -72,6 +96,36 @@ class ClassBasedViewIntegrationTests(TestCase): def setUp(self): self.view = BasicView.as_view() + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -82,10 +136,88 @@ def test_400_parse_error(self): assert sanitise_json_error(response.data) == expected +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class ClassBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = BasicAsyncView.as_view() + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + class FunctionBasedViewIntegrationTests(TestCase): def setUp(self): self.view = basic_view + def test_get_succeeds(self): + request = factory.get('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = self.view(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -96,6 +228,54 @@ def test_400_parse_error(self): assert sanitise_json_error(response.data) == expected +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class FunctionBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = basic_async_view + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + class TestCustomExceptionHandler(TestCase): def setUp(self): self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER