Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

squash doc history in separate task #58

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 55 additions & 14 deletions tests/test_ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class MyTempFileYStore(TempFileYStore):

class MySQLiteYStore(SQLiteYStore):
db_path = MY_SQLITE_YSTORE_DB_PATH
document_ttl = 1000
document_ttl = 1

def __init__(self, *args, delete_db=False, **kwargs):
if delete_db:
Expand Down Expand Up @@ -61,29 +61,70 @@ async def test_ystore(YStore):
assert i == len(data)


async def count_yupdates(db):
"""Returns number of yupdates in a SQLite DB given a connection."""
return (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0]


@pytest.mark.asyncio
async def test_document_ttl_sqlite_ystore(test_ydoc):
"""Assert that document history is squashed after the document TTL."""
store_name = "my_store"
ystore = MySQLiteYStore(store_name, delete_db=True)
now = time.time()

for i in range(3):
# assert that adding a record before document TTL doesn't delete document history
with patch("time.time") as mock_time:
mock_time.return_value = now
await ystore.write(test_ydoc.update())
async with aiosqlite.connect(ystore.db_path) as db:
assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[
0
] == i + 1
await ystore.write(test_ydoc.update())
async with aiosqlite.connect(ystore.db_path) as db:
assert (await count_yupdates(db)) == i + 1

# assert that adding a record after document TTL deletes previous document history
with patch("time.time") as mock_time:
mock_time.return_value = now + ystore.document_ttl + 1
await ystore._squash_task

async with aiosqlite.connect(ystore.db_path) as db:
assert (await count_yupdates(db)) == 1


@pytest.mark.asyncio
async def test_document_ttl_simultaneous_write_sqlite_ystore(test_ydoc):
"""Assert that document history is squashed after the document TTL, and a
write that happens at the same time is also squashed."""
store_name = "my_store"
ystore = MySQLiteYStore(store_name, delete_db=True)

for i in range(3):
await ystore.write(test_ydoc.update())
async with aiosqlite.connect(ystore.db_path) as db:
# two updates in DB: one squashed update and the new update
assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2
assert (await count_yupdates(db)) == i + 1

await asyncio.sleep(ystore.document_ttl)
await ystore.write(test_ydoc.update())
await ystore._squash_task

async with aiosqlite.connect(ystore.db_path) as db:
assert (await count_yupdates(db)) == 1


@pytest.mark.asyncio
async def test_document_ttl_init_sqlite_ystore(test_ydoc):
"""Assert that document history is squashed on init if the document TTL has
already elapsed since last update."""
store_name = "my_store"
ystore = MySQLiteYStore(store_name, delete_db=True)
now = time.time()

with patch("time.time") as mock_time:
mock_time.return_value = now - ystore.document_ttl - 1
for i in range(3):
await ystore.write(test_ydoc.update())
async with aiosqlite.connect(ystore.db_path) as db:
assert (await count_yupdates(db)) == i + 1

del ystore
ystore = MySQLiteYStore(store_name)
await ystore.db_initialized

async with aiosqlite.connect(ystore.db_path) as db:
assert (await count_yupdates(db)) == 1


@pytest.mark.asyncio
Expand Down
76 changes: 50 additions & 26 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=
self.metadata_callback = metadata_callback
self.log = log or logging.getLogger(__name__)
self.db_initialized = asyncio.create_task(self.init_db())
self._squash_task: Optional[asyncio.Task] = None

async def init_db(self):
create_db = False
Expand Down Expand Up @@ -212,6 +213,17 @@ async def init_db(self):
await db.execute(f"PRAGMA user_version = {self.version}")
await db.commit()

# squash updates if document TTL already elapsed
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0
if self.document_ttl is not None and diff > self.document_ttl:
await self._squash()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore
await self.db_initialized
try:
Expand All @@ -231,36 +243,48 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: igno
async def write(self, data: bytes) -> None:
await self.db_initialized
async with aiosqlite.connect(self.db_path) as db:
# first, determine time elapsed since last update
cursor = await db.execute(
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
(self.path,),
# write this update to the DB
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Y.YDoc()
async with db.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
async for update, in cursor:
Y.apply_update(ydoc, update)
# delete history
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = Y.encode_state_as_update(ydoc)
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
)
await db.commit()
# create task that squashes document history after document_ttl
self._create_squash_task()

# finally, write this update to the DB
async def _squash(self):
"""Squashes document history into a single Y update."""
async with aiosqlite.connect(self.db_path) as db:
# squash updates
ydoc = Y.YDoc()
async with db.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
async for update, in cursor:
Y.apply_update(ydoc, update)
# delete history
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = Y.encode_state_as_update(ydoc)
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
(self.path, squashed_update, metadata, time.time()),
)
await db.commit()

async def _squash_later(self):
await asyncio.sleep(self.document_ttl)
await self._squash()

def _create_squash_task(self) -> None:
"""Creates a task that squashes document history after self.document_ttl
and binds it to the _squash_task attribute. If a task already exists,
this cancels the existing task."""
if self.document_ttl is None:
return
if self._squash_task is not None:
self._squash_task.cancel()

self._squash_task = asyncio.create_task(self._squash_later())