Skip to content

Commit

Permalink
Implement YStore versioning (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Nov 23, 2022
1 parent d626b1d commit 87e6186
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ keywords = [
"yjs",
]
dependencies = [
"aiofiles >=0.8.0,<1",
"aiofiles >=22.1.0,<23",
"aiosqlite >=0.17.0,<1",
"y-py >=0.5.3,<0.6.0",
]
Expand Down
27 changes: 22 additions & 5 deletions tests/test_ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,21 @@ class MyTempFileYStore(TempFileYStore):
prefix_dir = "test_temp_"


MY_SQLITE_YSTORE_DB_PATH = str(Path(tempfile.mkdtemp(prefix="test_sql_")) / "ystore.db")


class MySQLiteYStore(SQLiteYStore):
db_path = str(Path(tempfile.mkdtemp(prefix="test_sql_")) / "ystore.db")
db_path = MY_SQLITE_YSTORE_DB_PATH

def __del__(self):
os.remove(self.db_path)
def __init__(self, *args, delete_db=False, **kwargs):
if delete_db:
os.remove(self.db_path)
super().__init__(*args, **kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore))
async def test_file_ystore(YStore):
async def test_ystore(YStore):
store_name = "my_store"
ystore = YStore(store_name, metadata_callback=MetadataCallback())
data = [b"foo", b"bar", b"baz"]
Expand All @@ -56,7 +61,7 @@ async def test_file_ystore(YStore):
@pytest.mark.asyncio
async def test_document_ttl_sqlite_ystore():
store_name = "my_store"
ystore = MySQLiteYStore(store_name, metadata_callback=MetadataCallback())
ystore = MySQLiteYStore(store_name, metadata_callback=MetadataCallback(), delete_db=True)

await ystore.write(b"a")
async with aiosqlite.connect(ystore.db_path) as db:
Expand All @@ -77,3 +82,15 @@ async def test_document_ttl_sqlite_ystore():
await ystore.write(b"c")
async with aiosqlite.connect(ystore.db_path) as db:
assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 1


@pytest.mark.asyncio
@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore))
async def test_version(YStore, caplog):
    """Writing through a store class with an altered version logs a mismatch warning."""
    store_name = "my_store"
    prev_version = YStore.version
    # Pretend the code is at another version than the one any existing
    # store was stamped with.
    YStore.version = -1
    try:
        ystore = YStore(store_name, metadata_callback=MetadataCallback())
        await ystore.write(b"foo")
    finally:
        # ``version`` is a class attribute: restore it even when the write
        # fails, so the patched value cannot leak into other tests.
        YStore.version = prev_version
    assert "YStore version mismatch" in caplog.text
119 changes: 88 additions & 31 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import asyncio
import logging
import tempfile
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import AsyncIterator, Callable, Optional, Tuple

import aiofiles # type: ignore
import aiofiles.os # type: ignore
import aiosqlite # type: ignore
import y_py as Y

from .yutils import Decoder, write_var_uint
from .yutils import Decoder, get_new_path, write_var_uint


class YDocNotFound(Exception):
Expand All @@ -19,6 +21,7 @@ class YDocNotFound(Exception):
class BaseYStore(ABC):

metadata_callback: Optional[Callable] = None
version = 1

@abstractmethod
def __init__(self, path: str, metadata_callback=None):
Expand Down Expand Up @@ -52,19 +55,53 @@ class FileYStore(BaseYStore):
metadata_callback: Optional[Callable]
lock: asyncio.Lock

def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=None):
self.path = path
self.metadata_callback = metadata_callback
self.log = log or logging.getLogger(__name__)
self.lock = asyncio.Lock()

async def check_version(self) -> int:
    """Ensure the file at ``self.path`` starts with the current version header.

    The expected header is ``VERSION:<self.version>`` followed by a newline.
    An existing file whose header is missing, malformed, or carries another
    version is renamed to the path returned by ``get_new_path`` and a
    warning is logged. Whenever no valid file remains (nothing existed, or
    it was just moved away), a new file containing only the header is
    written.

    Returns:
        The byte offset just past the header, i.e. where the stored
        updates begin.
    """
    offset = None
    if await aiofiles.os.path.exists(self.path):
        async with aiofiles.open(self.path, "rb") as f:
            if await f.read(8) == b"VERSION:":
                try:
                    same_version = int(await f.readline()) == self.version
                except ValueError:
                    # A non-numeric version is as unusable as a missing
                    # header: treat it as a mismatch instead of crashing.
                    same_version = False
                if same_version:
                    offset = await f.tell()
        if offset is None:
            # The rename happens only after the file has been closed.
            new_path = await get_new_path(self.path)
            self.log.warning(f"YStore version mismatch, moving {self.path} to {new_path}")
            await aiofiles.os.rename(self.path, new_path)
    if offset is None:
        version_bytes = f"VERSION:{self.version}\n".encode()
        async with aiofiles.open(self.path, "wb") as f:
            await f.write(version_bytes)
        offset = len(version_bytes)
    return offset

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
async with self.lock:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except BaseException:
if not await aiofiles.os.path.exists(self.path):
raise YDocNotFound
offset = await self.check_version()
async with aiofiles.open(self.path, "rb") as f:
await f.seek(offset)
data = await f.read()
if not data:
raise YDocNotFound
is_data = True
assert data is not None
for d in Decoder(data).read_messages():
if is_data:
update = d
Expand All @@ -75,13 +112,10 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore

async def write(self, data: bytes) -> None:
parent = Path(self.path).parent
if not parent.exists():
parent.mkdir(parents=True)
mode = "wb"
else:
mode = "ab"
async with self.lock:
async with aiofiles.open(self.path, mode) as f:
await aiofiles.os.makedirs(parent, exist_ok=True)
await self.check_version()
async with aiofiles.open(self.path, "ab") as f:
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
Expand All @@ -101,9 +135,9 @@ class PrefixTempFileYStore(TempFileYStore):
prefix_dir: Optional[str] = None
base_dir: Optional[str] = None

def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=None):
full_path = str(Path(self.get_base_dir()) / path)
super().__init__(full_path, metadata_callback=metadata_callback)
super().__init__(full_path, metadata_callback=metadata_callback, log=log)

def get_base_dir(self) -> str:
if self.base_dir is None:
Expand Down Expand Up @@ -131,27 +165,50 @@ class MySQLiteYStore(SQLiteYStore):
# Defaults to 1 day.
document_ttl: int = 24 * 60 * 60
path: str
db_created: asyncio.Event
db_initialized: asyncio.Task

def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=None):
self.path = path
self.metadata_callback = metadata_callback
self.db_created = asyncio.Event()
asyncio.create_task(self.create_db())

async def create_db(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"CREATE TABLE IF NOT EXISTS yupdates (path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)"
)
await db.execute(
"CREATE INDEX IF NOT EXISTS idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await db.commit()
self.db_created.set()
self.log = log or logging.getLogger(__name__)
self.db_initialized = asyncio.create_task(self.init_db())

async def init_db(self):
    """Prepare the SQLite database at ``self.db_path`` for the current version.

    A database that already has a ``yupdates`` table and whose
    ``user_version`` pragma equals ``self.version`` is left untouched. If
    the table exists with another version, the database file is renamed to
    the path returned by ``get_new_path`` (with a warning) and a new one is
    created. If the file or the table is missing, the table and its index
    are created in place. A created database is stamped with
    ``self.version``.
    """
    needs_move = False
    # Creation is the default; it is skipped only for a database that
    # already holds the table at the expected version.
    needs_create = True
    if await aiofiles.os.path.exists(self.db_path):
        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                "SELECT count(name) FROM sqlite_master WHERE type='table' and name='yupdates'"
            )
            row = await cursor.fetchone()
            if row[0]:
                cursor = await db.execute("pragma user_version")
                row = await cursor.fetchone()
                if row[0] == self.version:
                    needs_create = False
                else:
                    needs_move = True
    if needs_move:
        new_path = await get_new_path(self.db_path)
        self.log.warning(f"YStore version mismatch, moving {self.db_path} to {new_path}")
        await aiofiles.os.rename(self.db_path, new_path)
    if needs_create:
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)"
            )
            await db.execute(
                "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
            )
            # PRAGMA statements cannot take bound parameters, hence the f-string.
            await db.execute(f"PRAGMA user_version = {self.version}")
            await db.commit()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
await self.db_created.wait()
await self.db_initialized
try:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
Expand All @@ -167,8 +224,8 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
raise YDocNotFound

async def write(self, data: bytes) -> None:
await self.db_created.wait()
metadata = await self.get_metadata()
await self.db_initialized
async with aiosqlite.connect(self.db_path) as db:
# first, determine time elapsed since last update
cursor = await db.execute(
Expand Down
16 changes: 16 additions & 0 deletions ypy_websocket/yutils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
from enum import IntEnum
from pathlib import Path
from typing import Optional

import aiofiles.os # type: ignore
import y_py as Y


Expand Down Expand Up @@ -133,3 +135,17 @@ async def sync(ydoc: Y.YDoc, websocket, log):
websocket.path,
)
await websocket.send(msg)


async def get_new_path(path: str) -> str:
    """Return a path derived from ``path`` that does not exist yet.

    Candidates are built by inserting ``(1)``, ``(2)``, ... between the
    path without its suffix and the suffix, e.g. ``dir/store.db`` gives
    ``dir/store(1).db``. The first candidate that does not exist is
    returned.

    Args:
        path: The path to derive the new name from.

    Returns:
        The first non-existing candidate path, as a string.
    """
    p = Path(path)
    ext = p.suffix
    p_noext = p.with_suffix("")
    i = 1
    while True:
        new_path = f"{p_noext}({i}){ext}"
        # Test the candidate path itself. Listing the current working
        # directory and comparing bare entry names against a full path
        # never detects a collision for files located in another directory.
        if not await aiofiles.os.path.exists(new_path):
            return new_path
        i += 1

0 comments on commit 87e6186

Please sign in to comment.