Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support lifespan_scope["state"] #110

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/hypercorn/asyncio/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Callable

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState
from ..utils import LifespanFailureError, LifespanTimeoutError


Expand All @@ -14,14 +14,21 @@ class UnexpectedMessageError(Exception):


class Lifespan:
def __init__(self, app: AppWrapper, config: Config, loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self,
app: AppWrapper,
config: Config,
loop: asyncio.AbstractEventLoop,
lifespan_state: LifespanState,
) -> None:
self.app = app
self.config = config
self.startup = asyncio.Event()
self.shutdown = asyncio.Event()
self.app_queue: asyncio.Queue = asyncio.Queue(config.max_app_queue_size)
self.supported = True
self.loop = loop
self.state = lifespan_state

# This mimics the Trio nursery.start task_status and is
# required to ensure the support has been checked before
Expand All @@ -33,6 +40,7 @@ async def handle_lifespan(self) -> None:
scope: LifespanScope = {
"type": "lifespan",
"asgi": {"spec_version": "2.0", "version": "3.0"},
"state": self.state,
}

def _call_soon(func: Callable, *args: Any) -> Any:
Expand Down
9 changes: 5 additions & 4 deletions src/hypercorn/asyncio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .udp_server import UDPServer
from .worker_context import WorkerContext
from ..config import Config, Sockets
from ..typing import AppWrapper
from ..typing import AppWrapper, LifespanState
from ..utils import (
check_multiprocess_shutdown_event,
load_application,
Expand Down Expand Up @@ -71,7 +71,8 @@ def _signal_handler(*_: Any) -> None: # noqa: N803

shutdown_trigger = signal_event.wait # type: ignore

lifespan = Lifespan(app, config, loop)
lifespan_state: LifespanState = {}
lifespan = Lifespan(app, config, loop, lifespan_state)

lifespan_task = loop.create_task(lifespan.handle_lifespan())
await lifespan.wait_for_startup()
Expand All @@ -93,7 +94,7 @@ def _signal_handler(*_: Any) -> None: # noqa: N803

async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
server_tasks.add(asyncio.current_task(loop))
await TCPServer(app, loop, config, context, reader, writer)
await TCPServer(app, loop, config, context, lifespan_state, reader, writer)

servers = []
for sock in sockets.secure_sockets:
Expand Down Expand Up @@ -127,7 +128,7 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW
sock = _share_socket(sock)

_, protocol = await loop.create_datagram_endpoint(
lambda: UDPServer(app, loop, config, context), sock=sock
lambda: UDPServer(app, loop, config, context, lifespan_state), sock=sock
)
server_tasks.add(loop.create_task(protocol.run()))
bind = repr_socket_addr(sock.family, sock.getsockname())
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/asyncio/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import AppWrapper
from ..typing import AppWrapper, ConnectionState, LifespanState
from ..utils import parse_socket_addr

MAX_RECV = 2**16
Expand All @@ -22,6 +22,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
config: Config,
context: WorkerContext,
state: LifespanState,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
Expand All @@ -34,6 +35,7 @@ def __init__(
self.writer = writer
self.send_lock = asyncio.Lock()
self.idle_lock = asyncio.Lock()
self.state = state

self._idle_handle: Optional[asyncio.Task] = None

Expand All @@ -59,6 +61,7 @@ async def run(self) -> None:
self.config,
self.context,
task_group,
ConnectionState(self.state.copy()),
ssl,
client,
server,
Expand Down
12 changes: 10 additions & 2 deletions src/hypercorn/asyncio/udp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .worker_context import WorkerContext
from ..config import Config
from ..events import Event, RawData
from ..typing import AppWrapper
from ..typing import AppWrapper, ConnectionState, LifespanState
from ..utils import parse_socket_addr

if TYPE_CHECKING:
Expand All @@ -22,6 +22,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
config: Config,
context: WorkerContext,
state: LifespanState,
) -> None:
self.app = app
self.config = config
Expand All @@ -30,6 +31,7 @@ def __init__(
self.protocol: "QuicProtocol"
self.protocol_queue: asyncio.Queue = asyncio.Queue(10)
self.transport: Optional[asyncio.DatagramTransport] = None
self.state = state

def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
self.transport = transport
Expand All @@ -48,7 +50,13 @@ async def run(self) -> None:
server = parse_socket_addr(socket.family, socket.getsockname())
async with TaskGroup(self.loop) as task_group:
self.protocol = QuicProtocol(
self.app, self.config, self.context, task_group, server, self.protocol_send
self.app,
self.config,
self.context,
task_group,
ConnectionState(self.state.copy()),
server,
self.protocol_send,
)

while not self.context.terminated.is_set() or not self.protocol.idle:
Expand Down
8 changes: 7 additions & 1 deletion src/hypercorn/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol
from ..config import Config
from ..events import Event, RawData
from ..typing import AppWrapper, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext


class ProtocolWrapper:
Expand All @@ -16,6 +16,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
state: ConnectionState,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
Expand All @@ -30,13 +31,15 @@ def __init__(
self.client = client
self.server = server
self.send = send
self.state = state
self.protocol: Union[H11Protocol, H2Protocol]
if alpn_protocol == "h2":
self.protocol = H2Protocol(
self.app,
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand All @@ -48,6 +51,7 @@ def __init__(
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand All @@ -66,6 +70,7 @@ async def handle(self, event: Event) -> None:
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand All @@ -80,6 +85,7 @@ async def handle(self, event: Event) -> None:
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand Down
3 changes: 3 additions & 0 deletions src/hypercorn/protocol/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import List, Tuple

from hypercorn.typing import ConnectionState


@dataclass(frozen=True)
class Event:
Expand All @@ -15,6 +17,7 @@ class Request(Event):
http_version: str
method: str
raw_path: bytes
state: ConnectionState


@dataclass(frozen=True)
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/h11.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .ws_stream import WSStream
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..typing import AppWrapper, H11SendableEvent, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, H11SendableEvent, TaskGroup, WorkerContext

STREAM_ID = 1

Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
connection_state: ConnectionState,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
Expand All @@ -102,6 +103,7 @@ def __init__(
self.ssl = ssl
self.stream: Optional[Union[HTTPStream, WSStream]] = None
self.task_group = task_group
self.connection_state = connection_state

async def initiate(self) -> None:
pass
Expand Down Expand Up @@ -226,6 +228,7 @@ async def _create_stream(self, request: h11.Request) -> None:
http_version=request.http_version.decode(),
method=request.method.decode("ascii").upper(),
raw_path=request.target,
state=self.connection_state,
)
)

Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/h2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .ws_stream import WSStream
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..typing import AppWrapper, Event as IOEvent, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, Event as IOEvent, TaskGroup, WorkerContext
from ..utils import filter_pseudo_headers

BUFFER_HIGH_WATER = 2 * 2**14 # Twice the default max frame size (two frames worth)
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
connection_state: ConnectionState,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
Expand All @@ -95,6 +96,7 @@ def __init__(
self.config = config
self.context = context
self.task_group = task_group
self.connection_state = connection_state

self.connection = h2.connection.H2Connection(
config=h2.config.H2Configuration(client_side=False, header_encoding=None)
Expand Down Expand Up @@ -347,6 +349,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None:
http_version="2",
method=method,
raw_path=raw_path,
state=self.connection_state,
)
)

Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/h3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .http_stream import HTTPStream
from .ws_stream import WSStream
from ..config import Config
from ..typing import AppWrapper, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext
from ..utils import filter_pseudo_headers


Expand All @@ -33,6 +33,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
state: ConnectionState,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
quic: QuicConnection,
Expand All @@ -47,6 +48,7 @@ def __init__(
self.server = server
self.streams: Dict[int, Union[HTTPStream, WSStream]] = {}
self.task_group = task_group
self.state = state

async def handle(self, quic_event: QuicEvent) -> None:
for event in self.connection.handle_event(quic_event):
Expand Down Expand Up @@ -123,6 +125,7 @@ async def _create_stream(self, request: HeadersReceived) -> None:
http_version="3",
method=method,
raw_path=raw_path,
state=self.state,
)
)

Expand Down
2 changes: 2 additions & 0 deletions src/hypercorn/protocol/http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ async def handle(self, event: Event) -> None:
"headers": event.headers,
"client": self.client,
"server": self.server,
"state": event.state,
"extensions": {},
}
if event.http_version in PUSH_VERSIONS:
Expand Down Expand Up @@ -144,6 +145,7 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None:
http_version=self.scope["http_version"],
method="GET",
raw_path=message["path"].encode(),
state=self.scope["state"],
)
)
elif (
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/quic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .h3 import H3Protocol
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import AppWrapper, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext


class QuicProtocol:
Expand All @@ -32,6 +32,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
state: ConnectionState,
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
) -> None:
Expand All @@ -43,6 +44,7 @@ def __init__(
self.send = send
self.server = server
self.task_group = task_group
self.state = state

self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)
Expand Down Expand Up @@ -106,6 +108,7 @@ async def _handle_events(
self.config,
self.context,
self.task_group,
self.state,
client,
self.server,
connection,
Expand Down
1 change: 1 addition & 0 deletions src/hypercorn/protocol/ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ async def handle(self, event: Event) -> None:
"headers": event.headers,
"client": self.client,
"server": self.server,
"state": event.state,
"subprotocols": self.handshake.subprotocols or [],
"extensions": {"websocket.http.response": {}},
}
Expand Down
Loading