Skip to content

Commit

Permalink
implement a TTL for RedisChannelLayer.receive_buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanpetrello committed Jul 27, 2020
1 parent 2075071 commit 2e8097f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
37 changes: 36 additions & 1 deletion channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,40 @@ class UnsupportedRedis(Exception):
pass


class ExpiringCache(collections.defaultdict):
def __init__(self, default, ttl=60, *args, **kw):
collections.defaultdict.__init__(self, default)
self._expires = collections.OrderedDict()
self.ttl = ttl

def __setitem__(self, k, v):
collections.defaultdict.__setitem__(self, k, v)
self._expires[k] = time.time() + self.ttl

def __delitem__(self, k):
try:
collections.defaultdict.__delitem__(self, k)
except KeyError:
# RedisChannelLayer itself _does_ periodically clean up this
# dictionary (e.g., when exceptions like asyncio.CancelledError
# occur)
pass

def expire(self):
expired = []
for k in self._expires.keys():
if self._expires[k] < time.time():
expired.append(k)
else:
# as this is an OrderedDict, every key after this
# was inserted *later*, so if _this_ key is *not* expired,
# the ones after it aren't either (so we can stop iterating)
break
for k in expired:
del self._expires[k]
del self[k]


class RedisChannelLayer(BaseChannelLayer):
"""
Redis channel layer.
Expand Down Expand Up @@ -226,7 +260,7 @@ def __init__(
# Event loop they are trying to receive on
self.receive_event_loop = None
# Buffered messages by process-local channel name
self.receive_buffer = collections.defaultdict(asyncio.Queue)
self.receive_buffer = ExpiringCache(asyncio.Queue, ttl=self.expiry)
# Detached channel cleanup tasks
self.receive_cleaners = []
# Per-channel cleanup locks to prevent a receive starting and moving
Expand Down Expand Up @@ -616,6 +650,7 @@ async def group_discard(self, group, channel):
key = self._group_key(group)
async with self.connection(self.consistent_hash(group)) as connection:
await connection.zrem(key, channel)
self.receive_buffer.expire()

async def group_send(self, group, message):
"""
Expand Down
48 changes: 47 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import asyncio
import random
import time

import async_timeout
import pytest
from async_generator import async_generator, yield_

from asgiref.sync import async_to_sync
from channels_redis.core import ChannelFull, ConnectionPool, RedisChannelLayer
from channels_redis.core import (
ChannelFull,
ConnectionPool,
RedisChannelLayer,
ExpiringCache,
)

TEST_HOSTS = [("localhost", 6379)]

Expand Down Expand Up @@ -627,3 +633,43 @@ def test_custom_group_key_format():
channel_layer = RedisChannelLayer(prefix="test_prefix")
group_name = channel_layer._group_key("test_group")
assert group_name == b"test_prefix:group:test_group"


def test_expiring_buffer_default_value():
buff = ExpiringCache(asyncio.Queue)
assert isinstance(buff["foo"], asyncio.Queue)


def test_expiring_buffer_default_ttl():
buff = ExpiringCache(None)
assert buff.ttl == 60


def test_expiring_buffer_ttl_expiration():
past = time.time() - 60
buff = ExpiringCache(None)

for x in range(100):
buff[x] = "example"
assert len(buff) == 100
buff.expire()
assert len(buff) == 100

for x in range(100):
buff._expires[x] = past
buff["extra"] = "extra"
buff.expire()
assert len(buff) == 1
assert "extra" in buff
assert len(buff._expires) == 1


def test_expiring_buffer_ttl_already_gone():
past = time.time() - 60
buff = ExpiringCache(None)
buff["delete"] = "example"
buff._expires["delete"] = past
del buff["delete"]
buff.expire()
assert len(buff) == 0
assert len(buff._expires) == 0

0 comments on commit 2e8097f

Please sign in to comment.