From a53d4d02ffe5ed5a33139fa214d0f36ae8844e8c Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Mon, 27 Jun 2022 22:37:00 +0800 Subject: [PATCH 1/5] Test for https://github.com/encode/starlette/issues/1527 --- tests/middleware/test_base.py | 45 +++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 976d77b86..3b1570434 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,13 +1,16 @@ import contextvars +from contextlib import AsyncExitStack +from typing import AsyncGenerator, Awaitable, Callable import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send class CustomMiddleware(BaseHTTPMiddleware): @@ -206,3 +209,41 @@ async def homepage(request): client = test_client_factory(app) response = client.get("/") assert response.status_code == 200, response.content + + +@pytest.mark.anyio +async def test_client_disconnects_before_response_is_sent() -> None: + app: ASGIApp + + async def homepage(request: Request): + # await anyio.sleep(5) + return PlainTextResponse("hi!") + + async def dispatch( + request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: + return await call_next(request) + + app = BaseHTTPMiddleware(Route("/", homepage), dispatch=dispatch) + app = BaseHTTPMiddleware(app, dispatch=dispatch) + + async def recv_gen() -> AsyncGenerator[Message, None]: + yield {"type": "http.request"} + yield {"type": "http.disconnect"} + yield {"type": "http.disconnect"} + + async def send_gen() -> AsyncGenerator[None, Message]: + msg = yield + assert msg["type"] == "http.response.start" + msg = yield + raise AssertionError("Should not be called") + + scope = {"type": "http", "method": "GET", "path": "/"} + + async with AsyncExitStack() as stack: + recv = recv_gen() + stack.push_async_callback(recv.aclose) + send = send_gen() + stack.push_async_callback(send.aclose) + await send.__anext__() + await app(scope, recv.__aiter__().__anext__, send.asend) From 5b604fec402e8e2b3627d3e7d21c26ac28bf17ca Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Tue, 28 Jun 2022 02:04:08 +0800 Subject: [PATCH 2/5] Shield send "http.response.start" from cancellation --- starlette/middleware/base.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 49a5e3e2d..1dca5e4e0 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,7 +4,7 @@ from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ @@ -28,12 +28,22 @@ async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream() + async def send(msg: Message) -> None: + # Shield send "http.response.start" from cancellation. + # Otherwise, `await recv_stream.receive()` will raise `anyio.EndOfStream` if request is disconnected, + # due to `task_group.cancel_scope.cancel()` in `StreamingResponse.__call__..wrap` + # and cancellation check in `await checkpoint()` of `MemoryObjectSendStream.send`, + # and then `RuntimeError: No response returned.` will be raised below. + shield = msg["type"] == "http.response.start" + with anyio.CancelScope(shield=shield): + await send_stream.send(msg) + async def coro() -> None: nonlocal app_exc async with send_stream: try: - await self.app(scope, request.receive, send_stream.send) + await self.app(scope, request.receive, send) except Exception as exc: app_exc = exc From d0165cae151882df656d12bc5a1509f296c0560e Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Tue, 28 Jun 2022 02:37:01 +0800 Subject: [PATCH 3/5] Fix E501 line too long --- starlette/middleware/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 1dca5e4e0..6ed9c20c7 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -30,9 +30,11 @@ async def call_next(request: Request) -> Response: async def send(msg: Message) -> None: # Shield send "http.response.start" from cancellation. - # Otherwise, `await recv_stream.receive()` will raise `anyio.EndOfStream` if request is disconnected, - # due to `task_group.cancel_scope.cancel()` in `StreamingResponse.__call__..wrap` - # and cancellation check in `await checkpoint()` of `MemoryObjectSendStream.send`, + # Otherwise, `await recv_stream.receive()` will raise + # `anyio.EndOfStream` if request is disconnected, + # due to `task_group.cancel_scope.cancel()` in + # `StreamingResponse.__call__..wrap` and cancellation check + # during `await checkpoint()` in `MemoryObjectSendStream.send`, # and then `RuntimeError: No response returned.` will be raised below. shield = msg["type"] == "http.response.start" with anyio.CancelScope(shield=shield): From 47db5e27c64a9bd5e6a30fea3bc6e56cb6b039bc Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 27 Jun 2022 11:39:56 -0700 Subject: [PATCH 4/5] add pragma --- tests/middleware/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 3b1570434..80087686d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -236,7 +236,7 @@ async def send_gen() -> AsyncGenerator[None, Message]: msg = yield assert msg["type"] == "http.response.start" msg = yield - raise AssertionError("Should not be called") + raise AssertionError("Should not be called") # pragma: no cover scope = {"type": "http", "method": "GET", "path": "/"} From 68c5e9345c2276ed793bdb83c6723c9388b81404 Mon Sep 17 00:00:00 2001 From: Aaron Chong Date: Tue, 28 Jun 2022 02:44:28 +0800 Subject: [PATCH 5/5] Reword comments Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- starlette/middleware/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 6ed9c20c7..721a2eb58 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -31,11 +31,13 @@ async def call_next(request: Request) -> Response: async def send(msg: Message) -> None: # Shield send "http.response.start" from cancellation. # Otherwise, `await recv_stream.receive()` will raise - # `anyio.EndOfStream` if request is disconnected, + # `anyio.EndOfStream` if the connection is disconnected, # due to `task_group.cancel_scope.cancel()` in - # `StreamingResponse.__call__..wrap` and cancellation check - # during `await checkpoint()` in `MemoryObjectSendStream.send`, - # and then `RuntimeError: No response returned.` will be raised below. + # `StreamingResponse.__call__..wrap` + # and cancellation check during `await checkpoint()` in + # `MemoryObjectSendStream.send`. + # This would trigger the check we have in this middleware resulting in + # `RuntimeError: No response returned.` being raised below. shield = msg["type"] == "http.response.start" with anyio.CancelScope(shield=shield): await send_stream.send(msg)