Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sqlite-anyio's Connection async context manager and exception logger #51

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 79 additions & 61 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
from pycrdt import Doc
from sqlite_anyio import Connection, connect
from sqlite_anyio import Connection, connect, exception_logger

from .yutils import Decoder, get_new_path, write_var_uint

Expand Down Expand Up @@ -388,41 +388,54 @@ async def _init_db(self):
create_db = True
else:
async with self.lock:
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
db = await connect(
self.db_path,
exception_handler=exception_logger,
log=self.log,
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
async with db:
cursor = await db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
create_db = True
else:
create_db = True
await db.close()
if move_db:
new_path = await get_new_path(self.db_path)
self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path)
await anyio.Path(self.db_path).rename(new_path)
if create_db:
async with self.lock:
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
db = await connect(
self.db_path,
exception_handler=exception_logger,
log=self.log,
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
async with db:
cursor = await db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await db.close()
self._db = await connect(self.db_path)
self._db = await connect(
self.db_path,
exception_handler=exception_logger,
log=self.log,
)
assert self.db_initialized is not None
self.db_initialized.set()

Expand All @@ -437,15 +450,16 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
await self.db_initialized.wait()
try:
async with self.lock:
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
)
found = False
for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
async with self._db:
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
)
for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
except Exception:
Expand All @@ -461,35 +475,39 @@ async def write(self, data: bytes) -> None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
async with self.lock:
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,))
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
async with self._db:
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? "
"ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute(
"SELECT yupdate FROM yupdates WHERE path = ?",
(self.path,),
)
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
)

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
(self.path, data, metadata, time.time()),
)

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
await self._db.commit()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"anyio >=3.6.2,<5",
"sqlite-anyio >=0.2.0,<0.3.0",
"sqlite-anyio >=0.2.1,<0.3.0",
"pycrdt >=0.8.16,<0.9.0",
]

Expand Down
Loading