From 19c131bec0995f9984f654aec3378b5706ffef25 Mon Sep 17 00:00:00 2001 From: Jialin Zhang Date: Tue, 23 Apr 2024 13:48:59 -0700 Subject: [PATCH] adding exception handling for room start tasks --- pycrdt_websocket/yroom.py | 81 ++++++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 26 deletions(-) diff --git a/pycrdt_websocket/yroom.py b/pycrdt_websocket/yroom.py index f0c12c8..542e272 100644 --- a/pycrdt_websocket/yroom.py +++ b/pycrdt_websocket/yroom.py @@ -39,11 +39,16 @@ class YRoom: _update_receive_stream: MemoryObjectReceiveStream _task_group: TaskGroup | None = None _started: Event | None = None + _stopped: Event __start_lock: Lock | None = None _subscription: Subscription | None = None def __init__( - self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None + self, + ready: bool = True, + ystore: BaseYStore | None = None, + exception_handler: Callable[[Exception, Logger], bool] | None = None, + log: Logger | None = None, ): """Initialize the object. @@ -63,6 +68,7 @@ def __init__( Arguments: ready: Whether the internal YDoc is ready to be synchronized right away. ystore: An optional store in which to persist document updates. + exception_handler: An optional callback to call when an exception is raised, that returns True if the exception was handled. log: An optional logger. """ self.ydoc = Doc() @@ -76,6 +82,8 @@ def __init__( self.log = log or getLogger(__name__) self.clients = [] self._on_message = None + self.exception_handler = exception_handler + self._stopped = Event() @property def _start_lock(self) -> Lock: @@ -138,12 +146,18 @@ async def _broadcast_updates(self): # broadcast internal ydoc's update to all clients, that includes changes from the # clients and changes from the backend (out-of-band changes) for client in self.clients: - self.log.debug("Sending Y update to client with endpoint: %s", client.path) - message = create_update_message(update) - self._task_group.start_soon(client.send, message) + try: + self.log.debug("Sending Y update to client with endpoint: %s", client.path) + message = create_update_message(update) + self._task_group.start_soon(client.send, message) + except Exception as exception: + self._handle_exception(exception) if self.ystore: - self.log.debug("Writing Y update to YStore") - self._task_group.start_soon(self.ystore.write, update) + try: + self._task_group.start_soon(self.ystore.write, update) + self.log.debug("Writing Y update to YStore") + except Exception as exception: + self._handle_exception(exception) async def __aenter__(self) -> YRoom: async with self._start_lock: @@ -151,10 +165,9 @@ async def __aenter__(self) -> YRoom: raise RuntimeError("YRoom already running") async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) + self._task_group = await exit_stack.enter_async_context(create_task_group()) self._exit_stack = exit_stack.pop_all() - await tg.start(partial(self.start, from_context_manager=True)) + await self._task_group.start(partial(self.start, from_context_manager=True)) return self @@ -162,6 +175,13 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + def _handle_exception(self, exception: Exception) -> None: + exception_handled = False + if self.exception_handler is not None: + exception_handled = self.exception_handler(exception, self.log) + if not exception_handled: + raise exception + async def start( self, *, @@ -177,6 +197,7 @@ async def start( task_status.started() self.started.set() assert self._task_group is not None + self._task_group.start_soon(self._stopped.wait) self._task_group.start_soon(self._broadcast_updates) return @@ -184,20 +205,24 @@ async def start( if self._task_group is not None: raise RuntimeError("YRoom already running") - async with create_task_group() as self._task_group: - task_status.started() - self.started.set() - self._task_group.start_soon(self._broadcast_updates) - self._task_group.start_soon(self._watch_ready) + while True: + try: + async with create_task_group() as self._task_group: + if not self.started.is_set(): + task_status.started() + self.started.set() + self._task_group.start_soon(self._stopped.wait) + self._task_group.start_soon(self._broadcast_updates) + self._task_group.start_soon(self._watch_ready) + return + except Exception as exception: + self._handle_exception(exception) async def stop(self) -> None: """Stop the room.""" if self._task_group is None: raise RuntimeError("YRoom not running") - - if self._task_group is None: - return - + self._stopped.set() self._task_group.cancel_scope.cancel() self._task_group = None if self._subscription is not None: @@ -209,10 +234,10 @@ async def serve(self, websocket: Websocket): Arguments: websocket: The WebSocket through which to serve the client. """ - async with create_task_group() as tg: - self.clients.append(websocket) - await sync(self.ydoc, websocket, self.log) - try: + try: + async with create_task_group() as tg: + self.clients.append(websocket) + await sync(self.ydoc, websocket, self.log) async for message in websocket: # filter messages (e.g. awareness) skip = False @@ -245,8 +270,12 @@ async def serve(self, websocket: Websocket): client.path, ) tg.start_soon(client.send, message) - except Exception as e: - self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e) + # remove this client + self.clients = [c for c in self.clients if c != websocket] + except Exception as exception: + self._handle_exception(exception) - # remove this client - self.clients = [c for c in self.clients if c != websocket] + def exception_logger(exception: Exception, log: Logger) -> bool: + """An exception handler that logs the exception and discards it.""" + log.error("WebsocketServer exception", exc_info=exception) + return True # the exception was handled