From f1edf2e5e636a400b3bd1d8eaf4506fe6472d4b0 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 25 Apr 2024 11:42:45 +0200 Subject: [PATCH] Add WebsocketServer auto_restart parameter --- .github/workflows/test.yml | 2 +- pycrdt_websocket/websocket_server.py | 32 ++++++++++++++++++++++------ tests/conftest.py | 2 +- tests/test_server.py | 16 ++++++++++++++ 4 files changed, 44 insertions(+), 8 deletions(-) create mode 100644 tests/test_server.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 446f797..e307b15 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,7 +55,7 @@ jobs: mypy pycrdt_websocket tests - name: Run tests run: | - pytest -v + pytest -v --color=yes check_release: runs-on: ubuntu-latest diff --git a/pycrdt_websocket/websocket_server.py b/pycrdt_websocket/websocket_server.py index 0c5f0bd..600a45f 100644 --- a/pycrdt_websocket/websocket_server.py +++ b/pycrdt_websocket/websocket_server.py @@ -21,7 +21,11 @@ class WebsocketServer: __start_lock: Lock | None = None def __init__( - self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None + self, + rooms_ready: bool = True, + auto_clean_rooms: bool = True, + auto_restart: bool = False, + log: Logger | None = None, ) -> None: """Initialize the object. @@ -41,10 +45,12 @@ def __init__( Arguments: rooms_ready: Whether rooms are ready to be synchronized when opened. auto_clean_rooms: Whether rooms should be deleted when no client is there anymore. + auto_restart: Whether to restart the server if it crashes. log: An optional logger. """ self.rooms_ready = rooms_ready self.auto_clean_rooms = auto_clean_rooms + self.auto_restart = auto_restart self.log = log or getLogger(__name__) self.rooms = {} @@ -159,6 +165,11 @@ async def _serve(self, websocket: Websocket, tg: TaskGroup): tg.cancel_scope.cancel() async def __aenter__(self) -> WebsocketServer: + if self.auto_restart: + raise RuntimeError( + "WebsocketServer does not support auto-restart when used as a context manager" + ) + async with self._start_lock: if self._task_group is not None: raise RuntimeError("WebsocketServer already running") @@ -198,11 +209,20 @@ async def start( if self._task_group is not None: raise RuntimeError("WebsocketServer already running") - async with create_task_group() as self._task_group: - task_status.started() - self.started.set() - # wait forever - self._task_group.start_soon(Event().wait) + while True: + try: + async with create_task_group() as self._task_group: + if not self.started.is_set(): + task_status.started() + self.started.set() + # wait forever + self._task_group.start_soon(Event().wait) + break + except Exception as e: + if not self.auto_restart: + raise e + + self.log.error("WebsocketServer crashed, restarting...", exc_info=e) async def stop(self) -> None: """Stop the WebSocket server.""" diff --git a/tests/conftest.py b/tests/conftest.py index caba90e..561a186 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ async def yws_server(request, unused_tcp_port, websocket_server_api): ) await ensure_server_running("localhost", unused_tcp_port) pytest.port = unused_tcp_port - yield unused_tcp_port + yield unused_tcp_port, websocket_server shutdown_event.set() except Exception: pass diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..6fd2fb1 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,16 @@ +import pytest +from anyio import sleep + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True) +@pytest.mark.parametrize("yws_server", [{"auto_restart": True}], indirect=True) +async def test_server_restart(yws_server): + port, server = yws_server + + async def raise_error(): + raise RuntimeError("foo") + + server._task_group.start_soon(raise_error) + await sleep(0.1)