Skip to content

Commit

Permalink
Introduce lifespan state (#1818)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <[email protected]>
  • Loading branch information
adriangb and Kludex authored Mar 5, 2023
1 parent c927f7a commit 2a94a96
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 17 deletions.
47 changes: 45 additions & 2 deletions tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import socket
import threading
import time
from typing import Optional, Union

import pytest

from tests.response import Response
from uvicorn import Server
from uvicorn.config import WS_PROTOCOLS, Config
from uvicorn.lifespan.off import LifespanOff
from uvicorn.lifespan.on import LifespanOn
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol

Expand Down Expand Up @@ -184,12 +187,23 @@ def add_done_callback(self, callback):
pass


def get_connected_protocol(app, protocol_cls, **kwargs):
def get_connected_protocol(
app,
protocol_cls,
lifespan: Optional[Union[LifespanOff, LifespanOn]] = None,
**kwargs,
):
loop = MockLoop()
transport = MockTransport()
config = Config(app=app, **kwargs)
lifespan = lifespan or LifespanOff(config)
server_state = ServerState()
protocol = protocol_cls(config=config, server_state=server_state, _loop=loop)
protocol = protocol_cls(
config=config,
server_state=server_state,
app_state=lifespan.state.copy(),
_loop=loop,
)
protocol.connection_made(transport)
return protocol

Expand Down Expand Up @@ -980,3 +994,32 @@ async def app(scope, receive, send):
protocol.data_received(SIMPLE_GET_REQUEST)
await protocol.loop.run_one()
assert b"x-test-header: test value" in protocol.transport.buffer


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_lifespan_state(protocol_cls):
expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}]

async def app(scope, receive, send):
expected_state = expected_states.pop(0)
assert scope["state"] == expected_state
# modifications to keys are not preserved
scope["state"]["a"] = 456
# unless of course the value itself is mutated
scope["state"]["b"].append(2)
return await Response("Hi!")(scope, receive, send)

lifespan = LifespanOn(config=Config(app=app))
# skip over actually running the lifespan, that is tested
# in the lifespan tests
lifespan.state.update({"a": 123, "b": [1]})

for _ in range(2):
protocol = get_connected_protocol(app, protocol_cls, lifespan=lifespan)
protocol.data_received(SIMPLE_GET_REQUEST)
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hi!" in protocol.transport.buffer

assert not expected_states # consumed
55 changes: 55 additions & 0 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import typing
from copy import deepcopy

import httpx
import pytest
Expand Down Expand Up @@ -1087,3 +1088,57 @@ async def open_connection(url):
async with run_server(config):
headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"]


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_lifespan_state(ws_protocol_cls, http_protocol_cls, unused_tcp_port: int):
expected_states = [
{"a": 123, "b": [1]},
{"a": 123, "b": [1, 2]},
]

actual_states = []

async def lifespan_app(scope, receive, send):
message = await receive()
assert message["type"] == "lifespan.startup"
scope["state"]["a"] = 123
scope["state"]["b"] = [1]
await send({"type": "lifespan.startup.complete"})
message = await receive()
assert message["type"] == "lifespan.shutdown"
await send({"type": "lifespan.shutdown.complete"})

class App(WebSocketResponse):
async def websocket_connect(self, message):
actual_states.append(deepcopy(self.scope["state"]))
self.scope["state"]["a"] = 456
self.scope["state"]["b"].append(2)
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.open

async def app_wrapper(scope, receive, send):
if scope["type"] == "lifespan":
return await lifespan_app(scope, receive, send)
else:
return await App(scope, receive, send)

config = Config(
app=app_wrapper,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="on",
port=unused_tcp_port,
)
async with run_server(config):
is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert is_open
is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert is_open

assert expected_states == actual_states
6 changes: 4 additions & 2 deletions tests/test_auto_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_loop_auto():
async def test_http_auto():
config = Config(app=app)
server_state = ServerState()
protocol = AutoHTTPProtocol(config=config, server_state=server_state)
protocol = AutoHTTPProtocol(config=config, server_state=server_state, app_state={})
expected_http = "H11Protocol" if httptools is None else "HttpToolsProtocol"
assert type(protocol).__name__ == expected_http

Expand All @@ -54,6 +54,8 @@ async def test_http_auto():
async def test_websocket_auto():
config = Config(app=app)
server_state = ServerState()
protocol = AutoWebSocketsProtocol(config=config, server_state=server_state)
protocol = AutoWebSocketsProtocol(
config=config, server_state=server_state, app_state={}
)
expected_websockets = "WSProtocol" if websockets is None else "WebSocketProtocol"
assert type(protocol).__name__ == expected_websockets
25 changes: 25 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ async def asgi3app(scope, receive, send):
assert scope == {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.0"},
"state": {},
}

async def test():
Expand All @@ -188,6 +189,7 @@ def asgi2app(scope):
assert scope == {
"type": "lifespan",
"asgi": {"version": "2.0", "spec_version": "2.0"},
"state": {},
}

async def asgi(receive, send):
Expand Down Expand Up @@ -245,3 +247,26 @@ async def test():
assert "the lifespan event failed" in error_messages.pop(0)
assert "Application shutdown failed. Exiting." in error_messages.pop(0)
loop.close()


def test_lifespan_state():
async def app(scope, receive, send):
message = await receive()
assert message["type"] == "lifespan.startup"
await send({"type": "lifespan.startup.complete"})
scope["state"]["foo"] = 123
message = await receive()
assert message["type"] == "lifespan.shutdown"
await send({"type": "lifespan.shutdown.complete"})

async def test():
config = Config(app=app, lifespan="on")
lifespan = LifespanOn(config)

await lifespan.startup()
assert lifespan.state == {"foo": 123}
await lifespan.shutdown()

loop = asyncio.new_event_loop()
loop.run_until_complete(test())
loop.close()
3 changes: 3 additions & 0 deletions uvicorn/lifespan/off.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any, Dict

from uvicorn import Config


class LifespanOff:
def __init__(self, config: Config) -> None:
self.should_exit = False
self.state: Dict[str, Any] = {}

async def startup(self) -> None:
pass
Expand Down
6 changes: 4 additions & 2 deletions uvicorn/lifespan/on.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from asyncio import Queue
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Any, Dict, Union

from uvicorn import Config

Expand Down Expand Up @@ -42,6 +42,7 @@ def __init__(self, config: Config) -> None:
self.startup_failed = False
self.shutdown_failed = False
self.should_exit = False
self.state: Dict[str, Any] = {}

async def startup(self) -> None:
self.logger.info("Waiting for application startup.")
Expand Down Expand Up @@ -79,9 +80,10 @@ async def shutdown(self) -> None:
async def main(self) -> None:
try:
app = self.config.loaded_app
scope: LifespanScope = {
scope: LifespanScope = { # type: ignore[typeddict-item]
"type": "lifespan",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.0"},
"state": self.state,
}
await app(scope, self.receive, self.send)
except BaseException as exc:
Expand Down
20 changes: 18 additions & 2 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import http
import logging
import sys
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
from urllib.parse import unquote

import h11
Expand Down Expand Up @@ -42,6 +52,7 @@
HTTPScope,
)


H11Event = Union[
h11.Request,
h11.InformationalResponse,
Expand Down Expand Up @@ -69,6 +80,7 @@ def __init__(
self,
config: Config,
server_state: ServerState,
app_state: Dict[str, Any],
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
Expand All @@ -89,6 +101,7 @@ def __init__(
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state

# Timeouts
self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None
Expand Down Expand Up @@ -229,6 +242,7 @@ def handle_events(self) -> None:
"raw_path": raw_path,
"query_string": query_string,
"headers": self.headers,
"state": self.app_state,
}

upgrade = self._get_upgrade()
Expand Down Expand Up @@ -290,7 +304,9 @@ def handle_websocket_upgrade(self, event: H11Event) -> None:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
app_state=self.app_state,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
21 changes: 19 additions & 2 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import TYPE_CHECKING, Callable, Deque, List, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)

import httptools

Expand Down Expand Up @@ -44,6 +55,7 @@
HTTPScope,
)


HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")

Expand All @@ -66,6 +78,7 @@ def __init__(
self,
config: Config,
server_state: ServerState,
app_state: Dict[str, Any],
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
Expand All @@ -81,6 +94,7 @@ def __init__(
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state

# Timeouts
self.timeout_keep_alive_task: Optional[TimerHandle] = None
Expand Down Expand Up @@ -201,7 +215,9 @@ def handle_websocket_upgrade(self) -> None:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
app_state=self.app_state,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down Expand Up @@ -237,6 +253,7 @@ def on_message_begin(self) -> None:
"scheme": self.scheme,
"root_path": self.root_path,
"headers": self.headers,
"state": self.app_state,
}

# Parser callbacks
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/auto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import typing

AutoWebSocketsProtocol: typing.Optional[typing.Type[asyncio.Protocol]]
AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]]
try:
import websockets # noqa
except ImportError: # pragma: no cover
Expand Down
Loading

0 comments on commit 2a94a96

Please sign in to comment.