Skip to content

Commit

Permalink
Replace aiosqlite with sqlite-anyio
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Mar 21, 2024
1 parent 6fe12dd commit fc99a62
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 67 deletions.
137 changes: 71 additions & 66 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path
from typing import AsyncIterator, Awaitable, Callable, cast

import aiosqlite
from sqlite_anyio import connect
import anyio
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
Expand Down Expand Up @@ -340,43 +340,49 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
self._starting = False
task_status.started()

async def stop(self) -> None:
"""Stop the store."""
await self._db.close()
await super().stop()

async def _init_db(self):
self._db = await connect(self.db_path)
create_db = False
move_db = False
if not await anyio.Path(self.db_path).exists():
create_db = True
else:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
cursor = await db.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
cursor = self._db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
create_db = True
if move_db:
new_path = await get_new_path(self.db_path)
self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path)
await anyio.Path(self.db_path).rename(new_path)
if create_db:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await db.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await db.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
cursor = self._db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await self._db.commit()
self.db_initialized.set()

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Expand All @@ -388,17 +394,16 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
await self.db_initialized.wait()
try:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
) as cursor:
found = False
async for update, metadata, timestamp in cursor:
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
found = False
async for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
except Exception:
raise YDocNotFound

Expand All @@ -410,38 +415,38 @@ async def write(self, data: bytes) -> None:
"""
await self.db_initialized.wait()
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
# first, determine time elapsed since last update
cursor = await db.execute(
"SELECT timestamp FROM yupdates WHERE path = ? "
"ORDER BY timestamp DESC LIMIT 1",
(self.path,),
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? "
"ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
async with db.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
async for (update,) in cursor:
ydoc.apply_update(update)
# delete history
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
)

# finally, write this update to the DB
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
(self.path, squashed_update, metadata, time.time()),
)
await db.commit()

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
await self._db.commit()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"anyio >=3.6.2,<5",
"aiosqlite >=0.18.0,<1",
"sqlite-anyio >=0.2.0,<0.3.0",
"pycrdt >=0.8.7,<0.9.0",
]

Expand Down

0 comments on commit fc99a62

Please sign in to comment.