diff --git a/pycrdt_websocket/ystore.py b/pycrdt_websocket/ystore.py index 8660c77..0145757 100644 --- a/pycrdt_websocket/ystore.py +++ b/pycrdt_websocket/ystore.py @@ -15,7 +15,7 @@ from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus from pycrdt import Doc -from sqlite_anyio import Connection, connect +from sqlite_anyio import Connection, connect, exception_logger from .yutils import Decoder, get_new_path, write_var_uint @@ -388,21 +388,26 @@ async def _init_db(self): create_db = True else: async with self.lock: - db = await connect(self.db_path) - cursor = await db.cursor() - await cursor.execute( - "SELECT count(name) FROM sqlite_master " - "WHERE type='table' and name='yupdates'" + db = await connect( + self.db_path, + exception_handler=exception_logger, + log=self.log, ) - table_exists = (await cursor.fetchone())[0] - if table_exists: - await cursor.execute("pragma user_version") - version = (await cursor.fetchone())[0] - if version != self.version: - move_db = True + async with db: + cursor = await db.cursor() + await cursor.execute( + "SELECT count(name) FROM sqlite_master " + "WHERE type='table' and name='yupdates'" + ) + table_exists = (await cursor.fetchone())[0] + if table_exists: + await cursor.execute("pragma user_version") + version = (await cursor.fetchone())[0] + if version != self.version: + move_db = True + create_db = True + else: create_db = True - else: - create_db = True await db.close() if move_db: new_path = await get_new_path(self.db_path) @@ -410,19 +415,27 @@ async def _init_db(self): await anyio.Path(self.db_path).rename(new_path) if create_db: async with self.lock: - db = await connect(self.db_path) - cursor = await db.cursor() - await cursor.execute( - "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, " - "metadata BLOB, timestamp REAL NOT NULL)" + db = await connect( + self.db_path, + exception_handler=exception_logger, + log=self.log, ) - await cursor.execute( - "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" - ) - await cursor.execute(f"PRAGMA user_version = {self.version}") - await db.commit() + async with db: + cursor = await db.cursor() + await cursor.execute( + "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, " + "metadata BLOB, timestamp REAL NOT NULL)" + ) + await cursor.execute( + "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await cursor.execute(f"PRAGMA user_version = {self.version}") await db.close() - self._db = await connect(self.db_path) + self._db = await connect( + self.db_path, + exception_handler=exception_logger, + log=self.log, + ) assert self.db_initialized is not None self.db_initialized.set() @@ -437,15 +450,16 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: await self.db_initialized.wait() try: async with self.lock: - cursor = await self._db.cursor() - await cursor.execute( - "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", - (self.path,), - ) found = False - for update, metadata, timestamp in await cursor.fetchall(): - found = True - yield update, metadata, timestamp + async with self._db: + cursor = await self._db.cursor() + await cursor.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (self.path,), + ) + for update, metadata, timestamp in await cursor.fetchall(): + found = True + yield update, metadata, timestamp if not found: raise YDocNotFound except Exception: @@ -461,35 +475,39 @@ async def write(self, data: bytes) -> None: raise RuntimeError("YStore not started") await self.db_initialized.wait() async with self.lock: - # first, determine time elapsed since last update - cursor = await self._db.cursor() - await cursor.execute( - "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", - (self.path,), - ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - - if self.document_ttl is not None and diff > self.document_ttl: - # squash updates - ydoc = Doc() - await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)) - for (update,) in await cursor.fetchall(): - ydoc.apply_update(update) - # delete history - await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) - # insert squashed updates - squashed_update = ydoc.get_update() + async with self._db: + # first, determine time elapsed since last update + cursor = await self._db.cursor() + await cursor.execute( + "SELECT timestamp FROM yupdates WHERE path = ? " + "ORDER BY timestamp DESC LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + + if self.document_ttl is not None and diff > self.document_ttl: + # squash updates + ydoc = Doc() + await cursor.execute( + "SELECT yupdate FROM yupdates WHERE path = ?", + (self.path,), + ) + for (update,) in await cursor.fetchall(): + ydoc.apply_update(update) + # delete history + await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # insert squashed updates + squashed_update = ydoc.get_update() + metadata = await self.get_metadata() + await cursor.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, squashed_update, metadata, time.time()), + ) + + # finally, write this update to the DB metadata = await self.get_metadata() await cursor.execute( "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, squashed_update, metadata, time.time()), + (self.path, data, metadata, time.time()), ) - - # finally, write this update to the DB - metadata = await self.get_metadata() - await cursor.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, data, metadata, time.time()), - ) - await self._db.commit() diff --git a/pyproject.toml b/pyproject.toml index 7878e16..91623e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ "anyio >=3.6.2,<5", - "sqlite-anyio >=0.2.0,<0.3.0", + "sqlite-anyio >=0.2.1,<0.3.0", "pycrdt >=0.8.16,<0.9.0", ]