From 9a13f126c86e7ad2486231c68fce456773104fcf Mon Sep 17 00:00:00 2001 From: martins Date: Fri, 26 Jul 2019 18:17:45 +0300 Subject: [PATCH 1/5] * rewrites `RedisChannelLayer.receive` specific channel section * also properly receives specific channels of different prefixes --- channels_redis/core.py | 182 ++++++++++++++++------------------------- 1 file changed, 71 insertions(+), 111 deletions(-) diff --git a/channels_redis/core.py b/channels_redis/core.py index 669be2a..6678381 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -166,6 +166,66 @@ class UnsupportedRedis(Exception): pass +class ReceiveBuffer: + def __init__(self, channel_layer): + self.loop = None + self.channel_layer = channel_layer + self.getters = collections.defaultdict(collections.deque) + self.buffers = collections.defaultdict(lambda: collections.deque(maxlen=20)) + self.receiver = None + + def __bool__(self): + return bool(self.getters) + + def get(self, channel): + getter = self.loop.create_future() + + if channel in self.buffers: + getter.set_result(self.buffers[channel].popleft()) + if not self.buffers[channel]: + del self.buffers[channel] + else: + getter.channel = channel + getter.add_done_callback(self._getter_done_prematurely) + self.getters[channel].append(getter) + + # ensure receiver is running + if not self.receiver: + self.receiver = asyncio.ensure_future(self.receiver_factory(self.channel_layer.non_local_name(channel))) + + return getter + + def _getter_done_prematurely(self, getter): + channel = getter.channel + self.getters[channel].remove(getter) + if not self.getters[channel]: + del self.getters[channel] + if not self and self.receiver: + self.receiver.cancel() + + def put(self, channel, message): + if channel in self.getters: + getter = self.getters[channel].popleft() + getter.remove_done_callback(self._getter_done_prematurely) + if not self.getters[channel]: + del self.getters[channel] + getter.set_result(message) + else: + self.buffers[channel].append(message) + + async def receiver_factory(self, real_channel): + try: + while self: + message_channel, message = await self.channel_layer.receive_single(real_channel) + if type(message_channel) is list: + for chan in message_channel: + self.put(chan, message) + else: + self.put(message_channel, message) + finally: + self.receiver = None + + class RedisChannelLayer(BaseChannelLayer): """ Redis channel layer. @@ -209,14 +269,8 @@ def __init__( ) # Set up any encryption objects self._setup_encryption(symmetric_encryption_keys) - # Number of coroutines trying to receive right now - self.receive_count = 0 - # The receive lock - self.receive_lock = None - # Event loop they are trying to receive on - self.receive_event_loop = None # Buffered messages by process-local channel name - self.receive_buffer = collections.defaultdict(asyncio.Queue) + self.receive_buffers = collections.defaultdict(lambda: ReceiveBuffer(self)) # Detached channel cleanup tasks self.receive_cleaners = [] # Per-channel cleanup locks to prevent a receive starting and moving @@ -352,110 +406,16 @@ async def receive(self, channel): ), "Wrong client prefix" # Enter receiving section loop = asyncio.get_event_loop() - self.receive_count += 1 - try: - if self.receive_count == 1: - # If we're the first coroutine in, create the receive lock! - self.receive_lock = asyncio.Lock() - self.receive_event_loop = loop - else: - # Otherwise, check our event loop matches - if self.receive_event_loop != loop: - raise RuntimeError( - "Two event loops are trying to receive() on one channel layer at once!" - ) - - # Wait for our message to appear - message = None - while self.receive_buffer[channel].empty(): - tasks = [ - self.receive_lock.acquire(), - self.receive_buffer[channel].get(), - ] - tasks = [asyncio.ensure_future(task) for task in tasks] - try: - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - # Cancel all pending tasks. - task.cancel() - except asyncio.CancelledError: - # Ensure all tasks are cancelled if we are cancelled. - # Also see: https://bugs.python.org/issue23859 - del self.receive_buffer[channel] - for task in tasks: - if not task.cancel(): - assert task.done() - if task.result() is True: - self.receive_lock.release() - - raise - - message, token, exception = None, None, None - for task in done: - try: - result = task.result() - except Exception as error: # NOQA - # We should not propagate exceptions immediately as otherwise this may cause - # the lock to be held and never be released. - exception = error - continue - - if result is True: - token = result - else: - assert isinstance(result, dict) - message = result - - if message or exception: - if token: - # We will not be receving as we already have the message. - self.receive_lock.release() - - if exception: - raise exception - else: - break - else: - assert token - - # We hold the receive lock, receive and then release it. - try: - # There is no interruption point from when the message is - # unpacked in receive_single to when we get back here, so - # the following lines are essentially atomic. - message_channel, message = await self.receive_single( - real_channel - ) - if type(message_channel) is list: - for chan in message_channel: - self.receive_buffer[chan].put_nowait(message) - else: - self.receive_buffer[message_channel].put_nowait(message) - message = None - except: - del self.receive_buffer[channel] - raise - finally: - self.receive_lock.release() - - # We know there's a message available, because there - # couldn't have been any interruption between empty() and here - if message is None: - message = self.receive_buffer[channel].get_nowait() - - if self.receive_buffer[channel].empty(): - del self.receive_buffer[channel] - return message - - finally: - self.receive_count -= 1 - # If we were the last out, drop the receive lock - if self.receive_count == 0: - assert not self.receive_lock.locked() - self.receive_lock = None - self.receive_event_loop = None + receive_buffer = self.receive_buffers[real_channel] + + # Check our event loop matches + if receive_buffer.loop != loop and receive_buffer.receiver: + raise RuntimeError( + "Two event loops are trying to receive() on one channel layer at once!" + ) + else: + receive_buffer.loop = loop + return await receive_buffer.get(channel) else: # Do a plain direct receive return (await self.receive_single(channel))[1] From 859d53248fd1174ba094d380eda6c13d2f124522 Mon Sep 17 00:00:00 2001 From: martins Date: Fri, 26 Jul 2019 18:45:19 +0300 Subject: [PATCH 2/5] make black happy --- channels_redis/core.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/channels_redis/core.py b/channels_redis/core.py index 6678381..785f03d 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -173,13 +173,13 @@ def __init__(self, channel_layer): self.getters = collections.defaultdict(collections.deque) self.buffers = collections.defaultdict(lambda: collections.deque(maxlen=20)) self.receiver = None - + def __bool__(self): return bool(self.getters) - + def get(self, channel): getter = self.loop.create_future() - + if channel in self.buffers: getter.set_result(self.buffers[channel].popleft()) if not self.buffers[channel]: @@ -188,13 +188,15 @@ def get(self, channel): getter.channel = channel getter.add_done_callback(self._getter_done_prematurely) self.getters[channel].append(getter) - + # ensure receiver is running if not self.receiver: - self.receiver = asyncio.ensure_future(self.receiver_factory(self.channel_layer.non_local_name(channel))) - + self.receiver = asyncio.ensure_future( + self.receiver_factory(self.channel_layer.non_local_name(channel)) + ) + return getter - + def _getter_done_prematurely(self, getter): channel = getter.channel self.getters[channel].remove(getter) @@ -202,7 +204,7 @@ def _getter_done_prematurely(self, getter): del self.getters[channel] if not self and self.receiver: self.receiver.cancel() - + def put(self, channel, message): if channel in self.getters: getter = self.getters[channel].popleft() @@ -212,11 +214,13 @@ def put(self, channel, message): getter.set_result(message) else: self.buffers[channel].append(message) - + async def receiver_factory(self, real_channel): try: while self: - message_channel, message = await self.channel_layer.receive_single(real_channel) + message_channel, message = await self.channel_layer.receive_single( + real_channel + ) if type(message_channel) is list: for chan in message_channel: self.put(chan, message) From d623ec39a7388285e374edead17d01903d925a47 Mon Sep 17 00:00:00 2001 From: martins Date: Tue, 30 Jul 2019 13:41:43 +0300 Subject: [PATCH 3/5] * tweaks ReceiveBuffer * adds some tests for ReceiveBuffer --- channels_redis/core.py | 36 ++++++++++++++++------- tests/test_core.py | 67 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/channels_redis/core.py b/channels_redis/core.py index 785f03d..40d0e79 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -167,9 +167,17 @@ class UnsupportedRedis(Exception): class ReceiveBuffer: - def __init__(self, channel_layer): + """ + Receive buffer + + It manages waiters and buffers messages for all specific channels under the same 'real channel' + Also manages the receive loop for the 'real channel' + """ + + def __init__(self, receive_single, real_channel): self.loop = None - self.channel_layer = channel_layer + self.real_channel = real_channel + self.receive_single = receive_single self.getters = collections.defaultdict(collections.deque) self.buffers = collections.defaultdict(lambda: collections.deque(maxlen=20)) self.receiver = None @@ -178,6 +186,13 @@ def __bool__(self): return bool(self.getters) def get(self, channel): + """ + :param channel: name of the channel + :return: Future for the next message on channel + """ + assert channel.startswith( + self.real_channel + ), "channel not managed by this buffer" getter = self.loop.create_future() if channel in self.buffers: @@ -191,10 +206,7 @@ def get(self, channel): # ensure receiver is running if not self.receiver: - self.receiver = asyncio.ensure_future( - self.receiver_factory(self.channel_layer.non_local_name(channel)) - ) - + self.receiver = asyncio.ensure_future(self.receiver_factory()) return getter def _getter_done_prematurely(self, getter): @@ -215,12 +227,10 @@ def put(self, channel, message): else: self.buffers[channel].append(message) - async def receiver_factory(self, real_channel): + async def receiver_factory(self): try: while self: - message_channel, message = await self.channel_layer.receive_single( - real_channel - ) + message_channel, message = await self.receive_single(self.real_channel) if type(message_channel) is list: for chan in message_channel: self.put(chan, message) @@ -274,7 +284,7 @@ def __init__( # Set up any encryption objects self._setup_encryption(symmetric_encryption_keys) # Buffered messages by process-local channel name - self.receive_buffers = collections.defaultdict(lambda: ReceiveBuffer(self)) + self.receive_buffers = {} # Detached channel cleanup tasks self.receive_cleaners = [] # Per-channel cleanup locks to prevent a receive starting and moving @@ -410,6 +420,10 @@ async def receive(self, channel): ), "Wrong client prefix" # Enter receiving section loop = asyncio.get_event_loop() + if real_channel not in self.receive_buffers: + self.receive_buffers[real_channel] = ReceiveBuffer( + self.receive_single, real_channel + ) receive_buffer = self.receive_buffers[real_channel] # Check our event loop matches diff --git a/tests/test_core.py b/tests/test_core.py index dbf2401..bd973ae 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,7 +5,7 @@ from async_generator import async_generator, yield_ from asgiref.sync import async_to_sync -from channels_redis.core import ChannelFull, RedisChannelLayer +from channels_redis.core import ChannelFull, RedisChannelLayer, ReceiveBuffer TEST_HOSTS = [("localhost", 6379)] @@ -343,3 +343,68 @@ async def test_receive_cancel(channel_layer): await asyncio.wait_for(task, None) except asyncio.CancelledError: pass + + +@pytest.mark.asyncio +async def test_receive_multiple_specific_prefixes(channel_layer): + """ + Makes sure we receive on multiple real channels + """ + channel_layer = RedisChannelLayer(capacity=10) + channel1 = await channel_layer.new_channel() + channel2 = await channel_layer.new_channel(prefix="thing") + r1, _, r2 = tasks = [ + asyncio.ensure_future(x) + for x in ( + channel_layer.receive(channel1), + channel_layer.send(channel2, {"type": "message"}), + channel_layer.receive(channel2), + ) + ] + await asyncio.wait(tasks, timeout=0.5) + + assert not r1.done() + assert r2.done() and r2.result()["type"] == "message" + r1.cancel() + + +@pytest.mark.asyncio +async def test_buffer_wrong_channel(channel_layer): + async def dummy_receive(channel): + return channel, {"type": "message"} + + buffer = ReceiveBuffer(dummy_receive, "whatever!") + buffer.loop = asyncio.get_event_loop() + with pytest.raises(AssertionError): + buffer.get("wrong!13685sjmh") + + +@pytest.mark.asyncio +async def test_buffer_receiver_stopped(channel_layer): + async def dummy_receive(channel): + return "whatever!meh", {"type": "message"} + + buffer = ReceiveBuffer(dummy_receive, "whatever!") + buffer.loop = asyncio.get_event_loop() + + await buffer.get("whatever!meh") + assert buffer.receiver is None + + +@pytest.mark.asyncio +async def test_buffer_receiver_canceled(channel_layer): + async def dummy_receive(channel): + await asyncio.sleep(2) + return "whatever!meh", {"type": "message"} + + buffer = ReceiveBuffer(dummy_receive, "whatever!") + buffer.loop = asyncio.get_event_loop() + + get1 = buffer.get("whatever!meh") + assert buffer.receiver is not None + get2 = buffer.get("whatever!meh2") + get1.cancel() + assert buffer.receiver is not None + get2.cancel() + await asyncio.sleep(0.1) + assert buffer.receiver is None From 916ee02f903784cc3c93003794bbad479ba22dbd Mon Sep 17 00:00:00 2001 From: martins Date: Tue, 30 Jul 2019 13:46:49 +0300 Subject: [PATCH 4/5] oops --- tests/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_core.py b/tests/test_core.py index bd973ae..9d9fbe8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,7 +5,7 @@ from async_generator import async_generator, yield_ from asgiref.sync import async_to_sync -from channels_redis.core import ChannelFull, RedisChannelLayer, ReceiveBuffer +from channels_redis.core import ChannelFull, ReceiveBuffer, RedisChannelLayer TEST_HOSTS = [("localhost", 6379)] From 84de709c6781dd87f24bed244ccf6668473584c9 Mon Sep 17 00:00:00 2001 From: martins Date: Fri, 30 Aug 2019 10:48:03 +0300 Subject: [PATCH 5/5] removes size limit on ReceiveBuffer.buffers --- channels_redis/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/channels_redis/core.py b/channels_redis/core.py index 40d0e79..adca5c9 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -179,7 +179,7 @@ def __init__(self, receive_single, real_channel): self.real_channel = real_channel self.receive_single = receive_single self.getters = collections.defaultdict(collections.deque) - self.buffers = collections.defaultdict(lambda: collections.deque(maxlen=20)) + self.buffers = collections.defaultdict(collections.deque) self.receiver = None def __bool__(self):