Skip to content

Commit

Permalink
adding exception handling for room start tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Jialin Zhang committed Apr 26, 2024
1 parent fd7c0b5 commit 19c131b
Showing 1 changed file with 55 additions and 26 deletions.
81 changes: 55 additions & 26 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ class YRoom:
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup | None = None
_started: Event | None = None
_stopped: Event
__start_lock: Lock | None = None
_subscription: Subscription | None = None

def __init__(
self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None
self,
ready: bool = True,
ystore: BaseYStore | None = None,
exception_handler: Callable[[Exception, Logger], bool] | None = None,
log: Logger | None = None,
):
"""Initialize the object.
Expand All @@ -63,6 +68,7 @@ def __init__(
Arguments:
ready: Whether the internal YDoc is ready to be synchronized right away.
ystore: An optional store in which to persist document updates.
exception_handler: An optional callback to call when an exception is raised, that returns True if the exception was handled.
log: An optional logger.
"""
self.ydoc = Doc()
Expand All @@ -76,6 +82,8 @@ def __init__(
self.log = log or getLogger(__name__)
self.clients = []
self._on_message = None
self.exception_handler = exception_handler
self._stopped = Event()

@property
def _start_lock(self) -> Lock:
Expand Down Expand Up @@ -138,30 +146,42 @@ async def _broadcast_updates(self):
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
try:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
except Exception as exception:
self._handle_exception(exception)
if self.ystore:
self.log.debug("Writing Y update to YStore")
self._task_group.start_soon(self.ystore.write, update)
try:
self._task_group.start_soon(self.ystore.write, update)
self.log.debug("Writing Y update to YStore")
except Exception as exception:
self._handle_exception(exception)

async def __aenter__(self) -> YRoom:
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -177,27 +197,32 @@ async def start(
task_status.started()
self.started.set()
assert self._task_group is not None
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._broadcast_updates)
return

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
self._task_group.start_soon(self._broadcast_updates)
self._task_group.start_soon(self._watch_ready)
while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._broadcast_updates)
self._task_group.start_soon(self._watch_ready)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the room."""
if self._task_group is None:
raise RuntimeError("YRoom not running")

if self._task_group is None:
return

self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None
if self._subscription is not None:
Expand All @@ -209,10 +234,10 @@ async def serve(self, websocket: Websocket):
Arguments:
websocket: The WebSocket through which to serve the client.
"""
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
try:
try:
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
async for message in websocket:
# filter messages (e.g. awareness)
skip = False
Expand Down Expand Up @@ -245,8 +270,12 @@ async def serve(self, websocket: Websocket):
client.path,
)
tg.start_soon(client.send, message)
except Exception as e:
self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e)
# remove this client
self.clients = [c for c in self.clients if c != websocket]
except Exception as exception:
self._handle_exception(exception)

# remove this client
self.clients = [c for c in self.clients if c != websocket]
def exception_logger(exception: Exception, log: Logger) -> bool:
"""An exception handler that logs the exception and discards it."""
log.error("WebsocketServer exception", exc_info=exception)
return True # the exception was handled

0 comments on commit 19c131b

Please sign in to comment.