Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotate httptools_impl.py #1484

Merged
merged 2 commits into from
May 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ files =
uvicorn/protocols/http/__init__.py,
uvicorn/protocols/websockets/__init__.py,
uvicorn/protocols/http/h11_impl.py,
uvicorn/protocols/http/httptools_impl.py,
tests/middleware/test_wsgi.py,
tests/middleware/test_proxy_headers.py,
tests/test_config.py,
Expand Down
165 changes: 98 additions & 67 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,26 @@
import http
import logging
import re
import sys
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import Callable, Deque, List, Optional, Tuple, Union, cast

import httptools
from asgiref.typing import (
ASGI3Application,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)

from uvicorn._logging import TRACE_LOG_LEVEL
from uvicorn.config import Config
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
Expand All @@ -21,12 +35,18 @@
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState

if sys.version_info < (3, 8): # pragma: py-gte-38
from typing_extensions import Literal
else: # pragma: py-lt-38
from typing import Literal
Comment on lines +40 to +43
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have a _compat.py module or something where this sorta boilerplate can go?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No


HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")


def _get_status_line(status_code):
def _get_status_line(status_code: int) -> bytes:
try:
phrase = http.HTTPStatus(status_code).phrase.encode()
except ValueError:
Expand All @@ -40,7 +60,12 @@ def _get_status_line(status_code):


class HttpToolsProtocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self,
config: Config,
server_state: ServerState,
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()

Expand All @@ -56,7 +81,7 @@ def __init__(self, config, server_state, _loop=None):
self.limit_concurrency = config.limit_concurrency

# Timeouts
self.timeout_keep_alive_task = None
self.timeout_keep_alive_task: Optional[TimerHandle] = None
self.timeout_keep_alive = config.timeout_keep_alive

# Global state
Expand All @@ -66,21 +91,23 @@ def __init__(self, config, server_state, _loop=None):
self.default_headers = server_state.default_headers

# Per-connection state
self.transport = None
self.flow = None
self.server = None
self.client = None
self.scheme = None
self.pipeline = deque()
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
Comment on lines +94 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you using the # type: ignore[assignment] just to avoid some assert self.transport is not Nones later?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Optional[Literal["http", "https"]] = None
self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque()

# Per-request state
self.scope = None
self.headers = None
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.expect_100_continue = False
self.cycle = None
self.cycle: RequestResponseCycle = None # type: ignore[assignment]

# Protocol interface
def connection_made(self, transport):
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
self.connections.add(self)

self.transport = transport
Expand All @@ -90,14 +117,14 @@ def connection_made(self, transport):
self.scheme = "https" if is_ssl(transport) else "http"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

def connection_lost(self, exc):
def connection_lost(self, exc: Optional[Exception]) -> None:
self.connections.discard(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)

if self.cycle and not self.cycle.response_complete:
Expand All @@ -109,15 +136,15 @@ def connection_lost(self, exc):
if exc is None:
self.transport.close()

def eof_received(self):
def eof_received(self) -> None:
pass

def _unset_keepalive_if_required(self):
def _unset_keepalive_if_required(self) -> None:
if self.timeout_keep_alive_task is not None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def data_received(self, data):
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

try:
Expand All @@ -130,7 +157,7 @@ def data_received(self, data):
except httptools.HttpParserUpgrade:
self.handle_upgrade()

def handle_upgrade(self):
def handle_upgrade(self) -> None:
upgrade_value = None
for name, value in self.headers:
if name == b"upgrade":
Expand All @@ -148,7 +175,7 @@ def handle_upgrade(self):
return

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)

self.connections.discard(self)
Expand All @@ -157,14 +184,14 @@ def handle_upgrade(self):
for name, value in self.scope["headers"]:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
protocol = self.ws_protocol_class( # type: ignore[call-arg]
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
self.transport.set_protocol(protocol)

def send_400_response(self, msg: str):
def send_400_response(self, msg: str) -> None:

content = [STATUS_LINE[400]]
for name, value in self.default_headers:
Expand All @@ -181,11 +208,11 @@ def send_400_response(self, msg: str):
self.transport.write(b"".join(content))
self.transport.close()

def on_message_begin(self):
def on_message_begin(self) -> None:
self.url = b""
self.expect_100_continue = False
self.headers = []
self.scope = {
self.scope = { # type: ignore[typeddict-item]
"type": "http",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": "1.1",
Expand All @@ -197,16 +224,16 @@ def on_message_begin(self):
}

# Parser callbacks
def on_url(self, url):
def on_url(self, url: bytes) -> None:
self.url += url

def on_header(self, name: bytes, value: bytes):
def on_header(self, name: bytes, value: bytes) -> None:
name = name.lower()
if name == b"expect" and value.lower() == b"100-continue":
self.expect_100_continue = True
self.headers.append((name, value))

def on_headers_complete(self):
def on_headers_complete(self) -> None:
http_version = self.parser.get_http_version()
method = self.parser.get_method()
self.scope["method"] = method.decode("ascii")
Expand Down Expand Up @@ -258,21 +285,21 @@ def on_headers_complete(self):
self.flow.pause_reading()
self.pipeline.appendleft((self.cycle, app))

def on_body(self, body: bytes):
def on_body(self, body: bytes) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
return
self.cycle.body += body
if len(self.cycle.body) > HIGH_WATER_LIMIT:
self.flow.pause_reading()
self.cycle.message_event.set()

def on_message_complete(self):
def on_message_complete(self) -> None:
if self.parser.should_upgrade() or self.cycle.response_complete:
return
self.cycle.more_body = False
self.cycle.message_event.set()

def on_response_complete(self):
def on_response_complete(self) -> None:
# Callback for pipelined HTTP requests to be started.
self.server_state.total_requests += 1

Expand All @@ -296,7 +323,7 @@ def on_response_complete(self):
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)

def shutdown(self):
def shutdown(self) -> None:
"""
Called by the server to commence a graceful shutdown.
"""
Expand All @@ -305,19 +332,19 @@ def shutdown(self):
else:
self.cycle.keep_alive = False

def pause_writing(self):
def pause_writing(self) -> None:
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()

def resume_writing(self):
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()

def timeout_keep_alive_handler(self):
def timeout_keep_alive_handler(self) -> None:
"""
Called on a keep-alive connection if no new data is received after a short
delay.
Expand All @@ -329,17 +356,17 @@ def timeout_keep_alive_handler(self):
class RequestResponseCycle:
def __init__(
self,
scope,
transport,
flow,
logger,
access_logger,
access_log,
default_headers,
message_event,
expect_100_continue,
keep_alive,
on_response,
scope: HTTPScope,
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
access_logger: logging.Logger,
access_log: bool,
default_headers: List[Tuple[bytes, bytes]],
message_event: asyncio.Event,
expect_100_continue: bool,
keep_alive: bool,
on_response: Callable[..., None],
):
self.scope = scope
self.transport = transport
Expand All @@ -363,11 +390,11 @@ def __init__(
# Response state
self.response_started = False
self.response_complete = False
self.chunked_encoding = None
self.chunked_encoding: Optional[bool] = None
self.expected_content_length = 0

# ASGI exception wrapper
async def run_asgi(self, app):
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app(self.scope, self.receive, self.send)
except BaseException as exc:
Expand All @@ -391,25 +418,27 @@ async def run_asgi(self, app):
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = None

async def send_500_response(self):
await self.send(
{
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
)
await self.send(
{"type": "http.response.body", "body": b"Internal Server Error"}
)
self.on_response = lambda: None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This went from self.on_response = None to self.on_response = lambda: None`. The new version sounds more correct. Was this a bug?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. That was introduced on #189. I don't know if the line itself is really necessary (?)...


async def send_500_response(self) -> None:
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await self.send(response_start_event)
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)

# ASGI interface
async def send(self, message):
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]

if self.flow.write_paused and not self.disconnected:
Expand All @@ -423,6 +452,7 @@ async def send(self, message):
if message_type != "http.response.start":
msg = "Expected ASGI message 'http.response.start', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast(HTTPResponseStartEvent, message)

self.response_started = True
self.waiting_for_100_continue = False
Expand Down Expand Up @@ -481,7 +511,7 @@ async def send(self, message):
msg = "Expected ASGI message 'http.response.body', but got '%s'."
raise RuntimeError(msg % message_type)

body = message.get("body", b"")
body = cast(bytes, message.get("body", b""))
more_body = message.get("more_body", False)

# Write response body
Expand Down Expand Up @@ -518,7 +548,7 @@ async def send(self, message):
msg = "Unexpected ASGI message '%s' sent, after response already completed."
raise RuntimeError(msg % message_type)

async def receive(self):
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
self.waiting_for_100_continue = False
Expand All @@ -528,6 +558,7 @@ async def receive(self):
await self.message_event.wait()
self.message_event.clear()

message: Union[HTTPDisconnectEvent, HTTPRequestEvent]
if self.disconnected or self.response_complete:
message = {"type": "http.disconnect"}
else:
Expand Down