Skip to content

Commit

Permalink
Merge branch 'master' into test/requests
Browse files Browse the repository at this point in the history
  • Loading branch information
Orenoid authored Sep 1, 2024
2 parents adeb680 + 72c2334 commit 7c5047c
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 37 deletions.
12 changes: 12 additions & 0 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
## 0.38.3

September 1, 2024

#### Added

* Support for Python 3.13 [#2662](https://github.com/encode/starlette/pull/2662).

#### Fixed

* Don't poll for disconnects in `BaseHTTPMiddleware` via `StreamingResponse` [#2620](https://github.com/encode/starlette/pull/2620).

## 0.38.2

July 27, 2024
Expand Down
2 changes: 1 addition & 1 deletion starlette/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.38.2"
__version__ = "0.38.3"
35 changes: 24 additions & 11 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -56,6 +55,7 @@ async def wrapped_receive(self) -> Message:
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg

# wrapped_rcv state 3: not yet consumed
Expand Down Expand Up @@ -198,20 +198,33 @@ async def dispatch(
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):
def __init__(
self,
content: ContentStream,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)

async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})
169 changes: 144 additions & 25 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Generator,
)

Expand All @@ -16,7 +17,7 @@
from starlette.background import BackgroundTask
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
Expand Down Expand Up @@ -260,7 +261,6 @@ async def homepage(request: Request) -> PlainTextResponse:
@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1438
request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()

Expand Down Expand Up @@ -293,13 +293,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand All @@ -313,7 +307,6 @@ async def send(message: Message) -> None:

@pytest.mark.anyio
async def test_do_not_block_on_background_tasks() -> None:
request_body_sent = False
response_complete = anyio.Event()
events: list[str | Message] = []

Expand Down Expand Up @@ -345,12 +338,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -379,7 +367,6 @@ async def send(message: Message) -> None:
@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()

Expand Down Expand Up @@ -424,13 +411,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -778,7 +759,9 @@ async def rcv() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"1", "more_body": True}
yield {"type": "http.request", "body": b"2", "more_body": True}
yield {"type": "http.request", "body": b"3"}
await anyio.sleep(float("inf"))
raise AssertionError( # pragma: no cover
"Should not be called, no need to poll for disconnect"
)

sent: list[Message] = []

Expand Down Expand Up @@ -1033,3 +1016,139 @@ async def endpoint(request: Request) -> Response:
resp.raise_for_status()

assert bodies == [b"Hello, World!-foo"]


@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
self.version = version
self.events = events
super().__init__(app)

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
self.events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
return res

async def sleepy(request: Request) -> Response:
try:
await request.body()
except ClientDisconnect:
pass
else: # pragma: no cover
raise AssertionError("Should have raised ClientDisconnect")
return Response(b"")

events: list[str] = []

app = Starlette(
routes=[Route("/", sleepy)],
middleware=[
Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive() -> AsyncIterator[Message]:
yield {"type": "http.disconnect"}

sent: list[Message] = []

async def send(message: Message) -> None:
sent.append(message)

await app(scope, receive().__anext__, send)

assert events == [
"1:STARTED",
"2:STARTED",
"3:STARTED",
"4:STARTED",
"5:STARTED",
"6:STARTED",
"7:STARTED",
"8:STARTED",
"9:STARTED",
"10:STARTED",
"10:COMPLETED",
"9:COMPLETED",
"8:COMPLETED",
"7:COMPLETED",
"6:COMPLETED",
"5:COMPLETED",
"4:COMPLETED",
"3:COMPLETED",
"2:COMPLETED",
"1:COMPLETED",
]

assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-length", b"0")],
},
{"type": "http.response.body", "body": b"", "more_body": False},
]


@pytest.mark.anyio
@pytest.mark.parametrize("send_body", [True, False])
async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None:
for _ in range(2):
msg = await receive()
while msg["type"] == "http.request":
msg = await receive()
assert msg["type"] == "http.disconnect"
await Response(b"good!")(scope, receive, send)

class MyMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = MyMiddleware(app_poll_disconnect)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive() -> AsyncIterator[Message]:
# the key here is that we only ever send 1 htt.disconnect message
if send_body:
yield {"type": "http.request", "body": b"hello", "more_body": True}
yield {"type": "http.request", "body": b"", "more_body": False}
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover

sent: list[Message] = []

async def send(message: Message) -> None:
sent.append(message)

await app(scope, receive().__anext__, send)

assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-length", b"5")],
},
{"type": "http.response.body", "body": b"good!", "more_body": True},
{"type": "http.response.body", "body": b"", "more_body": False},
]

0 comments on commit 7c5047c

Please sign in to comment.