Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Support for async callbacks #1284

Merged
merged 6 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/1284.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Using coroutines as callbacks is now possible.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Sean Stewart <seandstewart>
Sergey Miletskiy
SeungHyun Hwang
Stanislau Arkhipenka
Stephan Meier
Taku Fukada
Taras Voinarovskyi
Thanos Lefteris
Expand Down
13 changes: 8 additions & 5 deletions aioredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview)
AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview)
AnyChannelT = ChannelT
PubSubHandler = Callable[[Dict[str, str]], None]
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]

SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
Expand Down Expand Up @@ -4157,7 +4157,7 @@ def unsubscribe(self, *args) -> Awaitable:
async def listen(self) -> AsyncIterator:
"""Listen for messages on channels this client has been subscribed to"""
while self.subscribed:
response = self.handle_message(await self.parse_response(block=True))
response = await self.handle_message(await self.parse_response(block=True))
if response is not None:
yield response

Expand All @@ -4173,7 +4173,7 @@ async def get_message(
"""
response = await self.parse_response(block=False, timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
return await self.handle_message(response, ignore_subscribe_messages)
return None

def ping(self, message=None) -> Awaitable:
Expand All @@ -4183,7 +4183,7 @@ def ping(self, message=None) -> Awaitable:
message = "" if message is None else message
return self.execute_command("PING", message)

def handle_message(self, response, ignore_subscribe_messages=False):
async def handle_message(self, response, ignore_subscribe_messages=False):
"""
Parses a pub/sub message. If the channel or pattern was subscribed to
with a message handler, the handler is invoked instead of a parsed
Expand Down Expand Up @@ -4232,7 +4232,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
else:
handler = self.channels.get(message["channel"], None)
if handler:
handler(message)
if inspect.iscoroutinefunction(handler):
await handler(message)
else:
handler(message)
return None
elif message_type != "pong":
# this is a subscribe/unsubscribe message. ignore if we don't
Expand Down
22 changes: 22 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ def setup_method(self, method):
def message_handler(self, message):
self.message = message

async def async_message_handler(self, message):
self.async_message = message

async def test_published_message_to_channel(self, r):
p = r.pubsub()
await p.subscribe("foo")
Expand Down Expand Up @@ -311,6 +314,25 @@ async def test_channel_message_handler(self, r):
assert await wait_for_message(p) is None
assert self.message == make_message("message", "foo", "test message")

async def test_channel_async_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.subscribe(foo=self.async_message_handler)
assert await wait_for_message(p) is None
assert await r.publish("foo", "test message") == 1
assert await wait_for_message(p) is None
assert self.async_message == make_message("message", "foo", "test message")

async def test_channel_sync_async_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.subscribe(foo=self.message_handler)
await p.subscribe(bar=self.async_message_handler)
assert await wait_for_message(p) is None
assert await r.publish("foo", "test message") == 1
assert await r.publish("bar", "test message 2") == 1
assert await wait_for_message(p) is None
assert self.message == make_message("message", "foo", "test message")
assert self.async_message == make_message("message", "bar", "test message 2")

async def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.psubscribe(**{"f*": self.message_handler})
Expand Down