Skip to content

Commit

Permalink
Raise ClientDisconnected on send() when client disconnected (#2220)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Feb 12, 2024
1 parent bd552df commit 1e5f1be
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 80 deletions.
96 changes: 40 additions & 56 deletions tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from uvicorn.lifespan.on import LifespanOn
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.utils import ClientDisconnected

try:
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
Expand Down Expand Up @@ -369,9 +370,7 @@ async def test_close(http_protocol_cls: HTTPProtocol):


@pytest.mark.anyio
async def test_chunked_encoding(
http_protocol_cls: HTTPProtocol,
):
async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
app = Response(
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
)
Expand All @@ -385,9 +384,7 @@ async def test_chunked_encoding(


@pytest.mark.anyio
async def test_chunked_encoding_empty_body(
http_protocol_cls: HTTPProtocol,
):
async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
app = Response(
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
)
Expand Down Expand Up @@ -416,9 +413,7 @@ async def test_chunked_encoding_head_request(


@pytest.mark.anyio
async def test_pipelined_requests(
http_protocol_cls: HTTPProtocol,
):
async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -440,9 +435,7 @@ async def test_pipelined_requests(


@pytest.mark.anyio
async def test_undersized_request(
http_protocol_cls: HTTPProtocol,
):
async def test_undersized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx", headers={"content-length": "10"})

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -452,9 +445,7 @@ async def test_undersized_request(


@pytest.mark.anyio
async def test_oversized_request(
http_protocol_cls: HTTPProtocol,
):
async def test_oversized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx" * 20, headers={"content-length": "10"})

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -464,9 +455,7 @@ async def test_oversized_request(


@pytest.mark.anyio
async def test_large_post_request(
http_protocol_cls: HTTPProtocol,
):
async def test_large_post_request(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -486,9 +475,7 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol):


@pytest.mark.anyio
async def test_app_exception(
http_protocol_cls: HTTPProtocol,
):
async def test_app_exception(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
raise Exception()

Expand All @@ -500,9 +487,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_exception_during_response(
http_protocol_cls: HTTPProtocol,
):
async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b"1", "more_body": True})
Expand All @@ -516,9 +501,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_no_response_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
...

Expand All @@ -530,9 +513,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_partial_response_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})

Expand All @@ -544,9 +525,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_duplicate_start_message(
http_protocol_cls: HTTPProtocol,
):
async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.start", "status": 200})
Expand All @@ -559,9 +538,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_missing_start_message(
http_protocol_cls: HTTPProtocol,
):
async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.body", "body": b""})

Expand All @@ -573,9 +550,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_message_after_body_complete(
http_protocol_cls: HTTPProtocol,
):
async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
Expand All @@ -589,9 +564,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_value_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_value_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
Expand All @@ -605,9 +578,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_early_disconnect(
http_protocol_cls: HTTPProtocol,
):
async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
got_disconnect_event = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
Expand All @@ -629,9 +600,26 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_early_response(
http_protocol_cls: HTTPProtocol,
):
async def test_disconnect_on_send(http_protocol_cls: HTTPProtocol) -> None:
got_disconnected = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
try:
await send({"type": "http.response.start", "status": 200})
except ClientDisconnected:
nonlocal got_disconnected
got_disconnected = True

protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
protocol.connection_lost(None)
await protocol.loop.run_one()
assert got_disconnected


@pytest.mark.anyio
async def test_early_response(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -643,9 +631,7 @@ async def test_early_response(


@pytest.mark.anyio
async def test_read_after_response(
http_protocol_cls: HTTPProtocol,
):
async def test_read_after_response(http_protocol_cls: HTTPProtocol):
message_after_response = None

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
Expand All @@ -663,9 +649,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_http10_request(
http_protocol_cls: HTTPProtocol,
):
async def test_http10_request(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
content = "Version: %s" % scope["http_version"]
Expand Down Expand Up @@ -876,8 +860,8 @@ async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable):
@pytest.mark.parametrize(
"asgi2or3_app, expected_scopes",
[
(asgi3app, {"version": "3.0", "spec_version": "2.3"}),
(asgi2app, {"version": "2.0", "spec_version": "2.3"}),
(asgi3app, {"version": "3.0", "spec_version": "2.4"}),
(asgi2app, {"version": "2.0", "spec_version": "2.4"}),
],
)
async def test_scopes(
Expand Down
17 changes: 10 additions & 7 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
service_unavailable,
)
from uvicorn.protocols.utils import (
ClientDisconnected,
get_client_addr,
get_local_addr,
get_path_with_query_string,
Expand Down Expand Up @@ -205,7 +206,7 @@ def handle_events(self) -> None:
"type": "http",
"asgi": {
"version": self.config.asgi_version,
"spec_version": "2.3",
"spec_version": "2.4",
},
"http_version": event.http_version.decode("ascii"),
"server": self.server,
Expand Down Expand Up @@ -412,6 +413,8 @@ async def run_asgi(self, app: "ASGI3Application") -> None:
result = await app( # type: ignore[func-returns-value]
self.scope, self.receive, self.send
)
except ClientDisconnected:
pass
except BaseException as exc:
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
Expand All @@ -436,7 +439,7 @@ async def run_asgi(self, app: "ASGI3Application") -> None:
self.on_response = lambda: None

async def send_500_response(self) -> None:
response_start_event: "HTTPResponseStartEvent" = {
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
Expand All @@ -445,22 +448,22 @@ async def send_500_response(self) -> None:
],
}
await self.send(response_start_event)
response_body_event: "HTTPResponseBodyEvent" = {
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)

# ASGI interface
async def send(self, message: "ASGISendEvent") -> None:
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]

if self.flow.write_paused and not self.disconnected:
await self.flow.drain()

if self.disconnected:
return
raise ClientDisconnected

if not self.response_started:
# Sending response status line and headers
Expand Down Expand Up @@ -527,7 +530,7 @@ async def send(self, message: "ASGISendEvent") -> None:
self.transport.close()
self.on_response()

async def receive(self) -> "ASGIReceiveEvent":
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
headers: list[tuple[str, str]] = []
event = h11.InformationalResponse(
Expand All @@ -545,7 +548,7 @@ async def receive(self) -> "ASGIReceiveEvent":
if self.disconnected or self.response_complete:
return {"type": "http.disconnect"}

message: "HTTPRequestEvent" = {
message: HTTPRequestEvent = {
"type": "http.request",
"body": self.body,
"more_body": self.more_body,
Expand Down
Loading

0 comments on commit 1e5f1be

Please sign in to comment.