diff --git a/CHANGES/1284.feature b/CHANGES/1284.feature new file mode 100644 index 000000000..567e90c32 --- /dev/null +++ b/CHANGES/1284.feature @@ -0,0 +1 @@ +Using coroutines as callbacks is now possible. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 3a4e5e628..36953225a 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -63,6 +63,7 @@ Sean Stewart Sergey Miletskiy SeungHyun Hwang Stanislau Arkhipenka +Stephan Meier Taku Fukada Taras Voinarovskyi Thanos Lefteris diff --git a/aioredis/client.py b/aioredis/client.py index 3b5e091d5..4f71e90cd 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -74,7 +74,7 @@ AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = ChannelT -PubSubHandler = Callable[[Dict[str, str]], None] +PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -4157,7 +4157,7 @@ def unsubscribe(self, *args) -> Awaitable: async def listen(self) -> AsyncIterator: """Listen for messages on channels this client has been subscribed to""" while self.subscribed: - response = self.handle_message(await self.parse_response(block=True)) + response = await self.handle_message(await self.parse_response(block=True)) if response is not None: yield response @@ -4173,7 +4173,7 @@ async def get_message( """ response = await self.parse_response(block=False, timeout=timeout) if response: - return self.handle_message(response, ignore_subscribe_messages) + return await self.handle_message(response, ignore_subscribe_messages) return None def ping(self, message=None) -> Awaitable: @@ -4183,7 +4183,7 @@ def ping(self, message=None) -> Awaitable: message = "" if message is None else message return self.execute_command("PING", message) - def handle_message(self, response, ignore_subscribe_messages=False): + async def handle_message(self, response, ignore_subscribe_messages=False): """ Parses a pub/sub message. If the channel or pattern was subscribed to with a message handler, the handler is invoked instead of a parsed @@ -4232,7 +4232,10 @@ def handle_message(self, response, ignore_subscribe_messages=False): else: handler = self.channels.get(message["channel"], None) if handler: - handler(message) + if inspect.iscoroutinefunction(handler): + await handler(message) + else: + handler(message) return None elif message_type != "pong": # this is a subscribe/unsubscribe message. ignore if we don't diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 935f9cae5..85c004659 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -270,6 +270,9 @@ def setup_method(self, method): def message_handler(self, message): self.message = message + async def async_message_handler(self, message): + self.async_message = message + async def test_published_message_to_channel(self, r): p = r.pubsub() await p.subscribe("foo") @@ -311,6 +314,25 @@ async def test_channel_message_handler(self, r): assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + async def test_channel_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.async_message == make_message("message", "foo", "test message") + + async def test_channel_sync_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + await p.subscribe(bar=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await r.publish("bar", "test message 2") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + assert self.async_message == make_message("message", "bar", "test message 2") + async def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) await p.psubscribe(**{"f*": self.message_handler})