Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep the YStore task group alive #42

Merged
merged 1 commit into from
May 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BaseYStore(ABC):
metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None
version = 2
_started: Event | None = None
_stopped: Event | None = None
_task_group: TaskGroup | None = None
__start_lock: Lock | None = None

Expand All @@ -50,6 +51,12 @@ def started(self) -> Event:
self._started = Event()
return self._started

@property
def stopped(self) -> Event:
if self._stopped is None:
self._stopped = Event()
return self._stopped

@property
def _start_lock(self) -> Lock:
if self.__start_lock is None:
Expand Down Expand Up @@ -96,12 +103,14 @@ async def start(
async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
await self.stopped.wait()

async def stop(self) -> None:
"""Stop the store."""
if self._task_group is None:
raise RuntimeError("YStore not running")

self.stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None

Expand Down Expand Up @@ -309,7 +318,7 @@ class MySQLiteYStore(SQLiteYStore):
document_ttl: int | None = None
path: str
lock: Lock
db_initialized: Event
db_initialized: Event | None
_db: Connection

def __init__(
Expand All @@ -329,6 +338,7 @@ def __init__(
self.metadata_callback = metadata_callback
self.log = log or getLogger(__name__)
self.lock = Lock()
self.db_initialized = None

async def start(
self,
Expand Down Expand Up @@ -356,10 +366,11 @@ async def start(
self._task_group.start_soon(self._init_db)
task_status.started()
self.started.set()
await self.stopped.wait()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to make the start method blocking here? In websocketServer, we have used self._task_group.start_soon(self._stopped.wait) to keep the task group alive. Should we use this pattern?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's equivalent, in WebsocketServer the task group context manager doesn't exit as long as tasks are running.


async def stop(self) -> None:
"""Stop the store."""
if hasattr(self, "db_initialized") and self.db_initialized.is_set():
if self.db_initialized is not None and self.db_initialized.is_set():
await self._db.close()
await super().stop()

Expand Down Expand Up @@ -405,6 +416,7 @@ async def _init_db(self):
await db.commit()
await db.close()
self._db = await connect(self.db_path)
assert self.db_initialized is not None
self.db_initialized.set()

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Expand All @@ -413,8 +425,8 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Returns:
A tuple of (update, metadata, timestamp) for each update.
"""
if not hasattr(self, "db_initialized"):
raise RuntimeError("ystore is not started")
if self.db_initialized is None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
try:
async with self.lock:
Expand All @@ -438,8 +450,8 @@ async def write(self, data: bytes) -> None:
Arguments:
data: The update to store.
"""
if not hasattr(self, "db_initialized"):
raise RuntimeError("ystore is not started")
if self.db_initialized is None:
raise RuntimeError("YStore not started")
await self.db_initialized.wait()
async with self.lock:
# first, determine time elapsed since last update
Expand Down
Loading