Skip to content

Commit

Permalink
Replace aiosqlite with sqlite-anyio
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Mar 22, 2024
1 parent 6fe12dd commit 1b9ba5f
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 81 deletions.
152 changes: 81 additions & 71 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from pathlib import Path
from typing import AsyncIterator, Awaitable, Callable, cast

import aiosqlite
import anyio
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
from pycrdt import Doc
from sqlite_anyio import Connection, connect

from .yutils import Decoder, get_new_path, write_var_uint

Expand Down Expand Up @@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
if self._task_group is not None:
raise RuntimeError("YStore already running")

self.started.set()
self._starting = False
task_status.started()
async with create_task_group() as self._task_group:
self.started.set()
self._starting = False
task_status.started()

def stop(self) -> None:
async def stop(self) -> None:
"""Stop the store."""
if self._task_group is None:
raise RuntimeError("YStore not running")
Expand Down Expand Up @@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore):
path: str
lock: Lock
db_initialized: Event
_db: Connection

def __init__(
self,
Expand Down Expand Up @@ -340,43 +342,54 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
self._starting = False
task_status.started()

async def stop(self) -> None:
"""Stop the store."""
if self.db_initialized.is_set():
await self._db.close()
await super().stop()

async def _init_db(self):
create_db = False
move_db = False
if not await anyio.Path(self.db_path).exists():
create_db = True
else:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
cursor = await db.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
create_db = True
await db.close()
if move_db:
new_path = await get_new_path(self.db_path)
self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path)
await anyio.Path(self.db_path).rename(new_path)
if create_db:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await db.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await db.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
await db.close()
self._db = await connect(self.db_path)
self.db_initialized.set()

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Expand All @@ -388,17 +401,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
await self.db_initialized.wait()
try:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
) as cursor:
found = False
async for update, metadata, timestamp in cursor:
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
)
found = False
for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
except Exception:
raise YDocNotFound

Expand All @@ -410,38 +423,35 @@ async def write(self, data: bytes) -> None:
"""
await self.db_initialized.wait()
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
# first, determine time elapsed since last update
cursor = await db.execute(
"SELECT timestamp FROM yupdates WHERE path = ? "
"ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
async with db.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
async for (update,) in cursor:
ydoc.apply_update(update)
# delete history
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
)

# finally, write this update to the DB
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,))
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
(self.path, squashed_update, metadata, time.time()),
)
await db.commit()

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
await self._db.commit()
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"anyio >=3.6.2,<5",
"aiosqlite >=0.18.0,<1",
"sqlite-anyio >=0.2.0,<0.3.0",
"pycrdt >=0.8.7,<0.9.0",
]

Expand Down Expand Up @@ -68,7 +68,7 @@ include = [

[tool.ruff]
line-length = 99
select = [
lint.select = [
"ASYNC", # flake8-async
"E", "F", "W", # default Flake8
"G", # flake8-logging-format
Expand Down
22 changes: 14 additions & 8 deletions tests/test_ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from pathlib import Path
from unittest.mock import patch

import aiosqlite
import pytest
from sqlite_anyio import connect

from pycrdt_websocket.ystore import SQLiteYStore, TempFileYStore

Expand Down Expand Up @@ -59,31 +59,36 @@ async def test_ystore(YStore):

assert i == len(data)

await ystore.stop()


@pytest.mark.anyio
async def test_document_ttl_sqlite_ystore(test_ydoc):
store_name = "my_store"
ystore = MySQLiteYStore(store_name, delete_db=True)
await ystore.start()
now = time.time()
db = await connect(ystore.db_path)
cursor = await db.cursor()

for i in range(3):
# assert that adding a record before document TTL doesn't delete document history
with patch("time.time") as mock_time:
mock_time.return_value = now
await ystore.write(test_ydoc.update())
async with aiosqlite.connect(ystore.db_path) as db:
assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[
0
] == i + 1
assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[
0
] == i + 1

# assert that adding a record after document TTL deletes previous document history
with patch("time.time") as mock_time:
mock_time.return_value = now + ystore.document_ttl + 1
await ystore.write(test_ydoc.update())
async with aiosqlite.connect(ystore.db_path) as db:
# two updates in DB: one squashed update and the new update
assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2
# two updates in DB: one squashed update and the new update
assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2

await db.close()
await ystore.stop()


@pytest.mark.anyio
Expand All @@ -97,3 +102,4 @@ async def test_version(YStore, caplog):
await ystore.write(b"foo")
YStore.version = prev_version
assert "YStore version mismatch" in caplog.text
await ystore.stop()

0 comments on commit 1b9ba5f

Please sign in to comment.