Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement a TTL for RedisChannelLayer.receive_buffer (to avoid a memory leak) #213

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,40 @@ class UnsupportedRedis(Exception):
pass


class ExpiringCache(collections.defaultdict):
    """A ``defaultdict`` whose entries expire ``ttl`` seconds after they
    were last assigned.

    Expiry deadlines are tracked in an ``OrderedDict`` kept ordered by
    deadline, so :meth:`expire` can stop scanning at the first live entry
    instead of walking the whole mapping.
    """

    def __init__(self, default, ttl=60, *args, **kw):
        # default: factory used for missing keys (as in defaultdict).
        # ttl: number of seconds an entry lives after each (re)assignment.
        collections.defaultdict.__init__(self, default)
        self._expires = collections.OrderedDict()
        self.ttl = ttl

    def __setitem__(self, k, v):
        collections.defaultdict.__setitem__(self, k, v)
        self._expires[k] = time.time() + self.ttl
        # Re-assigning an existing key refreshes its deadline, so it must
        # also move to the end of the OrderedDict; otherwise the
        # "everything after the first live key is also live" invariant
        # used by expire() would be broken and later expired entries
        # could be kept alive forever (the leak this class exists to fix).
        self._expires.move_to_end(k)

    def __delitem__(self, k):
        # Keep the deadline bookkeeping in sync so entries deleted by the
        # caller do not linger in self._expires until the next expire().
        self._expires.pop(k, None)
        try:
            collections.defaultdict.__delitem__(self, k)
        except KeyError:
            # RedisChannelLayer itself _does_ periodically clean up this
            # dictionary (e.g., when exceptions like asyncio.CancelledError
            # occur)
            pass

    def expire(self):
        """Remove every entry whose TTL has elapsed."""
        now = time.time()
        expired = []
        for k, deadline in self._expires.items():
            if deadline < now:
                expired.append(k)
            else:
                # _expires is ordered by deadline: every key after this
                # one expires *later*, so if _this_ key is still live the
                # ones after it are too (so we can stop iterating).
                break
        for k in expired:
            # __delitem__ also drops the _expires entry.
            del self[k]


class RedisChannelLayer(BaseChannelLayer):
"""
Redis channel layer.
Expand Down Expand Up @@ -226,7 +260,7 @@ def __init__(
# Event loop they are trying to receive on
self.receive_event_loop = None
# Buffered messages by process-local channel name
self.receive_buffer = collections.defaultdict(asyncio.Queue)
self.receive_buffer = ExpiringCache(asyncio.Queue, ttl=self.expiry)
# Detached channel cleanup tasks
self.receive_cleaners = []
# Per-channel cleanup locks to prevent a receive starting and moving
Expand Down Expand Up @@ -616,6 +650,7 @@ async def group_discard(self, group, channel):
key = self._group_key(group)
async with self.connection(self.consistent_hash(group)) as connection:
await connection.zrem(key, channel)
self.receive_buffer.expire()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can't rely on this only happening on group_discard (since some may not be using the group extension?).

Doing this sort of thing once per receive() call seems too expensive, though.


async def group_send(self, group, message):
"""
Expand Down
48 changes: 47 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import asyncio
import random
import time

import async_timeout
import pytest
from async_generator import async_generator, yield_

from asgiref.sync import async_to_sync
from channels_redis.core import ChannelFull, ConnectionPool, RedisChannelLayer
from channels_redis.core import (
ChannelFull,
ConnectionPool,
ExpiringCache,
RedisChannelLayer,
)

TEST_HOSTS = [("localhost", 6379)]

Expand Down Expand Up @@ -627,3 +633,43 @@ def test_custom_group_key_format():
channel_layer = RedisChannelLayer(prefix="test_prefix")
group_name = channel_layer._group_key("test_group")
assert group_name == b"test_prefix:group:test_group"


def test_expiring_buffer_default_value():
    """A missing key is populated via the supplied default factory."""
    cache = ExpiringCache(asyncio.Queue)
    assert isinstance(cache["foo"], asyncio.Queue)


def test_expiring_buffer_default_ttl():
    """When no TTL is given, it defaults to 60 seconds."""
    cache = ExpiringCache(None)
    assert cache.ttl == 60


def test_expiring_buffer_ttl_expiration():
    """expire() removes exactly the entries whose TTL has elapsed."""
    cache = ExpiringCache(None)

    for key in range(100):
        cache[key] = "example"
    assert len(cache) == 100

    # Nothing has expired yet, so expire() must be a no-op.
    cache.expire()
    assert len(cache) == 100

    # Backdate all 100 deadlines, then add one fresh entry.
    stale = time.time() - 60
    for key in range(100):
        cache._expires[key] = stale
    cache["extra"] = "extra"

    cache.expire()
    assert len(cache) == 1
    assert "extra" in cache
    assert len(cache._expires) == 1


def test_expiring_buffer_ttl_already_gone():
    """expire() copes with keys the caller already deleted from the dict."""
    cache = ExpiringCache(None)
    cache["delete"] = "example"
    cache._expires["delete"] = time.time() - 60

    # Remove the value directly; expire() must still clean up without error.
    del cache["delete"]
    cache.expire()

    assert len(cache) == 0
    assert len(cache._expires) == 0