Skip to content

Commit

Permalink
Switch to asyncio streams API (#869)
Browse files Browse the repository at this point in the history
* Switch to asyncio streams API

* Tweak buffer swapping

* Properly handle exceptions

* More explanatory comments, 3.6 compatibility

* Drop unused MAX_RECV
  • Loading branch information
florimondmanca authored May 29, 2021
1 parent b72c386 commit 960d465
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 23 deletions.
Empty file added uvicorn/_handlers/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions uvicorn/_handlers/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import asyncio
from typing import TYPE_CHECKING

from uvicorn.config import Config

if TYPE_CHECKING: # pragma: no cover
from uvicorn.server import ServerState


async def handle_http(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
server_state: "ServerState",
config: Config,
) -> None:
# Run transport/protocol session from streams.
#
# This is a bit fiddly, so let me explain why we do this in the first place.
#
# This was introduced to switch to the asyncio streams API while retaining our
# existing protocols-based code.
#
# The aim was to:
# * Make it easier to support alternative async libaries (all of which expose
# a streams API, rather than anything similar to asyncio's transports and
# protocols) while keeping the change footprint (and risk) at a minimum.
# * Keep a "fast track" for asyncio that's as efficient as possible, by reusing
# our asyncio-optimized protocols-based implementation.
#
# See: https://github.com/encode/uvicorn/issues/169
# See: https://github.com/encode/uvicorn/pull/869

# Use a future to coordinate between the protocol and this handler task.
# https://docs.python.org/3/library/asyncio-protocol.html#connecting-existing-sockets
loop = asyncio.get_event_loop()
connection_lost = loop.create_future()

# Switch the protocol from the stream reader to our own HTTP protocol class.
protocol = config.http_protocol_class(
config=config,
server_state=server_state,
on_connection_lost=lambda: connection_lost.set_result(True),
)
transport = writer.transport
transport.set_protocol(protocol)

# Asyncio stream servers don't `await` handler tasks (like the one we're currently
# running), so we must make sure exceptions that occur in protocols but outside the
# ASGI cycle (e.g. bugs) are properly retrieved and logged.
# Vanilla asyncio handles exceptions properly out-of-the-box, but uvloop doesn't.
# So we need to attach a callback to handle exceptions ourselves for that case.
# (It's not easy to know which loop we're effectively running on, so we attach the
# callback in all cases. In practice it won't be called on vanilla asyncio.)
task = _get_current_task()

@task.add_done_callback
def retrieve_exception(task: asyncio.Task) -> None:
exc = task.exception()

if exc is None:
return

loop.call_exception_handler(
{
"message": "Fatal error in server handler",
"exception": exc,
"transport": transport,
"protocol": protocol,
}
)
# Hang up the connection so the client doesn't wait forever.
transport.close()

# Kick off the HTTP protocol.
protocol.connection_made(transport)

# Pass any data already in the read buffer.
# The assumption here is that we haven't read any data off the stream reader
# yet: all data that the client might have already sent since the connection has
# been established is in the `_buffer`.
data = reader._buffer # type: ignore
if data:
protocol.data_received(data)

# Let the transport run in the background. When closed, this future will complete
# and we'll exit here.
await connection_lost


def _get_current_task() -> asyncio.Task:
try:
current_task = asyncio.current_task
except AttributeError: # pragma: no cover
# Python 3.6.
current_task = asyncio.Task.current_task

task = current_task()
assert task is not None
return task
13 changes: 11 additions & 2 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import http
import logging
from typing import Callable
from urllib.parse import unquote

import h11
Expand Down Expand Up @@ -34,12 +35,15 @@ def _get_status_phrase(status_code):


class H11Protocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
Expand Down Expand Up @@ -107,6 +111,9 @@ def connection_lost(self, exc):
if self.flow is not None:
self.flow.resume_writing()

if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass

Expand Down Expand Up @@ -253,7 +260,9 @@ def handle_upgrade(self, event):
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
on_connection_lost=self.on_connection_lost,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
13 changes: 11 additions & 2 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
import urllib
from typing import Callable

import httptools

Expand Down Expand Up @@ -39,12 +40,15 @@ def _get_status_line(status_code):


class HttpToolsProtocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
Expand Down Expand Up @@ -107,6 +111,9 @@ def connection_lost(self, exc):
if self.flow is not None:
self.flow.resume_writing()

if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass

Expand Down Expand Up @@ -166,7 +173,9 @@ def handle_upgrade(self):
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
on_connection_lost=self.on_connection_lost,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
8 changes: 7 additions & 1 deletion uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import http
import logging
from typing import Callable
from urllib.parse import unquote

import websockets
Expand All @@ -23,12 +24,15 @@ def is_serving(self):


class WebSocketProtocol(websockets.WebSocketServerProtocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
Expand Down Expand Up @@ -74,6 +78,8 @@ def connection_lost(self, exc):
self.connections.remove(self)
self.handshake_completed_event.set()
super().connection_lost(exc)
if self.on_connection_lost is not None:
self.on_connection_lost()

def shutdown(self):
self.ws_server.closing = True
Expand Down
8 changes: 7 additions & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import Callable
from urllib.parse import unquote

import h11
Expand All @@ -16,12 +17,15 @@


class WSProtocol(asyncio.Protocol):
def __init__(self, config, server_state, _loop=None):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
Expand Down Expand Up @@ -65,6 +69,8 @@ def connection_lost(self, exc):
if exc is not None:
self.queue.put_nowait({"type": "websocket.disconnect"})
self.connections.remove(self)
if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass
Expand Down
37 changes: 20 additions & 17 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import functools
import logging
import os
import platform
Expand All @@ -13,6 +12,8 @@

import click

from ._handlers.http import handle_http

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
Expand Down Expand Up @@ -77,25 +78,26 @@ async def serve(self, sockets=None):
extra={"color_message": color_message},
)

async def startup(self, sockets=None):
async def startup(self, sockets: list = None) -> None:
await self.lifespan.startup()
if self.lifespan.should_exit:
self.should_exit = True
return

config = self.config

create_protocol = functools.partial(
config.http_protocol_class, config=config, server_state=self.server_state
)

loop = asyncio.get_event_loop()
async def handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
await handle_http(
reader, writer, server_state=self.server_state, config=config
)

if sockets is not None:
# Explicitly passed a list of open sockets.
# We use this when the server is run from a Gunicorn worker.

def _share_socket(sock: socket) -> socket:
def _share_socket(sock: socket.SocketType) -> socket.SocketType:
# Windows requires the socket be explicitly shared across
# multiple workers (processes).
from socket import fromshare # type: ignore
Expand All @@ -107,17 +109,17 @@ def _share_socket(sock: socket) -> socket:
for sock in sockets:
if config.workers > 1 and platform.system() == "Windows":
sock = _share_socket(sock)
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
server = await asyncio.start_server(
handler, sock=sock, ssl=config.ssl, backlog=config.backlog
)
self.servers.append(server)
listeners = sockets

elif config.fd is not None:
# Use an existing socket, from a file descriptor.
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
server = await asyncio.start_server(
handler, sock=sock, ssl=config.ssl, backlog=config.backlog
)
assert server.sockets is not None # mypy
listeners = server.sockets
Expand All @@ -128,8 +130,8 @@ def _share_socket(sock: socket) -> socket:
uds_perms = 0o666
if os.path.exists(config.uds):
uds_perms = os.stat(config.uds).st_mode
server = await loop.create_unix_server(
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
server = await asyncio.start_unix_server(
handler, path=config.uds, ssl=config.ssl, backlog=config.backlog
)
os.chmod(config.uds, uds_perms)
assert server.sockets is not None # mypy
Expand All @@ -139,8 +141,8 @@ def _share_socket(sock: socket) -> socket:
else:
# Standard case. Create a socket from a host/port pair.
try:
server = await loop.create_server(
create_protocol,
server = await asyncio.start_server(
handler,
host=config.host,
port=config.port,
ssl=config.ssl,
Expand All @@ -150,7 +152,8 @@ def _share_socket(sock: socket) -> socket:
logger.error(exc)
await self.lifespan.shutdown()
sys.exit(1)
assert server.sockets is not None # mypy

assert server.sockets is not None
listeners = server.sockets
self.servers = [server]

Expand Down

0 comments on commit 960d465

Please sign in to comment.