diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 48b7239c..1745c61c 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -158,6 +158,7 @@ def build_simple_parsers( except ImportError: from ._cython_compat import FAKE_CYTHON as cython + # # Alignment padding is handled with the following formula below # diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 795c394b..03634da7 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -413,7 +413,7 @@ def _make_method_handler(self, interface, method): if not asyncio.iscoroutinefunction(method.fn): return super()._make_method_handler(interface, method) - def handler(msg, send_reply): + def _coro_method_handler(msg, send_reply): def done(fut): with send_reply: result = fut.result() @@ -430,7 +430,7 @@ def done(fut): fut = asyncio.ensure_future(method.fn(interface, *args)) fut.add_done_callback(done) - return handler + return _coro_method_handler async def _auth_readline(self) -> str: buf = b"" diff --git a/src/dbus_fast/message.py b/src/dbus_fast/message.py index ccb2cba1..12e56de3 100644 --- a/src/dbus_fast/message.py +++ b/src/dbus_fast/message.py @@ -13,10 +13,10 @@ ) REQUIRED_FIELDS = { - MessageType.METHOD_CALL: ("path", "member"), - MessageType.SIGNAL: ("path", "member", "interface"), - MessageType.ERROR: ("error_name", "reply_serial"), - MessageType.METHOD_RETURN: ("reply_serial",), + MessageType.METHOD_CALL.value: ("path", "member"), + MessageType.SIGNAL.value: ("path", "member", "interface"), + MessageType.ERROR.value: ("error_name", "reply_serial"), + MessageType.METHOD_RETURN.value: ("reply_serial",), } HEADER_PATH = HeaderField.PATH.value @@ -146,7 +146,7 @@ def __init__( if self.error_name is not None: assert_interface_name_valid(self.error_name) # type: ignore[arg-type] - required_fields = REQUIRED_FIELDS.get(self.message_type) + required_fields = REQUIRED_FIELDS.get(self.message_type.value) if not required_fields: raise InvalidMessageError(f"got unknown message type: {self.message_type}") for field in required_fields: diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index 1f648fde..b71f5b87 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -881,69 +881,75 @@ def _process_message(self, msg: _Message) -> None: def _make_method_handler( self, interface: ServiceInterface, method: _Method ) -> Callable[[Message, Callable[[Message], None]], None]: - def handler(msg: Message, send_reply: Callable[[Message], None]) -> None: - args = ServiceInterface._msg_body_to_args(msg) - result = method.fn(interface, *args) - body, fds = ServiceInterface._fn_result_to_body( - result, - signature_tree=method.out_signature_tree, - replace_fds=self._negotiate_unix_fd, + method_fn = method.fn + out_signature_tree = method.out_signature_tree + negotiate_unix_fd = self._negotiate_unix_fd + out_signature = method.out_signature + message_type_method_return = MessageType.METHOD_RETURN + msg_body_to_args = ServiceInterface._msg_body_to_args + fn_result_to_body = ServiceInterface._fn_result_to_body + + def _callback_method_handler( + msg: Message, send_reply: Callable[[Message], None] + ) -> None: + body, fds = fn_result_to_body( + method_fn(interface, *msg_body_to_args(msg)), + signature_tree=out_signature_tree, + replace_fds=negotiate_unix_fd, ) send_reply( Message( - message_type=MessageType.METHOD_RETURN, + message_type=message_type_method_return, reply_serial=msg.serial, destination=msg.sender, - signature=method.out_signature, + signature=out_signature, body=body, unix_fds=fds, ) ) - return handler + return _callback_method_handler def _find_message_handler( self, msg ) -> Optional[Callable[[Message, Callable], None]]: - handler: Optional[Callable[[Message, Callable], None]] = None - if ( msg.interface == "org.freedesktop.DBus.Introspectable" and msg.member == "Introspect" and msg.signature == "" ): - handler = self._default_introspect_handler + return self._default_introspect_handler - elif msg.interface == "org.freedesktop.DBus.Properties": - handler = self._default_properties_handler + if msg.interface == "org.freedesktop.DBus.Properties": + return self._default_properties_handler - elif msg.interface == "org.freedesktop.DBus.Peer": + if msg.interface == "org.freedesktop.DBus.Peer": if msg.member == "Ping" and msg.signature == "": - handler = self._default_ping_handler + return self._default_ping_handler elif msg.member == "GetMachineId" and msg.signature == "": - handler = self._default_get_machine_id_handler - elif ( + return self._default_get_machine_id_handler + + if ( msg.interface == "org.freedesktop.DBus.ObjectManager" and msg.member == "GetManagedObjects" ): - handler = self._default_get_managed_objects_handler + return self._default_get_managed_objects_handler - elif msg.path: - for interface in self._path_exports.get(msg.path, []): + msg_path = msg.path + if msg_path: + for interface in self._path_exports.get(msg_path, []): for method in ServiceInterface._get_methods(interface): if method.disabled: continue + if ( msg.interface == interface.name and msg.member == method.name and msg.signature == method.in_signature ): - handler = ServiceInterface._get_handler(interface, method, self) - break - if handler: - break + return ServiceInterface._get_handler(interface, method, self) - return handler + return None def _default_introspect_handler( self, msg: Message, send_reply: Callable[[Message], None] diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index 60dce504..0aef097b 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -58,7 +58,6 @@ def __init__( introspection: intr.Interface, bus: "message_bus.BaseMessageBus", ) -> None: - self.bus_name = bus_name self.path = path self.introspection = introspection diff --git a/src/dbus_fast/validators.py b/src/dbus_fast/validators.py index 594310e9..f35ccd4c 100644 --- a/src/dbus_fast/validators.py +++ b/src/dbus_fast/validators.py @@ -135,6 +135,7 @@ def is_member_name_valid(member: str) -> bool: return True +@lru_cache(maxsize=32) def assert_bus_name_valid(name: str) -> None: """Raise an error if this is not a valid bus name. @@ -150,6 +151,7 @@ def assert_bus_name_valid(name: str) -> None: raise InvalidBusNameError(name) +@lru_cache(maxsize=1024) def assert_object_path_valid(path: str) -> None: """Raise an error if this is not a valid object path. @@ -165,6 +167,7 @@ def assert_object_path_valid(path: str) -> None: raise InvalidObjectPathError(path) +@lru_cache(maxsize=32) def assert_interface_name_valid(name: str) -> None: """Raise an error if this is not a valid interface name. @@ -180,6 +183,7 @@ def assert_interface_name_valid(name: str) -> None: raise InvalidInterfaceNameError(name) +@lru_cache(maxsize=512) def assert_member_name_valid(member: str) -> None: """Raise an error if this is not a valid member name.