diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index 626d94da..16aadd4a 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -136,13 +136,30 @@ def _message_handler(self, msg): def _add_signal(self, intr_signal, interface): def on_signal_fn(fn, *, unpack_variants: bool = False): fn_signature = inspect.signature(fn) - if len(fn_signature.parameters) != len(intr_signal.args) and ( - inspect.Parameter.VAR_POSITIONAL - not in [par.kind for par in fn_signature.parameters.values()] - or len(fn_signature.parameters) - 1 > len(intr_signal.args) + if 0 < len( + [ + par + for par in fn_signature.parameters.values() + if par.kind == inspect.Parameter.KEYWORD_ONLY + and par.default == inspect.Parameter.empty + ] ): raise TypeError( - f"reply_notify must be a function with {len(intr_signal.args)} parameters" + "reply_notify cannot have required keyword only parameters" + ) + + positional_params = [ + par.kind + for par in fn_signature.parameters.values() + if par.kind + not in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD] + ] + if len(positional_params) != len(intr_signal.args) and ( + inspect.Parameter.VAR_POSITIONAL not in positional_params + or len(positional_params) - 1 > len(intr_signal.args) + ): + raise TypeError( + f"reply_notify must be a function with {len(intr_signal.args)} positional parameters" ) if not self._signal_handlers: diff --git a/tests/client/test_signals.py b/tests/client/test_signals.py index 490f3d27..d08473dd 100644 --- a/tests/client/test_signals.py +++ b/tests/client/test_signals.py @@ -290,6 +290,75 @@ def varargs_plus_handler(value, *_): bus2.disconnect() +@pytest.mark.asyncio +async def test_kwargs_callback(): + """Test callback for signal with kwargs.""" + bus1 = await MessageBus().connect() + bus2 = await MessageBus().connect() + + await bus1.request_name("test.signals.name") + service_interface = ExampleInterface() + bus1.export("/test/path", service_interface) + + obj = bus2.get_proxy_object( + "test.signals.name", "/test/path", bus1._introspect_export_path("/test/path") + ) + interface = obj.get_interface(service_interface.name) + + async def ping(): + await bus2.call( + Message( + destination=bus1.unique_name, + interface="org.freedesktop.DBus.Peer", + path="/test/path", + member="Ping", + ) + ) + + kwargs_handler_counter = 0 + kwargs_handler_err = None + kwarg_default_handler_counter = 0 + kwarg_default_handler_err = None + + def kwargs_handler(value, **_): + nonlocal kwargs_handler_counter + nonlocal kwargs_handler_err + try: + assert value == "hello" + kwargs_handler_counter += 1 + except AssertionError as ex: + kwargs_handler_err = ex + + def kwarg_default_handler(value, *, _=True): + nonlocal kwarg_default_handler_counter + nonlocal kwarg_default_handler_err + try: + assert value == "hello" + kwarg_default_handler_counter += 1 + except AssertionError as ex: + kwarg_default_handler_err = ex + + interface.on_some_signal(kwargs_handler) + interface.on_some_signal(kwarg_default_handler) + await ping() + + service_interface.SomeSignal() + await ping() + assert kwargs_handler_err is None + assert kwargs_handler_counter == 1 + assert kwarg_default_handler_err is None + assert kwarg_default_handler_counter == 1 + + def kwarg_bad_handler(value, *, bad_kwarg): + pass + + with pytest.raises(TypeError): + interface.on_some_signal(kwarg_bad_handler) + + bus1.disconnect() + bus2.disconnect() + + @pytest.mark.asyncio async def test_on_signal_type_error(): """Test on callback raises type errors for invalid callbacks."""