diff --git a/setup.cfg b/setup.cfg index 546caefb0..29afbf6da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,7 +41,8 @@ files = uvicorn/middleware/__init__.py, uvicorn/protocols/__init__.py, uvicorn/protocols/http/__init__.py, - uvicorn/protocols/websockets/__init__.py + uvicorn/protocols/websockets/__init__.py, + uvicorn/protocols/http/httptools_impl.py [mypy-tests.*] diff --git a/uvicorn/_types.py b/uvicorn/_types.py index 0a547a50d..ddb701056 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -1,14 +1,49 @@ +import asyncio +import sys import types -import typing +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + MutableMapping, + Optional, + Tuple, + Type, + Union, +) + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +if TYPE_CHECKING: + from uvicorn.config import Config + from uvicorn.server_state import ServerState # WSGI -Environ = typing.MutableMapping[str, typing.Any] -ExcInfo = typing.Tuple[ - typing.Type[BaseException], BaseException, typing.Optional[types.TracebackType] -] -StartResponse = typing.Callable[ - [str, typing.Iterable[typing.Tuple[str, str]], typing.Optional[ExcInfo]], None -] -WSGIApp = typing.Callable[ - [Environ, StartResponse], typing.Union[typing.Iterable[bytes], BaseException] -] +Environ = MutableMapping[str, Any] +ExcInfo = Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]] +StartResponse = Callable[[str, Iterable[Tuple[str, str]], Optional[ExcInfo]], None] +WSGIApp = Callable[[Environ, StartResponse], Union[Iterable[bytes], BaseException]] + + +class WebProtocol(Protocol): + def __init__( + self, + config: "Config", + server_state: "ServerState", + on_connection_lost: Optional[Callable[..., Any]], + _loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + ... + + def connection_made(self, transport) -> None: + ... + + def data_received(self, data: bytes) -> None: + ... + + def shutdown(self) -> None: + ... diff --git a/uvicorn/config.py b/uvicorn/config.py index 98430e867..6d36d0722 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -9,6 +9,7 @@ import sys from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from uvicorn._types import WebProtocol from uvicorn.logging import TRACE_LOG_LEVEL if sys.version_info < (3, 8): @@ -134,6 +135,8 @@ def create_ssl_context( class Config: + ws_protocol_class: Type[WebProtocol] + def __init__( self, app: Union[ASGIApplication, Callable, str], @@ -143,7 +146,7 @@ def __init__( fd: Optional[int] = None, loop: LoopSetupType = "auto", http: Union[Type[asyncio.Protocol], HTTPProtocolType] = "auto", - ws: Union[Type[asyncio.Protocol], WSProtocolType] = "auto", + ws: Union[Type[WebProtocol], WSProtocolType] = "auto", ws_max_size: int = 16 * 1024 * 1024, ws_ping_interval: int = 20, ws_ping_timeout: int = 20, @@ -337,7 +340,7 @@ def load(self) -> None: if isinstance(self.ws, str): ws_protocol_class = import_from_string(WS_PROTOCOLS[self.ws]) - self.ws_protocol_class: Optional[Type[asyncio.Protocol]] = ws_protocol_class + self.ws_protocol_class = ws_protocol_class else: self.ws_protocol_class = self.ws @@ -374,7 +377,7 @@ def load(self) -> None: if self.interface == "wsgi": self.loaded_app = WSGIMiddleware(self.loaded_app) - self.ws_protocol_class = None + self.ws_protocol_class = None # type: ignore[assignment] elif self.interface == "asgi2": self.loaded_app = ASGI2Middleware(self.loaded_app) diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 557bd4f2a..e577652ef 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -3,10 +3,13 @@ import logging import re import urllib -from typing import Callable +from asyncio.events import TimerHandle +from typing import Any, ByteString, Callable, Optional, Tuple, cast import httptools +from asgiref.typing import ASGI3Application, ASGIReceiveEvent, ASGISendEvent, HTTPScope +from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL from uvicorn.protocols.http.flow_control import ( CLOSE_HEADER, @@ -21,12 +24,13 @@ get_remote_addr, is_ssl, ) +from uvicorn.server import ServerState HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]') HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]") -def _get_status_line(status_code): +def _get_status_line(status_code: int) -> ByteString: try: phrase = http.HTTPStatus(status_code).phrase.encode() except ValueError: @@ -41,8 +45,12 @@ def _get_status_line(status_code): class HttpToolsProtocol(asyncio.Protocol): def __init__( - self, config, server_state, on_connection_lost: Callable = None, _loop=None - ): + self, + config: Config, + server_state: ServerState, + on_connection_lost: Callable[..., None] = None, + _loop: Optional[asyncio.BaseEventLoop] = None, + ) -> None: if not config.loaded: config.load() @@ -59,7 +67,7 @@ def __init__( self.limit_concurrency = config.limit_concurrency # Timeouts - self.timeout_keep_alive_task = None + self.timeout_keep_alive_task: Optional[TimerHandle] = None self.timeout_keep_alive = config.timeout_keep_alive # Global state @@ -69,22 +77,22 @@ def __init__( self.default_headers = server_state.default_headers # Per-connection state - self.transport = None - self.flow = None - self.server = None - self.client = None - self.scheme = None - self.pipeline = [] + self.transport: asyncio.Transport + self.flow: FlowControl + self.server: Optional[Tuple[str, int]] = None + self.client: Optional[Tuple[str, int]] = None + self.scheme: str + self.pipeline: list = [] # Per-request state - self.url = None - self.scope = None - self.headers = None + self.url: bytes + self.scope: HTTPScope + self.headers: list = [] self.expect_100_continue = False - self.cycle = None + self.cycle: RequestResponseCycle = None # type: ignore # Protocol interface - def connection_made(self, transport): + def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore self.connections.add(self) self.transport = transport @@ -93,15 +101,15 @@ def connection_made(self, transport): self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" - if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + if self.logger.level <= TRACE_LOG_LEVEL and self.client is not None: + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) - def connection_lost(self, exc): + def connection_lost(self, exc: Any) -> None: self.connections.discard(self) - if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + if self.logger.level <= TRACE_LOG_LEVEL and self.client is not None: + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix) if self.cycle and not self.cycle.response_complete: @@ -116,15 +124,15 @@ def connection_lost(self, exc): if self.on_connection_lost is not None: self.on_connection_lost() - def eof_received(self): + def eof_received(self) -> None: pass - def _unset_keepalive_if_required(self): + def _unset_keepalive_if_required(self) -> None: if self.timeout_keep_alive_task is not None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None - def data_received(self, data): + def data_received(self, data: bytes) -> None: self._unset_keepalive_if_required() try: @@ -136,7 +144,7 @@ def data_received(self, data): except httptools.HttpParserUpgrade: self.handle_upgrade() - def handle_upgrade(self): + def handle_upgrade(self) -> None: upgrade_value = None for name, value in self.headers: if name == b"upgrade": @@ -168,8 +176,8 @@ def handle_upgrade(self): self.transport.close() return - if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + if self.logger.level <= TRACE_LOG_LEVEL and self.client is not None: + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) self.connections.discard(self) @@ -185,10 +193,10 @@ def handle_upgrade(self): ) protocol.connection_made(self.transport) protocol.data_received(b"".join(output)) - self.transport.set_protocol(protocol) + self.transport.set_protocol(protocol) # type: ignore[arg-type] # Parser callbacks - def on_url(self, url): + def on_url(self, url: bytes) -> None: method = self.parser.get_method() parsed_url = httptools.parse_url(url) raw_path = parsed_url.path @@ -211,15 +219,16 @@ def on_url(self, url): "raw_path": raw_path, "query_string": parsed_url.query if parsed_url.query else b"", "headers": self.headers, + "extensions": {}, } - def on_header(self, name: bytes, value: bytes): + def on_header(self, name: bytes, value: bytes) -> None: name = name.lower() if name == b"expect" and value.lower() == b"100-continue": self.expect_100_continue = True self.headers.append((name, value)) - def on_headers_complete(self): + def on_headers_complete(self) -> None: http_version = self.parser.get_http_version() if http_version != "1.1": self.scope["http_version"] = http_version @@ -261,7 +270,7 @@ def on_headers_complete(self): self.flow.pause_reading() self.pipeline.insert(0, (self.cycle, app)) - def on_body(self, body: bytes): + def on_body(self, body: bytes) -> None: if self.parser.should_upgrade() or self.cycle.response_complete: return self.cycle.body += body @@ -269,13 +278,13 @@ def on_body(self, body: bytes): self.flow.pause_reading() self.cycle.message_event.set() - def on_message_complete(self): + def on_message_complete(self) -> None: if self.parser.should_upgrade() or self.cycle.response_complete: return self.cycle.more_body = False self.cycle.message_event.set() - def on_response_complete(self): + def on_response_complete(self) -> None: # Callback for pipelined HTTP requests to be started. self.server_state.total_requests += 1 @@ -299,7 +308,7 @@ def on_response_complete(self): task.add_done_callback(self.tasks.discard) self.tasks.add(task) - def shutdown(self): + def shutdown(self) -> None: """ Called by the server to commence a graceful shutdown. """ @@ -308,19 +317,19 @@ def shutdown(self): else: self.cycle.keep_alive = False - def pause_writing(self): + def pause_writing(self) -> None: """ Called by the transport when the write buffer exceeds the high water mark. """ self.flow.pause_writing() - def resume_writing(self): + def resume_writing(self) -> None: """ Called by the transport when the write buffer drops below the low water mark. """ self.flow.resume_writing() - def timeout_keep_alive_handler(self): + def timeout_keep_alive_handler(self) -> None: """ Called on a keep-alive connection if no new data is received after a short delay. @@ -332,17 +341,17 @@ def timeout_keep_alive_handler(self): class RequestResponseCycle: def __init__( self, - scope, - transport, - flow, - logger, - access_logger, - access_log, - default_headers, - message_event, - expect_100_continue, - keep_alive, - on_response, + scope: HTTPScope, + transport: asyncio.Transport, + flow: FlowControl, + logger: logging.Logger, + access_logger: logging.Logger, + access_log: bool, + default_headers: list, + message_event: asyncio.Event, + expect_100_continue: bool, + keep_alive: bool, + on_response: Callable, ): self.scope = scope self.transport = transport @@ -366,11 +375,11 @@ def __init__( # Response state self.response_started = False self.response_complete = False - self.chunked_encoding = None + self.chunked_encoding: Optional[bool] = None self.expected_content_length = 0 # ASGI exception wrapper - async def run_asgi(self, app): + async def run_asgi(self, app: ASGI3Application) -> None: try: result = await app(self.scope, self.receive, self.send) except BaseException as exc: @@ -394,11 +403,11 @@ async def run_asgi(self, app): self.logger.error(msg) self.transport.close() finally: - self.on_response = None + self.on_response = lambda _: None - async def send_500_response(self): + async def send_500_response(self) -> None: await self.send( - { + { # type: ignore "type": "http.response.start", "status": 500, "headers": [ @@ -408,11 +417,14 @@ async def send_500_response(self): } ) await self.send( - {"type": "http.response.body", "body": b"Internal Server Error"} + { # type: ignore + "type": "http.response.body", + "body": b"Internal Server Error", + } ) # ASGI interface - async def send(self, message): + async def send(self, message: ASGISendEvent) -> None: message_type = message["type"] if self.flow.write_paused and not self.disconnected: @@ -430,8 +442,9 @@ async def send(self, message): self.response_started = True self.waiting_for_100_continue = False - status_code = message["status"] - headers = self.default_headers + list(message.get("headers", [])) + status_code = message["status"] # type: ignore + headers = list(cast(list, message.get("headers", []))) + headers += self.default_headers if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers: headers = headers + [CLOSE_HEADER] @@ -484,8 +497,8 @@ async def send(self, message): msg = "Expected ASGI message 'http.response.body', but got '%s'." raise RuntimeError(msg % message_type) - body = message.get("body", b"") - more_body = message.get("more_body", False) + body: bytes = message.get("body", b"") # type: ignore + more_body = message.get("more_body", False) # type: ignore # Write response body if self.scope["method"] == "HEAD": @@ -521,7 +534,7 @@ async def send(self, message): msg = "Unexpected ASGI message '%s' sent, after response already completed." raise RuntimeError(msg % message_type) - async def receive(self): + async def receive(self) -> ASGIReceiveEvent: if self.waiting_for_100_continue and not self.transport.is_closing(): self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.waiting_for_100_continue = False @@ -536,9 +549,9 @@ async def receive(self): else: message = { "type": "http.request", - "body": self.body, - "more_body": self.more_body, + "body": self.body, # type: ignore + "more_body": self.more_body, # type: ignore } self.body = b"" - return message + return message # type: ignore diff --git a/uvicorn/server.py b/uvicorn/server.py index e8de6512d..cfbbda1f4 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -9,16 +9,13 @@ import time from email.utils import formatdate from types import FrameType -from typing import Any, List, Optional, Set, Tuple, Union +from typing import Any, List, Optional import click from uvicorn._handlers.http import handle_http from uvicorn.config import Config -from uvicorn.protocols.http.h11_impl import H11Protocol -from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol -from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol -from uvicorn.protocols.websockets.wsproto_impl import WSProtocol +from uvicorn.server_state import ServerState if sys.platform != "win32": from asyncio import start_unix_server as _start_unix_server @@ -35,20 +32,6 @@ async def _start_unix_server(*args: Any, **kwargs: Any) -> Any: logger = logging.getLogger("uvicorn.error") -Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol] - - -class ServerState: - """ - Shared servers state that is available between all protocol instances. - """ - - def __init__(self) -> None: - self.total_requests = 0 - self.connections: Set[Protocols] = set() - self.tasks: Set[asyncio.Task] = set() - self.default_headers: List[Tuple[bytes, bytes]] = [] - class Server: def __init__(self, config: Config) -> None: @@ -306,7 +289,6 @@ def install_signal_handlers(self) -> None: signal.signal(sig, self.handle_exit) def handle_exit(self, sig: signal.Signals, frame: FrameType) -> None: - if self.should_exit: self.force_exit = True else: diff --git a/uvicorn/server_state.py b/uvicorn/server_state.py new file mode 100644 index 000000000..20d45eabe --- /dev/null +++ b/uvicorn/server_state.py @@ -0,0 +1,16 @@ +import asyncio +from typing import List, Set, Tuple + +from uvicorn._types import WebProtocol + + +class ServerState: + """ + Shared servers state that is available between all protocol instances. + """ + + def __init__(self) -> None: + self.total_requests = 0 + self.connections: Set[WebProtocol] = set() + self.tasks: Set[asyncio.Task] = set() + self.default_headers: List[Tuple[bytes, bytes]] = []