diff --git a/jupyter_collaboration/app.py b/jupyter_collaboration/app.py index 2c35982e..136c3c0d 100644 --- a/jupyter_collaboration/app.py +++ b/jupyter_collaboration/app.py @@ -12,7 +12,7 @@ from .loaders import FileLoaderMapping from .stores import SQLiteYStore from .utils import AWARENESS_EVENTS_SCHEMA_PATH, EVENTS_SCHEMA_PATH -from .websocketserver import JupyterWebsocketServer +from .websocketserver import JupyterWebsocketServer, exception_logger class YDocExtension(ExtensionApp): @@ -85,6 +85,9 @@ def initialize_handlers(self): rooms_ready=False, auto_clean_rooms=False, ystore_class=self.ystore_class, + # Log exceptions, because we don't want the websocket server + # to _ever_ crash permanently in a live jupyter_server. + exception_handler=exception_logger, log=self.log, ) diff --git a/jupyter_collaboration/handlers.py b/jupyter_collaboration/handlers.py index d042287d..2ba132e2 100644 --- a/jupyter_collaboration/handlers.py +++ b/jupyter_collaboration/handlers.py @@ -7,6 +7,7 @@ import json import time import uuid +from logging import Logger from typing import Any from jupyter_server.auth import authorized @@ -80,6 +81,20 @@ async def prepare(self): if self._websocket_server.room_exists(self._room_id): self.room: YRoom = await self._websocket_server.get_room(self._room_id) else: + # Logging exceptions, instead of raising them here to ensure + # that the y-rooms stay alive even after an exception is seen. + def exception_logger(exception: Exception, log: Logger) -> bool: + """A function that catches any exceptions raised in the websocket + server and logs them. + The protects the y-room's task group from cancelling + anytime an exception is raised. + """ + log.error( + f"Document Room Exception, (room_id={self._room_id or 'unknown'}): ", + exc_info=exception, + ) + return True + if self._room_id.count(":") >= 2: # DocumentRoom file_format, file_type, file_id = decode_file_path(self._room_id) @@ -101,13 +116,18 @@ async def prepare(self): self.event_logger, ystore, self.log, - self._document_save_delay, + exception_handler=exception_logger, + save_delay=self._document_save_delay, ) else: # TransientRoom # it is a transient document (e.g. awareness) - self.room = TransientRoom(self._room_id, self.log) + self.room = TransientRoom( + self._room_id, + log=self.log, + exception_handler=exception_logger, + ) await self._websocket_server.start_room(self.room) self._websocket_server.add_room(self._room_id, self.room) diff --git a/jupyter_collaboration/rooms.py b/jupyter_collaboration/rooms.py index 691943c5..9a97e36f 100644 --- a/jupyter_collaboration/rooms.py +++ b/jupyter_collaboration/rooms.py @@ -5,7 +5,7 @@ import asyncio from logging import Logger -from typing import Any +from typing import Any, Callable from jupyter_events import EventLogger from jupyter_ydoc import ydocs as YDOCS @@ -31,8 +31,9 @@ def __init__( ystore: BaseYStore | None, log: Logger | None, save_delay: float | None = None, + exception_handler: Callable[[Exception, Logger], bool] | None = None, ): - super().__init__(ready=False, ystore=ystore, log=log) + super().__init__(ready=False, ystore=ystore, exception_handler=exception_handler, log=log) self._room_id: str = room_id self._file_format: str = file_format @@ -281,8 +282,13 @@ async def _maybe_save_document(self, saving_document: asyncio.Task | None) -> No class TransientRoom(YRoom): """A Y room for sharing state (e.g. awareness).""" - def __init__(self, room_id: str, log: Logger | None): - super().__init__(log=log) + def __init__( + self, + room_id: str, + log: Logger | None = None, + exception_handler: Callable[[Exception, Logger], bool] | None = None, + ): + super().__init__(log=log, exception_handler=exception_handler) self._room_id = room_id diff --git a/jupyter_collaboration/websocketserver.py b/jupyter_collaboration/websocketserver.py index 365da078..3031dff7 100644 --- a/jupyter_collaboration/websocketserver.py +++ b/jupyter_collaboration/websocketserver.py @@ -5,7 +5,7 @@ import asyncio from logging import Logger -from typing import Any +from typing import Any, Callable from pycrdt_websocket.websocket_server import WebsocketServer, YRoom from pycrdt_websocket.ystore import BaseYStore @@ -16,6 +16,16 @@ class RoomNotFound(LookupError): pass +def exception_logger(exception: Exception, log: Logger) -> bool: + """A function that catches any exceptions raised in the websocket + server and logs them. + This protects the websocket server's task group from cancelling + anytime an exception is raised. + """ + log.error("Jupyter Websocket Server: ", exc_info=exception) + return True + + class JupyterWebsocketServer(WebsocketServer): """Ypy websocket server. @@ -30,9 +40,15 @@ def __init__( ystore_class: BaseYStore, rooms_ready: bool = True, auto_clean_rooms: bool = True, + exception_handler: Callable[[Exception, Logger], bool] | None = None, log: Logger | None = None, ): - super().__init__(rooms_ready, auto_clean_rooms, log) + super().__init__( + rooms_ready=rooms_ready, + auto_clean_rooms=auto_clean_rooms, + exception_handler=exception_handler, + log=log, + ) self.ystore_class = ystore_class self.ypatch_nb = 0 self.connected_users: dict[Any, Any] = {} diff --git a/pyproject.toml b/pyproject.toml index 43bd496e..5bb2aeaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ dependencies = [ "jupyter_server>=2.0.0,<3.0.0", "jupyter_ydoc>=2.0.0,<3.0.0", - "pycrdt-websocket>=0.13.0,<0.14.0", + "pycrdt-websocket>=0.13.1,<0.14.0", "jupyter_events>=0.10.0", "jupyter_server_fileid>=0.7.0,<1", "jsonschema>=4.18.0"