From 0724304319706af3ea6aa5ed06e84baf65b680b1 Mon Sep 17 00:00:00 2001 From: Stephan Meier Date: Wed, 2 Feb 2022 12:05:24 +0100 Subject: [PATCH 1/6] add test for async callback --- tests/test_pubsub.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 935f9cae5..c782ac2fa 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -270,6 +270,9 @@ def setup_method(self, method): def message_handler(self, message): self.message = message + async def async_message_handler(self, message): + self.message = message + async def test_published_message_to_channel(self, r): p = r.pubsub() await p.subscribe("foo") @@ -311,6 +314,14 @@ async def test_channel_message_handler(self, r): assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + async def test_channel_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + async def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) await p.psubscribe(**{"f*": self.message_handler}) From d83cb29089b069971d31bc065946668f17959ac3 Mon Sep 17 00:00:00 2001 From: Stephan Meier Date: Wed, 2 Feb 2022 12:06:16 +0100 Subject: [PATCH 2/6] add support for async callbacks --- aioredis/client.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index 3b5e091d5..7e0284e09 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -4157,7 +4157,7 @@ def unsubscribe(self, *args) -> Awaitable: async def listen(self) -> AsyncIterator: """Listen for messages on channels this client has been subscribed to""" while self.subscribed: - response = self.handle_message(await self.parse_response(block=True)) + response = await self.handle_message(await self.parse_response(block=True)) if response is not None: yield response @@ -4173,7 +4173,7 @@ async def get_message( """ response = await self.parse_response(block=False, timeout=timeout) if response: - return self.handle_message(response, ignore_subscribe_messages) + return await self.handle_message(response, ignore_subscribe_messages) return None def ping(self, message=None) -> Awaitable: @@ -4183,7 +4183,7 @@ def ping(self, message=None) -> Awaitable: message = "" if message is None else message return self.execute_command("PING", message) - def handle_message(self, response, ignore_subscribe_messages=False): + async def handle_message(self, response, ignore_subscribe_messages=False): """ Parses a pub/sub message. If the channel or pattern was subscribed to with a message handler, the handler is invoked instead of a parsed @@ -4232,7 +4232,10 @@ def handle_message(self, response, ignore_subscribe_messages=False): else: handler = self.channels.get(message["channel"], None) if handler: - handler(message) + if inspect.iscoroutinefunction(handler): + await handler(message) + else: + handler(message) return None elif message_type != "pong": # this is a subscribe/unsubscribe message. ignore if we don't From eab542282e4ead99d334dabc59717a1809537372 Mon Sep 17 00:00:00 2001 From: Stephan Meier Date: Wed, 2 Feb 2022 12:15:54 +0100 Subject: [PATCH 3/6] update contributors --- CONTRIBUTORS.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 3a4e5e628..36953225a 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -63,6 +63,7 @@ Sean Stewart Sergey Miletskiy SeungHyun Hwang Stanislau Arkhipenka +Stephan Meier Taku Fukada Taras Voinarovskyi Thanos Lefteris From 2a53f06b0744586d3fb2b98a70789eaff03ce940 Mon Sep 17 00:00:00 2001 From: Stephan Meier Date: Wed, 2 Feb 2022 12:27:42 +0100 Subject: [PATCH 4/6] add news fragment --- CHANGES/1284.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 CHANGES/1284.feature diff --git a/CHANGES/1284.feature b/CHANGES/1284.feature new file mode 100644 index 000000000..567e90c32 --- /dev/null +++ b/CHANGES/1284.feature @@ -0,0 +1 @@ +Using coroutines as callbacks is now possible. From de40317e03fb9f2bcd0f91be6cf1eee658fd7274 Mon Sep 17 00:00:00 2001 From: Stephan Meier Date: Thu, 3 Feb 2022 14:03:48 +0100 Subject: [PATCH 5/6] add combined sync + async callback test --- tests/test_pubsub.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index c782ac2fa..85c004659 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -271,7 +271,7 @@ def message_handler(self, message): self.message = message async def async_message_handler(self, message): - self.message = message + self.async_message = message async def test_published_message_to_channel(self, r): p = r.pubsub() @@ -320,7 +320,18 @@ async def test_channel_async_message_handler(self, r): assert await wait_for_message(p) is None assert await r.publish("foo", "test message") == 1 assert await wait_for_message(p) is None + assert self.async_message == make_message("message", "foo", "test message") + + async def test_channel_sync_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + await p.subscribe(bar=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await r.publish("bar", "test message 2") == 1 + assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + assert self.async_message == make_message("message", "bar", "test message 2") async def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) From 359a741286017a3defaf982545c255d277b6cb5c Mon Sep 17 00:00:00 2001 From: Stephan Meier Date: Thu, 3 Feb 2022 17:09:41 +0100 Subject: [PATCH 6/6] fix linter error --- aioredis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioredis/client.py b/aioredis/client.py index 7e0284e09..4f71e90cd 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -74,7 +74,7 @@ AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = ChannelT -PubSubHandler = Callable[[Dict[str, str]], None] +PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE"