Skip to content

Commit

Permalink
Fixed #35059 -- Ensured that ASGIHandler always sends the request_fin…
Browse files Browse the repository at this point in the history
…ished signal.

Prior to this work, when async tasks that process the request are cancelled due
to receiving an early "http.disconnect" ASGI message, the request_finished
signal was not being sent, potentially leading to resource leaks (such as
database connections).

This branch ensures that the request_finished signal is sent even in the case
of early termination of the response.

Regression in 64cea1e.

Co-authored-by: Natalia <[email protected]>
Co-authored-by: Carlton Gibson <[email protected]>
  • Loading branch information
3 people committed Jan 31, 2024
1 parent a43d75e commit 11393ab
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 3 deletions.
18 changes: 16 additions & 2 deletions django/core/handlers/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,18 @@ async def handle(self, scope, receive, send):
if request is None:
body_file.close()
await self.send_response(error_response, send)
await sync_to_async(error_response.close)()
return

async def process_request(request, send):
response = await self.run_get_response(request)
await self.send_response(response, send)
try:
await self.send_response(response, send)
except asyncio.CancelledError:
# Client disconnected during send_response (ignore exception).
pass

return response

# Try to catch a disconnect while getting response.
tasks = [
Expand Down Expand Up @@ -221,6 +228,14 @@ async def process_request(request, send):
except asyncio.CancelledError:
# Task re-raised the CancelledError as expected.
pass

try:
response = tasks[1].result()
except asyncio.CancelledError:
await signals.request_finished.asend(sender=self.__class__)
else:
await sync_to_async(response.close)()

body_file.close()

async def listen_for_disconnect(self, receive):
Expand Down Expand Up @@ -346,7 +361,6 @@ async def send_response(self, response, send):
"more_body": not last,
}
)
await sync_to_async(response.close, thread_sensitive=True)()

@classmethod
def chunk_bytes(cls, data):
Expand Down
4 changes: 4 additions & 0 deletions docs/releases/5.0.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ Bugfixes
* Fixed a regression in Django 5.0 that caused a crash of the ``dumpdata``
management command when a base queryset used ``prefetch_related()``
(:ticket:`35159`).

* Fixed a regression in Django 5.0 that caused the ``request_finished`` signal to
sometimes not be fired when running Django through an ASGI server, resulting
in potential resource leaks (:ticket:`35059`).
134 changes: 133 additions & 1 deletion tests/asgi/tests.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import sys
import threading
import time
from pathlib import Path

from asgiref.sync import sync_to_async
from asgiref.testing import ApplicationCommunicator

from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.asgi import get_asgi_application
from django.core.exceptions import RequestDataTooBig
from django.core.handlers.asgi import ASGIHandler, ASGIRequest
from django.core.signals import request_finished, request_started
from django.db import close_old_connections
Expand All @@ -20,6 +23,7 @@
)
from django.urls import path
from django.utils.http import http_date
from django.views.decorators.csrf import csrf_exempt

from .urls import sync_waiter, test_filename

Expand Down Expand Up @@ -205,6 +209,96 @@ async def test_post_body(self):
self.assertEqual(response_body["type"], "http.response.body")
self.assertEqual(response_body["body"], b"Echo!")

async def test_create_request_error(self):
# Track request_finished signal.
signal_handler = SignalHandler()
request_finished.connect(signal_handler)
self.addCleanup(request_finished.disconnect, signal_handler)

# Request class that always fails creation with RequestDataTooBig.
class TestASGIRequest(ASGIRequest):

def __init__(self, scope, body_file):
super().__init__(scope, body_file)
raise RequestDataTooBig()

# Handler to use the custom request class.
class TestASGIHandler(ASGIHandler):
request_class = TestASGIRequest

application = TestASGIHandler()
scope = self.async_request_factory._base_scope(path="/not-important/")
communicator = ApplicationCommunicator(application, scope)

# Initiate request.
await communicator.send_input({"type": "http.request"})
# Give response.close() time to finish.
await communicator.wait()

self.assertEqual(len(signal_handler.calls), 1)
self.assertNotEqual(
signal_handler.calls[0]["thread"], threading.current_thread()
)

async def test_cancel_post_request_with_sync_processing(self):
"""
The request.body object should be available and readable in view
code, even if the ASGIHandler cancels processing part way through.
"""
loop = asyncio.get_event_loop()
# Events to monitor the view processing from the parent test code.
view_started_event = asyncio.Event()
view_finished_event = asyncio.Event()
# Record received request body or exceptions raised in the test view
outcome = []

# This view will run in a new thread because it is wrapped in
# sync_to_async. The view consumes the POST body data after a short
# delay. The test will cancel the request using http.disconnect during
# the delay, but because this is a sync view the code runs to
# completion. There should be no exceptions raised inside the view
# code.
@csrf_exempt
@sync_to_async
def post_view(request):
try:
loop.call_soon_threadsafe(view_started_event.set)
time.sleep(0.1)
# Do something to read request.body after pause
outcome.append({"request_body": request.body})
return HttpResponse("ok")
except Exception as e:
outcome.append({"exception": e})
finally:
loop.call_soon_threadsafe(view_finished_event.set)

# Request class to use the view.
class TestASGIRequest(ASGIRequest):
urlconf = (path("post/", post_view),)

# Handler to use request class.
class TestASGIHandler(ASGIHandler):
request_class = TestASGIRequest

application = TestASGIHandler()
scope = self.async_request_factory._base_scope(
method="POST",
path="/post/",
)
communicator = ApplicationCommunicator(application, scope)

await communicator.send_input({"type": "http.request", "body": b"Body data!"})

# Wait until the view code has started, then send http.disconnect.
await view_started_event.wait()
await communicator.send_input({"type": "http.disconnect"})
# Wait until view code has finished.
await view_finished_event.wait()
with self.assertRaises(asyncio.TimeoutError):
await communicator.receive_output()

self.assertEqual(outcome, [{"request_body": b"Body data!"}])

async def test_untouched_request_body_gets_closed(self):
application = get_asgi_application()
scope = self.async_request_factory._base_scope(method="POST", path="/post/")
Expand Down Expand Up @@ -345,7 +439,9 @@ async def test_request_lifecycle_signals_dispatched_with_thread_sensitive(self):
# AsyncToSync should have executed the signals in the same thread.
self.assertEqual(len(signal_handler.calls), 2)
request_started_call, request_finished_call = signal_handler.calls
self.assertEqual(request_started_call["thread"], request_finished_call["thread"])
self.assertEqual(
request_started_call["thread"], request_finished_call["thread"]
)

async def test_concurrent_async_uses_multiple_thread_pools(self):
sync_waiter.active_threads.clear()
Expand Down Expand Up @@ -381,6 +477,10 @@ async def test_concurrent_async_uses_multiple_thread_pools(self):
async def test_asyncio_cancel_error(self):
# Flag to check if the view was cancelled.
view_did_cancel = False
# Track request_finished signal.
signal_handler = SignalHandler()
request_finished.connect(signal_handler)
self.addCleanup(request_finished.disconnect, signal_handler)

# A view that will listen for the cancelled error.
async def view(request):
Expand Down Expand Up @@ -415,6 +515,13 @@ class TestASGIHandler(ASGIHandler):
# Give response.close() time to finish.
await communicator.wait()
self.assertIs(view_did_cancel, False)
# Exactly one call to request_finished handler.
self.assertEqual(len(signal_handler.calls), 1)
handler_call = signal_handler.calls.pop()
# It was NOT on the async thread.
self.assertNotEqual(handler_call["thread"], threading.current_thread())
# The signal sender is the handler class.
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})

# Request cycle with a disconnect before the view can respond.
application = TestASGIHandler()
Expand All @@ -430,11 +537,22 @@ class TestASGIHandler(ASGIHandler):
await communicator.receive_output()
await communicator.wait()
self.assertIs(view_did_cancel, True)
# Exactly one call to request_finished handler.
self.assertEqual(len(signal_handler.calls), 1)
handler_call = signal_handler.calls.pop()
# It was NOT on the async thread.
self.assertNotEqual(handler_call["thread"], threading.current_thread())
# The signal sender is the handler class.
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})

async def test_asyncio_streaming_cancel_error(self):
# Similar to test_asyncio_cancel_error(), but during a streaming
# response.
view_did_cancel = False
# Track request_finished signals.
signal_handler = SignalHandler()
request_finished.connect(signal_handler)
self.addCleanup(request_finished.disconnect, signal_handler)

async def streaming_response():
nonlocal view_did_cancel
Expand Down Expand Up @@ -469,6 +587,13 @@ class TestASGIHandler(ASGIHandler):
self.assertEqual(response_body["body"], b"Hello World!")
await communicator.wait()
self.assertIs(view_did_cancel, False)
# Exactly one call to request_finished handler.
self.assertEqual(len(signal_handler.calls), 1)
handler_call = signal_handler.calls.pop()
# It was NOT on the async thread.
self.assertNotEqual(handler_call["thread"], threading.current_thread())
# The signal sender is the handler class.
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})

# Request cycle with a disconnect.
application = TestASGIHandler()
Expand All @@ -487,6 +612,13 @@ class TestASGIHandler(ASGIHandler):
await communicator.receive_output()
await communicator.wait()
self.assertIs(view_did_cancel, True)
# Exactly one call to request_finished handler.
self.assertEqual(len(signal_handler.calls), 1)
handler_call = signal_handler.calls.pop()
# It was NOT on the async thread.
self.assertNotEqual(handler_call["thread"], threading.current_thread())
# The signal sender is the handler class.
self.assertEqual(handler_call["kwargs"], {"sender": TestASGIHandler})

async def test_streaming(self):
scope = self.async_request_factory._base_scope(
Expand Down

0 comments on commit 11393ab

Please sign in to comment.