diff --git a/setup.cfg b/setup.cfg index 641424992..03c5a1a7c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ files = uvicorn/protocols/http/__init__.py, uvicorn/protocols/websockets/__init__.py, uvicorn/protocols/http/h11_impl.py, + uvicorn/protocols/http/httptools_impl.py, tests/middleware/test_wsgi.py, tests/middleware/test_proxy_headers.py, tests/test_config.py, diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 9af6f9b9e..e932b74b0 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -2,12 +2,26 @@ import http import logging import re +import sys import urllib +from asyncio.events import TimerHandle from collections import deque +from typing import Callable, Deque, List, Optional, Tuple, Union, cast import httptools +from asgiref.typing import ( + ASGI3Application, + ASGIReceiveEvent, + ASGISendEvent, + HTTPDisconnectEvent, + HTTPRequestEvent, + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + HTTPScope, +) from uvicorn._logging import TRACE_LOG_LEVEL +from uvicorn.config import Config from uvicorn.protocols.http.flow_control import ( CLOSE_HEADER, HIGH_WATER_LIMIT, @@ -21,12 +35,18 @@ get_remote_addr, is_ssl, ) +from uvicorn.server import ServerState + +if sys.version_info < (3, 8): # pragma: py-gte-38 + from typing_extensions import Literal +else: # pragma: py-lt-38 + from typing import Literal HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]') HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]") -def _get_status_line(status_code): +def _get_status_line(status_code: int) -> bytes: try: phrase = http.HTTPStatus(status_code).phrase.encode() except ValueError: @@ -40,7 +60,12 @@ def _get_status_line(status_code): class HttpToolsProtocol(asyncio.Protocol): - def __init__(self, config, server_state, _loop=None): + def __init__( + self, + config: Config, + server_state: ServerState, + _loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: if not config.loaded: config.load() @@ -56,7 +81,7 @@ def __init__(self, config, server_state, _loop=None): self.limit_concurrency = config.limit_concurrency # Timeouts - self.timeout_keep_alive_task = None + self.timeout_keep_alive_task: Optional[TimerHandle] = None self.timeout_keep_alive = config.timeout_keep_alive # Global state @@ -66,21 +91,23 @@ def __init__(self, config, server_state, _loop=None): self.default_headers = server_state.default_headers # Per-connection state - self.transport = None - self.flow = None - self.server = None - self.client = None - self.scheme = None - self.pipeline = deque() + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.flow: FlowControl = None # type: ignore[assignment] + self.server: Optional[Tuple[str, int]] = None + self.client: Optional[Tuple[str, int]] = None + self.scheme: Optional[Literal["http", "https"]] = None + self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque() # Per-request state - self.scope = None - self.headers = None + self.scope: HTTPScope = None # type: ignore[assignment] + self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment] self.expect_100_continue = False - self.cycle = None + self.cycle: RequestResponseCycle = None # type: ignore[assignment] # Protocol interface - def connection_made(self, transport): + def connection_made( # type: ignore[override] + self, transport: asyncio.Transport + ) -> None: self.connections.add(self) self.transport = transport @@ -90,14 +117,14 @@ def connection_made(self, transport): self.scheme = "https" if is_ssl(transport) else "http" if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) - def connection_lost(self, exc): + def connection_lost(self, exc: Optional[Exception]) -> None: self.connections.discard(self) if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix) if self.cycle and not self.cycle.response_complete: @@ -109,15 +136,15 @@ def connection_lost(self, exc): if exc is None: self.transport.close() - def eof_received(self): + def eof_received(self) -> None: pass - def _unset_keepalive_if_required(self): + def _unset_keepalive_if_required(self) -> None: if self.timeout_keep_alive_task is not None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None - def data_received(self, data): + def data_received(self, data: bytes) -> None: self._unset_keepalive_if_required() try: @@ -130,7 +157,7 @@ def data_received(self, data): except httptools.HttpParserUpgrade: self.handle_upgrade() - def handle_upgrade(self): + def handle_upgrade(self) -> None: upgrade_value = None for name, value in self.headers: if name == b"upgrade": @@ -148,7 +175,7 @@ def handle_upgrade(self): return if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) self.connections.discard(self) @@ -157,14 +184,14 @@ def handle_upgrade(self): for name, value in self.scope["headers"]: output += [name, b": ", value, b"\r\n"] output.append(b"\r\n") - protocol = self.ws_protocol_class( + protocol = self.ws_protocol_class( # type: ignore[call-arg] config=self.config, server_state=self.server_state ) protocol.connection_made(self.transport) protocol.data_received(b"".join(output)) self.transport.set_protocol(protocol) - def send_400_response(self, msg: str): + def send_400_response(self, msg: str) -> None: content = [STATUS_LINE[400]] for name, value in self.default_headers: @@ -181,11 +208,11 @@ def send_400_response(self, msg: str): self.transport.write(b"".join(content)) self.transport.close() - def on_message_begin(self): + def on_message_begin(self) -> None: self.url = b"" self.expect_100_continue = False self.headers = [] - self.scope = { + self.scope = { # type: ignore[typeddict-item] "type": "http", "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, "http_version": "1.1", @@ -197,16 +224,16 @@ def on_message_begin(self): } # Parser callbacks - def on_url(self, url): + def on_url(self, url: bytes) -> None: self.url += url - def on_header(self, name: bytes, value: bytes): + def on_header(self, name: bytes, value: bytes) -> None: name = name.lower() if name == b"expect" and value.lower() == b"100-continue": self.expect_100_continue = True self.headers.append((name, value)) - def on_headers_complete(self): + def on_headers_complete(self) -> None: http_version = self.parser.get_http_version() method = self.parser.get_method() self.scope["method"] = method.decode("ascii") @@ -258,7 +285,7 @@ def on_headers_complete(self): self.flow.pause_reading() self.pipeline.appendleft((self.cycle, app)) - def on_body(self, body: bytes): + def on_body(self, body: bytes) -> None: if self.parser.should_upgrade() or self.cycle.response_complete: return self.cycle.body += body @@ -266,13 +293,13 @@ def on_body(self, body: bytes): self.flow.pause_reading() self.cycle.message_event.set() - def on_message_complete(self): + def on_message_complete(self) -> None: if self.parser.should_upgrade() or self.cycle.response_complete: return self.cycle.more_body = False self.cycle.message_event.set() - def on_response_complete(self): + def on_response_complete(self) -> None: # Callback for pipelined HTTP requests to be started. self.server_state.total_requests += 1 @@ -296,7 +323,7 @@ def on_response_complete(self): task.add_done_callback(self.tasks.discard) self.tasks.add(task) - def shutdown(self): + def shutdown(self) -> None: """ Called by the server to commence a graceful shutdown. """ @@ -305,19 +332,19 @@ def shutdown(self): else: self.cycle.keep_alive = False - def pause_writing(self): + def pause_writing(self) -> None: """ Called by the transport when the write buffer exceeds the high water mark. """ self.flow.pause_writing() - def resume_writing(self): + def resume_writing(self) -> None: """ Called by the transport when the write buffer drops below the low water mark. """ self.flow.resume_writing() - def timeout_keep_alive_handler(self): + def timeout_keep_alive_handler(self) -> None: """ Called on a keep-alive connection if no new data is received after a short delay. @@ -329,17 +356,17 @@ def timeout_keep_alive_handler(self): class RequestResponseCycle: def __init__( self, - scope, - transport, - flow, - logger, - access_logger, - access_log, - default_headers, - message_event, - expect_100_continue, - keep_alive, - on_response, + scope: HTTPScope, + transport: asyncio.Transport, + flow: FlowControl, + logger: logging.Logger, + access_logger: logging.Logger, + access_log: bool, + default_headers: List[Tuple[bytes, bytes]], + message_event: asyncio.Event, + expect_100_continue: bool, + keep_alive: bool, + on_response: Callable[..., None], ): self.scope = scope self.transport = transport @@ -363,11 +390,11 @@ def __init__( # Response state self.response_started = False self.response_complete = False - self.chunked_encoding = None + self.chunked_encoding: Optional[bool] = None self.expected_content_length = 0 # ASGI exception wrapper - async def run_asgi(self, app): + async def run_asgi(self, app: ASGI3Application) -> None: try: result = await app(self.scope, self.receive, self.send) except BaseException as exc: @@ -391,25 +418,27 @@ async def run_asgi(self, app): self.logger.error(msg) self.transport.close() finally: - self.on_response = None - - async def send_500_response(self): - await self.send( - { - "type": "http.response.start", - "status": 500, - "headers": [ - (b"content-type", b"text/plain; charset=utf-8"), - (b"connection", b"close"), - ], - } - ) - await self.send( - {"type": "http.response.body", "body": b"Internal Server Error"} - ) + self.on_response = lambda: None + + async def send_500_response(self) -> None: + response_start_event: HTTPResponseStartEvent = { + "type": "http.response.start", + "status": 500, + "headers": [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ], + } + await self.send(response_start_event) + response_body_event: HTTPResponseBodyEvent = { + "type": "http.response.body", + "body": b"Internal Server Error", + "more_body": False, + } + await self.send(response_body_event) # ASGI interface - async def send(self, message): + async def send(self, message: ASGISendEvent) -> None: message_type = message["type"] if self.flow.write_paused and not self.disconnected: @@ -423,6 +452,7 @@ async def send(self, message): if message_type != "http.response.start": msg = "Expected ASGI message 'http.response.start', but got '%s'." raise RuntimeError(msg % message_type) + message = cast(HTTPResponseStartEvent, message) self.response_started = True self.waiting_for_100_continue = False @@ -481,7 +511,7 @@ async def send(self, message): msg = "Expected ASGI message 'http.response.body', but got '%s'." raise RuntimeError(msg % message_type) - body = message.get("body", b"") + body = cast(bytes, message.get("body", b"")) more_body = message.get("more_body", False) # Write response body @@ -518,7 +548,7 @@ async def send(self, message): msg = "Unexpected ASGI message '%s' sent, after response already completed." raise RuntimeError(msg % message_type) - async def receive(self): + async def receive(self) -> ASGIReceiveEvent: if self.waiting_for_100_continue and not self.transport.is_closing(): self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.waiting_for_100_continue = False @@ -528,6 +558,7 @@ async def receive(self): await self.message_event.wait() self.message_event.clear() + message: Union[HTTPDisconnectEvent, HTTPRequestEvent] if self.disconnected or self.response_complete: message = {"type": "http.disconnect"} else: