Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Support for async callbacks #1284

Merged
merged 6 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/1284.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Using coroutines as callbacks is now possible.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Sean Stewart <seandstewart>
Sergey Miletskiy
SeungHyun Hwang
Stanislau Arkhipenka
Stephan Meier
Taku Fukada
Taras Voinarovskyi
Thanos Lefteris
Expand Down
11 changes: 7 additions & 4 deletions aioredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4157,7 +4157,7 @@ def unsubscribe(self, *args) -> Awaitable:
async def listen(self) -> AsyncIterator:
"""Listen for messages on channels this client has been subscribed to"""
while self.subscribed:
response = self.handle_message(await self.parse_response(block=True))
response = await self.handle_message(await self.parse_response(block=True))
if response is not None:
yield response

Expand All @@ -4173,7 +4173,7 @@ async def get_message(
"""
response = await self.parse_response(block=False, timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
return await self.handle_message(response, ignore_subscribe_messages)
return None

def ping(self, message=None) -> Awaitable:
Expand All @@ -4183,7 +4183,7 @@ def ping(self, message=None) -> Awaitable:
message = "" if message is None else message
return self.execute_command("PING", message)

def handle_message(self, response, ignore_subscribe_messages=False):
async def handle_message(self, response, ignore_subscribe_messages=False):
"""
Parses a pub/sub message. If the channel or pattern was subscribed to
with a message handler, the handler is invoked instead of a parsed
Expand Down Expand Up @@ -4232,7 +4232,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
else:
handler = self.channels.get(message["channel"], None)
if handler:
handler(message)
if inspect.iscoroutinefunction(handler):
await handler(message)
else:
handler(message)
return None
elif message_type != "pong":
# this is a subscribe/unsubscribe message. ignore if we don't
Expand Down
11 changes: 11 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ def setup_method(self, method):
def message_handler(self, message):
self.message = message

async def async_message_handler(self, message):
self.message = message

async def test_published_message_to_channel(self, r):
p = r.pubsub()
await p.subscribe("foo")
Expand Down Expand Up @@ -311,6 +314,14 @@ async def test_channel_message_handler(self, r):
assert await wait_for_message(p) is None
assert self.message == make_message("message", "foo", "test message")

async def test_channel_async_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.subscribe(foo=self.async_message_handler)
assert await wait_for_message(p) is None
assert await r.publish("foo", "test message") == 1
assert await wait_for_message(p) is None
assert self.message == make_message("message", "foo", "test message")

async def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.psubscribe(**{"f*": self.message_handler})
Expand Down