From 7a83f4349dbf940ac194511f827be965fbe959cd Mon Sep 17 00:00:00 2001 From: Reinhold Bertram Date: Fri, 11 Feb 2022 09:15:53 +0100 Subject: [PATCH] Send HTTP 400 response for invalid request (#1352) * Send HTTP 400 response for invalid request Given an invalid request, respond with an HTTP 400 error instead of closing the connection without a response. * changed signature of send_400_response to msg as str Co-authored-by: Mark Breedlove --- tests/protocols/test_http.py | 4 +- uvicorn/protocols/http/h11_impl.py | 49 +++++++++++++----------- uvicorn/protocols/http/httptools_impl.py | 48 ++++++++++++----------- 3 files changed, 54 insertions(+), 47 deletions(-) diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 533b828df8..d5f343268b 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -742,5 +742,5 @@ def test_invalid_http_request(request_line, protocol_cls, caplog, event_loop): with get_connected_protocol(app, protocol_cls, event_loop) as protocol: protocol.data_received(request) - assert not protocol.transport.buffer - assert "Invalid HTTP request received." in caplog.messages + assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer + assert b"Invalid HTTP request received." in protocol.transport.buffer diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 78e5ffdc1c..6c1520f050 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -130,7 +130,7 @@ def handle_events(self): except h11.RemoteProtocolError as exc: msg = "Invalid HTTP request received." self.logger.warning(msg, exc_info=exc) - self.transport.close() + self.send_400_response(msg) return event_type = type(event) @@ -225,28 +225,7 @@ def handle_upgrade(self, event): if upgrade_value != b"websocket" or self.ws_protocol_class is None: msg = "Unsupported upgrade request." self.logger.warning(msg) - - from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol - - if AutoWebSocketsProtocol is None: - msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 - self.logger.warning(msg) - - reason = STATUS_PHRASES[400] - headers = [ - (b"content-type", b"text/plain; charset=utf-8"), - (b"connection", b"close"), - ] - event = h11.Response(status_code=400, headers=headers, reason=reason) - output = self.conn.send(event) - self.transport.write(output) - event = h11.Data(data=b"Unsupported upgrade request.") - output = self.conn.send(event) - self.transport.write(output) - event = h11.EndOfMessage() - output = self.conn.send(event) - self.transport.write(output) - self.transport.close() + self.send_400_response(msg) return if self.logger.level <= TRACE_LOG_LEVEL: @@ -265,6 +244,30 @@ def handle_upgrade(self, event): protocol.data_received(b"".join(output)) self.transport.set_protocol(protocol) + def send_400_response(self, msg: str): + + from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol + + if AutoWebSocketsProtocol is None: + msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 + self.logger.warning(msg) + + reason = STATUS_PHRASES[400] + headers = [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ] + event = h11.Response(status_code=400, headers=headers, reason=reason) + output = self.conn.send(event) + self.transport.write(output) + event = h11.Data(data=msg.encode("ascii")) + output = self.conn.send(event) + self.transport.write(output) + event = h11.EndOfMessage() + output = self.conn.send(event) + self.transport.write(output) + self.transport.close() + def on_response_complete(self): self.server_state.total_requests += 1 diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index e6183bbcb7..78afec0de6 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -126,7 +126,8 @@ def data_received(self, data): except httptools.HttpParserError as exc: msg = "Invalid HTTP request received." self.logger.warning(msg, exc_info=exc) - self.transport.close() + self.send_400_response(msg) + return except httptools.HttpParserUpgrade: self.handle_upgrade() @@ -139,27 +140,7 @@ def handle_upgrade(self): if upgrade_value != b"websocket" or self.ws_protocol_class is None: msg = "Unsupported upgrade request." self.logger.warning(msg) - - from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol - - if AutoWebSocketsProtocol is None: - msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 - self.logger.warning(msg) - - content = [STATUS_LINE[400]] - for name, value in self.default_headers: - content.extend([name, b": ", value, b"\r\n"]) - content.extend( - [ - b"content-type: text/plain; charset=utf-8\r\n", - b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", - b"connection: close\r\n", - b"\r\n", - msg.encode("ascii"), - ] - ) - self.transport.write(b"".join(content)) - self.transport.close() + self.send_400_response(msg) return if self.logger.level <= TRACE_LOG_LEVEL: @@ -179,6 +160,29 @@ def handle_upgrade(self): protocol.data_received(b"".join(output)) self.transport.set_protocol(protocol) + def send_400_response(self, msg: str): + + from uvicorn.protocols.websockets.auto import AutoWebSocketsProtocol + + if AutoWebSocketsProtocol is None: + msg = "No supported WebSocket library detected. Please use 'pip install uvicorn[standard]', or install 'websockets' or 'wsproto' manually." # noqa: E501 + self.logger.warning(msg) + + content = [STATUS_LINE[400]] + for name, value in self.default_headers: + content.extend([name, b": ", value, b"\r\n"]) + content.extend( + [ + b"content-type: text/plain; charset=utf-8\r\n", + b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", + b"connection: close\r\n", + b"\r\n", + msg.encode("ascii"), + ] + ) + self.transport.write(b"".join(content)) + self.transport.close() + # Parser callbacks def on_url(self, url): method = self.parser.get_method()