Skip to content

Commit

Permalink
Use sqlite-anyio's Connection async context manager and exception logger
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jun 11, 2024
1 parent 5900b31 commit 8b967bc
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 62 deletions.
140 changes: 79 additions & 61 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
from pycrdt import Doc
from sqlite_anyio import Connection, connect
from sqlite_anyio import Connection, connect, exception_logger

from .yutils import Decoder, get_new_path, write_var_uint

Expand Down Expand Up @@ -388,41 +388,54 @@ async def _init_db(self):
create_db = True
else:
async with self.lock:
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
db = await connect(
self.db_path,
exception_handler=exception_logger,
log=self.log,
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
async with db:
cursor = await db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
create_db = True
else:
create_db = True
await db.close()
if move_db:
new_path = await get_new_path(self.db_path)
self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path)
await anyio.Path(self.db_path).rename(new_path)
if create_db:
async with self.lock:
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
db = await connect(
self.db_path,
exception_handler=exception_logger,
log=self.log,
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
async with db:
cursor = await db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await db.close()
self._db = await connect(self.db_path)
self._db = await connect(
self.db_path,
exception_handler=exception_logger,
log=self.log,
)
assert self.db_initialized is not None
self.db_initialized.set()

Expand All @@ -437,15 +450,16 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
await self.db_initialized.wait()
try:
async with self.lock:
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
)
found = False
for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
async with self._db:
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
)
for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
except Exception:
Expand All @@ -461,35 +475,39 @@ async def write(self, data: bytes) -> None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
async with self.lock:
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,))
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
async with self._db:
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? "
"ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute(
"SELECT yupdate FROM yupdates WHERE path = ?",
(self.path,),
)
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
)

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
(self.path, data, metadata, time.time()),
)

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
await self._db.commit()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"anyio >=3.6.2,<5",
"sqlite-anyio >=0.2.0,<0.3.0",
"sqlite-anyio >=0.2.1,<0.3.0",
"pycrdt >=0.8.16,<0.9.0",
]

Expand Down

0 comments on commit 8b967bc

Please sign in to comment.