From 11393ab1316f973c5fbb534305750740d909b4e4 Mon Sep 17 00:00:00 2001 From: James Thorniley Date: Thu, 4 Jan 2024 13:14:30 +0000 Subject: [PATCH] Fixed #35059 -- Ensured that ASGIHandler always sends the request_finished signal. Prior to this work, when async tasks that process the request are cancelled due to receiving an early "http.disconnect" ASGI message, the request_finished signal was not being sent, potentially leading to resource leaks (such as database connections). This branch ensures that the request_finished signal is sent even in the case of early termination of the response. Regression in 64cea1e48f285ea2162c669208d95188b32bbc82. Co-authored-by: Natalia <124304+nessita@users.noreply.github.com> Co-authored-by: Carlton Gibson --- django/core/handlers/asgi.py | 18 ++++- docs/releases/5.0.2.txt | 4 ++ tests/asgi/tests.py | 134 ++++++++++++++++++++++++++++++++++- 3 files changed, 153 insertions(+), 3 deletions(-) diff --git a/django/core/handlers/asgi.py b/django/core/handlers/asgi.py index 7b0086fb765a..3af080599ab1 100644 --- a/django/core/handlers/asgi.py +++ b/django/core/handlers/asgi.py @@ -186,11 +186,18 @@ async def handle(self, scope, receive, send): if request is None: body_file.close() await self.send_response(error_response, send) + await sync_to_async(error_response.close)() return async def process_request(request, send): response = await self.run_get_response(request) - await self.send_response(response, send) + try: + await self.send_response(response, send) + except asyncio.CancelledError: + # Client disconnected during send_response (ignore exception). + pass + + return response # Try to catch a disconnect while getting response. tasks = [ @@ -221,6 +228,14 @@ async def process_request(request, send): except asyncio.CancelledError: # Task re-raised the CancelledError as expected. pass + + try: + response = tasks[1].result() + except asyncio.CancelledError: + await signals.request_finished.asend(sender=self.__class__) + else: + await sync_to_async(response.close)() + body_file.close() async def listen_for_disconnect(self, receive): @@ -346,7 +361,6 @@ async def send_response(self, response, send): "more_body": not last, } ) - await sync_to_async(response.close, thread_sensitive=True)() @classmethod def chunk_bytes(cls, data): diff --git a/docs/releases/5.0.2.txt b/docs/releases/5.0.2.txt index 83f1af7b4f00..64ffcb88bdae 100644 --- a/docs/releases/5.0.2.txt +++ b/docs/releases/5.0.2.txt @@ -28,3 +28,7 @@ Bugfixes * Fixed a regression in Django 5.0 that caused a crash of the ``dumpdata`` management command when a base queryset used ``prefetch_related()`` (:ticket:`35159`). + +* Fixed a regression in Django 5.0 that caused the ``request_finished`` signal to + sometimes not be fired when running Django through an ASGI server, resulting + in potential resource leaks (:ticket:`35059`). diff --git a/tests/asgi/tests.py b/tests/asgi/tests.py index 0fbb586f854d..963f45f798d4 100644 --- a/tests/asgi/tests.py +++ b/tests/asgi/tests.py @@ -1,12 +1,15 @@ import asyncio import sys import threading +import time from pathlib import Path +from asgiref.sync import sync_to_async from asgiref.testing import ApplicationCommunicator from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler from django.core.asgi import get_asgi_application +from django.core.exceptions import RequestDataTooBig from django.core.handlers.asgi import ASGIHandler, ASGIRequest from django.core.signals import request_finished, request_started from django.db import close_old_connections @@ -20,6 +23,7 @@ ) from django.urls import path from django.utils.http import http_date +from django.views.decorators.csrf import csrf_exempt from .urls import sync_waiter, test_filename @@ -205,6 +209,96 @@ async def test_post_body(self): self.assertEqual(response_body["type"], "http.response.body") self.assertEqual(response_body["body"], b"Echo!") + async def test_create_request_error(self): + # Track request_finished signal. + signal_handler = SignalHandler() + request_finished.connect(signal_handler) + self.addCleanup(request_finished.disconnect, signal_handler) + + # Request class that always fails creation with RequestDataTooBig. + class TestASGIRequest(ASGIRequest): + + def __init__(self, scope, body_file): + super().__init__(scope, body_file) + raise RequestDataTooBig() + + # Handler to use the custom request class. + class TestASGIHandler(ASGIHandler): + request_class = TestASGIRequest + + application = TestASGIHandler() + scope = self.async_request_factory._base_scope(path="/not-important/") + communicator = ApplicationCommunicator(application, scope) + + # Initiate request. + await communicator.send_input({"type": "http.request"}) + # Give response.close() time to finish. + await communicator.wait() + + self.assertEqual(len(signal_handler.calls), 1) + self.assertNotEqual( + signal_handler.calls[0]["thread"], threading.current_thread() + ) + + async def test_cancel_post_request_with_sync_processing(self): + """ + The request.body object should be available and readable in view + code, even if the ASGIHandler cancels processing part way through. + """ + loop = asyncio.get_event_loop() + # Events to monitor the view processing from the parent test code. + view_started_event = asyncio.Event() + view_finished_event = asyncio.Event() + # Record received request body or exceptions raised in the test view + outcome = [] + + # This view will run in a new thread because it is wrapped in + # sync_to_async. The view consumes the POST body data after a short + # delay. The test will cancel the request using http.disconnect during + # the delay, but because this is a sync view the code runs to + # completion. There should be no exceptions raised inside the view + # code. + @csrf_exempt + @sync_to_async + def post_view(request): + try: + loop.call_soon_threadsafe(view_started_event.set) + time.sleep(0.1) + # Do something to read request.body after pause + outcome.append({"request_body": request.body}) + return HttpResponse("ok") + except Exception as e: + outcome.append({"exception": e}) + finally: + loop.call_soon_threadsafe(view_finished_event.set) + + # Request class to use the view. + class TestASGIRequest(ASGIRequest): + urlconf = (path("post/", post_view),) + + # Handler to use request class. + class TestASGIHandler(ASGIHandler): + request_class = TestASGIRequest + + application = TestASGIHandler() + scope = self.async_request_factory._base_scope( + method="POST", + path="/post/", + ) + communicator = ApplicationCommunicator(application, scope) + + await communicator.send_input({"type": "http.request", "body": b"Body data!"}) + + # Wait until the view code has started, then send http.disconnect. + await view_started_event.wait() + await communicator.send_input({"type": "http.disconnect"}) + # Wait until view code has finished. + await view_finished_event.wait() + with self.assertRaises(asyncio.TimeoutError): + await communicator.receive_output() + + self.assertEqual(outcome, [{"request_body": b"Body data!"}]) + async def test_untouched_request_body_gets_closed(self): application = get_asgi_application() scope = self.async_request_factory._base_scope(method="POST", path="/post/") @@ -345,7 +439,9 @@ async def test_request_lifecycle_signals_dispatched_with_thread_sensitive(self): # AsyncToSync should have executed the signals in the same thread. self.assertEqual(len(signal_handler.calls), 2) request_started_call, request_finished_call = signal_handler.calls - self.assertEqual(request_started_call["thread"], request_finished_call["thread"]) + self.assertEqual( + request_started_call["thread"], request_finished_call["thread"] + ) async def test_concurrent_async_uses_multiple_thread_pools(self): sync_waiter.active_threads.clear() @@ -381,6 +477,10 @@ async def test_concurrent_async_uses_multiple_thread_pools(self): async def test_asyncio_cancel_error(self): # Flag to check if the view was cancelled. view_did_cancel = False + # Track request_finished signal. + signal_handler = SignalHandler() + request_finished.connect(signal_handler) + self.addCleanup(request_finished.disconnect, signal_handler) # A view that will listen for the cancelled error. async def view(request): @@ -415,6 +515,13 @@ class TestASGIHandler(ASGIHandler): # Give response.close() time to finish. await communicator.wait() self.assertIs(view_did_cancel, False) + # Exactly one call to request_finished handler. + self.assertEqual(len(signal_handler.calls), 1) + handler_call = signal_handler.calls.pop() + # It was NOT on the async thread. + self.assertNotEqual(handler_call["thread"], threading.current_thread()) + # The signal sender is the handler class. + self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler}) # Request cycle with a disconnect before the view can respond. application = TestASGIHandler() @@ -430,11 +537,22 @@ class TestASGIHandler(ASGIHandler): await communicator.receive_output() await communicator.wait() self.assertIs(view_did_cancel, True) + # Exactly one call to request_finished handler. + self.assertEqual(len(signal_handler.calls), 1) + handler_call = signal_handler.calls.pop() + # It was NOT on the async thread. + self.assertNotEqual(handler_call["thread"], threading.current_thread()) + # The signal sender is the handler class. + self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler}) async def test_asyncio_streaming_cancel_error(self): # Similar to test_asyncio_cancel_error(), but during a streaming # response. view_did_cancel = False + # Track request_finished signals. + signal_handler = SignalHandler() + request_finished.connect(signal_handler) + self.addCleanup(request_finished.disconnect, signal_handler) async def streaming_response(): nonlocal view_did_cancel @@ -469,6 +587,13 @@ class TestASGIHandler(ASGIHandler): self.assertEqual(response_body["body"], b"Hello World!") await communicator.wait() self.assertIs(view_did_cancel, False) + # Exactly one call to request_finished handler. + self.assertEqual(len(signal_handler.calls), 1) + handler_call = signal_handler.calls.pop() + # It was NOT on the async thread. + self.assertNotEqual(handler_call["thread"], threading.current_thread()) + # The signal sender is the handler class. + self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler}) # Request cycle with a disconnect. application = TestASGIHandler() @@ -487,6 +612,13 @@ class TestASGIHandler(ASGIHandler): await communicator.receive_output() await communicator.wait() self.assertIs(view_did_cancel, True) + # Exactly one call to request_finished handler. + self.assertEqual(len(signal_handler.calls), 1) + handler_call = signal_handler.calls.pop() + # It was NOT on the async thread. + self.assertNotEqual(handler_call["thread"], threading.current_thread()) + # The signal sender is the handler class. + self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler}) async def test_streaming(self): scope = self.async_request_factory._base_scope(