Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1284 from stephanm/async_callback
Browse files Browse the repository at this point in the history
Support for async callbacks
  • Loading branch information
seandstewart authored Feb 4, 2022
2 parents 8a38f98 + 359a741 commit 659a14d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGES/1284.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Using coroutines as callbacks is now possible.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Sean Stewart <seandstewart>
Sergey Miletskiy
SeungHyun Hwang
Stanislau Arkhipenka
Stephan Meier
Taku Fukada
Taras Voinarovskyi
Thanos Lefteris
Expand Down
13 changes: 8 additions & 5 deletions aioredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview)
AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview)
AnyChannelT = ChannelT
PubSubHandler = Callable[[Dict[str, str]], None]
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]

SYM_EMPTY = b""
EMPTY_RESPONSE = "EMPTY_RESPONSE"
Expand Down Expand Up @@ -4176,7 +4176,7 @@ def unsubscribe(self, *args) -> Awaitable:
async def listen(self) -> AsyncIterator:
"""Listen for messages on channels this client has been subscribed to"""
while self.subscribed:
response = self.handle_message(await self.parse_response(block=True))
response = await self.handle_message(await self.parse_response(block=True))
if response is not None:
yield response

Expand All @@ -4192,7 +4192,7 @@ async def get_message(
"""
response = await self.parse_response(block=False, timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
return await self.handle_message(response, ignore_subscribe_messages)
return None

def ping(self, message=None) -> Awaitable:
Expand All @@ -4202,7 +4202,7 @@ def ping(self, message=None) -> Awaitable:
message = "" if message is None else message
return self.execute_command("PING", message)

def handle_message(self, response, ignore_subscribe_messages=False):
async def handle_message(self, response, ignore_subscribe_messages=False):
"""
Parses a pub/sub message. If the channel or pattern was subscribed to
with a message handler, the handler is invoked instead of a parsed
Expand Down Expand Up @@ -4251,7 +4251,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
else:
handler = self.channels.get(message["channel"], None)
if handler:
handler(message)
if inspect.iscoroutinefunction(handler):
await handler(message)
else:
handler(message)
return None
elif message_type != "pong":
# this is a subscribe/unsubscribe message. ignore if we don't
Expand Down
22 changes: 22 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ def setup_method(self, method):
def message_handler(self, message):
self.message = message

async def async_message_handler(self, message):
self.async_message = message

async def test_published_message_to_channel(self, r):
p = r.pubsub()
await p.subscribe("foo")
Expand Down Expand Up @@ -311,6 +314,25 @@ async def test_channel_message_handler(self, r):
assert await wait_for_message(p) is None
assert self.message == make_message("message", "foo", "test message")

async def test_channel_async_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.subscribe(foo=self.async_message_handler)
assert await wait_for_message(p) is None
assert await r.publish("foo", "test message") == 1
assert await wait_for_message(p) is None
assert self.async_message == make_message("message", "foo", "test message")

async def test_channel_sync_async_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.subscribe(foo=self.message_handler)
await p.subscribe(bar=self.async_message_handler)
assert await wait_for_message(p) is None
assert await r.publish("foo", "test message") == 1
assert await r.publish("bar", "test message 2") == 1
assert await wait_for_message(p) is None
assert self.message == make_message("message", "foo", "test message")
assert self.async_message == make_message("message", "bar", "test message 2")

async def test_pattern_message_handler(self, r):
p = r.pubsub(ignore_subscribe_messages=True)
await p.psubscribe(**{"f*": self.message_handler})
Expand Down

0 comments on commit 659a14d

Please sign in to comment.