From ac955b50ea2921b114f6a89c2e1d3fbf34698deb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 27 Sep 2022 10:09:09 -1000 Subject: [PATCH] fix: improve typing on proxy_object (#41) --- src/dbus_fast/aio/proxy_object.py | 17 ++++++++++------ src/dbus_fast/proxy_object.py | 32 +++++++++++++++++++------------ 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/dbus_fast/aio/proxy_object.py b/src/dbus_fast/aio/proxy_object.py index d1542978..6bef63cd 100644 --- a/src/dbus_fast/aio/proxy_object.py +++ b/src/dbus_fast/aio/proxy_object.py @@ -1,5 +1,5 @@ import xml.etree.ElementTree as ET -from typing import List, Union +from typing import TYPE_CHECKING, Any, List, Union from .. import introspection as intr from .._private.util import replace_fds_with_idx, replace_idx_with_fds @@ -11,6 +11,9 @@ from ..signature import Variant from ..signature import unpack_variants as unpack +if TYPE_CHECKING: + from .message_bus import MessageBus as AioMessageBus + class ProxyInterface(BaseProxyInterface): """A class representing a proxy to an interface exported on the bus by @@ -74,7 +77,9 @@ class ProxyInterface(BaseProxyInterface): ` will be raised with information about the error. """ - def _add_method(self, intr_method): + bus: "AioMessageBus" + + def _add_method(self, intr_method: intr.Method) -> None: async def method_fn( *args, flags=MessageFlag.NONE, unpack_variants: bool = False ): @@ -119,8 +124,8 @@ async def method_fn( def _add_property( self, - intr_property, - ): + intr_property: intr.Property, + ) -> None: async def property_getter( *, flags=MessageFlag.NONE, unpack_variants: bool = False ): @@ -150,7 +155,7 @@ async def property_getter( return unpack(body) return body - async def property_setter(val): + async def property_setter(val: Any) -> None: variant = Variant(intr_property.signature, val) body, unix_fds = replace_fds_with_idx( @@ -188,7 +193,7 @@ def __init__( path: str, introspection: Union[intr.Node, str, ET.Element], bus: BaseMessageBus, - ): + ) -> None: super().__init__(bus_name, path, introspection, bus, ProxyInterface) def get_interface(self, name: str) -> ProxyInterface: diff --git a/src/dbus_fast/proxy_object.py b/src/dbus_fast/proxy_object.py index 16aadd4a..f0130d24 100644 --- a/src/dbus_fast/proxy_object.py +++ b/src/dbus_fast/proxy_object.py @@ -4,7 +4,8 @@ import re import xml.etree.ElementTree as ET from dataclasses import dataclass -from typing import Callable, Coroutine, Dict, List, Type, Union +from functools import lru_cache +from typing import Callable, Coroutine, Dict, List, Optional, Type, Union from . import introspection as intr from . import message_bus @@ -50,7 +51,13 @@ class BaseProxyInterface: :vartype bus: :class:`BaseMessageBus ` """ - def __init__(self, bus_name, path, introspection, bus): + def __init__( + self, + bus_name: str, + path: str, + introspection: intr.Interface, + bus: "message_bus.BaseMessageBus", + ) -> None: self.bus_name = bus_name self.path = path @@ -63,12 +70,13 @@ def __init__(self, bus_name, path, introspection, bus): _underscorer2 = re.compile(r"([a-z0-9])([A-Z])") @staticmethod - def _to_snake_case(member): + @lru_cache(maxsize=128) + def _to_snake_case(member: str) -> str: subbed = BaseProxyInterface._underscorer1.sub(r"\1_\2", member) return BaseProxyInterface._underscorer2.sub(r"\1_\2", subbed).lower() @staticmethod - def _check_method_return(msg, signature=None): + def _check_method_return(msg: Message, signature: Optional[str] = None): if msg.message_type == MessageType.ERROR: raise DBusError._from_message(msg) elif msg.message_type != MessageType.METHOD_RETURN: @@ -82,13 +90,13 @@ def _check_method_return(msg, signature=None): msg, ) - def _add_method(self, intr_method): + def _add_method(self, intr_method: intr.Method) -> None: raise NotImplementedError("this must be implemented in the inheriting class") - def _add_property(self, intr_property): + def _add_property(self, intr_property: intr.Property) -> None: raise NotImplementedError("this must be implemented in the inheriting class") - def _message_handler(self, msg): + def _message_handler(self, msg: Message) -> None: if ( not msg._matches( message_type=MessageType.SIGNAL, @@ -133,8 +141,8 @@ def _message_handler(self, msg): if isinstance(cb_result, Coroutine): asyncio.create_task(cb_result) - def _add_signal(self, intr_signal, interface): - def on_signal_fn(fn, *, unpack_variants: bool = False): + def _add_signal(self, intr_signal: intr.Signal, interface: intr.Interface) -> None: + def on_signal_fn(fn: Callable, *, unpack_variants: bool = False): fn_signature = inspect.signature(fn) if 0 < len( [ @@ -173,7 +181,7 @@ def on_signal_fn(fn, *, unpack_variants: bool = False): SignalHandler(fn, unpack_variants) ) - def off_signal_fn(fn, *, unpack_variants: bool = False): + def off_signal_fn(fn: Callable, *, unpack_variants: bool = False) -> None: try: i = self._signal_handlers[intr_signal.name].index( SignalHandler(fn, unpack_variants) @@ -235,7 +243,7 @@ def __init__( introspection: Union[intr.Node, str, ET.Element], bus: "message_bus.BaseMessageBus", ProxyInterface: Type[BaseProxyInterface], - ): + ) -> None: assert_object_path_valid(path) assert_bus_name_valid(bus_name) @@ -296,7 +304,7 @@ def get_interface(self, name: str) -> BaseProxyInterface: for intr_signal in intr_interface.signals: interface._add_signal(intr_signal, interface) - def get_owner_notify(msg, err): + def get_owner_notify(msg: Message, err: Optional[Exception]) -> None: if err: logging.error(f'getting name owner for "{name}" failed, {err}') return