Skip to content

Commit

Permalink
Replace aiosqlite with sqlite-anyio (#22)
Browse files Browse the repository at this point in the history
* Replace aiosqlite with sqlite-anyio

* Test on Trio as well

* Fix get_new_path
  • Loading branch information
davidbrochart authored Mar 25, 2024
1 parent 6fe12dd commit a016d14
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 178 deletions.
152 changes: 81 additions & 71 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from pathlib import Path
from typing import AsyncIterator, Awaitable, Callable, cast

import aiosqlite
import anyio
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
from pycrdt import Doc
from sqlite_anyio import Connection, connect

from .yutils import Decoder, get_new_path, write_var_uint

Expand Down Expand Up @@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
if self._task_group is not None:
raise RuntimeError("YStore already running")

self.started.set()
self._starting = False
task_status.started()
async with create_task_group() as self._task_group:
self.started.set()
self._starting = False
task_status.started()

def stop(self) -> None:
async def stop(self) -> None:
"""Stop the store."""
if self._task_group is None:
raise RuntimeError("YStore not running")
Expand Down Expand Up @@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore):
path: str
lock: Lock
db_initialized: Event
_db: Connection

def __init__(
self,
Expand Down Expand Up @@ -340,43 +342,54 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
self._starting = False
task_status.started()

async def stop(self) -> None:
"""Stop the store."""
if self.db_initialized.is_set():
await self._db.close()
await super().stop()

async def _init_db(self):
create_db = False
move_db = False
if not await anyio.Path(self.db_path).exists():
create_db = True
else:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
cursor = await db.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
create_db = True
await db.close()
if move_db:
new_path = await get_new_path(self.db_path)
self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path)
await anyio.Path(self.db_path).rename(new_path)
if create_db:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await db.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await db.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
await db.close()
self._db = await connect(self.db_path)
self.db_initialized.set()

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Expand All @@ -388,17 +401,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
await self.db_initialized.wait()
try:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
) as cursor:
found = False
async for update, metadata, timestamp in cursor:
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
)
found = False
for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
except Exception:
raise YDocNotFound

Expand All @@ -410,38 +423,35 @@ async def write(self, data: bytes) -> None:
"""
await self.db_initialized.wait()
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
# first, determine time elapsed since last update
cursor = await db.execute(
"SELECT timestamp FROM yupdates WHERE path = ? "
"ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
async with db.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
async for (update,) in cursor:
ydoc.apply_update(update)
# delete history
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
)

# finally, write this update to the DB
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,))
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
(self.path, squashed_update, metadata, time.time()),
)
await db.commit()

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
await self._db.commit()
5 changes: 2 additions & 3 deletions pycrdt_websocket/yutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,9 @@ async def get_new_path(path: str) -> str:
ext = p.suffix
p_noext = p.with_suffix("")
i = 1
dir_list = [p async for p in anyio.Path().iterdir()]
while True:
new_path = f"{p_noext}({i}){ext}"
if new_path not in dir_list:
if not await anyio.Path(new_path).exists():
break
i += 1
return str(new_path)
return new_path
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"anyio >=3.6.2,<5",
"aiosqlite >=0.18.0,<1",
"sqlite-anyio >=0.2.0,<0.3.0",
"pycrdt >=0.8.7,<0.9.0",
]

Expand All @@ -38,9 +38,10 @@ test = [
"mypy",
"pre-commit",
"pytest",
"pytest-asyncio",
"websockets >=10.0",
"uvicorn",
"httpx-ws >=0.5.2",
"hypercorn >=0.16.0",
"trio >=0.25.0",
"sniffio",
]
docs = [
"mkdocs",
Expand Down Expand Up @@ -68,7 +69,7 @@ include = [

[tool.ruff]
line-length = 99
select = [
lint.select = [
"ASYNC", # flake8-async
"E", "F", "W", # default Flake8
"G", # flake8-logging-format
Expand Down
45 changes: 28 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import subprocess
from functools import partial
from socket import socket

import pytest
from anyio import Event, create_task_group
from hypercorn import Config
from pycrdt import Array, Doc
from websockets import serve
from sniffio import current_async_library
from utils import ensure_server_running

from pycrdt_websocket import WebsocketServer
from pycrdt_websocket import ASGIServer, WebsocketServer


class TestYDoc:
Expand All @@ -23,32 +27,39 @@ def update(self):


@pytest.fixture
async def yws_server(request):
async def yws_server(request, unused_tcp_port):
try:
kwargs = request.param
except Exception:
except AttributeError:
kwargs = {}
websocket_server = WebsocketServer(**kwargs)
app = ASGIServer(websocket_server)
config = Config()
config.bind = [f"localhost:{unused_tcp_port}"]
shutdown_event = Event()
if current_async_library() == "trio":
from hypercorn.trio import serve
else:
from hypercorn.asyncio import serve
try:
async with websocket_server, serve(websocket_server.serve, "127.0.0.1", 1234):
yield websocket_server
async with create_task_group() as tg, websocket_server:
tg.start_soon(
partial(serve, app, config, shutdown_trigger=shutdown_event.wait, mode="asgi")
)
await ensure_server_running("localhost", unused_tcp_port)
yield unused_tcp_port
shutdown_event.set()
except Exception:
pass


@pytest.fixture
def yjs_client(request):
client_id = request.param
p = subprocess.Popen(["node", f"tests/yjs_client_{client_id}.js"])
yield p
p.kill()


@pytest.fixture
def test_ydoc():
return TestYDoc()


@pytest.fixture
def anyio_backend():
return "asyncio"
def unused_tcp_port() -> int:
with socket() as sock:
sock.bind(("localhost", 0))
return sock.getsockname()[1]
Loading

0 comments on commit a016d14

Please sign in to comment.