diff --git a/pycrdt_websocket/yroom.py b/pycrdt_websocket/yroom.py index f0c12c8..ea88227 100644 --- a/pycrdt_websocket/yroom.py +++ b/pycrdt_websocket/yroom.py @@ -39,11 +39,16 @@ class YRoom: _update_receive_stream: MemoryObjectReceiveStream _task_group: TaskGroup | None = None _started: Event | None = None + _stopped: Event __start_lock: Lock | None = None _subscription: Subscription | None = None def __init__( - self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None + self, + ready: bool = True, + ystore: BaseYStore | None = None, + exception_handler: Callable[[Exception, Logger], bool] | None = None, + log: Logger | None = None, ): """Initialize the object. @@ -63,19 +68,20 @@ def __init__( Arguments: ready: Whether the internal YDoc is ready to be synchronized right away. ystore: An optional store in which to persist document updates. + exception_handler: An optional callback to call when an exception is raised, that + returns True if the exception was handled. log: An optional logger. """ self.ydoc = Doc() self.awareness = Awareness(self.ydoc) - self._update_send_stream, self._update_receive_stream = create_memory_object_stream( - max_buffer_size=65536 - ) self.ready_event = Event() self.ready = ready self.ystore = ystore self.log = log or getLogger(__name__) self.clients = [] self._on_message = None + self.exception_handler = exception_handler + self._stopped = Event() @property def _start_lock(self) -> Lock: @@ -138,12 +144,18 @@ async def _broadcast_updates(self): # broadcast internal ydoc's update to all clients, that includes changes from the # clients and changes from the backend (out-of-band changes) for client in self.clients: - self.log.debug("Sending Y update to client with endpoint: %s", client.path) - message = create_update_message(update) - self._task_group.start_soon(client.send, message) + try: + self.log.debug("Sending Y update to client with endpoint: %s", client.path) + message = create_update_message(update) + self._task_group.start_soon(client.send, message) + except Exception as exception: + self._handle_exception(exception) if self.ystore: - self.log.debug("Writing Y update to YStore") - self._task_group.start_soon(self.ystore.write, update) + try: + self._task_group.start_soon(self.ystore.write, update) + self.log.debug("Writing Y update to YStore") + except Exception as exception: + self._handle_exception(exception) async def __aenter__(self) -> YRoom: async with self._start_lock: @@ -151,10 +163,9 @@ async def __aenter__(self) -> YRoom: raise RuntimeError("YRoom already running") async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) + self._task_group = await exit_stack.enter_async_context(create_task_group()) self._exit_stack = exit_stack.pop_all() - await tg.start(partial(self.start, from_context_manager=True)) + await self._task_group.start(partial(self.start, from_context_manager=True)) return self @@ -162,6 +173,13 @@ async def __aexit__(self, exc_type, exc_value, exc_tb): await self.stop() return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) + def _handle_exception(self, exception: Exception) -> None: + exception_handled = False + if self.exception_handler is not None: + exception_handled = self.exception_handler(exception, self.log) + if not exception_handled: + raise exception + async def start( self, *, @@ -177,6 +195,8 @@ async def start( task_status.started() self.started.set() assert self._task_group is not None + self._task_group.start_soon(self._stopped.wait) + self._task_group.start_soon(self._watch_ready) self._task_group.start_soon(self._broadcast_updates) return @@ -184,20 +204,27 @@ async def start( if self._task_group is not None: raise RuntimeError("YRoom already running") - async with create_task_group() as self._task_group: - task_status.started() - self.started.set() - self._task_group.start_soon(self._broadcast_updates) - self._task_group.start_soon(self._watch_ready) + while True: + try: + async with create_task_group() as self._task_group: + if not self.started.is_set(): + task_status.started() + self.started.set() + self._update_send_stream, self._update_receive_stream = ( + create_memory_object_stream(max_buffer_size=65536) + ) + self._task_group.start_soon(self._stopped.wait) + self._task_group.start_soon(self._watch_ready) + self._task_group.start_soon(self._broadcast_updates) + return + except Exception as exception: + self._handle_exception(exception) async def stop(self) -> None: """Stop the room.""" if self._task_group is None: raise RuntimeError("YRoom not running") - - if self._task_group is None: - return - + self._stopped.set() self._task_group.cancel_scope.cancel() self._task_group = None if self._subscription is not None: @@ -209,10 +236,10 @@ async def serve(self, websocket: Websocket): Arguments: websocket: The WebSocket through which to serve the client. """ - async with create_task_group() as tg: - self.clients.append(websocket) - await sync(self.ydoc, websocket, self.log) - try: + try: + async with create_task_group() as tg: + self.clients.append(websocket) + await sync(self.ydoc, websocket, self.log) async for message in websocket: # filter messages (e.g. awareness) skip = False @@ -245,8 +272,7 @@ async def serve(self, websocket: Websocket): client.path, ) tg.start_soon(client.send, message) - except Exception as e: - self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e) - - # remove this client - self.clients = [c for c in self.clients if c != websocket] + # remove this client + self.clients = [c for c in self.clients if c != websocket] + except Exception as exception: + self._handle_exception(exception) diff --git a/tests/test_yroom.py b/tests/test_yroom.py new file mode 100644 index 0000000..a0edbce --- /dev/null +++ b/tests/test_yroom.py @@ -0,0 +1,33 @@ +import pytest +from anyio import TASK_STATUS_IGNORED, sleep +from anyio.abc import TaskStatus +from pycrdt import Map + +from pycrdt_websocket import exception_logger +from pycrdt_websocket.yroom import YRoom + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True) +@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True) +async def test_yroom_restart(yws_server, yws_provider): + port, server = yws_server + yroom = YRoom(exception_handler=exception_logger) + + async def raise_error(task_status: TaskStatus[None] = TASK_STATUS_IGNORED): + task_status.started() + raise RuntimeError("foo") + + yroom.ydoc = yws_provider + await server.start_room(yroom) + yroom.ydoc["map"] = ymap1 = Map() + ymap1["key"] = "value" + task_group_1 = yroom._task_group + await yroom._task_group.start(raise_error) + ymap1["key2"] = "value2" + await sleep(0.1) + assert yroom._task_group is not task_group_1 + assert yroom._task_group is not None + assert not yroom._task_group.cancel_scope.cancel_called + await yroom.stop()