Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BackgroundTasks with BaseHTTPMiddleware #2688

Merged
merged 4 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})

if self.background:
await self.background()
Copy link

@dmitry-mli dmitry-mli Sep 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For full backward compatibility add it to constructor also:

class _StreamingResponse(Response):
    def __init__(
        self,
        content: AsyncContentStream,
        status_code: int = 200,
        headers: typing.Mapping[str, str] | None = None,
        media_type: str | None = None,
        background: BackgroundTask | None = None,                <<<
        info: typing.Mapping[str, typing.Any] | None = None,
    ) -> None:
        self.info = info
        self.body_iterator = content
        self.status_code = status_code
        self.media_type = media_type
        self.background = background                             <<<
        self.init_headers(headers)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because some middleware may be referring to that field. Or below will work also

self.background = None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No users should be calling the constructor unless they do something like type(response)(...) which is not something we need to support.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, this aligns with the class being marked as private. Then please include

self.background = None

For full backward compatibility, any response object injected into BaseHTTPMiddleware based middleware needs to have the background field (default None)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I see your point now with the edit. Good point. Added.

@Kludex this is another reason to move the background task logic out of responses and a good reason to rework our Response inheritance so that Response and StreamingResponse both inherit from a BaseResponse that has the API and initialization of common bits.

41 changes: 33 additions & 8 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,16 +1006,28 @@ async def endpoint(request: Request) -> Response:

@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
"""
Tests for:
- https://github.com/encode/starlette/issues/2516
- https://github.com/encode/starlette/pull/2687
"""
ordered_events: list[str] = []
unordered_events: list[str] = []

class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
def __init__(self, app: ASGIApp, version: int) -> None:
self.version = version
self.events = events
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
self.events.append(f"{self.version}:STARTED")
ordered_events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
ordered_events.append(f"{self.version}:COMPLETED")

def background() -> None:
unordered_events.append(f"{self.version}:BACKGROUND")

res.background = BackgroundTask(background)
return res

async def sleepy(request: Request) -> Response:
Expand All @@ -1027,11 +1039,9 @@ async def sleepy(request: Request) -> Response:
raise AssertionError("Should have raised ClientDisconnect")
return Response(b"")

events: list[str] = []

app = Starlette(
routes=[Route("/", sleepy)],
middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)],
)

scope = {
Expand All @@ -1051,7 +1061,7 @@ async def send(message: Message) -> None:

await app(scope, receive().__anext__, send)

assert events == [
assert ordered_events == [
"1:STARTED",
"2:STARTED",
"3:STARTED",
Expand All @@ -1074,6 +1084,21 @@ async def send(message: Message) -> None:
"1:COMPLETED",
]

assert sorted(unordered_events) == sorted(
[
"1:BACKGROUND",
"2:BACKGROUND",
"3:BACKGROUND",
"4:BACKGROUND",
"5:BACKGROUND",
"6:BACKGROUND",
"7:BACKGROUND",
"8:BACKGROUND",
"9:BACKGROUND",
"10:BACKGROUND",
]
)

assert sent == [
{
"type": "http.response.start",
Expand Down
Loading