Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: incorrect pxd typing for for _marshall #75

Merged
merged 3 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bench/marshall.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def marhsall_bluez_get_managed_objects_message():
message._marshall()
message._marshall(False)


count = 1000000
Expand Down
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pycairo = "^1.21.0"
PyGObject = "^3.42.2"
Cython = "^0.29.32"
setuptools = "^65.4.1"
pytest-timeout = "^2.1.0"

[tool.semantic_release]
branch = "main"
Expand Down
4 changes: 2 additions & 2 deletions src/dbus_fast/aio/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def write_callback(self, remove_writer: bool = True) -> None:
def buffer_message(self, msg: Message, future=None) -> None:
self.messages.append(
(
msg._marshall(negotiate_unix_fd=self.negotiate_unix_fd),
msg._marshall(self.negotiate_unix_fd),
copy(msg.unix_fds),
future,
)
Expand Down Expand Up @@ -216,7 +216,7 @@ def on_hello(reply, err):
)

self._method_return_handlers[hello_msg.serial] = on_hello
self._stream.write(hello_msg._marshall())
self._stream.write(hello_msg._marshall(False))
self._stream.flush()

return await future
Expand Down
4 changes: 2 additions & 2 deletions src/dbus_fast/glib/message_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def dispatch(self, callback, user_data):
return GLib.SOURCE_REMOVE
else:
message = self.bus._buffered_messages.pop(0)
self.message_stream = io.BytesIO(message._marshall())
self.message_stream = io.BytesIO(message._marshall(False))
return GLib.SOURCE_CONTINUE
except BlockingIOError:
return GLib.SOURCE_CONTINUE
Expand Down Expand Up @@ -233,7 +233,7 @@ def on_hello(reply, err):
)

self._method_return_handlers[hello_msg.serial] = on_hello
self._stream.write(hello_msg._marshall())
self._stream.write(hello_msg._marshall(False))
self._stream.flush()

self._authenticate(authenticate_notify)
Expand Down
2 changes: 1 addition & 1 deletion src/dbus_fast/message.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ cdef class Message:
cdef public list body
cdef public unsigned int serial

cpdef _marshall(self, negotiate_unix_fd: bint)
cpdef _marshall(self, bint negotiate_unix_fd)
31 changes: 17 additions & 14 deletions src/dbus_fast/message.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Union
from typing import Any, List, Optional, Union

from ._private.constants import LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField
from ._private.marshaller import Marshaller
Expand Down Expand Up @@ -95,17 +95,17 @@ class Message:

def __init__(
self,
destination: str = None,
path: str = None,
interface: str = None,
member: str = None,
destination: Optional[str] = None,
path: Optional[str] = None,
interface: Optional[str] = None,
member: Optional[str] = None,
message_type: MessageType = MessageType.METHOD_CALL,
flags: MessageFlag = MessageFlag.NONE,
error_name: str = None,
reply_serial: int = None,
sender: str = None,
error_name: Optional[Union[str, ErrorType]] = None,
reply_serial=0,
sender: Optional[str] = None,
unix_fds: List[int] = [],
signature: Union[str, SignatureTree] = "",
signature: Optional[Union[SignatureTree, str]] = None,
body: List[Any] = [],
serial: int = 0,
validate: bool = True,
Expand All @@ -119,7 +119,7 @@ def __init__(
flags if type(flags) is MessageFlag else MessageFlag(bytes([flags]))
)
self.error_name = (
error_name if type(error_name) is not ErrorType else error_name.value
str(error_name.value) if type(error_name) is ErrorType else error_name
)
self.reply_serial = reply_serial or 0
self.sender = sender
Expand All @@ -128,8 +128,8 @@ def __init__(
self.signature = signature.signature
self.signature_tree = signature
else:
self.signature = signature
self.signature_tree = get_signature_tree(signature)
self.signature = signature or ""
self.signature_tree = get_signature_tree(signature or "")
self.body = body
self.serial = serial or 0

Expand All @@ -154,7 +154,9 @@ def __init__(
raise InvalidMessageError(f"missing required field: {field}")

@staticmethod
def new_error(msg: "Message", error_name: str, error_text: str) -> "Message":
def new_error(
msg: "Message", error_name: Union[str, ErrorType], error_text: str
) -> "Message":
"""A convenience constructor to create an error message in reply to the given message.

:param msg: The message this error is in reply to.
Expand Down Expand Up @@ -255,7 +257,8 @@ def new_signal(
unix_fds=unix_fds,
)

def _marshall(self, negotiate_unix_fd=False):
def _marshall(self, negotiate_unix_fd: bool) -> bytearray:
"""Marshall this message into a byte array."""
# TODO maximum message size is 134217728 (128 MiB)
body_block = Marshaller(self.signature, self.body)
body_block.marshall()
Expand Down
8 changes: 4 additions & 4 deletions src/dbus_fast/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def is_member_name_valid(member: str) -> bool:
return True


def assert_bus_name_valid(name: str):
def assert_bus_name_valid(name: str) -> None:
"""Raise an error if this is not a valid bus name.

.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-bus
Expand All @@ -150,7 +150,7 @@ def assert_bus_name_valid(name: str):
raise InvalidBusNameError(name)


def assert_object_path_valid(path: str):
def assert_object_path_valid(path: str) -> None:
"""Raise an error if this is not a valid object path.

.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-marshaling-object-path
Expand All @@ -165,7 +165,7 @@ def assert_object_path_valid(path: str):
raise InvalidObjectPathError(path)


def assert_interface_name_valid(name: str):
def assert_interface_name_valid(name: str) -> None:
"""Raise an error if this is not a valid interface name.

.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-interface
Expand All @@ -180,7 +180,7 @@ def assert_interface_name_valid(name: str):
raise InvalidInterfaceNameError(name)


def assert_member_name_valid(member):
def assert_member_name_valid(member) -> None:
"""Raise an error if this is not a valid member name.

.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-member
Expand Down
4 changes: 2 additions & 2 deletions tests/test_marshaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_marshalling_with_table():
body.append(replace_variants(type_, message.body[i]))
message.body = body

buf = message._marshall()
buf = message._marshall(False)
data = bytes.fromhex(item["data"])

if buf != data:
Expand Down Expand Up @@ -173,6 +173,6 @@ def read(self, n) -> bytes:
def test_ay_buffer():
body = [bytes(10000)]
msg = Message(path="/test", member="test", signature="ay", body=body)
marshalled = msg._marshall()
marshalled = msg._marshall(False)
unmarshalled_msg = Unmarshaller(io.BytesIO(marshalled)).unmarshall()
assert unmarshalled_msg.body[0] == body[0]