Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebsocketServer exception handler #31

Merged
merged 3 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
mypy pycrdt_websocket tests
- name: Run tests
run: |
pytest -v
pytest -v --color=yes

check_release:
runs-on: ubuntu-latest
Expand Down
3 changes: 2 additions & 1 deletion pycrdt_websocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .asgi_server import ASGIServer as ASGIServer
from .websocket_provider import WebsocketProvider as WebsocketProvider
from .websocket_server import WebsocketServer as WebsocketServer
from .websocket_server import YRoom as YRoom
from .websocket_server import exception_logger as exception_logger
from .yroom import YRoom as YRoom
from .yutils import YMessageType as YMessageType

__version__ = "0.13.0"
71 changes: 49 additions & 22 deletions pycrdt_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import AsyncExitStack
from functools import partial
from logging import Logger, getLogger
from typing import Callable

from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
Expand All @@ -17,11 +18,16 @@ class WebsocketServer:
auto_clean_rooms: bool
rooms: dict[str, YRoom]
_started: Event | None = None
_stopped: Event
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

def __init__(
self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log: Logger | None = None
self,
rooms_ready: bool = True,
auto_clean_rooms: bool = True,
exception_handler: Callable[[Exception, Logger], bool] | None = None,
log: Logger | None = None,
) -> None:
"""Initialize the object.

Expand All @@ -41,12 +47,16 @@ def __init__(
Arguments:
rooms_ready: Whether rooms are ready to be synchronized when opened.
auto_clean_rooms: Whether rooms should be deleted when no client is there anymore.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
log: An optional logger.
"""
self.rooms_ready = rooms_ready
self.auto_clean_rooms = auto_clean_rooms
self.exception_handler = exception_handler
self.log = log or getLogger(__name__)
self.rooms = {}
self._stopped = Event()

@property
def started(self) -> Event:
Expand Down Expand Up @@ -146,35 +156,39 @@ async def serve(self, websocket: Websocket) -> None:
"`await websocket_server.start()`"
)

async with create_task_group() as tg:
tg.start_soon(self._serve, websocket, tg)

async def _serve(self, websocket: Websocket, tg: TaskGroup):
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)

if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
tg.cancel_scope.cancel()
try:
async with create_task_group():
room = await self.get_room(websocket.path)
await self.start_room(room)
await room.serve(websocket)
if self.auto_clean_rooms and not room.clients:
await self.delete_room(room=room)
except Exception as exception:
self._handle_exception(exception)

async def __aenter__(self) -> WebsocketServer:
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -190,24 +204,37 @@ async def start(
task_status.started()
self.started.set()
assert self._task_group is not None
# wait forever
self._task_group.start_soon(Event().wait)
# wait until stopped
self._task_group.start_soon(self._stopped.wait)
return

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
# wait forever
self._task_group.start_soon(Event().wait)
while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
# wait until stopped
self._task_group.start_soon(self._stopped.wait)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the WebSocket server."""
if self._task_group is None:
raise RuntimeError("WebsocketServer not running")

self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None


def exception_logger(exception: Exception, log: Logger) -> bool:
"""An exception handler that logs the exception and discards it."""
log.error("WebsocketServer exception", exc_info=exception)
return True # the exception was handled
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def yws_server(request, unused_tcp_port, websocket_server_api):
)
await ensure_server_running("localhost", unused_tcp_port)
pytest.port = unused_tcp_port
yield unused_tcp_port
yield unused_tcp_port, websocket_server
shutdown_event.set()
except Exception:
pass
Expand Down
18 changes: 18 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from anyio import sleep

from pycrdt_websocket import exception_logger

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_server_restart(yws_server):
port, server = yws_server

async def raise_error():
raise RuntimeError("foo")

server._task_group.start_soon(raise_error)
await sleep(0.1)
Loading