diff --git a/tests/test_ystore.py b/tests/test_ystore.py index bdd208f..a669d47 100644 --- a/tests/test_ystore.py +++ b/tests/test_ystore.py @@ -17,7 +17,7 @@ def __init__(self): def __call__(self): future = asyncio.Future() - future.set_result(bytes(self.i)) + future.set_result(str(self.i).encode()) self.i += 1 return future @@ -52,16 +52,18 @@ async def test_ystore(YStore): elif YStore == MySQLiteYStore: assert Path(MySQLiteYStore.db_path).exists() i = 0 - async for d, m in ystore.read(): + async for d, m, t in ystore.read(): assert d == data[i] # data - assert m == bytes(i) # metadata + assert m == str(i).encode() # metadata i += 1 + assert i == len(data) + @pytest.mark.asyncio async def test_document_ttl_sqlite_ystore(test_ydoc): store_name = "my_store" - ystore = MySQLiteYStore(store_name, metadata_callback=MetadataCallback(), delete_db=True) + ystore = MySQLiteYStore(store_name, delete_db=True) now = time.time() for i in range(3): @@ -89,7 +91,7 @@ async def test_version(YStore, caplog): store_name = "my_store" prev_version = YStore.version YStore.version = -1 - ystore = YStore(store_name, metadata_callback=MetadataCallback()) + ystore = YStore(store_name) await ystore.write(b"foo") YStore.version = prev_version assert "YStore version mismatch" in caplog.text diff --git a/ypy_websocket/ystore.py b/ypy_websocket/ystore.py index 793bd8e..374f534 100644 --- a/ypy_websocket/ystore.py +++ b/ypy_websocket/ystore.py @@ -1,5 +1,6 @@ import asyncio import logging +import struct import tempfile import time from abc import ABC, abstractmethod @@ -21,7 +22,7 @@ class YDocNotFound(Exception): class BaseYStore(ABC): metadata_callback: Optional[Callable] = None - version = 1 + version = 2 @abstractmethod def __init__(self, path: str, metadata_callback=None): @@ -44,7 +45,7 @@ async def encode_state_as_update(self, ydoc: Y.YDoc): await self.write(update) async def apply_updates(self, ydoc: Y.YDoc): - async for update, metadata in self.read(): # type: ignore + async for update, *rest in self.read(): # type: ignore Y.apply_update(ydoc, update) # type: ignore @@ -90,7 +91,7 @@ async def check_version(self) -> int: offset = len(version_bytes) return offset - async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore + async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore async with self.lock: if not await aiofiles.os.path.exists(self.path): raise YDocNotFound @@ -100,15 +101,16 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore data = await f.read() if not data: raise YDocNotFound - is_data = True - assert data is not None + i = 0 for d in Decoder(data).read_messages(): - if is_data: + if i == 0: update = d + elif i == 1: + metadata = d else: - # yield data and metadata - yield update, d - is_data = not is_data + timestamp = struct.unpack(" None: parent = Path(self.path).parent @@ -121,6 +123,9 @@ async def write(self, data: bytes) -> None: metadata = await self.get_metadata() metadata_len = write_var_uint(len(metadata)) await f.write(metadata_len + metadata) + timestamp = struct.pack(" AsyncIterator[Tuple[bytes, bytes]]: # type: ignore + async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore await self.db_initialized try: async with aiosqlite.connect(self.db_path) as db: async with db.execute( - "SELECT yupdate, metadata FROM yupdates WHERE path = ?", (self.path,) + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", (self.path,) ) as cursor: found = False - async for update, metadata in cursor: + async for update, metadata, timestamp in cursor: found = True - yield update, metadata + yield update, metadata, timestamp if not found: raise YDocNotFound except BaseException: diff --git a/ypy_websocket/yutils.py b/ypy_websocket/yutils.py index 4378cf5..a9cb0f7 100644 --- a/ypy_websocket/yutils.py +++ b/ypy_websocket/yutils.py @@ -56,9 +56,7 @@ def __init__(self, stream: bytes): self.i0 = 0 def read_var_uint(self) -> int: - if self.length == 0: - return 0 - if self.length < 0: + if self.length <= 0: raise RuntimeError("Y protocol error") uint = 0 i = 0 @@ -73,9 +71,11 @@ def read_var_uint(self) -> int: return uint def read_message(self) -> Optional[bytes]: + if self.length == 0: + return None length = self.read_var_uint() if length == 0: - return None + return b"" i1 = self.i0 + length message = self.stream[self.i0 : i1] # noqa self.i0 = i1