Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement timestamp in FileYStore, return timestamp in read #57

Merged
merged 2 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions tests/test_ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self):

def __call__(self):
future = asyncio.Future()
future.set_result(bytes(self.i))
future.set_result(str(self.i).encode())
self.i += 1
return future

Expand All @@ -31,6 +31,7 @@ class MyTempFileYStore(TempFileYStore):

class MySQLiteYStore(SQLiteYStore):
db_path = MY_SQLITE_YSTORE_DB_PATH
document_ttl = 1000

def __init__(self, *args, delete_db=False, **kwargs):
if delete_db:
Expand All @@ -52,16 +53,18 @@ async def test_ystore(YStore):
elif YStore == MySQLiteYStore:
assert Path(MySQLiteYStore.db_path).exists()
i = 0
async for d, m in ystore.read():
async for d, m, t in ystore.read():
assert d == data[i] # data
assert m == bytes(i) # metadata
assert m == str(i).encode() # metadata
i += 1

assert i == len(data)


@pytest.mark.asyncio
async def test_document_ttl_sqlite_ystore(test_ydoc):
store_name = "my_store"
ystore = MySQLiteYStore(store_name, metadata_callback=MetadataCallback(), delete_db=True)
ystore = MySQLiteYStore(store_name, delete_db=True)
now = time.time()

for i in range(3):
Expand Down Expand Up @@ -89,7 +92,7 @@ async def test_version(YStore, caplog):
store_name = "my_store"
prev_version = YStore.version
YStore.version = -1
ystore = YStore(store_name, metadata_callback=MetadataCallback())
ystore = YStore(store_name)
await ystore.write(b"foo")
YStore.version = prev_version
assert "YStore version mismatch" in caplog.text
37 changes: 21 additions & 16 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import struct
import tempfile
import time
from abc import ABC, abstractmethod
Expand All @@ -21,7 +22,7 @@ class YDocNotFound(Exception):
class BaseYStore(ABC):

metadata_callback: Optional[Callable] = None
version = 1
version = 2

@abstractmethod
def __init__(self, path: str, metadata_callback=None):
Expand All @@ -44,7 +45,7 @@ async def encode_state_as_update(self, ydoc: Y.YDoc):
await self.write(update)

async def apply_updates(self, ydoc: Y.YDoc):
async for update, metadata in self.read(): # type: ignore
async for update, *rest in self.read(): # type: ignore
Y.apply_update(ydoc, update) # type: ignore


Expand Down Expand Up @@ -90,7 +91,7 @@ async def check_version(self) -> int:
offset = len(version_bytes)
return offset

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore
async with self.lock:
if not await aiofiles.os.path.exists(self.path):
raise YDocNotFound
Expand All @@ -100,15 +101,16 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
data = await f.read()
if not data:
raise YDocNotFound
is_data = True
assert data is not None
i = 0
for d in Decoder(data).read_messages():
if is_data:
if i == 0:
update = d
elif i == 1:
metadata = d
else:
# yield data and metadata
yield update, d
is_data = not is_data
timestamp = struct.unpack("<d", d)[0]
yield update, metadata, timestamp
i = (i + 1) % 3

async def write(self, data: bytes) -> None:
parent = Path(self.path).parent
Expand All @@ -121,6 +123,9 @@ async def write(self, data: bytes) -> None:
metadata = await self.get_metadata()
metadata_len = write_var_uint(len(metadata))
await f.write(metadata_len + metadata)
timestamp = struct.pack("<d", time.time())
timestamp_len = write_var_uint(len(timestamp))
await f.write(timestamp_len + timestamp)


class TempFileYStore(FileYStore):
Expand Down Expand Up @@ -162,8 +167,8 @@ class MySQLiteYStore(SQLiteYStore):
db_path: str = "ystore.db"
# Determines the "time to live" for all documents, i.e. how recent the
# latest update of a document must be before purging document history.
# Defaults to 1 day.
document_ttl: int = 24 * 60 * 60
# Defaults to never purging document history (None).
document_ttl: Optional[int] = None
path: str
db_initialized: asyncio.Task

Expand Down Expand Up @@ -207,17 +212,17 @@ async def init_db(self):
await db.execute(f"PRAGMA user_version = {self.version}")
await db.commit()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore
await self.db_initialized
try:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT yupdate, metadata FROM yupdates WHERE path = ?", (self.path,)
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
found = False
async for update, metadata in cursor:
async for update, metadata, timestamp in cursor:
found = True
yield update, metadata
yield update, metadata, timestamp
if not found:
raise YDocNotFound
except BaseException:
Expand All @@ -234,7 +239,7 @@ async def write(self, data: bytes) -> None:
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if diff > self.document_ttl:
if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Y.YDoc()
async with db.execute(
Expand Down
8 changes: 4 additions & 4 deletions ypy_websocket/yutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def __init__(self, stream: bytes):
self.i0 = 0

def read_var_uint(self) -> int:
if self.length == 0:
return 0
if self.length < 0:
if self.length <= 0:
raise RuntimeError("Y protocol error")
uint = 0
i = 0
Expand All @@ -73,9 +71,11 @@ def read_var_uint(self) -> int:
return uint

def read_message(self) -> Optional[bytes]:
if self.length == 0:
return None
length = self.read_var_uint()
if length == 0:
return None
return b""
i1 = self.i0 + length
message = self.stream[self.i0 : i1] # noqa
self.i0 = i1
Expand Down