diff --git a/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py b/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py index c0bd024..3b02919 100644 --- a/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py +++ b/pycrdt_websocket/django_channels/storage/redis_yroom_storage.py @@ -14,14 +14,20 @@ class RedisYRoomStorage(BaseYRoomStorage): room_name: The name of the room. """ - def __init__(self, room_name: str, save_throttle_interval: int | None = None) -> None: + def __init__( + self, + room_name: str, + save_throttle_interval: int | None = None, + redis_expiration_seconds: int | None = 60 * 10, # 10 minutes, + ): super().__init__(room_name) self.save_throttle_interval = save_throttle_interval self.last_saved_at = time.time() self.redis_key = f"document:{self.room_name}" - self.redis = self._make_redis() + self.redis = self.make_redis() + self.redis_expiration_seconds = redis_expiration_seconds async def get_document(self) -> Doc: snapshot = await self.redis.get(self.redis_key) @@ -47,7 +53,11 @@ async def update_document(self, update: bytes): while True: try: pipe.multi() - pipe.set(self.redis_key, updated_snapshot) + pipe.set( + name=self.redis_key, + value=updated_snapshot, + ex=self.redis_expiration_seconds, + ) await pipe.execute() @@ -84,6 +94,12 @@ async def throttled_save_snapshot(self) -> None: self.last_saved_at = time.time() + def make_redis(self): + """Makes a Redis client. + Defaults to a local client""" + + return redis.Redis(host="localhost", port=6379, db=0) + async def close(self): await self.save_snapshot() await self.redis.close() @@ -92,9 +108,3 @@ def _apply_update_to_document(self, document: Doc, update: bytes) -> bytes: document.apply_update(update) return document.get_update() - - def _make_redis(self): - """Makes a Redis client. - Defaults to a local client""" - - return redis.Redis(host="localhost", port=6379, db=0)