From a016d148c0a70adec1cff6cc5adc77b9787629cc Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 25 Mar 2024 09:48:00 +0100 Subject: [PATCH] Replace aiosqlite with sqlite-anyio (#22) * Replace aiosqlite with sqlite-anyio * Test on Trio as well * Fix get_new_path --- pycrdt_websocket/ystore.py | 152 ++++++++++++++++++++----------------- pycrdt_websocket/yutils.py | 5 +- pyproject.toml | 11 +-- tests/conftest.py | 45 ++++++----- tests/test_asgi.py | 71 ++++++++--------- tests/test_pycrdt_yjs.py | 59 +++++++------- tests/test_ystore.py | 27 ++++--- tests/utils.py | 50 ++++++++++++ tests/yjs_client_0.js | 3 +- tests/yjs_client_1.js | 3 +- 10 files changed, 248 insertions(+), 178 deletions(-) create mode 100644 tests/utils.py diff --git a/pycrdt_websocket/ystore.py b/pycrdt_websocket/ystore.py index d746dba..39bd168 100644 --- a/pycrdt_websocket/ystore.py +++ b/pycrdt_websocket/ystore.py @@ -10,11 +10,11 @@ from pathlib import Path from typing import AsyncIterator, Awaitable, Callable, cast -import aiosqlite import anyio from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group from anyio.abc import TaskGroup, TaskStatus from pycrdt import Doc +from sqlite_anyio import Connection, connect from .yutils import Decoder, get_new_path, write_var_uint @@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): if self._task_group is not None: raise RuntimeError("YStore already running") - self.started.set() - self._starting = False - task_status.started() + async with create_task_group() as self._task_group: + self.started.set() + self._starting = False + task_status.started() - def stop(self) -> None: + async def stop(self) -> None: """Stop the store.""" if self._task_group is None: raise RuntimeError("YStore not running") @@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore): path: str lock: Lock db_initialized: Event + _db: Connection def __init__( self, @@ -340,6 +342,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): self._starting = False task_status.started() + async def stop(self) -> None: + """Stop the store.""" + if self.db_initialized.is_set(): + await self._db.close() + await super().stop() + async def _init_db(self): create_db = False move_db = False @@ -347,36 +355,41 @@ async def _init_db(self): create_db = True else: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( - "SELECT count(name) FROM sqlite_master " - "WHERE type='table' and name='yupdates'" - ) - table_exists = (await cursor.fetchone())[0] - if table_exists: - cursor = await db.execute("pragma user_version") - version = (await cursor.fetchone())[0] - if version != self.version: - move_db = True - create_db = True - else: + db = await connect(self.db_path) + cursor = await db.cursor() + await cursor.execute( + "SELECT count(name) FROM sqlite_master " + "WHERE type='table' and name='yupdates'" + ) + table_exists = (await cursor.fetchone())[0] + if table_exists: + await cursor.execute("pragma user_version") + version = (await cursor.fetchone())[0] + if version != self.version: + move_db = True create_db = True + else: + create_db = True + await db.close() if move_db: new_path = await get_new_path(self.db_path) self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path) await anyio.Path(self.db_path).rename(new_path) if create_db: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, " - "metadata BLOB, timestamp REAL NOT NULL)" - ) - await db.execute( - "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" - ) - await db.execute(f"PRAGMA user_version = {self.version}") - await db.commit() + db = await connect(self.db_path) + cursor = await db.cursor() + await cursor.execute( + "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, " + "metadata BLOB, timestamp REAL NOT NULL)" + ) + await cursor.execute( + "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await cursor.execute(f"PRAGMA user_version = {self.version}") + await db.commit() + await db.close() + self._db = await connect(self.db_path) self.db_initialized.set() async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: @@ -388,17 +401,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: await self.db_initialized.wait() try: async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", - (self.path,), - ) as cursor: - found = False - async for update, metadata, timestamp in cursor: - found = True - yield update, metadata, timestamp - if not found: - raise YDocNotFound + cursor = await self._db.cursor() + await cursor.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (self.path,), + ) + found = False + for update, metadata, timestamp in await cursor.fetchall(): + found = True + yield update, metadata, timestamp + if not found: + raise YDocNotFound except Exception: raise YDocNotFound @@ -410,38 +423,35 @@ async def write(self, data: bytes) -> None: """ await self.db_initialized.wait() async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - # first, determine time elapsed since last update - cursor = await db.execute( - "SELECT timestamp FROM yupdates WHERE path = ? " - "ORDER BY timestamp DESC LIMIT 1", - (self.path,), - ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - - if self.document_ttl is not None and diff > self.document_ttl: - # squash updates - ydoc = Doc() - async with db.execute( - "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) - ) as cursor: - async for (update,) in cursor: - ydoc.apply_update(update) - # delete history - await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) - # insert squashed updates - squashed_update = ydoc.get_update() - metadata = await self.get_metadata() - await db.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, squashed_update, metadata, time.time()), - ) - - # finally, write this update to the DB + # first, determine time elapsed since last update + cursor = await self._db.cursor() + await cursor.execute( + "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + + if self.document_ttl is not None and diff > self.document_ttl: + # squash updates + ydoc = Doc() + await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)) + for (update,) in await cursor.fetchall(): + ydoc.apply_update(update) + # delete history + await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # insert squashed updates + squashed_update = ydoc.get_update() metadata = await self.get_metadata() - await db.execute( + await cursor.execute( "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, data, metadata, time.time()), + (self.path, squashed_update, metadata, time.time()), ) - await db.commit() + + # finally, write this update to the DB + metadata = await self.get_metadata() + await cursor.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (self.path, data, metadata, time.time()), + ) + await self._db.commit() diff --git a/pycrdt_websocket/yutils.py b/pycrdt_websocket/yutils.py index 476ccd9..2d363b4 100644 --- a/pycrdt_websocket/yutils.py +++ b/pycrdt_websocket/yutils.py @@ -148,10 +148,9 @@ async def get_new_path(path: str) -> str: ext = p.suffix p_noext = p.with_suffix("") i = 1 - dir_list = [p async for p in anyio.Path().iterdir()] while True: new_path = f"{p_noext}({i}){ext}" - if new_path not in dir_list: + if not await anyio.Path(new_path).exists(): break i += 1 - return str(new_path) + return new_path diff --git a/pyproject.toml b/pyproject.toml index 9f5db78..ebc5bad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] dependencies = [ "anyio >=3.6.2,<5", - "aiosqlite >=0.18.0,<1", + "sqlite-anyio >=0.2.0,<0.3.0", "pycrdt >=0.8.7,<0.9.0", ] @@ -38,9 +38,10 @@ test = [ "mypy", "pre-commit", "pytest", - "pytest-asyncio", - "websockets >=10.0", - "uvicorn", + "httpx-ws >=0.5.2", + "hypercorn >=0.16.0", + "trio >=0.25.0", + "sniffio", ] docs = [ "mkdocs", @@ -68,7 +69,7 @@ include = [ [tool.ruff] line-length = 99 -select = [ +lint.select = [ "ASYNC", # flake8-async "E", "F", "W", # default Flake8 "G", # flake8-logging-format diff --git a/tests/conftest.py b/tests/conftest.py index d7cddd9..7f5f799 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,14 @@ -import subprocess +from functools import partial +from socket import socket import pytest +from anyio import Event, create_task_group +from hypercorn import Config from pycrdt import Array, Doc -from websockets import serve +from sniffio import current_async_library +from utils import ensure_server_running -from pycrdt_websocket import WebsocketServer +from pycrdt_websocket import ASGIServer, WebsocketServer class TestYDoc: @@ -23,32 +27,39 @@ def update(self): @pytest.fixture -async def yws_server(request): +async def yws_server(request, unused_tcp_port): try: kwargs = request.param - except Exception: + except AttributeError: kwargs = {} websocket_server = WebsocketServer(**kwargs) + app = ASGIServer(websocket_server) + config = Config() + config.bind = [f"localhost:{unused_tcp_port}"] + shutdown_event = Event() + if current_async_library() == "trio": + from hypercorn.trio import serve + else: + from hypercorn.asyncio import serve try: - async with websocket_server, serve(websocket_server.serve, "127.0.0.1", 1234): - yield websocket_server + async with create_task_group() as tg, websocket_server: + tg.start_soon( + partial(serve, app, config, shutdown_trigger=shutdown_event.wait, mode="asgi") + ) + await ensure_server_running("localhost", unused_tcp_port) + yield unused_tcp_port + shutdown_event.set() except Exception: pass -@pytest.fixture -def yjs_client(request): - client_id = request.param - p = subprocess.Popen(["node", f"tests/yjs_client_{client_id}.js"]) - yield p - p.kill() - - @pytest.fixture def test_ydoc(): return TestYDoc() @pytest.fixture -def anyio_backend(): - return "asyncio" +def unused_tcp_port() -> int: + with socket() as sock: + sock.bind(("localhost", 0)) + return sock.getsockname()[1] diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 0936364..901a7cb 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,43 +1,32 @@ import pytest -import uvicorn -from anyio import create_task_group, sleep +from anyio import sleep +from httpx_ws import aconnect_ws from pycrdt import Doc, Map -from websockets import connect - -from pycrdt_websocket import ASGIServer, WebsocketProvider, WebsocketServer - -websocket_server = WebsocketServer(auto_clean_rooms=False) -app = ASGIServer(websocket_server) - - -@pytest.mark.anyio -async def test_asgi(unused_tcp_port): - # server - config = uvicorn.Config("test_asgi:app", port=unused_tcp_port, log_level="info") - server = uvicorn.Server(config) - async with create_task_group() as tg, websocket_server: - tg.start_soon(server.serve) - while not server.started: - await sleep(0) - - # clients - # client 1 - ydoc1 = Doc() - ydoc1["map"] = ymap1 = Map() - ymap1["key"] = "value" - async with connect( - f"ws://localhost:{unused_tcp_port}/my-roomname" - ) as websocket1, WebsocketProvider(ydoc1, websocket1): - await sleep(0.1) - - # client 2 - ydoc2 = Doc() - async with connect( - f"ws://localhost:{unused_tcp_port}/my-roomname" - ) as websocket2, WebsocketProvider(ydoc2, websocket2): - await sleep(0.1) - - ydoc2["map"] = ymap2 = Map() - assert str(ymap2) == '{"key":"value"}' - - tg.cancel_scope.cancel() +from utils import Websocket + +from pycrdt_websocket import WebsocketProvider + +pytestmark = pytest.mark.anyio + + +@pytest.mark.parametrize("yws_server", [{"auto_clean_rooms": False}], indirect=True) +async def test_asgi(yws_server): + port = yws_server + # client 1 + ydoc1 = Doc() + ydoc1["map"] = ymap1 = Map() + ymap1["key"] = "value" + async with aconnect_ws( + f"http://localhost:{port}/my-roomname" + ) as websocket1, WebsocketProvider(ydoc1, Websocket(websocket1, "my-roomname")): + await sleep(0.1) + + # client 2 + ydoc2 = Doc() + async with aconnect_ws( + f"http://localhost:{port}/my-roomname" + ) as websocket2, WebsocketProvider(ydoc2, Websocket(websocket2, "my-roomname")): + await sleep(0.1) + + ydoc2["map"] = ymap2 = Map() + assert str(ymap2) == '{"key":"value"}' diff --git a/tests/test_pycrdt_yjs.py b/tests/test_pycrdt_yjs.py index bd0e8c3..8950b53 100644 --- a/tests/test_pycrdt_yjs.py +++ b/tests/test_pycrdt_yjs.py @@ -4,11 +4,14 @@ import pytest from anyio import Event, fail_after +from httpx_ws import aconnect_ws from pycrdt import Array, Doc, Map -from websockets import connect +from utils import Websocket, yjs_client from pycrdt_websocket import WebsocketProvider +pytestmark = pytest.mark.anyio + class Change: def __init__(self, event, timeout, ydata, sid, key): @@ -38,32 +41,32 @@ def watch(ydata, key: str | None = None, timeout: float = 1.0): return Change(change_event, timeout, ydata, sid, key) -@pytest.mark.anyio -@pytest.mark.parametrize("yjs_client", "0", indirect=True) -async def test_pycrdt_yjs_0(yws_server, yjs_client): - ydoc = Doc() - async with connect("ws://127.0.0.1:1234/my-roomname") as websocket, WebsocketProvider( - ydoc, websocket - ): - ydoc["map"] = ymap = Map() - for v_in in range(10): - ymap["in"] = float(v_in) - v_out = await watch(ymap, "out").wait() - assert v_out == v_in + 1.0 +async def test_pycrdt_yjs_0(yws_server): + port = yws_server + with yjs_client(0, port): + ydoc = Doc() + async with aconnect_ws( + f"http://localhost:{port}/my-roomname" + ) as websocket, WebsocketProvider(ydoc, Websocket(websocket, "my-roomname")): + ydoc["map"] = ymap = Map() + for v_in in range(10): + ymap["in"] = float(v_in) + v_out = await watch(ymap, "out").wait() + assert v_out == v_in + 1.0 -@pytest.mark.anyio -@pytest.mark.parametrize("yjs_client", "1", indirect=True) -async def test_pycrdt_yjs_1(yws_server, yjs_client): - ydoc = Doc() - ydoc["cells"] = ycells = Array() - ydoc["state"] = ystate = Map() - ycells_change = watch(ycells) - ystate_change = watch(ystate) - async with connect("ws://127.0.0.1:1234/my-roomname") as websocket, WebsocketProvider( - ydoc, websocket - ): - await ycells_change.wait() - await ystate_change.wait() - assert ycells.to_py() == [{"metadata": {"foo": "bar"}, "source": "1 + 2"}] - assert ystate.to_py() == {"state": {"dirty": False}} +async def test_pycrdt_yjs_1(yws_server): + port = yws_server + with yjs_client(1, port): + ydoc = Doc() + ydoc["cells"] = ycells = Array() + ydoc["state"] = ystate = Map() + ycells_change = watch(ycells) + ystate_change = watch(ystate) + async with aconnect_ws( + f"http://localhost:{port}/my-roomname" + ) as websocket, WebsocketProvider(ydoc, Websocket(websocket, "my-roomname")): + await ycells_change.wait() + await ystate_change.wait() + assert ycells.to_py() == [{"metadata": {"foo": "bar"}, "source": "1 + 2"}] + assert ystate.to_py() == {"state": {"dirty": False}} diff --git a/tests/test_ystore.py b/tests/test_ystore.py index 952284d..f94c190 100644 --- a/tests/test_ystore.py +++ b/tests/test_ystore.py @@ -4,11 +4,13 @@ from pathlib import Path from unittest.mock import patch -import aiosqlite import pytest +from sqlite_anyio import connect from pycrdt_websocket.ystore import SQLiteYStore, TempFileYStore +pytestmark = pytest.mark.anyio + class MetadataCallback: def __init__(self): @@ -37,7 +39,6 @@ def __init__(self, *args, delete_db=False, **kwargs): super().__init__(*args, **kwargs) -@pytest.mark.anyio @pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore)) async def test_ystore(YStore): store_name = "my_store" @@ -59,34 +60,37 @@ async def test_ystore(YStore): assert i == len(data) + await ystore.stop() + -@pytest.mark.anyio async def test_document_ttl_sqlite_ystore(test_ydoc): store_name = "my_store" ystore = MySQLiteYStore(store_name, delete_db=True) await ystore.start() now = time.time() + db = await connect(ystore.db_path) + cursor = await db.cursor() for i in range(3): # assert that adding a record before document TTL doesn't delete document history with patch("time.time") as mock_time: mock_time.return_value = now await ystore.write(test_ydoc.update()) - async with aiosqlite.connect(ystore.db_path) as db: - assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[ - 0 - ] == i + 1 + assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[ + 0 + ] == i + 1 # assert that adding a record after document TTL deletes previous document history with patch("time.time") as mock_time: mock_time.return_value = now + ystore.document_ttl + 1 await ystore.write(test_ydoc.update()) - async with aiosqlite.connect(ystore.db_path) as db: - # two updates in DB: one squashed update and the new update - assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2 + # two updates in DB: one squashed update and the new update + assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2 + + await db.close() + await ystore.stop() -@pytest.mark.anyio @pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore)) async def test_version(YStore, caplog): store_name = "my_store" @@ -97,3 +101,4 @@ async def test_version(YStore, caplog): await ystore.write(b"foo") YStore.version = prev_version assert "YStore version mismatch" in caplog.text + await ystore.stop() diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..3815efb --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,50 @@ +import subprocess +from contextlib import contextmanager + +from anyio import Lock, connect_tcp + + +class Websocket: + def __init__(self, websocket, path: str): + self._websocket = websocket + self._path = path + self._send_lock = Lock() + + @property + def path(self) -> str: + return self._path + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + try: + message = await self.recv() + except Exception: + raise StopAsyncIteration() + return message + + async def send(self, message: bytes): + async with self._send_lock: + await self._websocket.send_bytes(message) + + async def recv(self) -> bytes: + b = await self._websocket.receive_bytes() + return bytes(b) + + +@contextmanager +def yjs_client(client_id: int, port: int): + p = subprocess.Popen(["node", f"tests/yjs_client_{client_id}.js", str(port)]) + yield p + p.kill() + + +async def ensure_server_running(host: str, port: int) -> None: + while True: + try: + await connect_tcp(host, port) + except OSError: + pass + else: + break diff --git a/tests/yjs_client_0.js b/tests/yjs_client_0.js index 92c87c0..33285c5 100644 --- a/tests/yjs_client_0.js +++ b/tests/yjs_client_0.js @@ -2,6 +2,7 @@ const Y = require('yjs') const WebsocketProvider = require('y-websocket').WebsocketProvider const ws = require('ws') +const port = process.argv[2] const ydoc = new Y.Doc() const ymap = ydoc.getMap('map') @@ -18,7 +19,7 @@ ymap.observe(event => { }) const wsProvider = new WebsocketProvider( - 'ws://127.0.0.1:1234', 'my-roomname', + `ws://127.0.0.1:${port}`, 'my-roomname', ydoc, { WebSocketPolyfill: ws } ) diff --git a/tests/yjs_client_1.js b/tests/yjs_client_1.js index cb743ea..038c4d5 100644 --- a/tests/yjs_client_1.js +++ b/tests/yjs_client_1.js @@ -2,10 +2,11 @@ const Y = require('yjs') const WebsocketProvider = require('y-websocket').WebsocketProvider const ws = require('ws') +const port = process.argv[2] const ydoc = new Y.Doc() const wsProvider = new WebsocketProvider( - 'ws://127.0.0.1:1234', 'my-roomname', + `ws://127.0.0.1:${port}`, 'my-roomname', ydoc, { WebSocketPolyfill: ws } )