Skip to content

Commit

Permalink
adding exception handling for room start tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Jialin Zhang committed Apr 26, 2024
1 parent fd7c0b5 commit 1f2a11c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 30 deletions.
86 changes: 56 additions & 30 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ class YRoom:
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup | None = None
_started: Event | None = None
_stopped: Event
__start_lock: Lock | None = None
_subscription: Subscription | None = None

def __init__(
self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None
self,
ready: bool = True,
ystore: BaseYStore | None = None,
exception_handler: Callable[[Exception, Logger], bool] | None = None,
log: Logger | None = None,
):
"""Initialize the object.
Expand All @@ -63,19 +68,20 @@ def __init__(
Arguments:
ready: Whether the internal YDoc is ready to be synchronized right away.
ystore: An optional store in which to persist document updates.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
log: An optional logger.
"""
self.ydoc = Doc()
self.awareness = Awareness(self.ydoc)
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
self.ready_event = Event()
self.ready = ready
self.ystore = ystore
self.log = log or getLogger(__name__)
self.clients = []
self._on_message = None
self.exception_handler = exception_handler
self._stopped = Event()

@property
def _start_lock(self) -> Lock:
Expand Down Expand Up @@ -138,30 +144,42 @@ async def _broadcast_updates(self):
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
try:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
except Exception as exception:
self._handle_exception(exception)
if self.ystore:
self.log.debug("Writing Y update to YStore")
self._task_group.start_soon(self.ystore.write, update)
try:
self._task_group.start_soon(self.ystore.write, update)
self.log.debug("Writing Y update to YStore")
except Exception as exception:
self._handle_exception(exception)

async def __aenter__(self) -> YRoom:
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -177,27 +195,36 @@ async def start(
task_status.started()
self.started.set()
assert self._task_group is not None
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._watch_ready)
self._task_group.start_soon(self._broadcast_updates)
return

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
self._task_group.start_soon(self._broadcast_updates)
self._task_group.start_soon(self._watch_ready)
while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
self._update_send_stream, self._update_receive_stream = (
create_memory_object_stream(max_buffer_size=65536)
)
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._watch_ready)
self._task_group.start_soon(self._broadcast_updates)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the room."""
if self._task_group is None:
raise RuntimeError("YRoom not running")

if self._task_group is None:
return

self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None
if self._subscription is not None:
Expand All @@ -209,10 +236,10 @@ async def serve(self, websocket: Websocket):
Arguments:
websocket: The WebSocket through which to serve the client.
"""
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
try:
try:
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
async for message in websocket:
# filter messages (e.g. awareness)
skip = False
Expand Down Expand Up @@ -245,8 +272,7 @@ async def serve(self, websocket: Websocket):
client.path,
)
tg.start_soon(client.send, message)
except Exception as e:
self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e)

# remove this client
self.clients = [c for c in self.clients if c != websocket]
# remove this client
self.clients = [c for c in self.clients if c != websocket]
except Exception as exception:
self._handle_exception(exception)
33 changes: 33 additions & 0 deletions tests/test_yroom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from anyio import TASK_STATUS_IGNORED, sleep
from anyio.abc import TaskStatus
from pycrdt import Map

from pycrdt_websocket import exception_logger
from pycrdt_websocket.yroom import YRoom

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_yroom_restart(yws_server, yws_provider):
port, server = yws_server
yroom = YRoom(exception_handler=exception_logger)

async def raise_error(task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
task_status.started()
raise RuntimeError("foo")

yroom.ydoc = yws_provider
await server.start_room(yroom)
yroom.ydoc["map"] = ymap1 = Map()
ymap1["key"] = "value"
task_group_1 = yroom._task_group
await yroom._task_group.start(raise_error)
ymap1["key2"] = "value2"
await sleep(0.1)
assert yroom._task_group is not task_group_1
assert yroom._task_group is not None
assert not yroom._task_group.cancel_scope.cancel_called
await yroom.stop()

0 comments on commit 1f2a11c

Please sign in to comment.