From a249b75f71c0b2aadb5a0028976069e19935faee Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 3 May 2024 09:38:22 +0200 Subject: [PATCH] Keep the YStore task group alive --- pycrdt_websocket/ystore.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/pycrdt_websocket/ystore.py b/pycrdt_websocket/ystore.py index 5772189..1bce41e 100644 --- a/pycrdt_websocket/ystore.py +++ b/pycrdt_websocket/ystore.py @@ -28,6 +28,7 @@ class BaseYStore(ABC): metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None version = 2 _started: Event | None = None + _stopped: Event | None = None _task_group: TaskGroup | None = None __start_lock: Lock | None = None @@ -50,6 +51,12 @@ def started(self) -> Event: self._started = Event() return self._started + @property + def stopped(self) -> Event: + if self._stopped is None: + self._stopped = Event() + return self._stopped + @property def _start_lock(self) -> Lock: if self.__start_lock is None: @@ -96,12 +103,14 @@ async def start( async with create_task_group() as self._task_group: task_status.started() self.started.set() + await self.stopped.wait() async def stop(self) -> None: """Stop the store.""" if self._task_group is None: raise RuntimeError("YStore not running") + self.stopped.set() self._task_group.cancel_scope.cancel() self._task_group = None @@ -309,7 +318,7 @@ class MySQLiteYStore(SQLiteYStore): document_ttl: int | None = None path: str lock: Lock - db_initialized: Event + db_initialized: Event | None _db: Connection def __init__( @@ -329,6 +338,7 @@ def __init__( self.metadata_callback = metadata_callback self.log = log or getLogger(__name__) self.lock = Lock() + self.db_initialized = None async def start( self, @@ -356,10 +366,11 @@ async def start( self._task_group.start_soon(self._init_db) task_status.started() self.started.set() + await self.stopped.wait() async def stop(self) -> None: """Stop the store.""" - if hasattr(self, "db_initialized") and self.db_initialized.is_set(): + if self.db_initialized is not None and self.db_initialized.is_set(): await self._db.close() await super().stop() @@ -405,6 +416,7 @@ async def _init_db(self): await db.commit() await db.close() self._db = await connect(self.db_path) + assert self.db_initialized is not None self.db_initialized.set() async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: @@ -413,8 +425,8 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: Returns: A tuple of (update, metadata, timestamp) for each update. """ - if not hasattr(self, "db_initialized"): - raise RuntimeError("ystore is not started") + if self.db_initialized is None: + raise RuntimeError("YStore not started") await self.db_initialized.wait() try: async with self.lock: @@ -438,8 +450,8 @@ async def write(self, data: bytes) -> None: Arguments: data: The update to store. """ - if not hasattr(self, "db_initialized"): - raise RuntimeError("ystore is not started") + if self.db_initialized is None: + raise RuntimeError("YStore not started") await self.db_initialized.wait() async with self.lock: # first, determine time elapsed since last update