diff --git a/src/dbus_fast/aio/proxy_object.py b/src/dbus_fast/aio/proxy_object.py index 43f1e837..d1542978 100644 --- a/src/dbus_fast/aio/proxy_object.py +++ b/src/dbus_fast/aio/proxy_object.py @@ -9,6 +9,7 @@ from ..message_bus import BaseMessageBus from ..proxy_object import BaseProxyInterface, BaseProxyObject from ..signature import Variant +from ..signature import unpack_variants as unpack class ProxyInterface(BaseProxyInterface): @@ -74,7 +75,9 @@ class ProxyInterface(BaseProxyInterface): """ def _add_method(self, intr_method): - async def method_fn(*args, flags=MessageFlag.NONE): + async def method_fn( + *args, flags=MessageFlag.NONE, unpack_variants: bool = False + ): input_body, unix_fds = replace_fds_with_idx( intr_method.in_signature, list(args) ) @@ -103,16 +106,24 @@ async def method_fn(*args, flags=MessageFlag.NONE): if not out_len: return None - elif out_len == 1: + + if unpack_variants: + body = unpack(body) + + if out_len == 1: return body[0] - else: - return body + return body method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}" setattr(self, method_name, method_fn) - def _add_property(self, intr_property): - async def property_getter(): + def _add_property( + self, + intr_property, + ): + async def property_getter( + *, flags=MessageFlag.NONE, unpack_variants: bool = False + ): msg = await self.bus.call( Message( destination=self.bus_name, @@ -133,7 +144,11 @@ async def property_getter(): msg, ) - return replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value + body = replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value + + if unpack_variants: + return unpack(body) + return body async def property_setter(val): variant = Variant(intr_property.signature, val) diff --git a/src/dbus_fast/glib/proxy_object.py b/src/dbus_fast/glib/proxy_object.py index f7946642..a5a020a2 100644 --- a/src/dbus_fast/glib/proxy_object.py +++ b/src/dbus_fast/glib/proxy_object.py @@ -8,6 +8,7 @@ from ..message_bus import BaseMessageBus from ..proxy_object import BaseProxyInterface, BaseProxyObject from ..signature import Variant +from ..signature import unpack_variants as unpack # glib is optional try: @@ -113,7 +114,7 @@ def _add_method(self, intr_method): in_len = len(intr_method.in_args) out_len = len(intr_method.out_args) - def method_fn(*args): + def method_fn(*args, unpack_variants: bool = False): if len(args) != in_len + 1: raise TypeError( f"method {intr_method.name} expects {in_len} arguments and a callback (got {len(args)} args)" @@ -136,7 +137,10 @@ def call_notify(msg, err): except DBusError as e: err = e - callback(msg.body, err) + if unpack_variants: + callback(unpack(msg.body), err) + else: + callback(msg.body, err) self.bus.call( Message( @@ -150,7 +154,7 @@ def call_notify(msg, err): call_notify, ) - def method_fn_sync(*args): + def method_fn_sync(*args, unpack_variants: bool = False): main = GLib.MainLoop() call_error = None call_body = None @@ -171,10 +175,13 @@ def callback(body, err): if not out_len: return None - elif out_len == 1: + + if unpack_variants: + call_body = unpack(call_body) + + if out_len == 1: return call_body[0] - else: - return call_body + return call_body method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}" method_name_sync = f"{method_name}_sync" @@ -183,7 +190,7 @@ def callback(body, err): setattr(self, method_name_sync, method_fn_sync) def _add_property(self, intr_property): - def property_getter(callback): + def property_getter(callback, *, unpack_variants: bool = False): def call_notify(msg, err): if err: callback(None, err) @@ -204,8 +211,10 @@ def call_notify(msg, err): ) callback(None, err) return - - callback(variant.value, None) + if unpack_variants: + callback(unpack(variant.value), None) + else: + callback(variant.value, None) self.bus.call( Message( @@ -219,7 +228,7 @@ def call_notify(msg, err): call_notify, ) - def property_getter_sync(): + def property_getter_sync(*, unpack_variants: bool = False): property_value = None reply_error = None @@ -236,6 +245,8 @@ def callback(value, err): main.run() if reply_error: raise reply_error + if unpack_variants: + return unpack(property_value) return property_value def property_setter(value, callback): diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index 4ec6b7d6..626d94da 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -3,7 +3,8 @@ import logging import re import xml.etree.ElementTree as ET -from typing import Coroutine, List, Type, Union +from dataclasses import dataclass +from typing import Callable, Coroutine, Dict, List, Type, Union from . import introspection as intr from . import message_bus @@ -11,9 +12,18 @@ from .constants import ErrorType, MessageType from .errors import DBusError, InterfaceNotFoundError from .message import Message +from .signature import unpack_variants as unpack from .validators import assert_bus_name_valid, assert_object_path_valid +@dataclass +class SignalHandler: + """Signal handler.""" + + fn: Callable + unpack_variants: bool + + class BaseProxyInterface: """An abstract class representing a proxy to an interface exported on the bus by another client. @@ -46,7 +56,7 @@ def __init__(self, bus_name, path, introspection, bus): self.path = path self.introspection = introspection self.bus = bus - self._signal_handlers = {} + self._signal_handlers: Dict[str, List[SignalHandler]] = {} self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}" _underscorer1 = re.compile(r"(.)([A-Z][a-z]+)") @@ -110,13 +120,21 @@ def _message_handler(self, msg): return body = replace_idx_with_fds(msg.signature, msg.body, msg.unix_fds) + no_sig = None for handler in self._signal_handlers[msg.member]: - cb_result = handler(*body) + if handler.unpack_variants: + if not no_sig: + no_sig = unpack(body) + data = no_sig + else: + data = body + + cb_result = handler.fn(*data) if isinstance(cb_result, Coroutine): asyncio.create_task(cb_result) def _add_signal(self, intr_signal, interface): - def on_signal_fn(fn): + def on_signal_fn(fn, *, unpack_variants: bool = False): fn_signature = inspect.signature(fn) if len(fn_signature.parameters) != len(intr_signal.args) and ( inspect.Parameter.VAR_POSITIONAL @@ -134,11 +152,15 @@ def on_signal_fn(fn): if intr_signal.name not in self._signal_handlers: self._signal_handlers[intr_signal.name] = [] - self._signal_handlers[intr_signal.name].append(fn) + self._signal_handlers[intr_signal.name].append( + SignalHandler(fn, unpack_variants) + ) - def off_signal_fn(fn): + def off_signal_fn(fn, *, unpack_variants: bool = False): try: - i = self._signal_handlers[intr_signal.name].index(fn) + i = self._signal_handlers[intr_signal.name].index( + SignalHandler(fn, unpack_variants) + ) del self._signal_handlers[intr_signal.name][i] if not self._signal_handlers[intr_signal.name]: del self._signal_handlers[intr_signal.name] diff --git a/src/dbus_fast/signature.py b/src/dbus_fast/signature.py index da8ff715..16bd9e9b 100644 --- a/src/dbus_fast/signature.py +++ b/src/dbus_fast/signature.py @@ -5,6 +5,17 @@ from .validators import is_object_path_valid +def unpack_variants(data: Any): + """Unpack variants and remove signature info.""" + if isinstance(data, Variant): + return unpack_variants(data.value) + if isinstance(data, dict): + return {k: unpack_variants(v) for k, v in data.items()} + if isinstance(data, list): + return [unpack_variants(item) for item in data] + return data + + class SignatureType: """A class that represents a single complete type within a signature. diff --git a/tests/client/test_methods.py b/tests/client/test_methods.py index d4c73cf5..b2c30e66 100644 --- a/tests/client/test_methods.py +++ b/tests/client/test_methods.py @@ -4,6 +4,7 @@ from dbus_fast import DBusError, aio, glib from dbus_fast.message import MessageFlag from dbus_fast.service import ServiceInterface, method +from dbus_fast.signature import Variant from tests.util import check_gi_repository, skip_reason_no_gi has_gi = check_gi_repository() @@ -33,6 +34,11 @@ def ConcatStrings(self, what1: "s", what2: "s") -> "s": def EchoThree(self, what1: "s", what2: "s", what3: "s") -> "sss": return [what1, what2, what3] + @method() + def GetComplex(self) -> "a{sv}": + """Return complex output.""" + return {"hello": Variant("s", "world")} + @method() def ThrowsError(self): raise DBusError("test.error", "something went wrong") @@ -81,6 +87,12 @@ async def test_aio_proxy_object(): ) assert result is None + result = await interface.call_get_complex() + assert result == {"hello": Variant("s", "world")} + + result = await interface.call_get_complex(unpack_variants=True) + assert result == {"hello": "world"} + with pytest.raises(DBusError): try: await interface.call_throws_error() @@ -120,6 +132,12 @@ def test_glib_proxy_object(): result = interface.call_echo_three_sync("hello", "there", "world") assert result == ["hello", "there", "world"] + result = interface.call_get_complex_sync() + assert result == {"hello": Variant("s", "world")} + + result = interface.call_get_complex_sync(unpack_variants=True) + assert result == {"hello": "world"} + with pytest.raises(DBusError): try: result = interface.call_throws_error_sync() diff --git a/tests/client/test_properties.py b/tests/client/test_properties.py index 2f983f4b..f3422863 100644 --- a/tests/client/test_properties.py +++ b/tests/client/test_properties.py @@ -2,6 +2,7 @@ from dbus_fast import DBusError, Message, aio, glib from dbus_fast.service import PropertyAccess, ServiceInterface, dbus_property +from dbus_fast.signature import Variant from tests.util import check_gi_repository, skip_reason_no_gi has_gi = check_gi_repository() @@ -27,6 +28,11 @@ def SomeProperty(self, val: "s"): def Int64Property(self) -> "x": return self._int64_property + @dbus_property(access=PropertyAccess.READ) + def ComplexProperty(self) -> "a{sv}": + """Return complex output.""" + return {"hello": Variant("s", "world")} + @dbus_property() def ErrorThrowingProperty(self) -> "s": raise DBusError(self.error_name, self.error_text) @@ -59,6 +65,12 @@ async def test_aio_properties(): await interface.set_some_property("different") assert service_interface._some_property == "different" + prop = await interface.get_complex_property() + assert prop == {"hello": Variant("s", "world")} + + prop = await interface.get_complex_property(unpack_variants=True) + assert prop == {"hello": "world"} + with pytest.raises(DBusError): try: prop = await interface.get_error_throwing_property() @@ -102,6 +114,12 @@ def test_glib_properties(): interface.set_some_property_sync("different") assert service_interface._some_property == "different" + prop = interface.get_complex_property_sync() + assert prop == {"hello": Variant("s", "world")} + + prop = interface.get_complex_property_sync(unpack_variants=True) + assert prop == {"hello": "world"} + with pytest.raises(DBusError): try: prop = interface.get_error_throwing_property_sync() diff --git a/tests/client/test_signals.py b/tests/client/test_signals.py index 57c7559e..490f3d27 100644 --- a/tests/client/test_signals.py +++ b/tests/client/test_signals.py @@ -2,10 +2,10 @@ from dbus_fast import Message from dbus_fast.aio import MessageBus -from dbus_fast.aio.proxy_object import ProxyInterface from dbus_fast.constants import RequestNameReply from dbus_fast.introspection import Node from dbus_fast.service import ServiceInterface, signal +from dbus_fast.signature import Variant class ExampleInterface(ServiceInterface): @@ -20,6 +20,11 @@ def SomeSignal(self) -> "s": def SignalMultiple(self) -> "ss": return ["hello", "world"] + @signal() + def SignalComplex(self) -> "a{sv}": + """Broadcast a complex signal.""" + return {"hello": Variant("s", "world")} + @pytest.mark.asyncio async def test_signals(): @@ -159,6 +164,69 @@ def dummy_signal_handler(what): bus3.disconnect() +@pytest.mark.asyncio +async def test_complex_signals(): + """Test complex signals with and without signature removal.""" + bus1 = await MessageBus().connect() + bus2 = await MessageBus().connect() + + await bus1.request_name("test.signals.name") + service_interface = ExampleInterface() + bus1.export("/test/path", service_interface) + + obj = bus2.get_proxy_object( + "test.signals.name", "/test/path", bus1._introspect_export_path("/test/path") + ) + interface = obj.get_interface(service_interface.name) + + async def ping(): + await bus2.call( + Message( + destination=bus1.unique_name, + interface="org.freedesktop.DBus.Peer", + path="/test/path", + member="Ping", + ) + ) + + sig_handler_counter = 0 + sig_handler_err = None + no_sig_handler_counter = 0 + no_sig_handler_err = None + + def complex_handler_with_sig(value): + nonlocal sig_handler_counter + nonlocal sig_handler_err + try: + assert value == {"hello": Variant("s", "world")} + sig_handler_counter += 1 + except AssertionError as ex: + sig_handler_err = ex + + def complex_handler_no_sig(value): + nonlocal no_sig_handler_counter + nonlocal no_sig_handler_err + try: + assert value == {"hello": "world"} + no_sig_handler_counter += 1 + except AssertionError as ex: + no_sig_handler_err = ex + + interface.on_signal_complex(complex_handler_with_sig) + interface.on_signal_complex(complex_handler_no_sig, unpack_variants=True) + await ping() + + service_interface.SignalComplex() + await ping() + assert sig_handler_err is None + assert sig_handler_counter == 1 + assert no_sig_handler_err is None + assert no_sig_handler_counter == 1 + + bus1.disconnect() + bus2.disconnect() + + @pytest.mark.asyncio async def test_varargs_callback(): """Test varargs callback for signal.""" diff --git a/tests/test_unpack_variants.py b/tests/test_unpack_variants.py new file mode 100644 index 00000000..0c832990 --- /dev/null +++ b/tests/test_unpack_variants.py @@ -0,0 +1,56 @@ +"""Test unpack variants.""" +import pytest + +from dbus_fast.signature import Variant, unpack_variants + + +@pytest.mark.asyncio +async def test_dictionary(): + """Test variants unpacked from dictionary.""" + assert unpack_variants( + { + "string": Variant("s", "test"), + "boolean": Variant("b", True), + "int": Variant("u", 1), + "object": Variant("o", "/test/path"), + "array": Variant("as", ["test", "value"]), + "tuple": Variant("(su)", ["test", 1]), + "bytes": Variant("ay", b"\0x62\0x75\0x66"), + } + ) == { + "string": "test", + "boolean": True, + "int": 1, + "object": "/test/path", + "array": ["test", "value"], + "tuple": ["test", 1], + "bytes": b"\0x62\0x75\0x66", + } + + +@pytest.mark.asyncio +async def test_output_list(): + """Test variants unpacked from multiple outputs.""" + assert unpack_variants( + [{"hello": Variant("s", "world")}, {"boolean": Variant("b", True)}, 1] + ) == [{"hello": "world"}, {"boolean": True}, 1] + + +@pytest.mark.asyncio +async def test_nested_variants(): + """Test unpack variants handles nesting.""" + assert unpack_variants( + { + "dict": Variant("a{sv}", {"hello": Variant("s", "world")}), + "array": Variant( + "aa{sv}", + [ + {"hello": Variant("s", "world")}, + {"bytes": Variant("ay", b"\0x62\0x75\0x66")}, + ], + ), + } + ) == { + "dict": {"hello": "world"}, + "array": [{"hello": "world"}, {"bytes": b"\0x62\0x75\0x66"}], + }