diff --git a/pycrdt_websocket/__init__.py b/pycrdt_websocket/__init__.py index 34e6235..51a2a90 100644 --- a/pycrdt_websocket/__init__.py +++ b/pycrdt_websocket/__init__.py @@ -1,7 +1,8 @@ from .asgi_server import ASGIServer as ASGIServer from .websocket_provider import WebsocketProvider as WebsocketProvider from .websocket_server import WebsocketServer as WebsocketServer -from .websocket_server import YRoom as YRoom +from .websocket_server import exception_logger as exception_logger +from .yroom import YRoom as YRoom from .yutils import YMessageType as YMessageType __version__ = "0.13.0" diff --git a/pycrdt_websocket/websocket_server.py b/pycrdt_websocket/websocket_server.py index 600a45f..6e26e35 100644 --- a/pycrdt_websocket/websocket_server.py +++ b/pycrdt_websocket/websocket_server.py @@ -3,6 +3,7 @@ from contextlib import AsyncExitStack from functools import partial from logging import Logger, getLogger +from typing import Callable from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus @@ -25,6 +26,7 @@ def __init__( rooms_ready: bool = True, auto_clean_rooms: bool = True, auto_restart: bool = False, + exception_handler: Callable[[Exception], bool] | None = None, log: Logger | None = None, ) -> None: """Initialize the object. @@ -46,11 +48,13 @@ def __init__( rooms_ready: Whether rooms are ready to be synchronized when opened. auto_clean_rooms: Whether rooms should be deleted when no client is there anymore. auto_restart: Whether to restart the server if it crashes. + exception_handler: An optional callback to call when an exception is raised. log: An optional logger. """ self.rooms_ready = rooms_ready self.auto_clean_rooms = auto_clean_rooms self.auto_restart = auto_restart + self.exception_handler = exception_handler self.log = log or getLogger(__name__) self.rooms = {} @@ -152,8 +156,15 @@ async def serve(self, websocket: Websocket) -> None: "`await websocket_server.start()`" ) - async with create_task_group() as tg: - tg.start_soon(self._serve, websocket, tg) + try: + async with create_task_group() as tg: + tg.start_soon(self._serve, websocket, tg) + except Exception as exception: + exception_handled = False + if self.exception_handler is not None: + exception_handled = self.exception_handler(exception, self.log) + if not exception_handled: + raise exception async def _serve(self, websocket: Websocket, tg: TaskGroup): room = await self.get_room(websocket.path) @@ -197,12 +208,14 @@ async def start( Arguments: task_status: The status to set when the task has started. """ + self._stop_event = Event() + if from_context_manager: task_status.started() self.started.set() assert self._task_group is not None - # wait forever - self._task_group.start_soon(Event().wait) + # wait until stopped + self._task_group.start_soon(self._stop_event.wait) return async with self._start_lock: @@ -215,8 +228,8 @@ async def start( if not self.started.is_set(): task_status.started() self.started.set() - # wait forever - self._task_group.start_soon(Event().wait) + # wait until stopped + self._task_group.start_soon(self._stop_event.wait) break except Exception as e: if not self.auto_restart: @@ -229,5 +242,11 @@ async def stop(self) -> None: if self._task_group is None: raise RuntimeError("WebsocketServer not running") + self._stop_event.set() self._task_group.cancel_scope.cancel() self._task_group = None + + +def exception_logger(exception: Exception, log: Logger) -> bool: + log.error("WebsocketServer exception", exc_info=exception) + return True # True means the exception was handled