Skip to content

Commit

Permalink
Revert stream interface (#1355)
Browse files Browse the repository at this point in the history
* Revert "fix: move all data handle to protocol & ensure connection is closed (#1332)"

This reverts commit fc6e056.

* Revert stream interface
  • Loading branch information
Kludex authored Feb 3, 2022
1 parent 38720cf commit 9722ca4
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 148 deletions.
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ files =
uvicorn/supervisors/watchgodreload.py,
uvicorn/logging.py,
uvicorn/middleware/asgi2.py,
uvicorn/_handlers,
uvicorn/server.py,
uvicorn/__init__.py,
uvicorn/__main__.py,
Expand Down
Empty file removed uvicorn/_handlers/__init__.py
Empty file.
88 changes: 0 additions & 88 deletions uvicorn/_handlers/http.py

This file was deleted.

13 changes: 2 additions & 11 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import http
import logging
from typing import Callable
from urllib.parse import unquote

import h11
Expand Down Expand Up @@ -35,15 +34,12 @@ def _get_status_phrase(status_code):


class H11Protocol(asyncio.Protocol):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
def __init__(self, config, server_state, _loop=None):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
Expand Down Expand Up @@ -113,9 +109,6 @@ def connection_lost(self, exc):
if exc is None:
self.transport.close()

if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass

Expand Down Expand Up @@ -266,9 +259,7 @@ def handle_upgrade(self, event):
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
config=self.config,
server_state=self.server_state,
on_connection_lost=self.on_connection_lost,
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
13 changes: 2 additions & 11 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import re
import urllib
from collections import deque
from typing import Callable

import httptools

Expand Down Expand Up @@ -41,15 +40,12 @@ def _get_status_line(status_code):


class HttpToolsProtocol(asyncio.Protocol):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
def __init__(self, config, server_state, _loop=None):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.access_logger = logging.getLogger("uvicorn.access")
Expand Down Expand Up @@ -114,9 +110,6 @@ def connection_lost(self, exc):
if exc is None:
self.transport.close()

if self.on_connection_lost is not None:
self.on_connection_lost()

def eof_received(self):
pass

Expand Down Expand Up @@ -180,9 +173,7 @@ def handle_upgrade(self):
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class(
config=self.config,
server_state=self.server_state,
on_connection_lost=self.on_connection_lost,
config=self.config, server_state=self.server_state
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
8 changes: 1 addition & 7 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import http
import logging
from typing import Callable
from urllib.parse import unquote

import websockets
Expand All @@ -25,15 +24,12 @@ def is_serving(self):


class WebSocketProtocol(websockets.WebSocketServerProtocol):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
def __init__(self, config, server_state, _loop=None):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.root_path = config.root_path

Expand Down Expand Up @@ -96,8 +92,6 @@ def connection_lost(self, exc):

self.handshake_completed_event.set()
super().connection_lost(exc)
if self.on_connection_lost is not None:
self.on_connection_lost()
if exc is None:
self.transport.close()

Expand Down
8 changes: 1 addition & 7 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import logging
from typing import Callable
from urllib.parse import unquote

import h11
Expand All @@ -20,15 +19,12 @@


class WSProtocol(asyncio.Protocol):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
def __init__(self, config, server_state, _loop=None):
if not config.loaded:
config.load()

self.config = config
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
Expand Down Expand Up @@ -81,8 +77,6 @@ def connection_lost(self, exc):
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

if self.on_connection_lost is not None:
self.on_connection_lost()
if exc is None:
self.transport.close()

Expand Down
37 changes: 14 additions & 23 deletions uvicorn/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import functools
import logging
import os
import platform
Expand All @@ -9,11 +10,10 @@
import time
from email.utils import formatdate
from types import FrameType
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import click

from uvicorn._handlers.http import handle_http
from uvicorn.config import Config

if TYPE_CHECKING:
Expand All @@ -24,13 +24,6 @@

Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol]

if sys.platform != "win32":
from asyncio import start_unix_server as _start_unix_server
else:

async def _start_unix_server(*args: Any, **kwargs: Any) -> Any:
raise NotImplementedError("Cannot start a unix server on win32")


HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
Expand Down Expand Up @@ -99,12 +92,10 @@ async def startup(self, sockets: list = None) -> None:

config = self.config

async def handler(
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
await handle_http(
reader, writer, server_state=self.server_state, config=config
)
create_protocol = functools.partial(
config.http_protocol_class, config=config, server_state=self.server_state
)
loop = asyncio.get_running_loop()

if sockets is not None:
# Explicitly passed a list of open sockets.
Expand All @@ -122,17 +113,17 @@ def _share_socket(sock: socket.SocketType) -> socket.SocketType:
for sock in sockets:
if config.workers > 1 and platform.system() == "Windows":
sock = _share_socket(sock)
server = await asyncio.start_server(
handler, sock=sock, ssl=config.ssl, backlog=config.backlog
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
)
self.servers.append(server)
listeners = sockets

elif config.fd is not None:
# Use an existing socket, from a file descriptor.
sock = socket.fromfd(config.fd, socket.AF_UNIX, socket.SOCK_STREAM)
server = await asyncio.start_server(
handler, sock=sock, ssl=config.ssl, backlog=config.backlog
server = await loop.create_server(
create_protocol, sock=sock, ssl=config.ssl, backlog=config.backlog
)
assert server.sockets is not None # mypy
listeners = server.sockets
Expand All @@ -143,8 +134,8 @@ def _share_socket(sock: socket.SocketType) -> socket.SocketType:
uds_perms = 0o666
if os.path.exists(config.uds):
uds_perms = os.stat(config.uds).st_mode
server = await _start_unix_server(
handler, path=config.uds, ssl=config.ssl, backlog=config.backlog
server = await loop.create_server( # type: ignore[call-overload]
create_protocol, path=config.uds, ssl=config.ssl, backlog=config.backlog
)
os.chmod(config.uds, uds_perms)
assert server.sockets is not None # mypy
Expand All @@ -154,8 +145,8 @@ def _share_socket(sock: socket.SocketType) -> socket.SocketType:
else:
# Standard case. Create a socket from a host/port pair.
try:
server = await asyncio.start_server(
handler,
server = await loop.create_server(
create_protocol,
host=config.host,
port=config.port,
ssl=config.ssl,
Expand Down

0 comments on commit 9722ca4

Please sign in to comment.