Skip to content

Commit

Permalink
Remove auto_restart, add exception_handler
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 26, 2024
1 parent 0d60753 commit 61cf47f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 42 deletions.
70 changes: 29 additions & 41 deletions pycrdt_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ class WebsocketServer:
auto_clean_rooms: bool
rooms: dict[str, YRoom]
_started: Event | None = None
_stopped: Event
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

def __init__(
self,
rooms_ready: bool = True,
auto_clean_rooms: bool = True,
auto_restart: bool = False,
exception_handler: Callable[[Exception], bool] | None = None,
exception_handler: Callable[[Exception, Logger], bool] | None = None,
log: Logger | None = None,
) -> None:
"""Initialize the object.
Expand All @@ -47,16 +47,16 @@ def __init__(
Arguments:
rooms_ready: Whether rooms are ready to be synchronized when opened.
auto_clean_rooms: Whether rooms should be deleted when no client is there anymore.
auto_restart: Whether to restart the server if it crashes.
exception_handler: An optional callback to call when an exception is raised.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
log: An optional logger.
"""
self.rooms_ready = rooms_ready
self.auto_clean_rooms = auto_clean_rooms
self.auto_restart = auto_restart
self.exception_handler = exception_handler
self.log = log or getLogger(__name__)
self.rooms = {}
self._stopped = Event()

@property
def started(self) -> Event:
Expand Down Expand Up @@ -157,46 +157,38 @@ async def serve(self, websocket: Websocket) -> None:
)

try:
async with create_task_group() as tg:
tg.start_soon(self._serve, websocket, tg)
async with create_task_group():
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)
if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
except Exception as exception:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def _serve(self, websocket: Websocket, tg: TaskGroup):
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)

if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
tg.cancel_scope.cancel()
self._handle_exception(exception)

async def __aenter__(self) -> WebsocketServer:
if self.auto_restart:
raise RuntimeError(
"WebsocketServer does not support auto-restart when used as a context manager"
)

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -208,14 +200,12 @@ async def start(
Arguments:
task_status: The status to set when the task has started.
"""
self._stop_event = Event()

if from_context_manager:
task_status.started()
self.started.set()
assert self._task_group is not None
# wait until stopped
self._task_group.start_soon(self._stop_event.wait)
self._task_group.start_soon(self._stopped.wait)
return

async with self._start_lock:
Expand All @@ -229,24 +219,22 @@ async def start(
task_status.started()
self.started.set()
# wait until stopped
self._task_group.start_soon(self._stop_event.wait)
break
except Exception as e:
if not self.auto_restart:
raise e

self.log.error("WebsocketServer crashed, restarting...", exc_info=e)
self._task_group.start_soon(self._stopped.wait)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the WebSocket server."""
if self._task_group is None:
raise RuntimeError("WebsocketServer not running")

self._stop_event.set()
self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None


def exception_logger(exception: Exception, log: Logger) -> bool:
"""An exception handler that logs the exception and discards it."""
log.error("WebsocketServer exception", exc_info=exception)
return True # True means the exception was handled
return True # the exception was handled
4 changes: 3 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pytest
from anyio import sleep

from pycrdt_websocket import exception_logger

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"auto_restart": True}], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_server_restart(yws_server):
port, server = yws_server

Expand Down

0 comments on commit 61cf47f

Please sign in to comment.