Skip to content

Commit

Permalink
Add type annotation to h11_impl.py (encode#1397)
Browse files Browse the repository at this point in the history
* Add type annotation to `h11_impl.py`

* Finish

* Use typing.Literal only for Python 3.8+

* Use 3.8 instead

* Don't use TypedDict instantiation at runtime

* Fix setup.cfg
  • Loading branch information
Kludex committed Oct 29, 2022
1 parent 6c9e7f7 commit cf46fd1
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 60 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ files =
uvicorn/protocols/__init__.py,
uvicorn/protocols/http/__init__.py,
uvicorn/protocols/websockets/__init__.py,
uvicorn/protocols/http/h11_impl.py,
tests/middleware/test_wsgi.py,
tests/middleware/test_proxy_headers.py,
tests/test_config.py,
Expand Down
163 changes: 103 additions & 60 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import asyncio
import http
import logging
import sys
from typing import Callable, List, Optional, Tuple, Union, cast
from urllib.parse import unquote

import h11
from asgiref.typing import (
ASGI3Application,
ASGIReceiveEvent,
ASGISendEvent,
HTTPDisconnectEvent,
HTTPRequestEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
HTTPScope,
)

from uvicorn._logging import TRACE_LOG_LEVEL
from uvicorn.config import Config
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
Expand All @@ -19,9 +32,24 @@
get_remote_addr,
is_ssl,
)
from uvicorn.server import ServerState

if sys.version_info < (3, 8): # pragma: py-gte-38
from typing_extensions import Literal
else: # pragma: py-lt-38
from typing import Literal

H11Event = Union[
h11.Request,
h11.InformationalResponse,
h11.Response,
h11.Data,
h11.EndOfMessage,
h11.ConnectionClosed,
]


def _get_status_phrase(status_code):
def _get_status_phrase(status_code: int) -> bytes:
try:
return http.HTTPStatus(status_code).phrase.encode()
except ValueError:
Expand All @@ -34,7 +62,12 @@ def _get_status_phrase(status_code):


class H11Protocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self,
config: Config,
server_state: ServerState,
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
config.load()

Expand All @@ -50,7 +83,7 @@ def __init__(self, config, server_state, _loop=None):
self.limit_concurrency = config.limit_concurrency

# Timeouts
self.timeout_keep_alive_task = None
self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None
self.timeout_keep_alive = config.timeout_keep_alive

# Shared server state
Expand All @@ -60,19 +93,21 @@ def __init__(self, config, server_state, _loop=None):
self.default_headers = server_state.default_headers

# Per-connection state
self.transport = None
self.flow = None
self.server = None
self.client = None
self.scheme = None
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.flow: FlowControl = None # type: ignore[assignment]
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Optional[Literal["http", "https"]] = None

# Per-request state
self.scope = None
self.headers = None
self.cycle = None
self.scope: HTTPScope = None # type: ignore[assignment]
self.headers: List[Tuple[bytes, bytes]] = None # type: ignore[assignment]
self.cycle: RequestResponseCycle = None # type: ignore[assignment]

# Protocol interface
def connection_made(self, transport):
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
self.connections.add(self)

self.transport = transport
Expand All @@ -82,14 +117,14 @@ def connection_made(self, transport):
self.scheme = "https" if is_ssl(transport) else "http"

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)

def connection_lost(self, exc):
def connection_lost(self, exc: Optional[Exception]) -> None:
self.connections.discard(self)

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)

if self.cycle and not self.cycle.response_complete:
Expand All @@ -109,21 +144,21 @@ def connection_lost(self, exc):
if exc is None:
self.transport.close()

def eof_received(self):
def eof_received(self) -> None:
pass

def _unset_keepalive_if_required(self):
def _unset_keepalive_if_required(self) -> None:
if self.timeout_keep_alive_task is not None:
self.timeout_keep_alive_task.cancel()
self.timeout_keep_alive_task = None

def data_received(self, data):
def data_received(self, data: bytes) -> None:
self._unset_keepalive_if_required()

self.conn.receive_data(data)
self.handle_events()

def handle_events(self):
def handle_events(self) -> None:
while True:
try:
event = self.conn.next_event()
Expand All @@ -148,7 +183,7 @@ def handle_events(self):
elif event_type is h11.Request:
self.headers = [(key.lower(), value) for key, value in event.headers]
raw_path, _, query_string = event.target.partition(b"?")
self.scope = {
self.scope = { # type: ignore[typeddict-item]
"type": "http",
"asgi": {
"version": self.config.asgi_version,
Expand Down Expand Up @@ -216,7 +251,7 @@ def handle_events(self):
self.cycle.more_body = False
self.cycle.message_event.set()

def handle_upgrade(self, event):
def handle_upgrade(self, event: H11Event) -> None:
upgrade_value = None
for name, value in self.headers:
if name == b"upgrade":
Expand All @@ -234,22 +269,22 @@ def handle_upgrade(self, event):
return

if self.logger.level <= TRACE_LOG_LEVEL:
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)

self.connections.discard(self)
output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"]
for name, value in self.headers:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
protocol = self.ws_protocol_class( # type: ignore[call-arg]
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
self.transport.set_protocol(protocol)

def send_400_response(self, msg: str):
def send_400_response(self, msg: str) -> None:

reason = STATUS_PHRASES[400]
headers = [
Expand All @@ -267,7 +302,7 @@ def send_400_response(self, msg: str):
self.transport.write(output)
self.transport.close()

def on_response_complete(self):
def on_response_complete(self) -> None:
self.server_state.total_requests += 1

if self.transport.is_closing():
Expand All @@ -288,7 +323,7 @@ def on_response_complete(self):
self.conn.start_next_cycle()
self.handle_events()

def shutdown(self):
def shutdown(self) -> None:
"""
Called by the server to commence a graceful shutdown.
"""
Expand All @@ -299,19 +334,19 @@ def shutdown(self):
else:
self.cycle.keep_alive = False

def pause_writing(self):
def pause_writing(self) -> None:
"""
Called by the transport when the write buffer exceeds the high water mark.
"""
self.flow.pause_writing()

def resume_writing(self):
def resume_writing(self) -> None:
"""
Called by the transport when the write buffer drops below the low water mark.
"""
self.flow.resume_writing()

def timeout_keep_alive_handler(self):
def timeout_keep_alive_handler(self) -> None:
"""
Called on a keep-alive connection if no new data is received after a short
delay.
Expand All @@ -325,17 +360,17 @@ def timeout_keep_alive_handler(self):
class RequestResponseCycle:
def __init__(
self,
scope,
conn,
transport,
flow,
logger,
access_logger,
access_log,
default_headers,
message_event,
on_response,
):
scope: HTTPScope,
conn: h11.Connection,
transport: asyncio.Transport,
flow: FlowControl,
logger: logging.Logger,
access_logger: logging.Logger,
access_log: bool,
default_headers: List[Tuple[bytes, bytes]],
message_event: asyncio.Event,
on_response: Callable[..., None],
) -> None:
self.scope = scope
self.conn = conn
self.transport = transport
Expand All @@ -361,7 +396,7 @@ def __init__(
self.response_complete = False

# ASGI exception wrapper
async def run_asgi(self, app):
async def run_asgi(self, app: ASGI3Application) -> None:
try:
result = await app(self.scope, self.receive, self.send)
except BaseException as exc:
Expand All @@ -385,25 +420,27 @@ async def run_asgi(self, app):
self.logger.error(msg)
self.transport.close()
finally:
self.on_response = None

async def send_500_response(self):
await self.send(
{
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
)
await self.send(
{"type": "http.response.body", "body": b"Internal Server Error"}
)
self.on_response = lambda: None

async def send_500_response(self) -> None:
response_start_event: HTTPResponseStartEvent = {
"type": "http.response.start",
"status": 500,
"headers": [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
],
}
await self.send(response_start_event)
response_body_event: HTTPResponseBodyEvent = {
"type": "http.response.body",
"body": b"Internal Server Error",
"more_body": False,
}
await self.send(response_body_event)

# ASGI interface
async def send(self, message):
async def send(self, message: ASGISendEvent) -> None:
message_type = message["type"]

if self.flow.write_paused and not self.disconnected:
Expand All @@ -417,12 +454,16 @@ async def send(self, message):
if message_type != "http.response.start":
msg = "Expected ASGI message 'http.response.start', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast(HTTPResponseStartEvent, message)

self.response_started = True
self.waiting_for_100_continue = False

status_code = message["status"]
headers = self.default_headers + message.get("headers", [])
message_headers = cast(
List[Tuple[bytes, bytes]], message.get("headers", [])
)
headers = self.default_headers + message_headers

if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
headers = headers + [CLOSE_HEADER]
Expand Down Expand Up @@ -450,6 +491,7 @@ async def send(self, message):
if message_type != "http.response.body":
msg = "Expected ASGI message 'http.response.body', but got '%s'."
raise RuntimeError(msg % message_type)
message = cast(HTTPResponseBodyEvent, message)

body = message.get("body", b"")
more_body = message.get("more_body", False)
Expand Down Expand Up @@ -482,7 +524,7 @@ async def send(self, message):
self.transport.close()
self.on_response()

async def receive(self):
async def receive(self) -> ASGIReceiveEvent:
if self.waiting_for_100_continue and not self.transport.is_closing():
event = h11.InformationalResponse(
status_code=100, headers=[], reason="Continue"
Expand All @@ -496,6 +538,7 @@ async def receive(self):
await self.message_event.wait()
self.message_event.clear()

message: Union[HTTPDisconnectEvent, HTTPRequestEvent]
if self.disconnected or self.response_complete:
message = {"type": "http.disconnect"}
else:
Expand Down

0 comments on commit cf46fd1

Please sign in to comment.