-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Async/Await Prototype #70
base: master
Are you sure you want to change the base?
Changes from all commits
24af4e6
27e77c3
69b3b7e
27a4e4e
a53d567
6dd3e73
f109600
09af3ec
5fe1f1f
56e01b4
ca1c613
f6a0169
a6a7c48
1776fa9
5786fee
da1dee7
9c8624c
73e9e82
da05359
cbf0062
995d3d8
d63353f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,8 +7,10 @@ dist/ | |
# virtualenvs | ||
env/ | ||
pyenv/ | ||
.venv/ | ||
|
||
# pytest | ||
.coverage | ||
.pytest_cache/ | ||
htmlcov/ | ||
.mypy_cache/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import asyncio | ||
import collections | ||
from typing import Optional, Tuple, Callable | ||
|
||
from .client import Client, Message | ||
from .message_queue import SubscriptionManager | ||
|
||
|
||
class MqttClientWrapper: | ||
""" | ||
Wraps the callback-based client in a async/await API | ||
""" | ||
|
||
def __init__( | ||
self, | ||
inner_client: Client, | ||
receive_maximum: Optional[int] = None, | ||
loop: Optional[asyncio.AbstractEventLoop] = None, | ||
): | ||
if loop is None: | ||
loop = asyncio.get_event_loop() | ||
self.loop = loop | ||
|
||
self.client = inner_client | ||
|
||
receive_maximum = receive_maximum or 65665 # FIXME: Sane default? | ||
|
||
self.subscription_manager = SubscriptionManager(receive_maximum) | ||
# self.message_queue = asyncio.Queue(maxsize=receive_maximum or 0) | ||
# JKelf._subscriptions = collections.defaultdict(set) | ||
self._init_client() | ||
|
||
def _init_client(self): | ||
"""Set up client so messages are forwarded to registered subscriptions:""" | ||
|
||
def _on_message(client, topic, payload, qos, properties): | ||
liamdiprose marked this conversation as resolved.
Show resolved
Hide resolved
|
||
message = Message(client=client, topic=topic, payload=payload, **properties) | ||
self.subscription_manager.on_message(message) | ||
|
||
self.client.on_message = _on_message | ||
|
||
async def publish(self, topic: str, message: Message, qos=0): | ||
"""Publish a message to the MQTT topic""" | ||
self.client.publish(topic, message, qos) | ||
|
||
class Subscribe(collections.abc.Awaitable): | ||
def __init__(self, client_wrapper, topic, qos=0): | ||
self.topic = topic | ||
self.qos = qos | ||
self.client_wrapper = client_wrapper | ||
|
||
def __await__(self): | ||
return self.__await_impl__().__await__() | ||
|
||
async def __await_impl__(self): | ||
subscription = await self.client_wrapper.subscription_manager.add_subscription( | ||
self.topic | ||
) | ||
self.client_wrapper.client.subscribe(self.topic, qos=self.qos) | ||
|
||
return subscription | ||
|
||
async def _unsubscribe(self): | ||
self.client_wrapper.client.unsubscribe(self.topic) | ||
# TODO: Await unsubscribe callback | ||
|
||
async def __aenter__(self): | ||
return await self | ||
|
||
async def __aexit__(self, exc_type, exc_value, traceback): | ||
# TODO: Wait for unsubscribe callback (future) | ||
await self._unsubscribe() | ||
|
||
def subscribe(self, topic, qos=0) -> collections.abc.Awaitable: | ||
"""Subscribe the client to a topic""" | ||
|
||
# Developers Notes: | ||
# This returns an awaitable object (`Subscribe`) that sets up the `Subscription`. | ||
|
||
client_wrapper = self | ||
return self.Subscribe(client_wrapper, topic=topic, qos=qos) | ||
|
||
|
||
class Connect: | ||
""" | ||
An async context manager that provides a connected MQTT Client. | ||
Responsible for setting up and tearing down the client. | ||
|
||
>>> async with connect('iot.eclipse.org') as client: | ||
>>> await client.publish('test/message', 'hello world', qos=1) | ||
|
||
>>> client = await connect('iot.eclipse.org') | ||
>>> await client.publish('test/message', 'hello world', qos=1) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
broker_host: str, | ||
broker_port: int = 1883, | ||
client_id: Optional[str] = None, | ||
clean_session=True, | ||
loop: Optional[asyncio.AbstractEventLoop] = None, | ||
receive_maximum: Optional[int] = None, | ||
): | ||
|
||
self.loop = loop or asyncio.get_event_loop() | ||
|
||
client_args = {} | ||
if receive_maximum: | ||
client_args["receive_maximum"] = receive_maximum | ||
|
||
self.client = Client(client_id, clean_session=clean_session, **client_args) | ||
self.broker_host = broker_host | ||
self.broker_port = broker_port | ||
self._connect_future = self.loop.create_future() | ||
self._disconnect_future = self.loop.create_future() | ||
self._receive_maximum = receive_maximum | ||
|
||
async def _connect(self) -> MqttClientWrapper: | ||
def _on_connect(client, flags, rc, properties): | ||
self._connect_future.set_result(client) | ||
|
||
self.client.on_connect = _on_connect | ||
|
||
await self.client.connect(self.broker_host, self.broker_port) | ||
return await self._connect_future | ||
|
||
async def _disconnect(self): | ||
def _on_disconnect(client, packet, exc=None): | ||
self._disconnect_future.set_result(packet) | ||
|
||
self.client.on_disconnect = _on_disconnect | ||
await self.client.disconnect() | ||
return await self._disconnect_future | ||
|
||
def __await__(self): | ||
return self.__await_impl__().__await__() | ||
|
||
async def __await_impl__(self): | ||
client = await self._connect() | ||
return MqttClientWrapper( | ||
client, loop=self.loop, receive_maximum=self._receive_maximum | ||
) | ||
|
||
async def __aenter__(self) -> MqttClientWrapper: | ||
return await self | ||
|
||
async def __aexit__(self, exc_type, exc_value, traceback): | ||
await self._disconnect() | ||
|
||
|
||
# Make the context manager look like a function | ||
connect = Connect |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import asyncio | ||
import enum | ||
import functools | ||
import collections | ||
|
||
# from .aioclient import Subscription | ||
from .client import Message | ||
|
||
from typing import List, Tuple, Dict, Callable | ||
|
||
|
||
class TopicFilter: | ||
def __init__(self, topic_filter: str): | ||
self.levels = topic_filter.split("/") | ||
TopicFilter.validate_topic(self.levels) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have some reason to call It will generate some unexpected behavior if you will inherit from it (if you will inherit and will want to overwrite validate_topic, also you will need to overwrite
|
||
|
||
@staticmethod | ||
def validate_topic(filter_levels: List[str]): | ||
if filter_levels[-1][-1] == "#": | ||
if len(filter_levels[-1]) > 1: | ||
raise ValueError("Multi-level wildcard must be on its own level") | ||
|
||
if "#" in "".join(filter_levels[:-1]): | ||
raise ValueError( | ||
"Multi-level wildcard must be at the end of the topic filter" | ||
) | ||
|
||
for level in filter_levels: | ||
if len(level) > 1: | ||
if "+" in level: | ||
raise ValueError("Single-level wildcard (+) only allowed by itself") | ||
|
||
def match(self, topic: str) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think about using regexp ? Replacements: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I am considering switching this to a regexp - I wonder how much faster it'll be. I'm also considering using a tree structure to store the subscriptions, which I think will find all matching subscriptions with less work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tree structure will give us a speed boost if only you will have hundred of subscriptions. For a less then ten (in my opinion) it will be only overhead for iterating and managing this structure in memory. But it's interesting and you may implement it and compare with |
||
topic_levels = topic.split("/") | ||
|
||
for filter_level, topic_level in zip(self.levels, topic_levels): | ||
if filter_level == "+": | ||
continue | ||
elif filter_level == "#": | ||
return True | ||
else: | ||
if filter_level != topic_level: | ||
return False | ||
return True | ||
|
||
def __hash__(self): | ||
return hash("/".join(self.levels)) | ||
|
||
|
||
class Subscription: | ||
def __init__(self, message_queue: asyncio.Queue, on_unsubscribe: Callable): | ||
self._incoming_messages = message_queue | ||
self._on_unsubscribe = on_unsubscribe | ||
|
||
async def recv(self): | ||
"""Receive the next message published to this subscription""" | ||
message = await self._incoming_messages.get() | ||
# TODO: Hold off sending PUBACK for `message` until this point | ||
return message | ||
|
||
def __aiter__(self): | ||
return self | ||
|
||
async def __anext__(self): | ||
return await self.recv() | ||
|
||
async def unsubscribe(self): | ||
await self._on_unsubscribe() | ||
|
||
def _add_message(self, message: Message): | ||
self._incoming_messages.put_nowait(message) | ||
|
||
|
||
class DropPolicy(enum.Enum): | ||
OLDEST_FIRST = enum.auto() | ||
|
||
|
||
class SubscriptionManager: | ||
""" | ||
Manages incoming messages and the downstream subscription objects. | ||
|
||
Handles: | ||
- Maximum Queue size and message drop policy | ||
- Routing messages to subscriptions based on topic (possibily with wildcards) | ||
""" | ||
|
||
def __init__( | ||
self, receive_maximum: int, drop_policy: DropPolicy = DropPolicy.OLDEST_FIRST | ||
): | ||
self.receive_maximum = receive_maximum | ||
self.drop_policy = drop_policy | ||
self.subs: Dict[TopicFilter, List["asyncio.Queue"]] = collections.defaultdict( | ||
list | ||
) | ||
self.size = 0 | ||
|
||
async def add_subscription(self, topic_filter_str: str) -> Subscription: | ||
topic_filter = TopicFilter(topic_filter_str) | ||
|
||
subscribed_messages = asyncio.Queue(self.receive_maximum) | ||
sub_id = (topic_filter, len(self.subs[topic_filter])) | ||
self.subs[topic_filter].append(subscribed_messages) | ||
|
||
return Subscription( | ||
message_queue=subscribed_messages, | ||
on_unsubscribe=functools.partial(self.remove_subscription, sub_id=sub_id), | ||
) | ||
|
||
async def remove_subscription(self, sub_id: Tuple[TopicFilter, int]): | ||
topic_filter, idx = sub_id | ||
|
||
# TODO: Unsubscribe on underlying client | ||
# Note: Don't remove so indexes are preserved | ||
self.subs[topic_filter][idx] = None | ||
|
||
def on_message(self, message: Message): | ||
# TODO: check self.size | ||
|
||
# if over, attempt to drop qos=0 packet using `drop_policy` | ||
|
||
# if under, add to appropiate queues: | ||
for subscription_topic in self.subs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's all good, but we can reduce count of indents and make the code more beautiful. for topic in filter(lambda x: x.match(message.topic)):
for sub in self.subs[topic]:
subs.put_nowait(message) or for subscription_topic in self.subs:
if not subscription_topic.match(message.topic):
continue
for subscription in self.subs[subscription_topic]:
subscription.put_nowait(message) |
||
if subscription_topic.match(message.topic): | ||
for subscription in self.subs[subscription_topic]: | ||
subscription.put_nowait(message) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import pytest | ||
from unittest.mock import MagicMock | ||
|
||
import gmqtt | ||
from gmqtt import aioclient, message_queue | ||
from gmqtt.message_queue import TopicFilter | ||
from gmqtt.aioclient import MqttClientWrapper | ||
from gmqtt.client import Message | ||
|
||
# TODO: Fixtures | ||
# Mock client | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_plain_subscription(): | ||
sm = message_queue.SubscriptionManager(999) | ||
sub = await sm.add_subscription("topic/TEST") | ||
|
||
message = gmqtt.Message(topic="topic/TEST", payload="payload") | ||
sm.on_message(message) | ||
|
||
received = await sub.recv() | ||
assert received.topic == "topic/TEST" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_wildcard_subscription(): | ||
sm = message_queue.SubscriptionManager(999) | ||
sub = await sm.add_subscription("topic/+") | ||
|
||
message = gmqtt.Message(topic="topic/TEST", payload="payload") | ||
sm.on_message(message) | ||
|
||
received = await sub.recv() | ||
assert received.topic == "topic/TEST" | ||
|
||
|
||
def test_match_topic_filter(): | ||
tf = TopicFilter("topic/TEST") | ||
assert tf.match("topic/TEST") | ||
assert tf.match("topic/FOO") == False | ||
|
||
|
||
def test_match_topic_filter_multilevel_wildcard(): | ||
tf = TopicFilter("topic/#") | ||
|
||
assert tf.match("topic/TEST/1") | ||
assert tf.match("topic/1") | ||
assert tf.match("topic/") | ||
assert tf.match("topic") | ||
|
||
|
||
def test_invalid_multilevel_wildcard(): | ||
|
||
with pytest.raises(ValueError) as exc: | ||
TopicFilter("sport/tennis/#/ranking") | ||
|
||
with pytest.raises(ValueError): | ||
TopicFilter("#/tailing") | ||
|
||
with pytest.raises(ValueError): | ||
TopicFilter("sport/tennis#") | ||
|
||
|
||
@pytest.fixture | ||
def client(): | ||
mock_inner = MagicMock() | ||
mock_inner.subscribe.return_value = None | ||
mock_inner.on_message = None | ||
|
||
wrapper_client = MqttClientWrapper(mock_inner) | ||
|
||
return wrapper_client, mock_inner | ||
|
||
@pytest.mark.asyncio | ||
async def test_multiple_subs_duplicate_messages(client): | ||
client, mocked_inner = client | ||
|
||
sub1 = await client.subscribe("test/test") | ||
sub2 = await client.subscribe("test/test") | ||
|
||
message = Message(topic="test/test", payload="payload") | ||
mocked_inner.on_message(client=mocked_inner, topic=message.topic, payload=message.payload, qos=message.qos, properties={}) | ||
|
||
msg1 = await sub1.recv() | ||
msg2 = await sub2.recv() | ||
|
||
assert msg1.topic == "test/test" | ||
assert msg2.topic == "test/test" | ||
assert msg1.payload == b"payload" | ||
assert msg2.payload == b"payload" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's remove all old comments