diff --git a/pycrdt_websocket/ystore.py b/pycrdt_websocket/ystore.py index d746dba..092d115 100644 --- a/pycrdt_websocket/ystore.py +++ b/pycrdt_websocket/ystore.py @@ -10,11 +10,11 @@ from pathlib import Path from typing import AsyncIterator, Awaitable, Callable, cast -import aiosqlite import anyio from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus from pycrdt import Doc +from sqlite_anyio import Connection, connect from .yutils import Decoder, get_new_path, write_var_uint @@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): if self._task_group is not None: raise RuntimeError("YStore already running") - self.started.set() - self._starting = False - task_status.started() + async with create_task_group() as self._task_group: + self.started.set() + self._starting = False + task_status.started() - def stop(self) -> None: + async def stop(self) -> None: """Stop the store.""" if self._task_group is None: raise RuntimeError("YStore not running") @@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore): path: str lock: Lock db_initialized: Event + _db: Connection | None def __init__( self, @@ -319,6 +321,7 @@ def __init__( self.log = log or getLogger(__name__) self.lock = Lock() self.db_initialized = Event() + self._db = None async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): """Start the SQLiteYStore. @@ -340,6 +343,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): self._starting = False task_status.started() + async def stop(self) -> None: + """Stop the store.""" + if self._db is not None: + await self._db.close() + await super().stop() + async def _init_db(self): create_db = False move_db = False @@ -347,36 +356,41 @@ async def _init_db(self): create_db = True else: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( - "SELECT count(name) FROM sqlite_master " - "WHERE type='table' and name='yupdates'" - ) - table_exists = (await cursor.fetchone())[0] - if table_exists: - cursor = await db.execute("pragma user_version") - version = (await cursor.fetchone())[0] - if version != self.version: - move_db = True - create_db = True - else: + db = await connect(self.db_path) + cursor = await db.cursor() + await cursor.execute( + "SELECT count(name) FROM sqlite_master " + "WHERE type='table' and name='yupdates'" + ) + table_exists = (await cursor.fetchone())[0] + if table_exists: + await cursor.execute("pragma user_version") + version = (await cursor.fetchone())[0] + if version != self.version: + move_db = True create_db = True + else: + create_db = True + await db.close() if move_db: new_path = await get_new_path(self.db_path) self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path) await anyio.Path(self.db_path).rename(new_path) if create_db: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, " - "metadata BLOB, timestamp REAL NOT NULL)" - ) - await db.execute( - "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" - ) - await db.execute(f"PRAGMA user_version = {self.version}") - await db.commit() + db = await connect(self.db_path) + cursor = await db.cursor() + await cursor.execute( + "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, " + "metadata BLOB, timestamp REAL NOT NULL)" + ) + await cursor.execute( + "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await cursor.execute(f"PRAGMA user_version = {self.version}") + await db.commit() + await db.close() + self._db = await connect(self.db_path) self.db_initialized.set() async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: @@ -388,17 +402,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: await self.db_initialized.wait() try: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", - (self.path,), - ) as cursor: - found = False - async for update, metadata, timestamp in cursor: - found = True - yield update, metadata, timestamp - if not found: - raise YDocNotFound + cursor = await self._db.cursor() + await cursor.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (self.path,), + ) + found = False + for update, metadata, timestamp in await cursor.fetchall(): + found = True + yield update, metadata, timestamp + if not found: + raise YDocNotFound except Exception: raise YDocNotFound @@ -410,38 +424,35 @@ async def write(self, data: bytes) -> None: """ await self.db_initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - # first, determine time elapsed since last update - cursor = await db.execute( - "SELECT timestamp FROM yupdates WHERE path = ? " - "ORDER BY timestamp DESC LIMIT 1", - (self.path,), - ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - - if self.document_ttl is not None and diff > self.document_ttl: - # squash updates - ydoc = Doc() - async with db.execute( - "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) - ) as cursor: - async for (update,) in cursor: - ydoc.apply_update(update) - # delete history - await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) - # insert squashed updates - squashed_update = ydoc.get_update() - metadata = await self.get_metadata() - await db.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, squashed_update, metadata, time.time()), - ) - - # finally, write this update to the DB + # first, determine time elapsed since last update + cursor = await self._db.cursor() + await cursor.execute( + "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + + if self.document_ttl is not None and diff > self.document_ttl: + # squash updates + ydoc = Doc() + await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)) + for (update,) in await cursor.fetchall(): + ydoc.apply_update(update) + # delete history + await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # insert squashed updates + squashed_update = ydoc.get_update() metadata = await self.get_metadata() - await db.execute( + await cursor.execute( "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, data, metadata, time.time()), + (self.path, squashed_update, metadata, time.time()), ) - await db.commit() + + # finally, write this update to the DB + metadata = await self.get_metadata() + await cursor.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, data, metadata, time.time()), + ) + await self._db.commit() diff --git a/pyproject.toml b/pyproject.toml index 9f5db78..cc18acb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ "anyio >=3.6.2,<5", - "aiosqlite >=0.18.0,<1", + "sqlite-anyio >=0.2.0,<0.3.0", "pycrdt >=0.8.7,<0.9.0", ] @@ -68,7 +68,7 @@ include = [ [tool.ruff] line-length = 99 -select = [ +lint.select = [ "ASYNC", # flake8-async "E", "F", "W", # default Flake8 "G", # flake8-logging-format diff --git a/tests/test_ystore.py b/tests/test_ystore.py index 952284d..d374e3a 100644 --- a/tests/test_ystore.py +++ b/tests/test_ystore.py @@ -4,8 +4,8 @@ from pathlib import Path from unittest.mock import patch -import aiosqlite import pytest +from sqlite_anyio import connect from pycrdt_websocket.ystore import SQLiteYStore, TempFileYStore @@ -59,6 +59,8 @@ async def test_ystore(YStore): assert i == len(data) + await ystore.stop() + @pytest.mark.anyio async def test_document_ttl_sqlite_ystore(test_ydoc): @@ -66,24 +68,27 @@ async def test_document_ttl_sqlite_ystore(test_ydoc): ystore = MySQLiteYStore(store_name, delete_db=True) await ystore.start() now = time.time() + db = await connect(ystore.db_path) + cursor = await db.cursor() for i in range(3): # assert that adding a record before document TTL doesn't delete document history with patch("time.time") as mock_time: mock_time.return_value = now await ystore.write(test_ydoc.update()) - async with aiosqlite.connect(ystore.db_path) as db: - assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[ - 0 - ] == i + 1 + assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[ + 0 + ] == i + 1 # assert that adding a record after document TTL deletes previous document history with patch("time.time") as mock_time: mock_time.return_value = now + ystore.document_ttl + 1 await ystore.write(test_ydoc.update()) - async with aiosqlite.connect(ystore.db_path) as db: - # two updates in DB: one squashed update and the new update - assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2 + # two updates in DB: one squashed update and the new update + assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2 + + await db.close() + await ystore.stop() @pytest.mark.anyio @@ -97,3 +102,4 @@ async def test_version(YStore, caplog): await ystore.write(b"foo") YStore.version = prev_version assert "YStore version mismatch" in caplog.text + await ystore.stop()