Skip to content

Commit

Permalink
fix: incorrect pxd typing for for _marshall (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Oct 6, 2022
1 parent 23903c3 commit cf1f012
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 27 deletions.
2 changes: 1 addition & 1 deletion bench/marshall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def marhsall_bluez_get_managed_objects_message():
message._marshall()
message._marshall(False)


count = 1000000
Expand Down
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pycairo = "^1.21.0"
PyGObject = "^3.42.2"
Cython = "^0.29.32"
setuptools = "^65.4.1"
pytest-timeout = "^2.1.0"

[tool.semantic_release]
branch = "main"
Expand Down
4 changes: 2 additions & 2 deletions src/dbus_fast/aio/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def write_callback(self, remove_writer: bool = True) -> None:
def buffer_message(self, msg: Message, future=None) -> None:
self.messages.append(
(
msg._marshall(negotiate_unix_fd=self.negotiate_unix_fd),
msg._marshall(self.negotiate_unix_fd),
copy(msg.unix_fds),
future,
)
Expand Down Expand Up @@ -216,7 +216,7 @@ def on_hello(reply, err):
)

self._method_return_handlers[hello_msg.serial] = on_hello
self._stream.write(hello_msg._marshall())
self._stream.write(hello_msg._marshall(False))
self._stream.flush()

return await future
Expand Down
4 changes: 2 additions & 2 deletions src/dbus_fast/glib/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def dispatch(self, callback, user_data):
return GLib.SOURCE_REMOVE
else:
message = self.bus._buffered_messages.pop(0)
self.message_stream = io.BytesIO(message._marshall())
self.message_stream = io.BytesIO(message._marshall(False))
return GLib.SOURCE_CONTINUE
except BlockingIOError:
return GLib.SOURCE_CONTINUE
Expand Down Expand Up @@ -233,7 +233,7 @@ def on_hello(reply, err):
)

self._method_return_handlers[hello_msg.serial] = on_hello
self._stream.write(hello_msg._marshall())
self._stream.write(hello_msg._marshall(False))
self._stream.flush()

self._authenticate(authenticate_notify)
Expand Down
2 changes: 1 addition & 1 deletion src/dbus_fast/message.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ cdef class Message:
cdef public list body
cdef public unsigned int serial

cpdef _marshall(self, negotiate_unix_fd: bint)
cpdef _marshall(self, bint negotiate_unix_fd)
31 changes: 17 additions & 14 deletions src/dbus_fast/message.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Union
from typing import Any, List, Optional, Union

from ._private.constants import LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField
from ._private.marshaller import Marshaller
Expand Down Expand Up @@ -95,17 +95,17 @@ class Message:

def __init__(
self,
destination: str = None,
path: str = None,
interface: str = None,
member: str = None,
destination: Optional[str] = None,
path: Optional[str] = None,
interface: Optional[str] = None,
member: Optional[str] = None,
message_type: MessageType = MessageType.METHOD_CALL,
flags: MessageFlag = MessageFlag.NONE,
error_name: str = None,
reply_serial: int = None,
sender: str = None,
error_name: Optional[Union[str, ErrorType]] = None,
reply_serial=0,
sender: Optional[str] = None,
unix_fds: List[int] = [],
signature: Union[str, SignatureTree] = "",
signature: Optional[Union[SignatureTree, str]] = None,
body: List[Any] = [],
serial: int = 0,
validate: bool = True,
Expand All @@ -119,7 +119,7 @@ def __init__(
flags if type(flags) is MessageFlag else MessageFlag(bytes([flags]))
)
self.error_name = (
error_name if type(error_name) is not ErrorType else error_name.value
str(error_name.value) if type(error_name) is ErrorType else error_name
)
self.reply_serial = reply_serial or 0
self.sender = sender
Expand All @@ -128,8 +128,8 @@ def __init__(
self.signature = signature.signature
self.signature_tree = signature
else:
self.signature = signature
self.signature_tree = get_signature_tree(signature)
self.signature = signature or ""
self.signature_tree = get_signature_tree(signature or "")
self.body = body
self.serial = serial or 0

Expand All @@ -154,7 +154,9 @@ def __init__(
raise InvalidMessageError(f"missing required field: {field}")

@staticmethod
def new_error(msg: "Message", error_name: str, error_text: str) -> "Message":
def new_error(
msg: "Message", error_name: Union[str, ErrorType], error_text: str
) -> "Message":
"""A convenience constructor to create an error message in reply to the given message.
:param msg: The message this error is in reply to.
Expand Down Expand Up @@ -255,7 +257,8 @@ def new_signal(
unix_fds=unix_fds,
)

def _marshall(self, negotiate_unix_fd=False):
def _marshall(self, negotiate_unix_fd: bool) -> bytearray:
"""Marshall this message into a byte array."""
# TODO maximum message size is 134217728 (128 MiB)
body_block = Marshaller(self.signature, self.body)
body_block.marshall()
Expand Down
8 changes: 4 additions & 4 deletions src/dbus_fast/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def is_member_name_valid(member: str) -> bool:
return True


def assert_bus_name_valid(name: str):
def assert_bus_name_valid(name: str) -> None:
"""Raise an error if this is not a valid bus name.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-bus
Expand All @@ -150,7 +150,7 @@ def assert_bus_name_valid(name: str):
raise InvalidBusNameError(name)


def assert_object_path_valid(path: str):
def assert_object_path_valid(path: str) -> None:
"""Raise an error if this is not a valid object path.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-marshaling-object-path
Expand All @@ -165,7 +165,7 @@ def assert_object_path_valid(path: str):
raise InvalidObjectPathError(path)


def assert_interface_name_valid(name: str):
def assert_interface_name_valid(name: str) -> None:
"""Raise an error if this is not a valid interface name.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-interface
Expand All @@ -180,7 +180,7 @@ def assert_interface_name_valid(name: str):
raise InvalidInterfaceNameError(name)


def assert_member_name_valid(member):
def assert_member_name_valid(member) -> None:
"""Raise an error if this is not a valid member name.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-member
Expand Down
4 changes: 2 additions & 2 deletions tests/test_marshaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_marshalling_with_table():
body.append(replace_variants(type_, message.body[i]))
message.body = body

buf = message._marshall()
buf = message._marshall(False)
data = bytes.fromhex(item["data"])

if buf != data:
Expand Down Expand Up @@ -173,6 +173,6 @@ def read(self, n) -> bytes:
def test_ay_buffer():
body = [bytes(10000)]
msg = Message(path="/test", member="test", signature="ay", body=body)
marshalled = msg._marshall()
marshalled = msg._marshall(False)
unmarshalled_msg = Unmarshaller(io.BytesIO(marshalled)).unmarshall()
assert unmarshalled_msg.body[0] == body[0]

0 comments on commit cf1f012

Please sign in to comment.