Skip to content

Commit

Permalink
Add WebsocketServer auto_restart parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 26, 2024
1 parent 9f4253e commit f1edf2e
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
mypy pycrdt_websocket tests
- name: Run tests
run: |
pytest -v
pytest -v --color=yes
check_release:
runs-on: ubuntu-latest
Expand Down
32 changes: 26 additions & 6 deletions pycrdt_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ class WebsocketServer:
__start_lock: Lock | None = None

def __init__(
self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None
self,
rooms_ready: bool = True,
auto_clean_rooms: bool = True,
auto_restart: bool = False,
log: Logger | None = None,
) -> None:
"""Initialize the object.
Expand All @@ -41,10 +45,12 @@ def __init__(
Arguments:
rooms_ready: Whether rooms are ready to be synchronized when opened.
auto_clean_rooms: Whether rooms should be deleted when no client is there anymore.
auto_restart: Whether to restart the server if it crashes.
log: An optional logger.
"""
self.rooms_ready = rooms_ready
self.auto_clean_rooms = auto_clean_rooms
self.auto_restart = auto_restart
self.log = log or getLogger(__name__)
self.rooms = {}

Expand Down Expand Up @@ -159,6 +165,11 @@ async def _serve(self, websocket: Websocket, tg: TaskGroup):
tg.cancel_scope.cancel()

async def __aenter__(self) -> WebsocketServer:
if self.auto_restart:
raise RuntimeError(
"WebsocketServer does not support auto-restart when used as a context manager"
)

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")
Expand Down Expand Up @@ -198,11 +209,20 @@ async def start(
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
# wait forever
self._task_group.start_soon(Event().wait)
while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
# wait forever
self._task_group.start_soon(Event().wait)
break
except Exception as e:
if not self.auto_restart:
raise e

self.log.error("WebsocketServer crashed, restarting...", exc_info=e)

async def stop(self) -> None:
"""Stop the WebSocket server."""
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def yws_server(request, unused_tcp_port, websocket_server_api):
)
await ensure_server_running("localhost", unused_tcp_port)
pytest.port = unused_tcp_port
yield unused_tcp_port
yield unused_tcp_port, websocket_server
shutdown_event.set()
except Exception:
pass
Expand Down
16 changes: 16 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
from anyio import sleep

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"auto_restart": True}], indirect=True)
async def test_server_restart(yws_server):
port, server = yws_server

async def raise_error():
raise RuntimeError("foo")

server._task_group.start_soon(raise_error)
await sleep(0.1)

0 comments on commit f1edf2e

Please sign in to comment.