diff --git a/src/hypercorn/__main__.py b/src/hypercorn/__main__.py index 769a5e09..aed33b12 100644 --- a/src/hypercorn/__main__.py +++ b/src/hypercorn/__main__.py @@ -89,6 +89,19 @@ def main(sys_args: Optional[List[str]] = None) -> int: default=sentinel, type=int, ) + parser.add_argument( + "--max-requests", + help="""Maximum number of requests a worker will process before restarting""", + default=sentinel, + type=int, + ) + parser.add_argument( + "--max-requests-jitter", + help="This jitter causes the max-requests per worker to be " + "randomized by randint(0, max_requests_jitter)", + default=sentinel, + type=int, + ) parser.add_argument( "-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int ) @@ -252,6 +265,10 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: config.keyfile_password = args.keyfile_password if args.log_config is not sentinel: config.logconfig = args.log_config + if args.max_requests is not sentinel: + config.max_requests = args.max_requests + if args.max_requests_jitter is not sentinel: + config.max_requests_jitter = args.max_requests if args.pid is not sentinel: config.pid_path = args.pid if args.root_path is not sentinel: diff --git a/src/hypercorn/asyncio/run.py b/src/hypercorn/asyncio/run.py index 7c0982d9..c633c5bd 100644 --- a/src/hypercorn/asyncio/run.py +++ b/src/hypercorn/asyncio/run.py @@ -4,9 +4,11 @@ import platform import signal import ssl +import sys from functools import partial from multiprocessing.synchronize import Event as EventType from os import getpid +from random import randint from socket import socket from typing import Any, Awaitable, Callable, Optional, Set @@ -30,6 +32,14 @@ except ImportError: from taskgroup import Runner # type: ignore +try: + from asyncio import TaskGroup +except ImportError: + from taskgroup import TaskGroup # type: ignore + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + def _share_socket(sock: socket) -> socket: # Windows requires the socket be explicitly shared across @@ -84,7 +94,10 @@ def _signal_handler(*_: Any) -> None: # noqa: N803 ssl_context = config.create_ssl_context() ssl_handshake_timeout = config.ssl_handshake_timeout - context = WorkerContext() + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) server_tasks: Set[asyncio.Task] = set() async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: @@ -136,7 +149,13 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)") try: - await raise_shutdown(shutdown_trigger) + async with TaskGroup() as task_group: + task_group.create_task(raise_shutdown(shutdown_trigger)) + task_group.create_task(raise_shutdown(context.terminate.wait)) + except BaseExceptionGroup as error: + _, other_errors = error.split((ShutdownError, KeyboardInterrupt)) + if other_errors is not None: + raise other_errors except (ShutdownError, KeyboardInterrupt): pass finally: diff --git a/src/hypercorn/asyncio/worker_context.py b/src/hypercorn/asyncio/worker_context.py index fe9ad1c7..d16f76ba 100644 --- a/src/hypercorn/asyncio/worker_context.py +++ b/src/hypercorn/asyncio/worker_context.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import Type, Union +from typing import Optional, Type, Union from ..typing import Event @@ -26,9 +26,20 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper - def __init__(self) -> None: + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() self.terminated = self.event_class() + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() + @staticmethod async def sleep(wait: Union[float, int]) -> None: return await asyncio.sleep(wait) diff --git a/src/hypercorn/config.py b/src/hypercorn/config.py index fdc7a413..f00c7d5e 100644 --- a/src/hypercorn/config.py +++ b/src/hypercorn/config.py @@ -92,6 +92,8 @@ class Config: logger_class = Logger loglevel: str = "INFO" max_app_queue_size: int = 10 + max_requests: Optional[int] = None + max_requests_jitter: int = 0 pid_path: Optional[str] = None server_names: List[str] = [] shutdown_timeout = 60 * SECONDS diff --git a/src/hypercorn/protocol/h11.py b/src/hypercorn/protocol/h11.py index ec04593a..a33ad4ad 100755 --- a/src/hypercorn/protocol/h11.py +++ b/src/hypercorn/protocol/h11.py @@ -236,6 +236,7 @@ async def _create_stream(self, request: h11.Request) -> None: ) ) self.keep_alive_requests += 1 + await self.context.mark_request() async def _send_h11_event(self, event: H11SendableEvent) -> None: try: diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index 26048780..9c92ab3d 100755 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -354,6 +354,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: ) ) self.keep_alive_requests += 1 + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] diff --git a/src/hypercorn/protocol/h3.py b/src/hypercorn/protocol/h3.py index 88d9a4d3..151c0667 100644 --- a/src/hypercorn/protocol/h3.py +++ b/src/hypercorn/protocol/h3.py @@ -125,6 +125,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: raw_path=raw_path, ) ) + await self.context.mark_request() async def _create_server_push( self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]] diff --git a/src/hypercorn/run.py b/src/hypercorn/run.py index 05ab2391..cfe801aa 100644 --- a/src/hypercorn/run.py +++ b/src/hypercorn/run.py @@ -4,6 +4,7 @@ import signal import time from multiprocessing import get_context +from multiprocessing.connection import wait from multiprocessing.context import BaseContext from multiprocessing.process import BaseProcess from multiprocessing.synchronize import Event as EventType @@ -12,12 +13,10 @@ from .config import Config, Sockets from .typing import WorkerFunc -from .utils import load_application, wait_for_changes, write_pid_file +from .utils import check_for_updates, files_to_watch, load_application, write_pid_file def run(config: Config) -> int: - exit_code = 0 - if config.pid_path is not None: write_pid_file(config.pid_path) @@ -42,67 +41,82 @@ def run(config: Config) -> int: if config.use_reloader and config.workers == 0: raise RuntimeError("Cannot reload without workers") - if config.use_reloader or config.workers == 0: - # Load the application so that the correct paths are checked for - # changes, but only when the reloader is being used. - load_application(config.application_path, config.wsgi_max_body_size) - + exitcode = 0 if config.workers == 0: worker_func(config, sockets) else: + if config.use_reloader: + # Load the application so that the correct paths are checked for + # changes, but only when the reloader is being used. + load_application(config.application_path, config.wsgi_max_body_size) + ctx = get_context("spawn") active = True + shutdown_event = ctx.Event() + + def shutdown(*args: Any) -> None: + nonlocal active, shutdown_event + shutdown_event.set() + active = False + + processes: List[BaseProcess] = [] while active: # Ignore SIGINT before creating the processes, so that they # inherit the signal handling. This means that the shutdown # function controls the shutdown. signal.signal(signal.SIGINT, signal.SIG_IGN) - shutdown_event = ctx.Event() - processes = start_processes(config, worker_func, sockets, shutdown_event, ctx) - - def shutdown(*args: Any) -> None: - nonlocal active, shutdown_event - shutdown_event.set() - active = False + _populate(processes, config, worker_func, sockets, shutdown_event, ctx) for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: if hasattr(signal, signal_name): signal.signal(getattr(signal, signal_name), shutdown) if config.use_reloader: - wait_for_changes(shutdown_event) - shutdown_event.set() + files = files_to_watch() + while True: + finished = wait((process.sentinel for process in processes), timeout=1) + updated = check_for_updates(files) + if updated: + shutdown_event.set() + for process in processes: + process.join() + shutdown_event.clear() + break + if len(finished) > 0: + break else: - active = False + wait(process.sentinel for process in processes) - for process in processes: - process.join() - if process.exitcode != 0: - exit_code = process.exitcode + exitcode = _join_exited(processes) + if exitcode != 0: + shutdown_event.set() + active = False for process in processes: process.terminate() + exitcode = _join_exited(processes) if exitcode != 0 else exitcode + for sock in sockets.secure_sockets: sock.close() for sock in sockets.insecure_sockets: sock.close() - return exit_code + return exitcode -def start_processes( +def _populate( + processes: List[BaseProcess], config: Config, worker_func: WorkerFunc, sockets: Sockets, shutdown_event: EventType, ctx: BaseContext, -) -> List[BaseProcess]: - processes = [] - for _ in range(config.workers): +) -> None: + for _ in range(config.workers - len(processes)): process = ctx.Process( # type: ignore target=worker_func, kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets}, @@ -117,4 +131,15 @@ def start_processes( processes.append(process) if platform.system() == "Windows": time.sleep(0.1) - return processes + + +def _join_exited(processes: List[BaseProcess]) -> int: + exitcode = 0 + for index in reversed(range(len(processes))): + worker = processes[index] + if worker.exitcode is not None: + worker.join() + exitcode = worker.exitcode if exitcode == 0 else exitcode + del processes[index] + + return exitcode diff --git a/src/hypercorn/trio/run.py b/src/hypercorn/trio/run.py index d8721bbb..2cfe5db4 100644 --- a/src/hypercorn/trio/run.py +++ b/src/hypercorn/trio/run.py @@ -3,6 +3,7 @@ import sys from functools import partial from multiprocessing.synchronize import Event as EventType +from random import randint from typing import Awaitable, Callable, Optional import trio @@ -37,7 +38,10 @@ async def worker_serve( config.set_statsd_logger_class(StatsdLogger) lifespan = Lifespan(app, config) - context = WorkerContext() + max_requests = None + if config.max_requests is not None: + max_requests = config.max_requests + randint(0, config.max_requests_jitter) + context = WorkerContext(max_requests) async with trio.open_nursery() as lifespan_nursery: await lifespan_nursery.start(lifespan.handle_lifespan) @@ -82,6 +86,7 @@ async def worker_serve( async with trio.open_nursery(strict_exception_groups=True) as nursery: if shutdown_trigger is not None: nursery.start_soon(raise_shutdown, shutdown_trigger) + nursery.start_soon(raise_shutdown, context.terminate.wait) nursery.start_soon( partial( diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index bcfa1a51..c09c4fb6 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Type, Union +from typing import Optional, Type, Union import trio @@ -27,9 +27,20 @@ def is_set(self) -> bool: class WorkerContext: event_class: Type[Event] = EventWrapper - def __init__(self) -> None: + def __init__(self, max_requests: Optional[int]) -> None: + self.max_requests = max_requests + self.requests = 0 + self.terminate = self.event_class() self.terminated = self.event_class() + async def mark_request(self) -> None: + if self.max_requests is None: + return + + self.requests += 1 + if self.requests > self.max_requests: + await self.terminate.set() + @staticmethod async def sleep(wait: Union[float, int]) -> None: return await trio.sleep(wait) diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index 1299a776..2ebb711d 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -290,8 +290,12 @@ def is_set(self) -> bool: class WorkerContext(Protocol): event_class: Type[Event] + terminate: Event terminated: Event + async def mark_request(self) -> None: + ... + @staticmethod async def sleep(wait: Union[float, int]) -> None: ... diff --git a/src/hypercorn/utils.py b/src/hypercorn/utils.py index 9e3520d7..39249c53 100644 --- a/src/hypercorn/utils.py +++ b/src/hypercorn/utils.py @@ -4,7 +4,6 @@ import os import socket import sys -import time from enum import Enum from importlib import import_module from multiprocessing.synchronize import Event as EventType @@ -133,7 +132,7 @@ def wrap_app( return WSGIWrapper(cast(WSGIFramework, app), wsgi_max_body_size) -def wait_for_changes(shutdown_event: EventType) -> None: +def files_to_watch() -> Dict[Path, float]: last_updates: Dict[Path, float] = {} for module in list(sys.modules.values()): filename = getattr(module, "__file__", None) @@ -144,24 +143,21 @@ def wait_for_changes(shutdown_event: EventType) -> None: last_updates[Path(filename)] = path.stat().st_mtime except (FileNotFoundError, NotADirectoryError): pass + return last_updates - while not shutdown_event.is_set(): - time.sleep(1) - for index, (path, last_mtime) in enumerate(last_updates.items()): - if index % 10 == 0: - # Yield to the event loop - time.sleep(0) - - try: - mtime = path.stat().st_mtime - except FileNotFoundError: - return +def check_for_updates(files: Dict[Path, float]) -> bool: + for path, last_mtime in files.items(): + try: + mtime = path.stat().st_mtime + except FileNotFoundError: + return True + else: + if mtime > last_mtime: + return True else: - if mtime > last_mtime: - return - else: - last_updates[path] = mtime + files[path] = mtime + return False async def raise_shutdown(shutdown_event: Callable[..., Awaitable]) -> None: diff --git a/tests/asyncio/test_keep_alive.py b/tests/asyncio/test_keep_alive.py index 6b357f8f..a46f4cfd 100644 --- a/tests/asyncio/test_keep_alive.py +++ b/tests/asyncio/test_keep_alive.py @@ -50,7 +50,7 @@ async def _server(event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[TCPSe ASGIWrapper(slow_framework), event_loop, config, - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) diff --git a/tests/asyncio/test_sanity.py b/tests/asyncio/test_sanity.py index 2d7cb0bc..cde29297 100644 --- a/tests/asyncio/test_sanity.py +++ b/tests/asyncio/test_sanity.py @@ -21,7 +21,7 @@ async def test_http1_request(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -78,7 +78,7 @@ async def test_http1_websocket(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -115,7 +115,7 @@ async def test_http2_request(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) @@ -178,7 +178,7 @@ async def test_http2_websocket(event_loop: asyncio.AbstractEventLoop) -> None: ASGIWrapper(sanity_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(http2=True), # type: ignore ) diff --git a/tests/asyncio/test_tcp_server.py b/tests/asyncio/test_tcp_server.py index afe00c20..1dfd4212 100644 --- a/tests/asyncio/test_tcp_server.py +++ b/tests/asyncio/test_tcp_server.py @@ -18,7 +18,7 @@ async def test_completes_on_closed(event_loop: asyncio.AbstractEventLoop) -> Non ASGIWrapper(echo_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) @@ -34,7 +34,7 @@ async def test_complets_on_half_close(event_loop: asyncio.AbstractEventLoop) -> ASGIWrapper(echo_framework), event_loop, Config(), - WorkerContext(), + WorkerContext(None), MemoryReader(), # type: ignore MemoryWriter(), # type: ignore ) diff --git a/tests/protocol/test_h11.py b/tests/protocol/test_h11.py index 27c80b56..09e85b78 100755 --- a/tests/protocol/test_h11.py +++ b/tests/protocol/test_h11.py @@ -35,6 +35,8 @@ async def _protocol(monkeypatch: MonkeyPatch) -> H11Protocol: monkeypatch.setattr(hypercorn.protocol.h11, "HTTPStream", MockHTTPStream) context = Mock() context.event_class.return_value = AsyncMock(spec=IOEvent) + context.mark_request = AsyncMock() + context.terminate = context.event_class() context.terminated = context.event_class() context.terminated.is_set.return_value = False return H11Protocol(AsyncMock(), Config(), context, AsyncMock(), False, None, None, AsyncMock()) diff --git a/tests/protocol/test_h2.py b/tests/protocol/test_h2.py index 77bcf515..cec6c263 100644 --- a/tests/protocol/test_h2.py +++ b/tests/protocol/test_h2.py @@ -75,7 +75,7 @@ async def test_stream_buffer_complete(event_loop: asyncio.AbstractEventLoop) -> @pytest.mark.asyncio async def test_protocol_handle_protocol_error() -> None: protocol = H2Protocol( - Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + Mock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock() ) await protocol.handle(RawData(data=b"broken nonsense\r\n\r\n")) protocol.send.assert_awaited() # type: ignore @@ -85,7 +85,7 @@ async def test_protocol_handle_protocol_error() -> None: @pytest.mark.asyncio async def test_protocol_keep_alive_max_requests() -> None: protocol = H2Protocol( - Mock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock() + Mock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock() ) protocol.config.keep_alive_max_requests = 0 client = H2Connection() diff --git a/tests/protocol/test_http_stream.py b/tests/protocol/test_http_stream.py index 24af5969..6f656de0 100644 --- a/tests/protocol/test_http_stream.py +++ b/tests/protocol/test_http_stream.py @@ -31,7 +31,7 @@ @pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> HTTPStream: stream = HTTPStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.app_put = AsyncMock() stream.config._log = AsyncMock(spec=Logger) diff --git a/tests/protocol/test_ws_stream.py b/tests/protocol/test_ws_stream.py index 5f595828..05403130 100644 --- a/tests/protocol/test_ws_stream.py +++ b/tests/protocol/test_ws_stream.py @@ -165,7 +165,7 @@ def test_handshake_accept_additional_headers() -> None: @pytest_asyncio.fixture(name="stream") # type: ignore[misc] async def _stream() -> WSStream: stream = WSStream( - AsyncMock(), Config(), WorkerContext(), AsyncMock(), False, None, None, AsyncMock(), 1 + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 ) stream.task_group.spawn_app.return_value = AsyncMock() # type: ignore stream.app_put = AsyncMock() diff --git a/tests/trio/test_keep_alive.py b/tests/trio/test_keep_alive.py index d30d82db..6bed437f 100644 --- a/tests/trio/test_keep_alive.py +++ b/tests/trio/test_keep_alive.py @@ -47,7 +47,7 @@ def _client_stream( config.keep_alive_timeout = KEEP_ALIVE_TIMEOUT client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(slow_framework), config, WorkerContext(None), server_stream) nursery.start_soon(server.run) yield client_stream diff --git a/tests/trio/test_sanity.py b/tests/trio/test_sanity.py index 6d4be8c5..b5bf75ba 100644 --- a/tests/trio/test_sanity.py +++ b/tests/trio/test_sanity.py @@ -25,7 +25,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = h11.Connection(h11.CLIENT) await client_stream.send_all( @@ -76,7 +76,7 @@ async def test_http1_request(nursery: trio._core._run.Nursery) -> None: async def test_http1_websocket(nursery: trio._core._run.Nursery) -> None: client_stream, server_stream = trio.testing.memory_stream_pair() server_stream.socket = MockSocket() - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = wsproto.WSConnection(wsproto.ConnectionType.CLIENT) await client_stream.send_all(client.send(wsproto.events.Request(host="hypercorn", target="/"))) @@ -103,7 +103,7 @@ async def test_http2_request(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) client = h2.connection.H2Connection() client.initiate_connection() @@ -158,7 +158,7 @@ async def test_http2_websocket(nursery: trio._core._run.Nursery) -> None: server_stream.transport_stream = Mock(return_value=PropertyMock(return_value=MockSocket())) server_stream.do_handshake = AsyncMock() server_stream.selected_alpn_protocol = Mock(return_value="h2") - server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(), server_stream) + server = TCPServer(ASGIWrapper(sanity_framework), Config(), WorkerContext(None), server_stream) nursery.start_soon(server.run) h2_client = h2.connection.H2Connection() h2_client.initiate_connection()