Skip to content

Commit

Permalink
Annotate httptools_impl.py (#1484)
Browse files Browse the repository at this point in the history
* Annotate `httptools_impl.py`

* Update uvicorn/protocols/http/httptools_impl.py
  • Loading branch information
Kludex authored May 13, 2022
1 parent 208ef4e commit da9beb7
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 67 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ files =
uvicorn/protocols/http/__init__.py,
uvicorn/protocols/websockets/__init__.py,
uvicorn/protocols/http/h11_impl.py,
uvicorn/protocols/http/httptools_impl.py,
tests/middleware/test_wsgi.py,
tests/middleware/test_proxy_headers.py,
tests/test_config.py,
Expand Down
165 changes: 98 additions & 67 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,26 @@
import http
import logging
import re
import sys
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import Callable, Deque, List, Optional, Tuple, Union, cast

import httptools
from asgiref.typing import (
ASGI3Application,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)

from uvicorn._logging import TRACE_LOG_LEVEL
from uvicorn.config import Config
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
Expand All @@ -21,12 +35,18 @@
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState

if sys.version_info < (3, 8): # pragma: py-gte-38
from typing_extensions import Literal
else: # pragma: py-lt-38
from typing import Literal

HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")


def _get_status_line(status_code):
def _get_status_line(status_code: int) -> bytes:
try:
phrase = http.HTTPStatus(status_code).phrase.encode()
except ValueError:
Expand All @@ -40,7 +60,12 @@ def _get_status_line(status_code):


class HttpToolsProtocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self,
config: Config,
server_state: ServerState,
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()

Expand All @@ -56,7 +81,7 @@ def __init__(self, config, server_state, _loop=None):
self.limit_concurrency = config.limit_concurrency

# Timeouts
self.timeout_keep_alive_task = None
self.timeout_keep_alive_task: Optional[TimerHandle] = None
self.timeout_keep_alive = config.timeout_keep_alive

# Global state
Expand All @@ -66,21 +91,23 @@ def __init__(self, config, server_state, _loop=None):
self.default_headers = server_state.default_headers

# Per-connection state
self.transport = None
self.flow = None
self.server = None
self.client = None
self.scheme = None
self.pipeline = deque()
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Optional[Literal["http", "https"]] = None
self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque()

# Per-request state
self.scope = None
self.headers = None
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.expect_100_continue = False
self.cycle = None
self.cycle: RequestResponseCycle = None # type: ignore[assignment]

# Protocol interface
def connection_made(self, transport):
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
self.connections.add(self)

self.transport = transport
Expand All @@ -90,14 +117,14 @@ def connection_made(self, transport):
self.scheme = "https" if is_ssl(transport) else "http"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

def connection_lost(self, exc):
def connection_lost(self, exc: Optional[Exception]) -> None:
self.connections.discard(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)

if self.cycle and not self.cycle.response_complete:
Expand All @@ -109,15 +136,15 @@ def connection_lost(self, exc):
if exc is None:
self.transport.close()

def eof_received(self):
def eof_received(self) -> None:
pass

def _unset_keepalive_if_required(self):
def _unset_keepalive_if_required(self) -> None:
if self.timeout_keep_alive_task is not None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def data_received(self, data):
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

try:
Expand All @@ -130,7 +157,7 @@ def data_received(self, data):
except httptools.HttpParserUpgrade:
self.handle_upgrade()

def handle_upgrade(self):
def handle_upgrade(self) -> None:
upgrade_value = None
for name, value in self.headers:
if name == b"upgrade":
Expand All @@ -148,7 +175,7 @@ def handle_upgrade(self):
return

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)

self.connections.discard(self)
Expand All @@ -157,14 +184,14 @@ def handle_upgrade(self):
for name, value in self.scope["headers"]:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
protocol = self.ws_protocol_class( # type: ignore[call-arg]
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
self.transport.set_protocol(protocol)

def send_400_response(self, msg: str):
def send_400_response(self, msg: str) -> None:

content = [STATUS_LINE[400]]
for name, value in self.default_headers:
Expand All @@ -181,11 +208,11 @@ def send_400_response(self, msg: str):
self.transport.write(b"".join(content))
self.transport.close()

def on_message_begin(self):
def on_message_begin(self) -> None:
self.url = b""
self.expect_100_continue = False
self.headers = []
self.scope = {
self.scope = { # type: ignore[typeddict-item]
"type": "http",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": "1.1",
Expand All @@ -197,16 +224,16 @@ def on_message_begin(self):
}

# Parser callbacks
def on_url(self, url):
def on_url(self, url: bytes) -> None:
self.url += url

def on_header(self, name: bytes, value: bytes):
def on_header(self, name: bytes, value: bytes) -> None:
name = name.lower()
if name == b"expect" and value.lower() == b"100-continue":
self.expect_100_continue = True
self.headers.append((name, value))

def on_headers_complete(self):
def on_headers_complete(self) -> None:
http_version = self.parser.get_http_version()
method = self.parser.get_method()
self.scope["method"] = method.decode("ascii")
Expand Down Expand Up @@ -258,21 +285,21 @@ def on_headers_complete(self):
self.flow.pause_reading()
self.pipeline.appendleft((self.cycle, app))

def on_body(self, body: bytes):
def on_body(self, body: bytes) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
return
self.cycle.body += body
if len(self.cycle.body) > HIGH_WATER_LIMIT:
self.flow.pause_reading()
self.cycle.message_event.set()

def on_message_complete(self):
def on_message_complete(self) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
return
self.cycle.more_body = False
self.cycle.message_event.set()

def on_response_complete(self):
def on_response_complete(self) -> None:
# Callback for pipelined HTTP requests to be started.
self.server_state.total_requests += 1

Expand All @@ -296,7 +323,7 @@ def on_response_complete(self):
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)

def shutdown(self):
def shutdown(self) -> None:
"""
Called by the server to commence a graceful shutdown.
"""
Expand All @@ -305,19 +332,19 @@ def shutdown(self):
else:
self.cycle.keep_alive = False

def pause_writing(self):
def pause_writing(self) -> None:
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()

def resume_writing(self):
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()

def timeout_keep_alive_handler(self):
def timeout_keep_alive_handler(self) -> None:
"""
Called on a keep-alive connection if no new data is received after a short
delay.
Expand All @@ -329,17 +356,17 @@ def timeout_keep_alive_handler(self):
class RequestResponseCycle:
def __init__(
self,
scope,
transport,
flow,
logger,
access_logger,
access_log,
default_headers,
message_event,
expect_100_continue,
keep_alive,
on_response,
scope: HTTPScope,
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
access_logger: logging.Logger,
access_log: bool,
default_headers: List[Tuple[bytes, bytes]],
message_event: asyncio.Event,
expect_100_continue: bool,
keep_alive: bool,
on_response: Callable[..., None],
):
self.scope = scope
self.transport = transport
Expand All @@ -363,11 +390,11 @@ def __init__(
# Response state
self.response_started = False
self.response_complete = False
self.chunked_encoding = None
self.chunked_encoding: Optional[bool] = None
self.expected_content_length = 0

# ASGI exception wrapper
async def run_asgi(self, app):
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app(self.scope, self.receive, self.send)
except BaseException as exc:
Expand All @@ -391,25 +418,27 @@ async def run_asgi(self, app):
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = None

async def send_500_response(self):
await self.send(
{
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
)
await self.send(
{"type": "http.response.body", "body": b"Internal Server Error"}
)
self.on_response = lambda: None

async def send_500_response(self) -> None:
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await self.send(response_start_event)
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)

# ASGI interface
async def send(self, message):
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]

if self.flow.write_paused and not self.disconnected:
Expand All @@ -423,6 +452,7 @@ async def send(self, message):
if message_type != "http.response.start":
msg = "Expected ASGI message 'http.response.start', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast(HTTPResponseStartEvent, message)

self.response_started = True
self.waiting_for_100_continue = False
Expand Down Expand Up @@ -481,7 +511,7 @@ async def send(self, message):
msg = "Expected ASGI message 'http.response.body', but got '%s'."
raise RuntimeError(msg % message_type)

body = message.get("body", b"")
body = cast(bytes, message.get("body", b""))
more_body = message.get("more_body", False)

# Write response body
Expand Down Expand Up @@ -518,7 +548,7 @@ async def send(self, message):
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)

async def receive(self):
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
Expand All @@ -528,6 +558,7 @@ async def receive(self):
await self.message_event.wait()
self.message_event.clear()

message: Union[HTTPDisconnectEvent, HTTPRequestEvent]
if self.disconnected or self.response_complete:
message = {"type": "http.disconnect"}
else:
Expand Down

0 comments on commit da9beb7

Please sign in to comment.