diff --git a/channels_redis/core.py b/channels_redis/core.py index cde7962..b5f42e2 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -179,6 +179,40 @@ class UnsupportedRedis(Exception): pass +class ExpiringCache(collections.defaultdict): + def __init__(self, default, ttl=60, *args, **kw): + collections.defaultdict.__init__(self, default) + self._expires = collections.OrderedDict() + self.ttl = ttl + + def __setitem__(self, k, v): + collections.defaultdict.__setitem__(self, k, v) + self._expires[k] = time.time() + self.ttl + + def __delitem__(self, k): + try: + collections.defaultdict.__delitem__(self, k) + except KeyError: + # RedisChannelLayer itself _does_ periodically clean up this + # dictionary (e.g., when exceptions like asyncio.CancelledError + # occur) + pass + + def expire(self): + expired = [] + for k in self._expires.keys(): + if self._expires[k] < time.time(): + expired.append(k) + else: + # as this is an OrderedDict, every key after this + # was inserted *later*, so if _this_ key is *not* expired, + # the ones after it aren't either (so we can stop iterating) + break + for k in expired: + del self._expires[k] + del self[k] + + class RedisChannelLayer(BaseChannelLayer): """ Redis channel layer. @@ -226,7 +260,7 @@ def __init__( # Event loop they are trying to receive on self.receive_event_loop = None # Buffered messages by process-local channel name - self.receive_buffer = collections.defaultdict(asyncio.Queue) + self.receive_buffer = ExpiringCache(asyncio.Queue, ttl=self.expiry) # Detached channel cleanup tasks self.receive_cleaners = [] # Per-channel cleanup locks to prevent a receive starting and moving @@ -616,6 +650,7 @@ async def group_discard(self, group, channel): key = self._group_key(group) async with self.connection(self.consistent_hash(group)) as connection: await connection.zrem(key, channel) + self.receive_buffer.expire() async def group_send(self, group, message): """ diff --git a/tests/test_core.py b/tests/test_core.py index 29a4f07..0a71a1a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,12 +1,18 @@ import asyncio import random +import time import async_timeout import pytest from async_generator import async_generator, yield_ from asgiref.sync import async_to_sync -from channels_redis.core import ChannelFull, ConnectionPool, RedisChannelLayer +from channels_redis.core import ( + ChannelFull, + ConnectionPool, + ExpiringCache, + RedisChannelLayer, +) TEST_HOSTS = [("localhost", 6379)] @@ -627,3 +633,43 @@ def test_custom_group_key_format(): channel_layer = RedisChannelLayer(prefix="test_prefix") group_name = channel_layer._group_key("test_group") assert group_name == b"test_prefix:group:test_group" + + +def test_expiring_buffer_default_value(): + buff = ExpiringCache(asyncio.Queue) + assert isinstance(buff["foo"], asyncio.Queue) + + +def test_expiring_buffer_default_ttl(): + buff = ExpiringCache(None) + assert buff.ttl == 60 + + +def test_expiring_buffer_ttl_expiration(): + past = time.time() - 60 + buff = ExpiringCache(None) + + for x in range(100): + buff[x] = "example" + assert len(buff) == 100 + buff.expire() + assert len(buff) == 100 + + for x in range(100): + buff._expires[x] = past + buff["extra"] = "extra" + buff.expire() + assert len(buff) == 1 + assert "extra" in buff + assert len(buff._expires) == 1 + + +def test_expiring_buffer_ttl_already_gone(): + past = time.time() - 60 + buff = ExpiringCache(None) + buff["delete"] = "example" + buff._expires["delete"] = past + del buff["delete"] + buff.expire() + assert len(buff) == 0 + assert len(buff._expires) == 0