diff --git a/bench/marshall.py b/bench/marshall.py index f376d84d..23b45665 100644 --- a/bench/marshall.py +++ b/bench/marshall.py @@ -11,7 +11,7 @@ def marhsall_bluez_get_managed_objects_message(): - message._marshall() + message._marshall(False) count = 1000000 diff --git a/poetry.lock b/poetry.lock index a324f9bd..7231741b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -357,6 +357,17 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +[[package]] +name = "pytest-timeout" +version = "2.1.0" +description = "pytest plugin to abort hanging tests" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +pytest = ">=5.0.0" + [[package]] name = "pytz" version = "2022.4" @@ -596,7 +607,7 @@ docs = ["myst-parser", "Sphinx", "sphinx-rtd-theme", "sphinxcontrib-asyncio", "s [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "96d521d1e66777febd43aad81ec451dbf5a15873e1434052283b6ecdf3095c07" +content-hash = "381552380ec2e3115cbc867021d088eeac5b7e7a494070586fa3bc40ad01f6a3" [metadata.files] alabaster = [ @@ -849,6 +860,10 @@ pytest-cov = [ {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, ] +pytest-timeout = [ + {file = "pytest-timeout-2.1.0.tar.gz", hash = "sha256:c07ca07404c612f8abbe22294b23c368e2e5104b521c1790195561f37e1ac3d9"}, + {file = "pytest_timeout-2.1.0-py3-none-any.whl", hash = "sha256:f6f50101443ce70ad325ceb4473c4255e9d74e3c7cd0ef827309dfa4c0d975c6"}, +] pytz = [ {file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"}, {file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"}, diff --git a/pyproject.toml b/pyproject.toml index a14c8fb4..8c2978a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ pycairo = "^1.21.0" PyGObject = "^3.42.2" Cython = "^0.29.32" setuptools = "^65.4.1" +pytest-timeout = "^2.1.0" [tool.semantic_release] branch = "main" diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 6406da07..8865e794 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -95,7 +95,7 @@ def write_callback(self, remove_writer: bool = True) -> None: def buffer_message(self, msg: Message, future=None) -> None: self.messages.append( ( - msg._marshall(negotiate_unix_fd=self.negotiate_unix_fd), + msg._marshall(self.negotiate_unix_fd), copy(msg.unix_fds), future, ) @@ -216,7 +216,7 @@ def on_hello(reply, err): ) self._method_return_handlers[hello_msg.serial] = on_hello - self._stream.write(hello_msg._marshall()) + self._stream.write(hello_msg._marshall(False)) self._stream.flush() return await future diff --git a/src/dbus_fast/glib/message_bus.py b/src/dbus_fast/glib/message_bus.py index 8ceb370f..ed7b6801 100644 --- a/src/dbus_fast/glib/message_bus.py +++ b/src/dbus_fast/glib/message_bus.py @@ -98,7 +98,7 @@ def dispatch(self, callback, user_data): return GLib.SOURCE_REMOVE else: message = self.bus._buffered_messages.pop(0) - self.message_stream = io.BytesIO(message._marshall()) + self.message_stream = io.BytesIO(message._marshall(False)) return GLib.SOURCE_CONTINUE except BlockingIOError: return GLib.SOURCE_CONTINUE @@ -233,7 +233,7 @@ def on_hello(reply, err): ) self._method_return_handlers[hello_msg.serial] = on_hello - self._stream.write(hello_msg._marshall()) + self._stream.write(hello_msg._marshall(False)) self._stream.flush() self._authenticate(authenticate_notify) diff --git a/src/dbus_fast/message.pxd b/src/dbus_fast/message.pxd index cd7d9a80..27067339 100644 --- a/src/dbus_fast/message.pxd +++ b/src/dbus_fast/message.pxd @@ -20,4 +20,4 @@ cdef class Message: cdef public list body cdef public unsigned int serial - cpdef _marshall(self, negotiate_unix_fd: bint) + cpdef _marshall(self, bint negotiate_unix_fd) diff --git a/src/dbus_fast/message.py b/src/dbus_fast/message.py index 7b6626cb..453933e0 100644 --- a/src/dbus_fast/message.py +++ b/src/dbus_fast/message.py @@ -1,4 +1,4 @@ -from typing import Any, List, Union +from typing import Any, List, Optional, Union from ._private.constants import LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField from ._private.marshaller import Marshaller @@ -95,17 +95,17 @@ class Message: def __init__( self, - destination: str = None, - path: str = None, - interface: str = None, - member: str = None, + destination: Optional[str] = None, + path: Optional[str] = None, + interface: Optional[str] = None, + member: Optional[str] = None, message_type: MessageType = MessageType.METHOD_CALL, flags: MessageFlag = MessageFlag.NONE, - error_name: str = None, - reply_serial: int = None, - sender: str = None, + error_name: Optional[Union[str, ErrorType]] = None, + reply_serial=0, + sender: Optional[str] = None, unix_fds: List[int] = [], - signature: Union[str, SignatureTree] = "", + signature: Optional[Union[SignatureTree, str]] = None, body: List[Any] = [], serial: int = 0, validate: bool = True, @@ -119,7 +119,7 @@ def __init__( flags if type(flags) is MessageFlag else MessageFlag(bytes([flags])) ) self.error_name = ( - error_name if type(error_name) is not ErrorType else error_name.value + str(error_name.value) if type(error_name) is ErrorType else error_name ) self.reply_serial = reply_serial or 0 self.sender = sender @@ -128,8 +128,8 @@ def __init__( self.signature = signature.signature self.signature_tree = signature else: - self.signature = signature - self.signature_tree = get_signature_tree(signature) + self.signature = signature or "" + self.signature_tree = get_signature_tree(signature or "") self.body = body self.serial = serial or 0 @@ -154,7 +154,9 @@ def __init__( raise InvalidMessageError(f"missing required field: {field}") @staticmethod - def new_error(msg: "Message", error_name: str, error_text: str) -> "Message": + def new_error( + msg: "Message", error_name: Union[str, ErrorType], error_text: str + ) -> "Message": """A convenience constructor to create an error message in reply to the given message. :param msg: The message this error is in reply to. @@ -255,7 +257,8 @@ def new_signal( unix_fds=unix_fds, ) - def _marshall(self, negotiate_unix_fd=False): + def _marshall(self, negotiate_unix_fd: bool) -> bytearray: + """Marshall this message into a byte array.""" # TODO maximum message size is 134217728 (128 MiB) body_block = Marshaller(self.signature, self.body) body_block.marshall() diff --git a/src/dbus_fast/validators.py b/src/dbus_fast/validators.py index 176f79cb..b17e6621 100644 --- a/src/dbus_fast/validators.py +++ b/src/dbus_fast/validators.py @@ -135,7 +135,7 @@ def is_member_name_valid(member: str) -> bool: return True -def assert_bus_name_valid(name: str): +def assert_bus_name_valid(name: str) -> None: """Raise an error if this is not a valid bus name. .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-bus @@ -150,7 +150,7 @@ def assert_bus_name_valid(name: str): raise InvalidBusNameError(name) -def assert_object_path_valid(path: str): +def assert_object_path_valid(path: str) -> None: """Raise an error if this is not a valid object path. .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-marshaling-object-path @@ -165,7 +165,7 @@ def assert_object_path_valid(path: str): raise InvalidObjectPathError(path) -def assert_interface_name_valid(name: str): +def assert_interface_name_valid(name: str) -> None: """Raise an error if this is not a valid interface name. .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-interface @@ -180,7 +180,7 @@ def assert_interface_name_valid(name: str): raise InvalidInterfaceNameError(name) -def assert_member_name_valid(member): +def assert_member_name_valid(member) -> None: """Raise an error if this is not a valid member name. .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-member diff --git a/tests/test_marshaller.py b/tests/test_marshaller.py index be6d1b6f..f0809f01 100644 --- a/tests/test_marshaller.py +++ b/tests/test_marshaller.py @@ -78,7 +78,7 @@ def test_marshalling_with_table(): body.append(replace_variants(type_, message.body[i])) message.body = body - buf = message._marshall() + buf = message._marshall(False) data = bytes.fromhex(item["data"]) if buf != data: @@ -173,6 +173,6 @@ def read(self, n) -> bytes: def test_ay_buffer(): body = [bytes(10000)] msg = Message(path="/test", member="test", signature="ay", body=body) - marshalled = msg._marshall() + marshalled = msg._marshall(False) unmarshalled_msg = Unmarshaller(io.BytesIO(marshalled)).unmarshall() assert unmarshalled_msg.body[0] == body[0]