Skip to content

Commit

Permalink
Remove auto_restart, add exception_handler
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 26, 2024
1 parent 0d60753 commit 2cd4944
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 59 deletions.
106 changes: 49 additions & 57 deletions pycrdt_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ class WebsocketServer:
auto_clean_rooms: bool
rooms: dict[str, YRoom]
_started: Event | None = None
_task_group: TaskGroup | None = None
_stopped: Event
_task_group0: TaskGroup | None = None
_task_group1: TaskGroup | None = None
__start_lock: Lock | None = None

def __init__(
self,
rooms_ready: bool = True,
auto_clean_rooms: bool = True,
auto_restart: bool = False,
exception_handler: Callable[[Exception], bool] | None = None,
log: Logger | None = None,
) -> None:
Expand All @@ -47,16 +48,16 @@ def __init__(
Arguments:
rooms_ready: Whether rooms are ready to be synchronized when opened.
auto_clean_rooms: Whether rooms should be deleted when no client is there anymore.
auto_restart: Whether to restart the server if it crashes.
exception_handler: An optional callback to call when an exception is raised.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
log: An optional logger.
"""
self.rooms_ready = rooms_ready
self.auto_clean_rooms = auto_clean_rooms
self.auto_restart = auto_restart
self.exception_handler = exception_handler
self.log = log or getLogger(__name__)
self.rooms = {}
self._stopped = Event()

@property
def started(self) -> Event:
Expand Down Expand Up @@ -92,14 +93,14 @@ async def start_room(self, room: YRoom) -> None:
Arguments:
room: The room to start.
"""
if self._task_group is None:
if self._task_group0 is None:
raise RuntimeError(
"The WebsocketServer is not running: use `async with websocket_server:` or "
"`await websocket_server.start()`"
)

if not room.started.is_set():
await self._task_group.start(room.start)
await self._task_group1.start(room.start)

def get_room_name(self, room: YRoom) -> str:
"""Get the name of a room.
Expand Down Expand Up @@ -150,53 +151,45 @@ async def serve(self, websocket: Websocket) -> None:
Arguments:
websocket: The WebSocket through which to serve the client.
"""
if self._task_group is None:
if self._task_group0 is None:
raise RuntimeError(
"The WebsocketServer is not running: use `async with websocket_server:` or "
"`await websocket_server.start()`"
)

try:
async with create_task_group() as tg:
tg.start_soon(self._serve, websocket, tg)
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)
if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
except Exception as exception:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def _serve(self, websocket: Websocket, tg: TaskGroup):
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)

if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
tg.cancel_scope.cancel()
self._handle_exception(exception)

async def __aenter__(self) -> WebsocketServer:
if self.auto_restart:
raise RuntimeError(
"WebsocketServer does not support auto-restart when used as a context manager"
)

async with self._start_lock:
if self._task_group is not None:
if self._task_group0 is not None:
raise RuntimeError("WebsocketServer already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group0 = await exit_stack.enter_async_context(create_task_group())
self._task_group1 = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group0.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -208,45 +201,44 @@ async def start(
Arguments:
task_status: The status to set when the task has started.
"""
self._stop_event = Event()

if from_context_manager:
task_status.started()
self.started.set()
assert self._task_group is not None
assert self._task_group0 is not None
# wait until stopped
self._task_group.start_soon(self._stop_event.wait)
self._task_group1.start_soon(self._stopped.wait)
return

async with self._start_lock:
if self._task_group is not None:
if self._task_group0 is not None:
raise RuntimeError("WebsocketServer already running")

while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
# wait until stopped
self._task_group.start_soon(self._stop_event.wait)
break
except Exception as e:
if not self.auto_restart:
raise e

self.log.error("WebsocketServer crashed, restarting...", exc_info=e)
async with create_task_group() as self._task_group0:
while True:
try:
async with create_task_group() as self._task_group1:
if not self.started.is_set():
task_status.started()
self.started.set()
# wait until stopped
self._task_group1.start_soon(self._stopped.wait)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the WebSocket server."""
if self._task_group is None:
if self._task_group0 is None:
raise RuntimeError("WebsocketServer not running")

self._stop_event.set()
self._task_group.cancel_scope.cancel()
self._task_group = None
self._stopped.set()
self._task_group0.cancel_scope.cancel()
self._task_group0 = None
self._task_group1 = None


def exception_logger(exception: Exception, log: Logger) -> bool:
"""An exception handler that logs the exception and discards it.
"""
log.error("WebsocketServer exception", exc_info=exception)
return True # True means the exception was handled
return True # the exception was handled
5 changes: 3 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import pytest
from anyio import sleep
from pycrdt_websocket import exception_logger

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"auto_restart": True}], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_server_restart(yws_server):
port, server = yws_server

async def raise_error():
raise RuntimeError("foo")

server._task_group.start_soon(raise_error)
server._task_group1.start_soon(raise_error)
await sleep(0.1)

0 comments on commit 2cd4944

Please sign in to comment.