From 51db8972c24af24ebe9d4d4560f7582e0043fe47 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 11 Mar 2023 12:48:29 -0600 Subject: [PATCH] Support lifespan_scope["state"] --- src/hypercorn/asyncio/lifespan.py | 12 +++++++-- src/hypercorn/asyncio/run.py | 9 ++++--- src/hypercorn/asyncio/tcp_server.py | 5 +++- src/hypercorn/asyncio/udp_server.py | 12 +++++++-- src/hypercorn/protocol/__init__.py | 8 +++++- src/hypercorn/protocol/events.py | 3 +++ src/hypercorn/protocol/h11.py | 5 +++- src/hypercorn/protocol/h2.py | 5 +++- src/hypercorn/protocol/h3.py | 5 +++- src/hypercorn/protocol/http_stream.py | 2 ++ src/hypercorn/protocol/quic.py | 5 +++- src/hypercorn/protocol/ws_stream.py | 1 + src/hypercorn/trio/lifespan.py | 6 +++-- src/hypercorn/trio/run.py | 20 +++++++++++--- src/hypercorn/trio/tcp_server.py | 11 ++++++-- src/hypercorn/trio/udp_server.py | 12 +++++++-- src/hypercorn/typing.py | 9 ++++++- tests/asyncio/test_lifespan.py | 10 ++++--- tests/asyncio/test_sanity.py | 4 +++ tests/asyncio/test_tcp_server.py | 2 ++ tests/conftest.py | 3 ++- tests/helpers.py | 2 +- tests/middleware/test_dispatcher.py | 4 +-- tests/middleware/test_http_to_https.py | 7 ++++- tests/protocol/test_h11.py | 26 ++++++++++++++++--- tests/protocol/test_h2.py | 11 +++++++- tests/protocol/test_http_stream.py | 36 +++++++++++++++++++++++--- tests/protocol/test_ws_stream.py | 11 ++++++++ tests/test_app_wrappers.py | 7 ++++- tests/trio/test_keep_alive.py | 2 +- tests/trio/test_lifespan.py | 4 +-- tests/trio/test_sanity.py | 8 +++--- 32 files changed, 217 insertions(+), 50 deletions(-) diff --git a/src/hypercorn/asyncio/lifespan.py b/src/hypercorn/asyncio/lifespan.py index 244950c6..eaef9068 100644 --- a/src/hypercorn/asyncio/lifespan.py +++ b/src/hypercorn/asyncio/lifespan.py @@ -5,7 +5,7 @@ from typing import Any, Callable from ..config import Config -from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope +from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState from ..utils import LifespanFailureError, LifespanTimeoutError @@ -14,7 +14,13 @@ class UnexpectedMessageError(Exception): class Lifespan: - def __init__(self, app: AppWrapper, config: Config, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, + app: AppWrapper, + config: Config, + loop: asyncio.AbstractEventLoop, + lifespan_state: LifespanState, + ) -> None: self.app = app self.config = config self.startup = asyncio.Event() @@ -22,6 +28,7 @@ def __init__(self, app: AppWrapper, config: Config, loop: asyncio.AbstractEventL self.app_queue: asyncio.Queue = asyncio.Queue(config.max_app_queue_size) self.supported = True self.loop = loop + self.state = lifespan_state # This mimics the Trio nursery.start task_status and is # required to ensure the support has been checked before @@ -33,6 +40,7 @@ async def handle_lifespan(self) -> None: scope: LifespanScope = { "type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}, + "state": self.state, } def _call_soon(func: Callable, *args: Any) -> Any: diff --git a/src/hypercorn/asyncio/run.py b/src/hypercorn/asyncio/run.py index b50b9c65..1388b62b 100644 --- a/src/hypercorn/asyncio/run.py +++ b/src/hypercorn/asyncio/run.py @@ -17,7 +17,7 @@ from .udp_server import UDPServer from .worker_context import WorkerContext from ..config import Config, Sockets -from ..typing import AppWrapper +from ..typing import AppWrapper, LifespanState from ..utils import ( check_multiprocess_shutdown_event, load_application, @@ -71,7 +71,8 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 shutdown_trigger = signal_event.wait # type: ignore - lifespan = Lifespan(app, config, loop) + lifespan_state: LifespanState = {} + lifespan = Lifespan(app, config, loop, lifespan_state) lifespan_task = loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() @@ -93,7 +94,7 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: server_tasks.add(asyncio.current_task(loop)) - await TCPServer(app, loop, config, context, reader, writer) + await TCPServer(app, loop, config, context, lifespan_state, reader, writer) servers = [] for sock in sockets.secure_sockets: @@ -127,7 +128,7 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW sock = _share_socket(sock) _, protocol = await loop.create_datagram_endpoint( - lambda: UDPServer(app, loop, config, context), sock=sock + lambda: UDPServer(app, loop, config, context, lifespan_state), sock=sock ) server_tasks.add(loop.create_task(protocol.run())) bind = repr_socket_addr(sock.family, sock.getsockname()) diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index 91b3c050..f179170d 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -9,7 +9,7 @@ from ..config import Config from ..events import Closed, Event, RawData, Updated from ..protocol import ProtocolWrapper -from ..typing import AppWrapper +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr MAX_RECV = 2**16 @@ -22,6 +22,7 @@ def __init__( loop: asyncio.AbstractEventLoop, config: Config, context: WorkerContext, + state: LifespanState, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: @@ -34,6 +35,7 @@ def __init__( self.writer = writer self.send_lock = asyncio.Lock() self.idle_lock = asyncio.Lock() + self.state = state self._idle_handle: Optional[asyncio.Task] = None @@ -59,6 +61,7 @@ async def run(self) -> None: self.config, self.context, task_group, + ConnectionState(self.state.copy()), ssl, client, server, diff --git a/src/hypercorn/asyncio/udp_server.py b/src/hypercorn/asyncio/udp_server.py index 629ab9f4..32857cc1 100644 --- a/src/hypercorn/asyncio/udp_server.py +++ b/src/hypercorn/asyncio/udp_server.py @@ -7,7 +7,7 @@ from .worker_context import WorkerContext from ..config import Config from ..events import Event, RawData -from ..typing import AppWrapper +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr if TYPE_CHECKING: @@ -22,6 +22,7 @@ def __init__( loop: asyncio.AbstractEventLoop, config: Config, context: WorkerContext, + state: LifespanState, ) -> None: self.app = app self.config = config @@ -30,6 +31,7 @@ def __init__( self.protocol: "QuicProtocol" self.protocol_queue: asyncio.Queue = asyncio.Queue(10) self.transport: Optional[asyncio.DatagramTransport] = None + self.state = state def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore self.transport = transport @@ -48,7 +50,13 @@ async def run(self) -> None: server = parse_socket_addr(socket.family, socket.getsockname()) async with TaskGroup(self.loop) as task_group: self.protocol = QuicProtocol( - self.app, self.config, self.context, task_group, server, self.protocol_send + self.app, + self.config, + self.context, + task_group, + ConnectionState(self.state.copy()), + server, + self.protocol_send, ) while not self.context.terminated.is_set() or not self.protocol.idle: diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py index 39385681..4e8feae8 100755 --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -6,7 +6,7 @@ from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol from ..config import Config from ..events import Event, RawData -from ..typing import AppWrapper, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext class ProtocolWrapper: @@ -16,6 +16,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, + state: ConnectionState, ssl: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], @@ -30,6 +31,7 @@ def __init__( self.client = client self.server = server self.send = send + self.state = state self.protocol: Union[H11Protocol, H2Protocol] if alpn_protocol == "h2": self.protocol = H2Protocol( @@ -37,6 +39,7 @@ def __init__( self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, @@ -48,6 +51,7 @@ def __init__( self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, @@ -66,6 +70,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, @@ -80,6 +85,7 @@ async def handle(self, event: Event) -> None: self.config, self.context, self.task_group, + self.state, self.ssl, self.client, self.server, diff --git a/src/hypercorn/protocol/events.py b/src/hypercorn/protocol/events.py index d91d203f..126f5f53 100644 --- a/src/hypercorn/protocol/events.py +++ b/src/hypercorn/protocol/events.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import List, Tuple +from hypercorn.typing import ConnectionState + @dataclass(frozen=True) class Event: @@ -15,6 +17,7 @@ class Request(Event): http_version: str method: str raw_path: bytes + state: ConnectionState @dataclass(frozen=True) diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index e18d4884..b47a53f7 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -20,7 +20,7 @@ from .ws_stream import WSStream from ..config import Config from ..events import Closed, Event, RawData, Updated -from ..typing import AppWrapper, H11SendableEvent, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, H11SendableEvent, TaskGroup, WorkerContext STREAM_ID = 1 @@ -84,6 +84,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, + connection_state: ConnectionState, ssl: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], @@ -102,6 +103,7 @@ def __init__( self.ssl = ssl self.stream: Optional[Union[HTTPStream, WSStream]] = None self.task_group = task_group + self.connection_state = connection_state async def initiate(self) -> None: pass @@ -226,6 +228,7 @@ async def _create_stream(self, request: h11.Request) -> None: http_version=request.http_version.decode(), method=request.method.decode("ascii").upper(), raw_path=request.target, + state=self.connection_state, ) ) diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 6e76d493..80cb6a59 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -23,7 +23,7 @@ from .ws_stream import WSStream from ..config import Config from ..events import Closed, Event, RawData, Updated -from ..typing import AppWrapper, Event as IOEvent, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, Event as IOEvent, TaskGroup, WorkerContext from ..utils import filter_pseudo_headers BUFFER_HIGH_WATER = 2 * 2**14 # Twice the default max frame size (two frames worth) @@ -84,6 +84,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, + connection_state: ConnectionState, ssl: bool, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], @@ -95,6 +96,7 @@ def __init__( self.config = config self.context = context self.task_group = task_group + self.connection_state = connection_state self.connection = h2.connection.H2Connection( config=h2.config.H2Configuration(client_side=False, header_encoding=None) @@ -347,6 +349,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: http_version="2", method=method, raw_path=raw_path, + state=self.connection_state, ) ) diff --git a/src/hypercorn/protocol/h3.py b/src/hypercorn/protocol/h3.py index 88d9a4d3..29f95df2 100644 --- a/src/hypercorn/protocol/h3.py +++ b/src/hypercorn/protocol/h3.py @@ -22,7 +22,7 @@ from .http_stream import HTTPStream from .ws_stream import WSStream from ..config import Config -from ..typing import AppWrapper, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext from ..utils import filter_pseudo_headers @@ -33,6 +33,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, + state: ConnectionState, client: Optional[Tuple[str, int]], server: Optional[Tuple[str, int]], quic: QuicConnection, @@ -47,6 +48,7 @@ def __init__( self.server = server self.streams: Dict[int, Union[HTTPStream, WSStream]] = {} self.task_group = task_group + self.state = state async def handle(self, quic_event: QuicEvent) -> None: for event in self.connection.handle_event(quic_event): @@ -123,6 +125,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: http_version="3", method=method, raw_path=raw_path, + state=self.state, ) ) diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index 6cd9beea..2cd8c9a5 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -86,6 +86,7 @@ async def handle(self, event: Event) -> None: "headers": event.headers, "client": self.client, "server": self.server, + "state": event.state, "extensions": {}, } if event.http_version in PUSH_VERSIONS: @@ -144,6 +145,7 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None: http_version=self.scope["http_version"], method="GET", raw_path=message["path"].encode(), + state=self.scope["state"], ) ) elif ( diff --git a/src/hypercorn/protocol/quic.py b/src/hypercorn/protocol/quic.py index 3d16e54d..0a1eb761 100644 --- a/src/hypercorn/protocol/quic.py +++ b/src/hypercorn/protocol/quic.py @@ -22,7 +22,7 @@ from .h3 import H3Protocol from ..config import Config from ..events import Closed, Event, RawData -from ..typing import AppWrapper, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext class QuicProtocol: @@ -32,6 +32,7 @@ def __init__( config: Config, context: WorkerContext, task_group: TaskGroup, + state: ConnectionState, server: Optional[Tuple[str, int]], send: Callable[[Event], Awaitable[None]], ) -> None: @@ -43,6 +44,7 @@ def __init__( self.send = send self.server = server self.task_group = task_group + self.state = state self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False) self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile) @@ -106,6 +108,7 @@ async def _handle_events( self.config, self.context, self.task_group, + self.state, client, self.server, connection, diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index cebbf89b..e7c53f4b 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -217,6 +217,7 @@ async def handle(self, event: Event) -> None: "headers": event.headers, "client": self.client, "server": self.server, + "state": event.state, "subprotocols": self.handshake.subprotocols or [], "extensions": {"websocket.http.response": {}}, } diff --git a/src/hypercorn/trio/lifespan.py b/src/hypercorn/trio/lifespan.py index a45fc528..21f4dd26 100644 --- a/src/hypercorn/trio/lifespan.py +++ b/src/hypercorn/trio/lifespan.py @@ -3,7 +3,7 @@ import trio from ..config import Config -from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope +from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState from ..utils import LifespanFailureError, LifespanTimeoutError @@ -12,7 +12,7 @@ class UnexpectedMessageError(Exception): class Lifespan: - def __init__(self, app: AppWrapper, config: Config) -> None: + def __init__(self, app: AppWrapper, config: Config, state: LifespanState) -> None: self.app = app self.config = config self.startup = trio.Event() @@ -20,6 +20,7 @@ def __init__(self, app: AppWrapper, config: Config) -> None: self.app_send_channel, self.app_receive_channel = trio.open_memory_channel( config.max_app_queue_size ) + self.state = state self.supported = True async def handle_lifespan( @@ -29,6 +30,7 @@ async def handle_lifespan( scope: LifespanScope = { "type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}, + "state": self.state, } try: await self.app( diff --git a/src/hypercorn/trio/run.py b/src/hypercorn/trio/run.py index 5dfbf91f..b7cdbaac 100644 --- a/src/hypercorn/trio/run.py +++ b/src/hypercorn/trio/run.py @@ -12,7 +12,7 @@ from .udp_server import UDPServer from .worker_context import WorkerContext from ..config import Config, Sockets -from ..typing import AppWrapper +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import ( check_multiprocess_shutdown_event, load_application, @@ -32,7 +32,9 @@ async def worker_serve( ) -> None: config.set_statsd_logger_class(StatsdLogger) - lifespan = Lifespan(app, config) + lifespan_state: LifespanState = {} + + lifespan = Lifespan(app, config, lifespan_state) context = WorkerContext() async with trio.open_nursery() as lifespan_nursery: @@ -69,7 +71,11 @@ async def worker_serve( await config.log.info(f"Running on http://{bind} (CTRL + C to quit)") for sock in sockets.quic_sockets: - await server_nursery.start(UDPServer(app, config, context, sock).run) + await server_nursery.start( + UDPServer( + app, config, context, ConnectionState(lifespan_state.copy()), sock + ).run + ) bind = repr_socket_addr(sock.family, sock.getsockname()) await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)") @@ -82,7 +88,13 @@ async def worker_serve( nursery.start_soon( partial( trio.serve_listeners, - partial(TCPServer, app, config, context), + partial( + TCPServer, + app, + config, + context, + ConnectionState(lifespan_state.copy()), + ), listeners, handler_nursery=server_nursery, ), diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index 3419440f..c66d878c 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -10,7 +10,7 @@ from ..config import Config from ..events import Closed, Event, RawData, Updated from ..protocol import ProtocolWrapper -from ..typing import AppWrapper +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr MAX_RECV = 2**16 @@ -18,7 +18,12 @@ class TCPServer: def __init__( - self, app: AppWrapper, config: Config, context: WorkerContext, stream: trio.abc.Stream + self, + app: AppWrapper, + config: Config, + context: WorkerContext, + stream: trio.abc.Stream, + state: LifespanState, ) -> None: self.app = app self.config = config @@ -27,6 +32,7 @@ def __init__( self.send_lock = trio.Lock() self.idle_lock = trio.Lock() self.stream = stream + self.state = state self._idle_handle: Optional[trio.CancelScope] = None @@ -59,6 +65,7 @@ async def run(self) -> None: self.config, self.context, task_group, + ConnectionState(self.state.copy()), ssl, client, server, diff --git a/src/hypercorn/trio/udp_server.py b/src/hypercorn/trio/udp_server.py index b8d4530b..d66b0378 100644 --- a/src/hypercorn/trio/udp_server.py +++ b/src/hypercorn/trio/udp_server.py @@ -6,7 +6,7 @@ from .worker_context import WorkerContext from ..config import Config from ..events import Event, RawData -from ..typing import AppWrapper +from ..typing import AppWrapper, ConnectionState, LifespanState from ..utils import parse_socket_addr MAX_RECV = 2**16 @@ -18,12 +18,14 @@ def __init__( app: AppWrapper, config: Config, context: WorkerContext, + state: LifespanState, socket: trio.socket.socket, ) -> None: self.app = app self.config = config self.context = context self.socket = trio.socket.from_stdlib_socket(socket) + self.state = state async def run( self, task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED @@ -34,7 +36,13 @@ async def run( server = parse_socket_addr(self.socket.family, self.socket.getsockname()) async with TaskGroup() as task_group: self.protocol = QuicProtocol( - self.app, self.config, self.context, task_group, server, self.protocol_send + self.app, + self.config, + self.context, + task_group, + ConnectionState(self.state.copy()), + server, + self.protocol_send, ) while not self.context.terminated.is_set() or not self.protocol.idle: diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index 206415c0..396b28bf 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -2,7 +2,7 @@ from multiprocessing.synchronize import Event as EventType from types import TracebackType -from typing import Any, Awaitable, Callable, Dict, Iterable, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Dict, Iterable, NewType, Optional, Tuple, Type, Union import h2.events import h11 @@ -19,6 +19,10 @@ WorkerFunc = Callable[[Config, Optional[Sockets], Optional[EventType]], None] +LifespanState = Dict[str, Any] + +ConnectionState = NewType("ConnectionState", Dict[str, Any]) + class ASGIVersions(TypedDict, total=False): spec_version: str @@ -38,6 +42,7 @@ class HTTPScope(TypedDict): headers: Iterable[Tuple[bytes, bytes]] client: Optional[Tuple[str, int]] server: Optional[Tuple[str, Optional[int]]] + state: ConnectionState extensions: Dict[str, dict] @@ -54,12 +59,14 @@ class WebsocketScope(TypedDict): client: Optional[Tuple[str, int]] server: Optional[Tuple[str, Optional[int]]] subprotocols: Iterable[str] + state: ConnectionState extensions: Dict[str, dict] class LifespanScope(TypedDict): type: Literal["lifespan"] asgi: ASGIVersions + state: LifespanState WWWScope = Union[HTTPScope, WebsocketScope] diff --git a/tests/asyncio/test_lifespan.py b/tests/asyncio/test_lifespan.py index c59a395b..0bb02590 100644 --- a/tests/asyncio/test_lifespan.py +++ b/tests/asyncio/test_lifespan.py @@ -23,7 +23,7 @@ async def no_lifespan_app(scope: Scope, receive: Callable, send: Callable) -> No async def test_ensure_no_race_condition(event_loop: asyncio.AbstractEventLoop) -> None: config = Config() config.startup_timeout = 0.2 - lifespan = Lifespan(ASGIWrapper(no_lifespan_app), config, event_loop) + lifespan = Lifespan(ASGIWrapper(no_lifespan_app), config, event_loop, {}) task = event_loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() # Raises if there is a race condition await task @@ -33,7 +33,9 @@ async def test_ensure_no_race_condition(event_loop: asyncio.AbstractEventLoop) - async def test_startup_timeout_error(event_loop: asyncio.AbstractEventLoop) -> None: config = Config() config.startup_timeout = 0.01 - lifespan = Lifespan(ASGIWrapper(SlowLifespanFramework(0.02, asyncio.sleep)), config, event_loop) + lifespan = Lifespan( + ASGIWrapper(SlowLifespanFramework(0.02, asyncio.sleep)), config, event_loop, {} + ) task = event_loop.create_task(lifespan.handle_lifespan()) with pytest.raises(LifespanTimeoutError) as exc_info: await lifespan.wait_for_startup() @@ -43,7 +45,7 @@ async def test_startup_timeout_error(event_loop: asyncio.AbstractEventLoop) -> N @pytest.mark.asyncio async def test_startup_failure(event_loop: asyncio.AbstractEventLoop) -> None: - lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config(), event_loop) + lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config(), event_loop, {}) lifespan_task = event_loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() assert lifespan_task.done() @@ -58,7 +60,7 @@ async def return_app(scope: Scope, receive: Callable, send: Callable) -> None: @pytest.mark.asyncio async def test_lifespan_return(event_loop: asyncio.AbstractEventLoop) -> None: - lifespan = Lifespan(ASGIWrapper(return_app), Config(), event_loop) + lifespan = Lifespan(ASGIWrapper(return_app), Config(), event_loop, {}) lifespan_task = event_loop.create_task(lifespan.handle_lifespan()) await lifespan.wait_for_startup() await lifespan.wait_for_shutdown() diff --git a/tests/asyncio/test_sanity.py b/tests/asyncio/test_sanity.py index 287cd06d..222a6dd0 100644 --- a/tests/asyncio/test_sanity.py +++ b/tests/asyncio/test_sanity.py @@ -22,6 +22,7 @@ async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: event_loop, Config(), WorkerContext(), + {}, MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -79,6 +80,7 @@ async def test_http1_websocket(event_loop: asyncio.AbstractEventLoop) -> None: event_loop, Config(), WorkerContext(), + {}, MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -116,6 +118,7 @@ async def test_http2_request(event_loop: asyncio.AbstractEventLoop) -> None: event_loop, Config(), WorkerContext(), + {}, MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) @@ -179,6 +182,7 @@ async def test_http2_websocket(event_loop: asyncio.AbstractEventLoop) -> None: event_loop, Config(), WorkerContext(), + {}, MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) diff --git a/tests/asyncio/test_tcp_server.py b/tests/asyncio/test_tcp_server.py index f4915de0..b0f411d2 100644 --- a/tests/asyncio/test_tcp_server.py +++ b/tests/asyncio/test_tcp_server.py @@ -19,6 +19,7 @@ async def test_completes_on_closed(event_loop: asyncio.AbstractEventLoop) -> Non event_loop, Config(), WorkerContext(), + {}, MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -35,6 +36,7 @@ async def test_complets_on_half_close(event_loop: asyncio.AbstractEventLoop) -> event_loop, Config(), WorkerContext(), + {}, MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) diff --git a/tests/conftest.py b/tests/conftest.py index f25c3f1a..be84f59a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ from _pytest.monkeypatch import MonkeyPatch import hypercorn.config -from hypercorn.typing import HTTPScope +from hypercorn.typing import ConnectionState, HTTPScope @pytest.fixture(autouse=True) @@ -32,4 +32,5 @@ def _http_scope() -> HTTPScope: "client": ("127.0.0.1", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } diff --git a/tests/helpers.py b/tests/helpers.py index cdac68c3..9162e5be 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -11,7 +11,6 @@ class MockSocket: - family = AF_INET def getsockname(self) -> Tuple[str, int]: @@ -94,6 +93,7 @@ async def sanity_framework( if event["type"] in {"http.disconnect", "websocket.disconnect"}: break elif event["type"] == "lifespan.startup": + assert "state" in scope await send({"type": "lifspan.startup.complete"}) # type: ignore elif event["type"] == "lifespan.shutdown": await send({"type": "lifspan.shutdown.complete"}) # type: ignore diff --git a/tests/middleware/test_dispatcher.py b/tests/middleware/test_dispatcher.py index dbb3f43e..2d0b9e37 100644 --- a/tests/middleware/test_dispatcher.py +++ b/tests/middleware/test_dispatcher.py @@ -72,7 +72,7 @@ async def send(message: dict) -> None: async def receive() -> dict: return {"type": "lifespan.shutdown"} - await app({"type": "lifespan", "asgi": {"version": "3.0"}}, receive, send) + await app({"type": "lifespan", "asgi": {"version": "3.0"}, "state": {}}, receive, send) assert sent_events == [{"type": "lifespan.startup.complete"}] @@ -89,5 +89,5 @@ async def send(message: dict) -> None: async def receive() -> dict: return {"type": "lifespan.shutdown"} - await app({"type": "lifespan", "asgi": {"version": "3.0"}}, receive, send) + await app({"type": "lifespan", "asgi": {"version": "3.0"}, "state": {}}, receive, send) assert sent_events == [{"type": "lifespan.startup.complete"}] diff --git a/tests/middleware/test_http_to_https.py b/tests/middleware/test_http_to_https.py index a4880c07..01583e26 100644 --- a/tests/middleware/test_http_to_https.py +++ b/tests/middleware/test_http_to_https.py @@ -3,7 +3,7 @@ import pytest from hypercorn.middleware import HTTPToHTTPSRedirectMiddleware -from hypercorn.typing import HTTPScope, WebsocketScope +from hypercorn.typing import ConnectionState, HTTPScope, WebsocketScope from ..helpers import empty_framework @@ -31,6 +31,7 @@ async def send(message: dict) -> None: "client": ("127.0.0.1", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -69,6 +70,7 @@ async def send(message: dict) -> None: "server": None, "subprotocols": [], "extensions": {"websocket.http.response": {}}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -105,6 +107,7 @@ async def send(message: dict) -> None: "server": None, "subprotocols": [], "extensions": {"websocket.http.response": {}}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -141,6 +144,7 @@ async def send(message: dict) -> None: "server": None, "subprotocols": [], "extensions": {}, + "state": ConnectionState({}), } await app(scope, None, send) @@ -165,6 +169,7 @@ def test_http_to_https_redirect_new_url_header() -> None: "client": None, "server": None, "extensions": {}, + "state": ConnectionState({}), }, ) assert new_url == "https://localhost/" diff --git a/tests/protocol/test_h11.py b/tests/protocol/test_h11.py index 20e00917..0c8ce210 100755 --- a/tests/protocol/test_h11.py +++ b/tests/protocol/test_h11.py @@ -16,7 +16,7 @@ from hypercorn.protocol.events import Body, Data, EndBody, EndData, Request, Response, StreamClosed from hypercorn.protocol.h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol from hypercorn.protocol.http_stream import HTTPStream -from hypercorn.typing import Event as IOEvent +from hypercorn.typing import ConnectionState, Event as IOEvent try: from unittest.mock import AsyncMock @@ -37,7 +37,17 @@ async def _protocol(monkeypatch: MonkeyPatch) -> H11Protocol: context.event_class.return_value = AsyncMock(spec=IOEvent) context.terminated = context.event_class() context.terminated.is_set.return_value = False - return H11Protocol(AsyncMock(), Config(), context, AsyncMock(), False, None, None, AsyncMock()) + return H11Protocol( + AsyncMock(), + Config(), + context, + AsyncMock(), + ConnectionState({}), + False, + None, + None, + AsyncMock(), + ) @pytest.mark.asyncio @@ -169,6 +179,7 @@ async def test_protocol_handle_closed(protocol: H11Protocol) -> None: http_version="1.1", method="GET", raw_path=b"/", + state=ConnectionState({}), ) ), call(EndBody(stream_id=1)), @@ -191,6 +202,7 @@ async def test_protocol_handle_request(protocol: H11Protocol) -> None: http_version="1.1", method="GET", raw_path=b"/?a=b", + state=ConnectionState({}), ) ), call(EndBody(stream_id=1)), @@ -268,7 +280,15 @@ async def test_protocol_handle_max_incomplete(monkeypatch: MonkeyPatch) -> None: context = Mock() context.event_class.return_value = AsyncMock(spec=IOEvent) protocol = H11Protocol( - AsyncMock(), config, context, AsyncMock(), False, None, None, AsyncMock() + AsyncMock(), + config, + context, + AsyncMock(), + ConnectionState({}), + False, + None, + None, + AsyncMock(), ) await protocol.handle(RawData(data=b"GET / HTTP/1.1\r\nHost: hypercorn\r\n")) protocol.send.assert_called() # type: ignore diff --git a/tests/protocol/test_h2.py b/tests/protocol/test_h2.py index c44f39ae..d6bcf8bd 100644 --- a/tests/protocol/test_h2.py +++ b/tests/protocol/test_h2.py @@ -9,6 +9,7 @@ from hypercorn.config import Config from hypercorn.events import Closed, RawData from hypercorn.protocol.h2 import BUFFER_HIGH_WATER, BufferCompleteError, H2Protocol, StreamBuffer +from hypercorn.typing import ConnectionState try: from unittest.mock import AsyncMock @@ -73,7 +74,15 @@ async def test_stream_buffer_complete(event_loop: asyncio.AbstractEventLoop) -> @pytest.mark.asyncio async def test_protocol_handle_protocol_error() -> None: protocol = H2Protocol( - Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + Mock(), + Config(), + WorkerContext(), + AsyncMock(), + ConnectionState({}), + False, + None, + None, + AsyncMock(), ) await protocol.handle(RawData(data=b"broken nonsense\r\n\r\n")) protocol.send.assert_awaited() # type: ignore diff --git a/tests/protocol/test_http_stream.py b/tests/protocol/test_http_stream.py index 3cb2ad79..6e0357e5 100644 --- a/tests/protocol/test_http_stream.py +++ b/tests/protocol/test_http_stream.py @@ -18,7 +18,12 @@ StreamClosed, ) from hypercorn.protocol.http_stream import ASGIHTTPState, HTTPStream -from hypercorn.typing import HTTPResponseBodyEvent, HTTPResponseStartEvent, HTTPScope +from hypercorn.typing import ( + ConnectionState, + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + HTTPScope, +) from hypercorn.utils import UnexpectedMessageError try: @@ -42,7 +47,14 @@ async def _stream() -> HTTPStream: @pytest.mark.asyncio async def test_handle_request_http_1(stream: HTTPStream, http_version: str) -> None: await stream.handle( - Request(stream_id=1, http_version=http_version, headers=[], raw_path=b"/?a=b", method="GET") + Request( + stream_id=1, + http_version=http_version, + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) ) stream.task_group.spawn_app.assert_called() # type: ignore scope = stream.task_group.spawn_app.call_args[0][2] # type: ignore @@ -66,7 +78,14 @@ async def test_handle_request_http_1(stream: HTTPStream, http_version: str) -> N @pytest.mark.asyncio async def test_handle_request_http_2(stream: HTTPStream) -> None: await stream.handle( - Request(stream_id=1, http_version="2", headers=[], raw_path=b"/?a=b", method="GET") + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) ) stream.task_group.spawn_app.assert_called() # type: ignore scope = stream.task_group.spawn_app.call_args[0][2] # type: ignore @@ -116,7 +135,14 @@ async def test_handle_closed(stream: HTTPStream) -> None: @pytest.mark.asyncio async def test_send_response(stream: HTTPStream) -> None: await stream.handle( - Request(stream_id=1, http_version="2", headers=[], raw_path=b"/?a=b", method="GET") + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) ) await stream.app_send( cast(HTTPResponseStartEvent, {"type": "http.response.start", "status": 200, "headers": []}) @@ -148,6 +174,7 @@ async def test_invalid_server_name(stream: HTTPStream) -> None: headers=[(b"host", b"example.com")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) assert stream.send.call_args_list == [ # type: ignore @@ -177,6 +204,7 @@ async def test_send_push(stream: HTTPStream, http_scope: HTTPScope) -> None: http_version="2", method="GET", raw_path=b"/push", + state=ConnectionState({}), ) ) ] diff --git a/tests/protocol/test_ws_stream.py b/tests/protocol/test_ws_stream.py index f927cf59..0dd0aed1 100644 --- a/tests/protocol/test_ws_stream.py +++ b/tests/protocol/test_ws_stream.py @@ -21,6 +21,7 @@ WSStream, ) from hypercorn.typing import ( + ConnectionState, WebsocketAcceptEvent, WebsocketCloseEvent, WebsocketResponseBodyEvent, @@ -182,6 +183,7 @@ async def test_handle_request(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/?a=b", method="GET", + state=ConnectionState({}), ) ) stream.task_group.spawn_app.assert_called() # type: ignore @@ -212,6 +214,7 @@ async def test_handle_connection(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/?a=b", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -241,6 +244,7 @@ async def test_send_accept(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -260,6 +264,7 @@ async def test_send_accept_with_additional_headers(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send( @@ -284,6 +289,7 @@ async def test_send_reject(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send( @@ -318,6 +324,7 @@ async def test_invalid_server_name(stream: WSStream) -> None: headers=[(b"host", b"example.com"), (b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) assert stream.send.call_args_list == [ # type: ignore @@ -343,6 +350,7 @@ async def test_send_app_error_handshake(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(None) @@ -370,6 +378,7 @@ async def test_send_app_error_connected(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -392,6 +401,7 @@ async def test_send_connection(stream: WSStream) -> None: headers=[(b"sec-websocket-version", b"13")], raw_path=b"/", method="GET", + state=ConnectionState({}), ) ) await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) @@ -416,6 +426,7 @@ async def test_pings(stream: WSStream, event_loop: asyncio.AbstractEventLoop) -> headers=[(b"sec-websocket-version", b"13")], raw_path=b"/?a=b", method="GET", + state=ConnectionState({}), ) ) async with TaskGroup(event_loop) as task_group: diff --git a/tests/test_app_wrappers.py b/tests/test_app_wrappers.py index bb7b5897..3aefebce 100644 --- a/tests/test_app_wrappers.py +++ b/tests/test_app_wrappers.py @@ -8,7 +8,7 @@ import trio from hypercorn.app_wrappers import _build_environ, InvalidPathError, WSGIWrapper -from hypercorn.typing import ASGISendEvent, HTTPScope +from hypercorn.typing import ASGISendEvent, ConnectionState, HTTPScope def echo_body(environ: dict, start_response: Callable) -> List[bytes]: @@ -39,6 +39,7 @@ async def test_wsgi_trio() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } send_channel, receive_channel = trio.open_memory_channel(1) await send_channel.send({"type": "http.request"}) @@ -78,6 +79,7 @@ async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } queue: asyncio.Queue = asyncio.Queue() await queue.put({"type": "http.request"}) @@ -121,6 +123,7 @@ async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } queue: asyncio.Queue = asyncio.Queue() await queue.put({"type": "http.request", "body": b"abcde"}) @@ -156,6 +159,7 @@ def test_build_environ_encoding() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } environ = _build_environ(scope, b"") assert environ["SCRIPT_NAME"] == "/δΈ­".encode("utf8").decode("latin-1") @@ -177,6 +181,7 @@ def test_build_environ_root_path() -> None: "client": ("localhost", 80), "server": None, "extensions": {}, + "state": ConnectionState({}), } with pytest.raises(InvalidPathError): _build_environ(scope, b"") diff --git a/tests/trio/test_keep_alive.py b/tests/trio/test_keep_alive.py index d30d82db..0fbdc066 100644 --- a/tests/trio/test_keep_alive.py +++ b/tests/trio/test_keep_alive.py @@ -47,7 +47,7 @@ def _client_stream( config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(), {}, server_stream) nursery.start_soon(server.run) yield client_stream diff --git a/tests/trio/test_lifespan.py b/tests/trio/test_lifespan.py index dd8ab775..161e82fa 100644 --- a/tests/trio/test_lifespan.py +++ b/tests/trio/test_lifespan.py @@ -14,7 +14,7 @@ async def test_startup_timeout_error(nursery: trio._core._run.Nursery) -> None: config = Config() config.startup_timeout = 0.01 - lifespan = Lifespan(ASGIWrapper(SlowLifespanFramework(0.02, trio.sleep)), config) + lifespan = Lifespan(ASGIWrapper(SlowLifespanFramework(0.02, trio.sleep)), config, {}) nursery.start_soon(lifespan.handle_lifespan) with pytest.raises(LifespanTimeoutError) as exc_info: await lifespan.wait_for_startup() @@ -23,7 +23,7 @@ async def test_startup_timeout_error(nursery: trio._core._run.Nursery) -> None: @pytest.mark.trio async def test_startup_failure() -> None: - lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config()) + lifespan = Lifespan(ASGIWrapper(lifespan_failure), Config(), {}) with pytest.raises(LifespanFailureError) as exc_info: async with trio.open_nursery() as lifespan_nursery: await lifespan_nursery.start(lifespan.handle_lifespan) diff --git a/tests/trio/test_sanity.py b/tests/trio/test_sanity.py index 3828e37b..2dcd42bb 100644 --- a/tests/trio/test_sanity.py +++ b/tests/trio/test_sanity.py @@ -25,7 +25,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), {}, server_stream) nursery.start_soon(server.run) client = h11.Connection(h11.CLIENT) await client_stream.send_all( @@ -76,7 +76,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), {}, server_stream) nursery.start_soon(server.run) client = wsproto.WSConnection(wsproto.ConnectionType.CLIENT) await client_stream.send_all(client.send(wsproto.events.Request(host="hypercorn", target="/"))) @@ -103,7 +103,7 @@ async def test_http2_request(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), {}, server_stream) nursery.start_soon(server.run) client = h2.connection.H2Connection() client.initiate_connection() @@ -158,7 +158,7 @@ async def test_http2_websocket(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), {}, server_stream) nursery.start_soon(server.run) h2_client = h2.connection.H2Connection() h2_client.initiate_connection()