Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Async/Await Prototype #70

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
24af4e6
Initial async-with and await API prototype
Nov 16, 2019
27e77c3
Rename wrapper class with postfix that is used elsewhere in codebase
liamdiprose Nov 18, 2019
69b3b7e
Allow maximum queue size to be set to prevent a possible memory overflow
liamdiprose Nov 18, 2019
27a4e4e
Received messages are now the entire message object, not just the pay…
liamdiprose Nov 18, 2019
a53d567
Finish renaming ClientWrapper class
liamdiprose Nov 18, 2019
6dd3e73
Initial implementation of Subscription object-based API
liamdiprose Nov 19, 2019
f109600
Implement async iterator for subscription
liamdiprose Nov 19, 2019
09af3ec
Make connect context-manager awaitable (see doctest)
Nov 19, 2019
5fe1f1f
Make client_id optional
liamdiprose Nov 26, 2019
56e01b4
Simplify filling in optional parameter
liamdiprose Nov 27, 2019
ca1c613
Add more parameters to connect function
liamdiprose Nov 27, 2019
f6a0169
Git-ignore mypy cache and hidden '.venv' directories
liamdiprose Nov 27, 2019
a6a7c48
Use named argument
liamdiprose Nov 29, 2019
1776fa9
Initial implimentation of subscription manager that conforms to MQTT …
liamdiprose Nov 29, 2019
5786fee
Add subscription manager for handling incoming messages
liamdiprose Dec 4, 2019
da1dee7
Add tests for alternative async/await client API
liamdiprose Dec 4, 2019
9c8624c
Enable typechecking
liamdiprose Dec 4, 2019
73e9e82
Format async/await tests
liamdiprose Dec 4, 2019
da05359
Add tests for expect-duplicated messges
liamdiprose Dec 4, 2019
cbf0062
Finish renaming recv() method
liamdiprose Dec 4, 2019
995d3d8
Add FIXME note
liamdiprose Dec 4, 2019
d63353f
Work with port
liamdiprose Mar 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ dist/
# virtualenvs
env/
pyenv/
.venv/

# pytest
.coverage
.pytest_cache/
htmlcov/
.mypy_cache/
153 changes: 153 additions & 0 deletions gmqtt/aioclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import asyncio
import collections
from typing import Optional, Tuple, Callable

from .client import Client, Message
from .message_queue import SubscriptionManager


class MqttClientWrapper:
"""
Wraps the callback-based client in a async/await API
"""

def __init__(
self,
inner_client: Client,
receive_maximum: Optional[int] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
if loop is None:
loop = asyncio.get_event_loop()
self.loop = loop

self.client = inner_client

receive_maximum = receive_maximum or 65665 # FIXME: Sane default?

self.subscription_manager = SubscriptionManager(receive_maximum)
# self.message_queue = asyncio.Queue(maxsize=receive_maximum or 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove all old comments

# JKelf._subscriptions = collections.defaultdict(set)
self._init_client()

def _init_client(self):
"""Set up client so messages are forwarded to registered subscriptions:"""

def _on_message(client, topic, payload, qos, properties):
liamdiprose marked this conversation as resolved.
Show resolved Hide resolved
message = Message(client=client, topic=topic, payload=payload, **properties)
self.subscription_manager.on_message(message)

self.client.on_message = _on_message

async def publish(self, topic: str, message: Message, qos=0):
"""Publish a message to the MQTT topic"""
self.client.publish(topic, message, qos)

class Subscribe(collections.abc.Awaitable):
def __init__(self, client_wrapper, topic, qos=0):
self.topic = topic
self.qos = qos
self.client_wrapper = client_wrapper

def __await__(self):
return self.__await_impl__().__await__()

async def __await_impl__(self):
subscription = await self.client_wrapper.subscription_manager.add_subscription(
self.topic
)
self.client_wrapper.client.subscribe(self.topic, qos=self.qos)

return subscription

async def _unsubscribe(self):
self.client_wrapper.client.unsubscribe(self.topic)
# TODO: Await unsubscribe callback

async def __aenter__(self):
return await self

async def __aexit__(self, exc_type, exc_value, traceback):
# TODO: Wait for unsubscribe callback (future)
await self._unsubscribe()

def subscribe(self, topic, qos=0) -> collections.abc.Awaitable:
"""Subscribe the client to a topic"""

# Developers Notes:
# This returns an awaitable object (`Subscribe`) that sets up the `Subscription`.

client_wrapper = self
return self.Subscribe(client_wrapper, topic=topic, qos=qos)


class Connect:
"""
An async context manager that provides a connected MQTT Client.
Responsible for setting up and tearing down the client.

>>> async with connect('iot.eclipse.org') as client:
>>> await client.publish('test/message', 'hello world', qos=1)

>>> client = await connect('iot.eclipse.org')
>>> await client.publish('test/message', 'hello world', qos=1)
"""

def __init__(
self,
broker_host: str,
broker_port: int = 1883,
client_id: Optional[str] = None,
clean_session=True,
loop: Optional[asyncio.AbstractEventLoop] = None,
receive_maximum: Optional[int] = None,
):

self.loop = loop or asyncio.get_event_loop()

client_args = {}
if receive_maximum:
client_args["receive_maximum"] = receive_maximum

self.client = Client(client_id, clean_session=clean_session, **client_args)
self.broker_host = broker_host
self.broker_port = broker_port
self._connect_future = self.loop.create_future()
self._disconnect_future = self.loop.create_future()
self._receive_maximum = receive_maximum

async def _connect(self) -> MqttClientWrapper:
def _on_connect(client, flags, rc, properties):
self._connect_future.set_result(client)

self.client.on_connect = _on_connect

await self.client.connect(self.broker_host, self.broker_port)
return await self._connect_future

async def _disconnect(self):
def _on_disconnect(client, packet, exc=None):
self._disconnect_future.set_result(packet)

self.client.on_disconnect = _on_disconnect
await self.client.disconnect()
return await self._disconnect_future

def __await__(self):
return self.__await_impl__().__await__()

async def __await_impl__(self):
client = await self._connect()
return MqttClientWrapper(
client, loop=self.loop, receive_maximum=self._receive_maximum
)

async def __aenter__(self) -> MqttClientWrapper:
return await self

async def __aexit__(self, exc_type, exc_value, traceback):
await self._disconnect()


# Make the context manager look like a function
connect = Connect
125 changes: 125 additions & 0 deletions gmqtt/message_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import asyncio
import enum
import functools
import collections

# from .aioclient import Subscription
from .client import Message

from typing import List, Tuple, Dict, Callable


class TopicFilter:
def __init__(self, topic_filter: str):
self.levels = topic_filter.split("/")
TopicFilter.validate_topic(self.levels)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have some reason to call validate_topic by naming TopicFilter?

It will generate some unexpected behavior if you will inherit from it (if you will inherit and will want to overwrite validate_topic, also you will need to overwrite __init__ and call super and call validate_topic manually):

class B(TopicFilter):
    @staticmethod
    def validate_topic(*args):
        print("Never happens")
        return super().validate_topic(*args)

B("my-awesome-topic/bla")


@staticmethod
def validate_topic(filter_levels: List[str]):
if filter_levels[-1][-1] == "#":
if len(filter_levels[-1]) > 1:
raise ValueError("Multi-level wildcard must be on its own level")

if "#" in "".join(filter_levels[:-1]):
raise ValueError(
"Multi-level wildcard must be at the end of the topic filter"
)

for level in filter_levels:
if len(level) > 1:
if "+" in level:
raise ValueError("Single-level wildcard (+) only allowed by itself")

def match(self, topic: str) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about using regexp ?
For example if a user passes a topic ("root/sub1/+/sub3") we can transform it in regexp ("root/sub1/([^/]+)/sub3") and code will be more clear.

Replacements: + -> ([^/]+), # ->(.+)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I am considering switching this to a regexp - I wonder how much faster it'll be.

I'm also considering using a tree structure to store the subscriptions, which I think will find all matching subscriptions with less work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tree structure will give us a speed boost if only you will have hundred of subscriptions. For a less then ten (in my opinion) it will be only overhead for iterating and managing this structure in memory.

But it's interesting and you may implement it and compare with for-loop and matching by regexp implementation.

topic_levels = topic.split("/")

for filter_level, topic_level in zip(self.levels, topic_levels):
if filter_level == "+":
continue
elif filter_level == "#":
return True
else:
if filter_level != topic_level:
return False
return True

def __hash__(self):
return hash("/".join(self.levels))


class Subscription:
def __init__(self, message_queue: asyncio.Queue, on_unsubscribe: Callable):
self._incoming_messages = message_queue
self._on_unsubscribe = on_unsubscribe

async def recv(self):
"""Receive the next message published to this subscription"""
message = await self._incoming_messages.get()
# TODO: Hold off sending PUBACK for `message` until this point
return message

def __aiter__(self):
return self

async def __anext__(self):
return await self.recv()

async def unsubscribe(self):
await self._on_unsubscribe()

def _add_message(self, message: Message):
self._incoming_messages.put_nowait(message)


class DropPolicy(enum.Enum):
OLDEST_FIRST = enum.auto()


class SubscriptionManager:
"""
Manages incoming messages and the downstream subscription objects.

Handles:
- Maximum Queue size and message drop policy
- Routing messages to subscriptions based on topic (possibily with wildcards)
"""

def __init__(
self, receive_maximum: int, drop_policy: DropPolicy = DropPolicy.OLDEST_FIRST
):
self.receive_maximum = receive_maximum
self.drop_policy = drop_policy
self.subs: Dict[TopicFilter, List["asyncio.Queue"]] = collections.defaultdict(
list
)
self.size = 0

async def add_subscription(self, topic_filter_str: str) -> Subscription:
topic_filter = TopicFilter(topic_filter_str)

subscribed_messages = asyncio.Queue(self.receive_maximum)
sub_id = (topic_filter, len(self.subs[topic_filter]))
self.subs[topic_filter].append(subscribed_messages)

return Subscription(
message_queue=subscribed_messages,
on_unsubscribe=functools.partial(self.remove_subscription, sub_id=sub_id),
)

async def remove_subscription(self, sub_id: Tuple[TopicFilter, int]):
topic_filter, idx = sub_id

# TODO: Unsubscribe on underlying client
# Note: Don't remove so indexes are preserved
self.subs[topic_filter][idx] = None

def on_message(self, message: Message):
# TODO: check self.size

# if over, attempt to drop qos=0 packet using `drop_policy`

# if under, add to appropiate queues:
for subscription_topic in self.subs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's all good, but we can reduce count of indents and make the code more beautiful.

for topic in filter(lambda x: x.match(message.topic)):
    for sub in self.subs[topic]:
        subs.put_nowait(message)

or

for subscription_topic in self.subs:
    if not subscription_topic.match(message.topic):
        continue
    for subscription in self.subs[subscription_topic]:
        subscription.put_nowait(message)

if subscription_topic.match(message.topic):
for subscription in self.subs[subscription_topic]:
subscription.put_nowait(message)
Empty file added gmqtt/py.typed
Empty file.
91 changes: 91 additions & 0 deletions tests/test_aioclient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest
from unittest.mock import MagicMock

import gmqtt
from gmqtt import aioclient, message_queue
from gmqtt.message_queue import TopicFilter
from gmqtt.aioclient import MqttClientWrapper
from gmqtt.client import Message

# TODO: Fixtures
# Mock client


@pytest.mark.asyncio
async def test_plain_subscription():
sm = message_queue.SubscriptionManager(999)
sub = await sm.add_subscription("topic/TEST")

message = gmqtt.Message(topic="topic/TEST", payload="payload")
sm.on_message(message)

received = await sub.recv()
assert received.topic == "topic/TEST"


@pytest.mark.asyncio
async def test_wildcard_subscription():
sm = message_queue.SubscriptionManager(999)
sub = await sm.add_subscription("topic/+")

message = gmqtt.Message(topic="topic/TEST", payload="payload")
sm.on_message(message)

received = await sub.recv()
assert received.topic == "topic/TEST"


def test_match_topic_filter():
tf = TopicFilter("topic/TEST")
assert tf.match("topic/TEST")
assert tf.match("topic/FOO") == False


def test_match_topic_filter_multilevel_wildcard():
tf = TopicFilter("topic/#")

assert tf.match("topic/TEST/1")
assert tf.match("topic/1")
assert tf.match("topic/")
assert tf.match("topic")


def test_invalid_multilevel_wildcard():

with pytest.raises(ValueError) as exc:
TopicFilter("sport/tennis/#/ranking")

with pytest.raises(ValueError):
TopicFilter("#/tailing")

with pytest.raises(ValueError):
TopicFilter("sport/tennis#")


@pytest.fixture
def client():
mock_inner = MagicMock()
mock_inner.subscribe.return_value = None
mock_inner.on_message = None

wrapper_client = MqttClientWrapper(mock_inner)

return wrapper_client, mock_inner

@pytest.mark.asyncio
async def test_multiple_subs_duplicate_messages(client):
client, mocked_inner = client

sub1 = await client.subscribe("test/test")
sub2 = await client.subscribe("test/test")

message = Message(topic="test/test", payload="payload")
mocked_inner.on_message(client=mocked_inner, topic=message.topic, payload=message.payload, qos=message.qos, properties={})

msg1 = await sub1.recv()
msg2 = await sub2.recv()

assert msg1.topic == "test/test"
assert msg2.topic == "test/test"
assert msg1.payload == b"payload"
assert msg2.payload == b"payload"