diff --git a/pycrdt_websocket/websocket_server.py b/pycrdt_websocket/websocket_server.py index 6e26e35..2213634 100644 --- a/pycrdt_websocket/websocket_server.py +++ b/pycrdt_websocket/websocket_server.py @@ -18,14 +18,15 @@ class WebsocketServer: auto_clean_rooms: bool rooms: dict[str, YRoom] _started: Event | None = None - _task_group: TaskGroup | None = None + _stopped: Event + _task_group0: TaskGroup | None = None + _task_group1: TaskGroup | None = None __start_lock: Lock | None = None def __init__( self, rooms_ready: bool = True, auto_clean_rooms: bool = True, - auto_restart: bool = False, exception_handler: Callable[[Exception], bool] | None = None, log: Logger | None = None, ) -> None: @@ -47,16 +48,16 @@ def __init__( Arguments: rooms_ready: Whether rooms are ready to be synchronized when opened. auto_clean_rooms: Whether rooms should be deleted when no client is there anymore. - auto_restart: Whether to restart the server if it crashes. - exception_handler: An optional callback to call when an exception is raised. + exception_handler: An optional callback to call when an exception is raised, that + returns True if the exception was handled. log: An optional logger. """ self.rooms_ready = rooms_ready self.auto_clean_rooms = auto_clean_rooms - self.auto_restart = auto_restart self.exception_handler = exception_handler self.log = log or getLogger(__name__) self.rooms = {} + self._stopped = Event() @property def started(self) -> Event: @@ -92,14 +93,14 @@ async def start_room(self, room: YRoom) -> None: Arguments: room: The room to start. """ - if self._task_group is None: + if self._task_group0 is None: raise RuntimeError( "The WebsocketServer is not running: use `async with websocket_server:` or " "`await websocket_server.start()`" ) if not room.started.is_set(): - await self._task_group.start(room.start) + await self._task_group1.start(room.start) def get_room_name(self, room: YRoom) -> str: """Get the name of a room. @@ -150,46 +151,31 @@ async def serve(self, websocket: Websocket) -> None: Arguments: websocket: The WebSocket through which to serve the client. """ - if self._task_group is None: + if self._task_group0 is None: raise RuntimeError( "The WebsocketServer is not running: use `async with websocket_server:` or " "`await websocket_server.start()`" ) try: - async with create_task_group() as tg: - tg.start_soon(self._serve, websocket, tg) + room = await self.get_room(websocket.path) + await self.start_room(room) + await room.serve(websocket) + if self.auto_clean_rooms and not room.clients: + await self.delete_room(room=room) except Exception as exception: - exception_handled = False - if self.exception_handler is not None: - exception_handled = self.exception_handler(exception, self.log) - if not exception_handled: - raise exception - - async def _serve(self, websocket: Websocket, tg: TaskGroup): - room = await self.get_room(websocket.path) - await self.start_room(room) - await room.serve(websocket) - - if self.auto_clean_rooms and not room.clients: - await self.delete_room(room=room) - tg.cancel_scope.cancel() + self._handle_exception(exception) async def __aenter__(self) -> WebsocketServer: - if self.auto_restart: - raise RuntimeError( - "WebsocketServer does not support auto-restart when used as a context manager" - ) - async with self._start_lock: - if self._task_group is not None: + if self._task_group0 is not None: raise RuntimeError("WebsocketServer already running") async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) + self._task_group0 = await exit_stack.enter_async_context(create_task_group()) + self._task_group1 = await exit_stack.enter_async_context(create_task_group()) self._exit_stack = exit_stack.pop_all() - await tg.start(partial(self.start, from_context_manager=True)) + await self._task_group0.start(partial(self.start, from_context_manager=True)) return self @@ -197,6 +183,13 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + def _handle_exception(self, exception: Exception) -> None: + exception_handled = False + if self.exception_handler is not None: + exception_handled = self.exception_handler(exception, self.log) + if not exception_handled: + raise exception + async def start( self, *, @@ -208,45 +201,44 @@ async def start( Arguments: task_status: The status to set when the task has started. """ - self._stop_event = Event() - if from_context_manager: task_status.started() self.started.set() - assert self._task_group is not None + assert self._task_group0 is not None # wait until stopped - self._task_group.start_soon(self._stop_event.wait) + self._task_group1.start_soon(self._stopped.wait) return async with self._start_lock: - if self._task_group is not None: + if self._task_group0 is not None: raise RuntimeError("WebsocketServer already running") - while True: - try: - async with create_task_group() as self._task_group: - if not self.started.is_set(): - task_status.started() - self.started.set() - # wait until stopped - self._task_group.start_soon(self._stop_event.wait) - break - except Exception as e: - if not self.auto_restart: - raise e - - self.log.error("WebsocketServer crashed, restarting...", exc_info=e) + async with create_task_group() as self._task_group0: + while True: + try: + async with create_task_group() as self._task_group1: + if not self.started.is_set(): + task_status.started() + self.started.set() + # wait until stopped + self._task_group1.start_soon(self._stopped.wait) + return + except Exception as exception: + self._handle_exception(exception) async def stop(self) -> None: """Stop the WebSocket server.""" - if self._task_group is None: + if self._task_group0 is None: raise RuntimeError("WebsocketServer not running") - self._stop_event.set() - self._task_group.cancel_scope.cancel() - self._task_group = None + self._stopped.set() + self._task_group0.cancel_scope.cancel() + self._task_group0 = None + self._task_group1 = None def exception_logger(exception: Exception, log: Logger) -> bool: + """An exception handler that logs the exception and discards it. + """ log.error("WebsocketServer exception", exc_info=exception) - return True # True means the exception was handled + return True # the exception was handled diff --git a/tests/test_server.py b/tests/test_server.py index 6fd2fb1..94ef48c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,16 +1,17 @@ import pytest from anyio import sleep +from pycrdt_websocket import exception_logger pytestmark = pytest.mark.anyio @pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True) -@pytest.mark.parametrize("yws_server", [{"auto_restart": True}], indirect=True) +@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True) async def test_server_restart(yws_server): port, server = yws_server async def raise_error(): raise RuntimeError("foo") - server._task_group.start_soon(raise_error) + server._task_group1.start_soon(raise_error) await sleep(0.1)