Skip to content

Commit

Permalink
Merge pull request #31 from davidbrochart/auto-restart
Browse files Browse the repository at this point in the history
Add WebsocketServer exception handler
  • Loading branch information
Zsailer authored Apr 26, 2024
2 parents 9f4253e + 61cf47f commit fd7c0b5
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
mypy pycrdt_websocket tests
- name: Run tests
run: |
pytest -v
pytest -v --color=yes
check_release:
runs-on: ubuntu-latest
Expand Down
3 changes: 2 additions & 1 deletion pycrdt_websocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .asgi_server import ASGIServer as ASGIServer
from .websocket_provider import WebsocketProvider as WebsocketProvider
from .websocket_server import WebsocketServer as WebsocketServer
from .websocket_server import YRoom as YRoom
from .websocket_server import exception_logger as exception_logger
from .yroom import YRoom as YRoom
from .yutils import YMessageType as YMessageType

__version__ = "0.13.0"
71 changes: 49 additions & 22 deletions pycrdt_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import AsyncExitStack
from functools import partial
from logging import Logger, getLogger
from typing import Callable

from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
Expand All @@ -17,11 +18,16 @@ class WebsocketServer:
auto_clean_rooms: bool
rooms: dict[str, YRoom]
_started: Event | None = None
_stopped: Event
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

def __init__(
self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None
self,
rooms_ready: bool = True,
auto_clean_rooms: bool = True,
exception_handler: Callable[[Exception, Logger], bool] | None = None,
log: Logger | None = None,
) -> None:
"""Initialize the object.
Expand All @@ -41,12 +47,16 @@ def __init__(
Arguments:
rooms_ready: Whether rooms are ready to be synchronized when opened.
auto_clean_rooms: Whether rooms should be deleted when no client is there anymore.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
log: An optional logger.
"""
self.rooms_ready = rooms_ready
self.auto_clean_rooms = auto_clean_rooms
self.exception_handler = exception_handler
self.log = log or getLogger(__name__)
self.rooms = {}
self._stopped = Event()

@property
def started(self) -> Event:
Expand Down Expand Up @@ -146,35 +156,39 @@ async def serve(self, websocket: Websocket) -> None:
"`await websocket_server.start()`"
)

async with create_task_group() as tg:
tg.start_soon(self._serve, websocket, tg)

async def _serve(self, websocket: Websocket, tg: TaskGroup):
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)

if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
tg.cancel_scope.cancel()
try:
async with create_task_group():
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)
if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
except Exception as exception:
self._handle_exception(exception)

async def __aenter__(self) -> WebsocketServer:
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -190,24 +204,37 @@ async def start(
task_status.started()
self.started.set()
assert self._task_group is not None
# wait forever
self._task_group.start_soon(Event().wait)
# wait until stopped
self._task_group.start_soon(self._stopped.wait)
return

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
# wait forever
self._task_group.start_soon(Event().wait)
while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
# wait until stopped
self._task_group.start_soon(self._stopped.wait)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the WebSocket server."""
if self._task_group is None:
raise RuntimeError("WebsocketServer not running")

self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None


def exception_logger(exception: Exception, log: Logger) -> bool:
"""An exception handler that logs the exception and discards it."""
log.error("WebsocketServer exception", exc_info=exception)
return True # the exception was handled
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def yws_server(request, unused_tcp_port, websocket_server_api):
)
await ensure_server_running("localhost", unused_tcp_port)
pytest.port = unused_tcp_port
yield unused_tcp_port
yield unused_tcp_port, websocket_server
shutdown_event.set()
except Exception:
pass
Expand Down
18 changes: 18 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from anyio import sleep

from pycrdt_websocket import exception_logger

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_server_restart(yws_server):
port, server = yws_server

async def raise_error():
raise RuntimeError("foo")

server._task_group.start_soon(raise_error)
await sleep(0.1)

0 comments on commit fd7c0b5

Please sign in to comment.