-
-
Notifications
You must be signed in to change notification settings - Fork 754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Annotate httptools_impl.py
#1484
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,26 @@ | |
import http | ||
import logging | ||
import re | ||
import sys | ||
import urllib | ||
from asyncio.events import TimerHandle | ||
from collections import deque | ||
from typing import Callable, Deque, List, Optional, Tuple, Union, cast | ||
|
||
import httptools | ||
from asgiref.typing import ( | ||
ASGI3Application, | ||
ASGIReceiveEvent, | ||
ASGISendEvent, | ||
HTTPDisconnectEvent, | ||
HTTPRequestEvent, | ||
HTTPResponseBodyEvent, | ||
HTTPResponseStartEvent, | ||
HTTPScope, | ||
) | ||
|
||
from uvicorn._logging import TRACE_LOG_LEVEL | ||
from uvicorn.config import Config | ||
from uvicorn.protocols.http.flow_control import ( | ||
CLOSE_HEADER, | ||
HIGH_WATER_LIMIT, | ||
|
@@ -21,12 +35,18 @@ | |
get_remote_addr, | ||
is_ssl, | ||
) | ||
from uvicorn.server import ServerState | ||
|
||
if sys.version_info < (3, 8): # pragma: py-gte-38 | ||
from typing_extensions import Literal | ||
else: # pragma: py-lt-38 | ||
from typing import Literal | ||
|
||
HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]') | ||
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]") | ||
|
||
|
||
def _get_status_line(status_code): | ||
def _get_status_line(status_code: int) -> bytes: | ||
try: | ||
phrase = http.HTTPStatus(status_code).phrase.encode() | ||
except ValueError: | ||
|
@@ -40,7 +60,12 @@ def _get_status_line(status_code): | |
|
||
|
||
class HttpToolsProtocol(asyncio.Protocol): | ||
def __init__(self, config, server_state, _loop=None): | ||
def __init__( | ||
self, | ||
config: Config, | ||
server_state: ServerState, | ||
_loop: Optional[asyncio.AbstractEventLoop] = None, | ||
) -> None: | ||
if not config.loaded: | ||
config.load() | ||
|
||
|
@@ -56,7 +81,7 @@ def __init__(self, config, server_state, _loop=None): | |
self.limit_concurrency = config.limit_concurrency | ||
|
||
# Timeouts | ||
self.timeout_keep_alive_task = None | ||
self.timeout_keep_alive_task: Optional[TimerHandle] = None | ||
self.timeout_keep_alive = config.timeout_keep_alive | ||
|
||
# Global state | ||
|
@@ -66,21 +91,24 @@ def __init__(self, config, server_state, _loop=None): | |
self.default_headers = server_state.default_headers | ||
|
||
# Per-connection state | ||
self.transport = None | ||
self.flow = None | ||
self.server = None | ||
self.client = None | ||
self.scheme = None | ||
self.pipeline = deque() | ||
# Per-connection state | ||
Kludex marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.transport: asyncio.Transport = None # type: ignore[assignment] | ||
self.flow: FlowControl = None # type: ignore[assignment] | ||
Comment on lines
+94
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you using the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes |
||
self.server: Optional[Tuple[str, int]] = None | ||
self.client: Optional[Tuple[str, int]] = None | ||
self.scheme: Optional[Literal["http", "https"]] = None | ||
self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque() | ||
|
||
# Per-request state | ||
self.scope = None | ||
self.headers = None | ||
self.scope: HTTPScope = None # type: ignore[assignment] | ||
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment] | ||
self.expect_100_continue = False | ||
self.cycle = None | ||
self.cycle: RequestResponseCycle = None # type: ignore[assignment] | ||
|
||
# Protocol interface | ||
def connection_made(self, transport): | ||
def connection_made( # type: ignore[override] | ||
self, transport: asyncio.Transport | ||
) -> None: | ||
self.connections.add(self) | ||
|
||
self.transport = transport | ||
|
@@ -90,14 +118,14 @@ def connection_made(self, transport): | |
self.scheme = "https" if is_ssl(transport) else "http" | ||
|
||
if self.logger.level <= TRACE_LOG_LEVEL: | ||
prefix = "%s:%d - " % tuple(self.client) if self.client else "" | ||
prefix = "%s:%d - " % self.client if self.client else "" | ||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) | ||
|
||
def connection_lost(self, exc): | ||
def connection_lost(self, exc: Optional[Exception]) -> None: | ||
self.connections.discard(self) | ||
|
||
if self.logger.level <= TRACE_LOG_LEVEL: | ||
prefix = "%s:%d - " % tuple(self.client) if self.client else "" | ||
prefix = "%s:%d - " % self.client if self.client else "" | ||
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix) | ||
|
||
if self.cycle and not self.cycle.response_complete: | ||
|
@@ -109,15 +137,15 @@ def connection_lost(self, exc): | |
if exc is None: | ||
self.transport.close() | ||
|
||
def eof_received(self): | ||
def eof_received(self) -> None: | ||
pass | ||
|
||
def _unset_keepalive_if_required(self): | ||
def _unset_keepalive_if_required(self) -> None: | ||
if self.timeout_keep_alive_task is not None: | ||
self.timeout_keep_alive_task.cancel() | ||
self.timeout_keep_alive_task = None | ||
|
||
def data_received(self, data): | ||
def data_received(self, data: bytes) -> None: | ||
self._unset_keepalive_if_required() | ||
|
||
try: | ||
|
@@ -130,7 +158,7 @@ def data_received(self, data): | |
except httptools.HttpParserUpgrade: | ||
self.handle_upgrade() | ||
|
||
def handle_upgrade(self): | ||
def handle_upgrade(self) -> None: | ||
upgrade_value = None | ||
for name, value in self.headers: | ||
if name == b"upgrade": | ||
|
@@ -148,7 +176,7 @@ def handle_upgrade(self): | |
return | ||
|
||
if self.logger.level <= TRACE_LOG_LEVEL: | ||
prefix = "%s:%d - " % tuple(self.client) if self.client else "" | ||
prefix = "%s:%d - " % self.client if self.client else "" | ||
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) | ||
|
||
self.connections.discard(self) | ||
|
@@ -157,14 +185,14 @@ def handle_upgrade(self): | |
for name, value in self.scope["headers"]: | ||
output += [name, b": ", value, b"\r\n"] | ||
output.append(b"\r\n") | ||
protocol = self.ws_protocol_class( | ||
protocol = self.ws_protocol_class( # type: ignore[call-arg] | ||
config=self.config, server_state=self.server_state | ||
) | ||
protocol.connection_made(self.transport) | ||
protocol.data_received(b"".join(output)) | ||
self.transport.set_protocol(protocol) | ||
|
||
def send_400_response(self, msg: str): | ||
def send_400_response(self, msg: str) -> None: | ||
|
||
content = [STATUS_LINE[400]] | ||
for name, value in self.default_headers: | ||
|
@@ -181,11 +209,11 @@ def send_400_response(self, msg: str): | |
self.transport.write(b"".join(content)) | ||
self.transport.close() | ||
|
||
def on_message_begin(self): | ||
def on_message_begin(self) -> None: | ||
self.url = b"" | ||
self.expect_100_continue = False | ||
self.headers = [] | ||
self.scope = { | ||
self.scope = { # type: ignore[typeddict-item] | ||
"type": "http", | ||
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, | ||
"http_version": "1.1", | ||
|
@@ -197,16 +225,16 @@ def on_message_begin(self): | |
} | ||
|
||
# Parser callbacks | ||
def on_url(self, url): | ||
def on_url(self, url: bytes) -> None: | ||
self.url += url | ||
|
||
def on_header(self, name: bytes, value: bytes): | ||
def on_header(self, name: bytes, value: bytes) -> None: | ||
name = name.lower() | ||
if name == b"expect" and value.lower() == b"100-continue": | ||
self.expect_100_continue = True | ||
self.headers.append((name, value)) | ||
|
||
def on_headers_complete(self): | ||
def on_headers_complete(self) -> None: | ||
http_version = self.parser.get_http_version() | ||
method = self.parser.get_method() | ||
self.scope["method"] = method.decode("ascii") | ||
|
@@ -258,21 +286,21 @@ def on_headers_complete(self): | |
self.flow.pause_reading() | ||
self.pipeline.appendleft((self.cycle, app)) | ||
|
||
def on_body(self, body: bytes): | ||
def on_body(self, body: bytes) -> None: | ||
if self.parser.should_upgrade() or self.cycle.response_complete: | ||
return | ||
self.cycle.body += body | ||
if len(self.cycle.body) > HIGH_WATER_LIMIT: | ||
self.flow.pause_reading() | ||
self.cycle.message_event.set() | ||
|
||
def on_message_complete(self): | ||
def on_message_complete(self) -> None: | ||
if self.parser.should_upgrade() or self.cycle.response_complete: | ||
return | ||
self.cycle.more_body = False | ||
self.cycle.message_event.set() | ||
|
||
def on_response_complete(self): | ||
def on_response_complete(self) -> None: | ||
# Callback for pipelined HTTP requests to be started. | ||
self.server_state.total_requests += 1 | ||
|
||
|
@@ -296,7 +324,7 @@ def on_response_complete(self): | |
task.add_done_callback(self.tasks.discard) | ||
self.tasks.add(task) | ||
|
||
def shutdown(self): | ||
def shutdown(self) -> None: | ||
""" | ||
Called by the server to commence a graceful shutdown. | ||
""" | ||
|
@@ -305,19 +333,19 @@ def shutdown(self): | |
else: | ||
self.cycle.keep_alive = False | ||
|
||
def pause_writing(self): | ||
def pause_writing(self) -> None: | ||
""" | ||
Called by the transport when the write buffer exceeds the high water mark. | ||
""" | ||
self.flow.pause_writing() | ||
|
||
def resume_writing(self): | ||
def resume_writing(self) -> None: | ||
""" | ||
Called by the transport when the write buffer drops below the low water mark. | ||
""" | ||
self.flow.resume_writing() | ||
|
||
def timeout_keep_alive_handler(self): | ||
def timeout_keep_alive_handler(self) -> None: | ||
""" | ||
Called on a keep-alive connection if no new data is received after a short | ||
delay. | ||
|
@@ -329,17 +357,17 @@ def timeout_keep_alive_handler(self): | |
class RequestResponseCycle: | ||
def __init__( | ||
self, | ||
scope, | ||
transport, | ||
flow, | ||
logger, | ||
access_logger, | ||
access_log, | ||
default_headers, | ||
message_event, | ||
expect_100_continue, | ||
keep_alive, | ||
on_response, | ||
scope: HTTPScope, | ||
transport: asyncio.Transport, | ||
flow: FlowControl, | ||
logger: logging.Logger, | ||
access_logger: logging.Logger, | ||
access_log: bool, | ||
default_headers: List[Tuple[bytes, bytes]], | ||
message_event: asyncio.Event, | ||
expect_100_continue: bool, | ||
keep_alive: bool, | ||
on_response: Callable[..., None], | ||
): | ||
self.scope = scope | ||
self.transport = transport | ||
|
@@ -363,11 +391,11 @@ def __init__( | |
# Response state | ||
self.response_started = False | ||
self.response_complete = False | ||
self.chunked_encoding = None | ||
self.chunked_encoding: Optional[bool] = None | ||
self.expected_content_length = 0 | ||
|
||
# ASGI exception wrapper | ||
async def run_asgi(self, app): | ||
async def run_asgi(self, app: ASGI3Application) -> None: | ||
try: | ||
result = await app(self.scope, self.receive, self.send) | ||
except BaseException as exc: | ||
|
@@ -391,25 +419,27 @@ async def run_asgi(self, app): | |
self.logger.error(msg) | ||
self.transport.close() | ||
finally: | ||
self.on_response = None | ||
|
||
async def send_500_response(self): | ||
await self.send( | ||
{ | ||
"type": "http.response.start", | ||
"status": 500, | ||
"headers": [ | ||
(b"content-type", b"text/plain; charset=utf-8"), | ||
(b"connection", b"close"), | ||
], | ||
} | ||
) | ||
await self.send( | ||
{"type": "http.response.body", "body": b"Internal Server Error"} | ||
) | ||
self.on_response = lambda: None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This went from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. That was introduced on #189. I don't know if the line itself is really necessary (?)... |
||
|
||
async def send_500_response(self) -> None: | ||
response_start_event: HTTPResponseStartEvent = { | ||
"type": "http.response.start", | ||
"status": 500, | ||
"headers": [ | ||
(b"content-type", b"text/plain; charset=utf-8"), | ||
(b"connection", b"close"), | ||
], | ||
} | ||
await self.send(response_start_event) | ||
response_body_event: HTTPResponseBodyEvent = { | ||
"type": "http.response.body", | ||
"body": b"Internal Server Error", | ||
"more_body": False, | ||
} | ||
await self.send(response_body_event) | ||
|
||
# ASGI interface | ||
async def send(self, message): | ||
async def send(self, message: ASGISendEvent) -> None: | ||
message_type = message["type"] | ||
|
||
if self.flow.write_paused and not self.disconnected: | ||
|
@@ -423,6 +453,7 @@ async def send(self, message): | |
if message_type != "http.response.start": | ||
msg = "Expected ASGI message 'http.response.start', but got '%s'." | ||
raise RuntimeError(msg % message_type) | ||
message = cast(HTTPResponseStartEvent, message) | ||
|
||
self.response_started = True | ||
self.waiting_for_100_continue = False | ||
|
@@ -481,7 +512,7 @@ async def send(self, message): | |
msg = "Expected ASGI message 'http.response.body', but got '%s'." | ||
raise RuntimeError(msg % message_type) | ||
|
||
body = message.get("body", b"") | ||
body = cast(bytes, message.get("body", b"")) | ||
more_body = message.get("more_body", False) | ||
|
||
# Write response body | ||
|
@@ -518,7 +549,7 @@ async def send(self, message): | |
msg = "Unexpected ASGI message '%s' sent, after response already completed." | ||
raise RuntimeError(msg % message_type) | ||
|
||
async def receive(self): | ||
async def receive(self) -> ASGIReceiveEvent: | ||
if self.waiting_for_100_continue and not self.transport.is_closing(): | ||
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") | ||
self.waiting_for_100_continue = False | ||
|
@@ -528,6 +559,7 @@ async def receive(self): | |
await self.message_event.wait() | ||
self.message_event.clear() | ||
|
||
message: Union[HTTPDisconnectEvent, HTTPRequestEvent] | ||
if self.disconnected or self.response_complete: | ||
message = {"type": "http.disconnect"} | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we have a
_compat.py
module or something where this sorta boilerplate can go?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No