diff --git a/channels_redis/core.py b/channels_redis/core.py index f1c0418..293558b 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -5,6 +5,7 @@ import hashlib import itertools import logging +import random import sys import time import types @@ -300,12 +301,18 @@ async def send(self, channel, message): else: index = next(self._send_index_generator) async with self.connection(index) as connection: + # Discard old messages based on expiry + await connection.zremrangebyscore( + channel_key, min=0, max=int(time.time()) - int(self.expiry) + ) + # Check the length of the list before send # This can allow the list to leak slightly over capacity, but that's fine. - if await connection.llen(channel_key) >= self.get_capacity(channel): + if await connection.zcount(channel_key) >= self.get_capacity(channel): raise ChannelFull() + # Push onto the list then set it to expire in case it's not consumed - await connection.lpush(channel_key, self.serialize(message)) + await connection.zadd(channel_key, time.time(), self.serialize(message)) await connection.expire(channel_key, int(self.expiry)) def _backup_channel_name(self, channel): @@ -323,9 +330,9 @@ async def _brpop_with_clean(self, index, channel, timeout): # of the main message queue in the proper order; BRPOP must *not* be called # because that would deadlock the server cleanup_script = """ - local backed_up = redis.call('LRANGE', ARGV[2], 0, -1) + local backed_up = redis.call('ZRANGE', ARGV[2], 0, -1) for i = #backed_up, 1, -1 do - redis.call('LPUSH', ARGV[1], backed_up[i]) + redis.call('ZADD', ARGV[1], backed_up[i]) end redis.call('DEL', ARGV[2]) """ @@ -335,7 +342,15 @@ async def _brpop_with_clean(self, index, channel, timeout): # and the script executes atomically... await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue]) # ...and it doesn't matter here either, the message will be safe in the backup. - return await connection.brpoplpush(channel, backup_queue, timeout=timeout) + result = await connection.bzpopmin(channel, timeout=timeout) + + if result is not None: + _, member, timestamp = result + await connection.zadd(backup_queue, float(timestamp), member) + else: + member = None + + return member async def _clean_receive_backup(self, index, channel): """ @@ -343,7 +358,7 @@ async def _clean_receive_backup(self, index, channel): The result isn't interesting as it was already processed. """ async with self.connection(index) as connection: - await connection.brpop(self._backup_channel_name(channel)) + await connection.zpopmin(self._backup_channel_name(channel)) async def receive(self, channel): """ @@ -626,25 +641,30 @@ async def group_send(self, group, message): ) = self._map_channel_keys_to_connection(channel_names, message) for connection_index, channel_redis_keys in connection_to_channel_keys.items(): + # Discard old messages based on expiry + for key in channel_redis_keys: + await connection.zremrangebyscore( + key, min=0, max=int(time.time()) - int(self.expiry) + ) # Create a LUA script specific for this connection. # Make sure to use the message specific to this channel, it is # stored in channel_to_message dict and contains the # __asgi_channel__ key. - group_send_lua = ( - """ local over_capacity = 0 + group_send_lua = """ local over_capacity = 0 for i=1,#KEYS do - if redis.call('LLEN', KEYS[i]) < tonumber(ARGV[i + #KEYS]) then - redis.call('LPUSH', KEYS[i], ARGV[i]) + if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then + redis.call('ZADD', KEYS[i], %f, ARGV[i]) redis.call('EXPIRE', KEYS[i], %d) else over_capacity = over_capacity + 1 end end return over_capacity - """ - % self.expiry + """ % ( + time.time(), + self.expiry, ) # We need to filter the messages to keep those related to the connection @@ -769,12 +789,18 @@ def serialize(self, message): value = msgpack.packb(message, use_bin_type=True) if self.crypter: value = self.crypter.encrypt(value) - return value + + # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes. + random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big") + return random_prefix + value def deserialize(self, message): """ Deserializes from a byte string. """ + # Removes the random prefix + message = message[12:] + if self.crypter: message = self.crypter.decrypt(message, self.expiry + 10) return msgpack.unpackb(message, raw=False) diff --git a/tests/test_core.py b/tests/test_core.py index 614d9e6..6198565 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -24,6 +24,36 @@ ] +async def send_three_messages_with_delay(channel_name, channel_layer, delay): + await channel_layer.send(channel_name, {"type": "test.message", "text": "First!"}) + + await asyncio.sleep(delay) + + await channel_layer.send(channel_name, {"type": "test.message", "text": "Second!"}) + + await asyncio.sleep(delay) + + await channel_layer.send(channel_name, {"type": "test.message", "text": "Third!"}) + + +async def group_send_three_messages_with_delay(group_name, channel_layer, delay): + await channel_layer.group_send( + group_name, {"type": "test.message", "text": "First!"} + ) + + await asyncio.sleep(delay) + + await channel_layer.group_send( + group_name, {"type": "test.message", "text": "Second!"} + ) + + await asyncio.sleep(delay) + + await channel_layer.group_send( + group_name, {"type": "test.message", "text": "Third!"} + ) + + @pytest.fixture() @async_generator async def channel_layer(): @@ -445,3 +475,139 @@ async def test_random_reset__client_prefix(channel_layer): random.seed(1) channel_layer_2 = RedisChannelLayer() assert channel_layer_1.client_prefix != channel_layer_2.client_prefix + + +@pytest.mark.asyncio +async def test_message_expiry__earliest_message_expires(channel_layer): + expiry = 3 + delay = 2 + channel_layer = RedisChannelLayer(expiry=expiry) + channel_name = await channel_layer.new_channel() + + task = asyncio.ensure_future( + send_three_messages_with_delay(channel_name, channel_layer, delay) + ) + await asyncio.wait_for(task, None) + + # the first message should have expired, we should only see the second message and the third + message = await channel_layer.receive(channel_name) + assert message["type"] == "test.message" + assert message["text"] == "Second!" + + message = await channel_layer.receive(channel_name) + assert message["type"] == "test.message" + assert message["text"] == "Third!" + + # Make sure there's no third message even out of order + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(1): + await channel_layer.receive(channel_name) + + +@pytest.mark.asyncio +async def test_message_expiry__all_messages_under_expiration_time(channel_layer): + expiry = 3 + delay = 1 + channel_layer = RedisChannelLayer(expiry=expiry) + channel_name = await channel_layer.new_channel() + + task = asyncio.ensure_future( + send_three_messages_with_delay(channel_name, channel_layer, delay) + ) + await asyncio.wait_for(task, None) + + # expiry = 3, total delay under 3, all messages there + message = await channel_layer.receive(channel_name) + assert message["type"] == "test.message" + assert message["text"] == "First!" + + message = await channel_layer.receive(channel_name) + assert message["type"] == "test.message" + assert message["text"] == "Second!" + + message = await channel_layer.receive(channel_name) + assert message["type"] == "test.message" + assert message["text"] == "Third!" + + +@pytest.mark.asyncio +async def test_message_expiry__group_send(channel_layer): + expiry = 3 + delay = 2 + channel_layer = RedisChannelLayer(expiry=expiry) + channel_name = await channel_layer.new_channel() + + await channel_layer.group_add("test-group", channel_name) + + task = asyncio.ensure_future( + group_send_three_messages_with_delay("test-group", channel_layer, delay) + ) + await asyncio.wait_for(task, None) + + # the first message should have expired, we should only see the second message and the third + message = await channel_layer.receive(channel_name) + assert message["type"] == "test.message" + assert message["text"] == "Second!" + + message = await channel_layer.receive(channel_name) + assert message["type"] == "test.message" + assert message["text"] == "Third!" + + # Make sure there's no third message even out of order + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(1): + await channel_layer.receive(channel_name) + + +@pytest.mark.asyncio +async def test_message_expiry__group_send__one_channel_expires_message(channel_layer): + expiry = 3 + delay = 1 + + channel_layer = RedisChannelLayer(expiry=expiry) + channel_1 = await channel_layer.new_channel() + channel_2 = await channel_layer.new_channel(prefix="channel_2") + + await channel_layer.group_add("test-group", channel_1) + await channel_layer.group_add("test-group", channel_2) + + # Let's give channel_1 one additional message and then sleep + await channel_layer.send(channel_1, {"type": "test.message", "text": "Zero!"}) + await asyncio.sleep(2) + + task = asyncio.ensure_future( + group_send_three_messages_with_delay("test-group", channel_layer, delay) + ) + await asyncio.wait_for(task, None) + + # message Zero! was sent about 2 + 1 + 1 seconds ago and it should have expired + message = await channel_layer.receive(channel_1) + assert message["type"] == "test.message" + assert message["text"] == "First!" + + message = await channel_layer.receive(channel_1) + assert message["type"] == "test.message" + assert message["text"] == "Second!" + + message = await channel_layer.receive(channel_1) + assert message["type"] == "test.message" + assert message["text"] == "Third!" + + # Make sure there's no fourth message even out of order + with pytest.raises(asyncio.TimeoutError): + async with async_timeout.timeout(1): + await channel_layer.receive(channel_1) + + # channel_2 should receive all three messages from group_send + message = await channel_layer.receive(channel_2) + assert message["type"] == "test.message" + assert message["text"] == "First!" + + # the first message should have expired, we should only see the second message and the third + message = await channel_layer.receive(channel_2) + assert message["type"] == "test.message" + assert message["text"] == "Second!" + + message = await channel_layer.receive(channel_2) + assert message["type"] == "test.message" + assert message["text"] == "Third!"