Skip to content

Commit

Permalink
Used a sorted set to guarantee message expiration.
Browse files Browse the repository at this point in the history
  • Loading branch information
astutejoe authored and carltongibson committed Jun 28, 2020
1 parent 2ae5ad3 commit 4c88088
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 13 deletions.
52 changes: 39 additions & 13 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hashlib
import itertools
import logging
import random
import sys
import time
import types
Expand Down Expand Up @@ -300,12 +301,18 @@ async def send(self, channel, message):
else:
index = next(self._send_index_generator)
async with self.connection(index) as connection:
# Discard old messages based on expiry
await connection.zremrangebyscore(
channel_key, min=0, max=int(time.time()) - int(self.expiry)
)

# Check the length of the list before send
# This can allow the list to leak slightly over capacity, but that's fine.
if await connection.llen(channel_key) >= self.get_capacity(channel):
if await connection.zcount(channel_key) >= self.get_capacity(channel):
raise ChannelFull()

# Push onto the list then set it to expire in case it's not consumed
await connection.lpush(channel_key, self.serialize(message))
await connection.zadd(channel_key, time.time(), self.serialize(message))
await connection.expire(channel_key, int(self.expiry))

def _backup_channel_name(self, channel):
Expand All @@ -323,9 +330,9 @@ async def _brpop_with_clean(self, index, channel, timeout):
# of the main message queue in the proper order; BRPOP must *not* be called
# because that would deadlock the server
cleanup_script = """
local backed_up = redis.call('LRANGE', ARGV[2], 0, -1)
local backed_up = redis.call('ZRANGE', ARGV[2], 0, -1)
for i = #backed_up, 1, -1 do
redis.call('LPUSH', ARGV[1], backed_up[i])
redis.call('ZADD', ARGV[1], backed_up[i])
end
redis.call('DEL', ARGV[2])
"""
Expand All @@ -335,15 +342,23 @@ async def _brpop_with_clean(self, index, channel, timeout):
# and the script executes atomically...
await connection.eval(cleanup_script, keys=[], args=[channel, backup_queue])
# ...and it doesn't matter here either, the message will be safe in the backup.
return await connection.brpoplpush(channel, backup_queue, timeout=timeout)
result = await connection.bzpopmin(channel, timeout=timeout)

if result is not None:
_, member, timestamp = result
await connection.zadd(backup_queue, float(timestamp), member)
else:
member = None

return member

async def _clean_receive_backup(self, index, channel):
"""
Pop the oldest message off the channel backup queue.
The result isn't interesting as it was already processed.
"""
async with self.connection(index) as connection:
await connection.brpop(self._backup_channel_name(channel))
await connection.zpopmin(self._backup_channel_name(channel))

async def receive(self, channel):
"""
Expand Down Expand Up @@ -626,25 +641,30 @@ async def group_send(self, group, message):
) = self._map_channel_keys_to_connection(channel_names, message)

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
for key in channel_redis_keys:
await connection.zremrangebyscore(
key, min=0, max=int(time.time()) - int(self.expiry)
)

# Create a LUA script specific for this connection.
# Make sure to use the message specific to this channel, it is
# stored in channel_to_message dict and contains the
# __asgi_channel__ key.

group_send_lua = (
""" local over_capacity = 0
group_send_lua = """ local over_capacity = 0
for i=1,#KEYS do
if redis.call('LLEN', KEYS[i]) < tonumber(ARGV[i + #KEYS]) then
redis.call('LPUSH', KEYS[i], ARGV[i])
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
redis.call('ZADD', KEYS[i], %f, ARGV[i])
redis.call('EXPIRE', KEYS[i], %d)
else
over_capacity = over_capacity + 1
end
end
return over_capacity
"""
% self.expiry
""" % (
time.time(),
self.expiry,
)

# We need to filter the messages to keep those related to the connection
Expand Down Expand Up @@ -769,12 +789,18 @@ def serialize(self, message):
value = msgpack.packb(message, use_bin_type=True)
if self.crypter:
value = self.crypter.encrypt(value)
return value

# As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
return random_prefix + value

def deserialize(self, message):
"""
Deserializes from a byte string.
"""
# Removes the random prefix
message = message[12:]

if self.crypter:
message = self.crypter.decrypt(message, self.expiry + 10)
return msgpack.unpackb(message, raw=False)
Expand Down
166 changes: 166 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,36 @@
]


async def send_three_messages_with_delay(channel_name, channel_layer, delay):
await channel_layer.send(channel_name, {"type": "test.message", "text": "First!"})

await asyncio.sleep(delay)

await channel_layer.send(channel_name, {"type": "test.message", "text": "Second!"})

await asyncio.sleep(delay)

await channel_layer.send(channel_name, {"type": "test.message", "text": "Third!"})


async def group_send_three_messages_with_delay(group_name, channel_layer, delay):
await channel_layer.group_send(
group_name, {"type": "test.message", "text": "First!"}
)

await asyncio.sleep(delay)

await channel_layer.group_send(
group_name, {"type": "test.message", "text": "Second!"}
)

await asyncio.sleep(delay)

await channel_layer.group_send(
group_name, {"type": "test.message", "text": "Third!"}
)


@pytest.fixture()
@async_generator
async def channel_layer():
Expand Down Expand Up @@ -445,3 +475,139 @@ async def test_random_reset__client_prefix(channel_layer):
random.seed(1)
channel_layer_2 = RedisChannelLayer()
assert channel_layer_1.client_prefix != channel_layer_2.client_prefix


@pytest.mark.asyncio
async def test_message_expiry__earliest_message_expires(channel_layer):
expiry = 3
delay = 2
channel_layer = RedisChannelLayer(expiry=expiry)
channel_name = await channel_layer.new_channel()

task = asyncio.ensure_future(
send_three_messages_with_delay(channel_name, channel_layer, delay)
)
await asyncio.wait_for(task, None)

# the first message should have expired, we should only see the second message and the third
message = await channel_layer.receive(channel_name)
assert message["type"] == "test.message"
assert message["text"] == "Second!"

message = await channel_layer.receive(channel_name)
assert message["type"] == "test.message"
assert message["text"] == "Third!"

# Make sure there's no third message even out of order
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(1):
await channel_layer.receive(channel_name)


@pytest.mark.asyncio
async def test_message_expiry__all_messages_under_expiration_time(channel_layer):
expiry = 3
delay = 1
channel_layer = RedisChannelLayer(expiry=expiry)
channel_name = await channel_layer.new_channel()

task = asyncio.ensure_future(
send_three_messages_with_delay(channel_name, channel_layer, delay)
)
await asyncio.wait_for(task, None)

# expiry = 3, total delay under 3, all messages there
message = await channel_layer.receive(channel_name)
assert message["type"] == "test.message"
assert message["text"] == "First!"

message = await channel_layer.receive(channel_name)
assert message["type"] == "test.message"
assert message["text"] == "Second!"

message = await channel_layer.receive(channel_name)
assert message["type"] == "test.message"
assert message["text"] == "Third!"


@pytest.mark.asyncio
async def test_message_expiry__group_send(channel_layer):
expiry = 3
delay = 2
channel_layer = RedisChannelLayer(expiry=expiry)
channel_name = await channel_layer.new_channel()

await channel_layer.group_add("test-group", channel_name)

task = asyncio.ensure_future(
group_send_three_messages_with_delay("test-group", channel_layer, delay)
)
await asyncio.wait_for(task, None)

# the first message should have expired, we should only see the second message and the third
message = await channel_layer.receive(channel_name)
assert message["type"] == "test.message"
assert message["text"] == "Second!"

message = await channel_layer.receive(channel_name)
assert message["type"] == "test.message"
assert message["text"] == "Third!"

# Make sure there's no third message even out of order
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(1):
await channel_layer.receive(channel_name)


@pytest.mark.asyncio
async def test_message_expiry__group_send__one_channel_expires_message(channel_layer):
expiry = 3
delay = 1

channel_layer = RedisChannelLayer(expiry=expiry)
channel_1 = await channel_layer.new_channel()
channel_2 = await channel_layer.new_channel(prefix="channel_2")

await channel_layer.group_add("test-group", channel_1)
await channel_layer.group_add("test-group", channel_2)

# Let's give channel_1 one additional message and then sleep
await channel_layer.send(channel_1, {"type": "test.message", "text": "Zero!"})
await asyncio.sleep(2)

task = asyncio.ensure_future(
group_send_three_messages_with_delay("test-group", channel_layer, delay)
)
await asyncio.wait_for(task, None)

# message Zero! was sent about 2 + 1 + 1 seconds ago and it should have expired
message = await channel_layer.receive(channel_1)
assert message["type"] == "test.message"
assert message["text"] == "First!"

message = await channel_layer.receive(channel_1)
assert message["type"] == "test.message"
assert message["text"] == "Second!"

message = await channel_layer.receive(channel_1)
assert message["type"] == "test.message"
assert message["text"] == "Third!"

# Make sure there's no fourth message even out of order
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(1):
await channel_layer.receive(channel_1)

# channel_2 should receive all three messages from group_send
message = await channel_layer.receive(channel_2)
assert message["type"] == "test.message"
assert message["text"] == "First!"

# the first message should have expired, we should only see the second message and the third
message = await channel_layer.receive(channel_2)
assert message["type"] == "test.message"
assert message["text"] == "Second!"

message = await channel_layer.receive(channel_2)
assert message["type"] == "test.message"
assert message["text"] == "Third!"

0 comments on commit 4c88088

Please sign in to comment.