Skip to content

Commit

Permalink
fix: improve typing on proxy_object (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 27, 2022
1 parent acc26f4 commit ac955b5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
17 changes: 11 additions & 6 deletions src/dbus_fast/aio/proxy_object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import xml.etree.ElementTree as ET
from typing import List, Union
from typing import TYPE_CHECKING, Any, List, Union

from .. import introspection as intr
from .._private.util import replace_fds_with_idx, replace_idx_with_fds
Expand All @@ -11,6 +11,9 @@
from ..signature import Variant
from ..signature import unpack_variants as unpack

if TYPE_CHECKING:
from .message_bus import MessageBus as AioMessageBus


class ProxyInterface(BaseProxyInterface):
"""A class representing a proxy to an interface exported on the bus by
Expand Down Expand Up @@ -74,7 +77,9 @@ class ProxyInterface(BaseProxyInterface):
<dbus_fast.DBusError>` will be raised with information about the error.
"""

def _add_method(self, intr_method):
bus: "AioMessageBus"

def _add_method(self, intr_method: intr.Method) -> None:
async def method_fn(
*args, flags=MessageFlag.NONE, unpack_variants: bool = False
):
Expand Down Expand Up @@ -119,8 +124,8 @@ async def method_fn(

def _add_property(
self,
intr_property,
):
intr_property: intr.Property,
) -> None:
async def property_getter(
*, flags=MessageFlag.NONE, unpack_variants: bool = False
):
Expand Down Expand Up @@ -150,7 +155,7 @@ async def property_getter(
return unpack(body)
return body

async def property_setter(val):
async def property_setter(val: Any) -> None:
variant = Variant(intr_property.signature, val)

body, unix_fds = replace_fds_with_idx(
Expand Down Expand Up @@ -188,7 +193,7 @@ def __init__(
path: str,
introspection: Union[intr.Node, str, ET.Element],
bus: BaseMessageBus,
):
) -> None:
super().__init__(bus_name, path, introspection, bus, ProxyInterface)

def get_interface(self, name: str) -> ProxyInterface:
Expand Down
32 changes: 20 additions & 12 deletions src/dbus_fast/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import re
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import Callable, Coroutine, Dict, List, Type, Union
from functools import lru_cache
from typing import Callable, Coroutine, Dict, List, Optional, Type, Union

from . import introspection as intr
from . import message_bus
Expand Down Expand Up @@ -50,7 +51,13 @@ class BaseProxyInterface:
:vartype bus: :class:`BaseMessageBus <dbus_fast.message_bus.BaseMessageBus>`
"""

def __init__(self, bus_name, path, introspection, bus):
def __init__(
self,
bus_name: str,
path: str,
introspection: intr.Interface,
bus: "message_bus.BaseMessageBus",
) -> None:

self.bus_name = bus_name
self.path = path
Expand All @@ -63,12 +70,13 @@ def __init__(self, bus_name, path, introspection, bus):
_underscorer2 = re.compile(r"([a-z0-9])([A-Z])")

@staticmethod
def _to_snake_case(member):
@lru_cache(maxsize=128)
def _to_snake_case(member: str) -> str:
subbed = BaseProxyInterface._underscorer1.sub(r"\1_\2", member)
return BaseProxyInterface._underscorer2.sub(r"\1_\2", subbed).lower()

@staticmethod
def _check_method_return(msg, signature=None):
def _check_method_return(msg: Message, signature: Optional[str] = None):
if msg.message_type == MessageType.ERROR:
raise DBusError._from_message(msg)
elif msg.message_type != MessageType.METHOD_RETURN:
Expand All @@ -82,13 +90,13 @@ def _check_method_return(msg, signature=None):
msg,
)

def _add_method(self, intr_method):
def _add_method(self, intr_method: intr.Method) -> None:
raise NotImplementedError("this must be implemented in the inheriting class")

def _add_property(self, intr_property):
def _add_property(self, intr_property: intr.Property) -> None:
raise NotImplementedError("this must be implemented in the inheriting class")

def _message_handler(self, msg):
def _message_handler(self, msg: Message) -> None:
if (
not msg._matches(
message_type=MessageType.SIGNAL,
Expand Down Expand Up @@ -133,8 +141,8 @@ def _message_handler(self, msg):
if isinstance(cb_result, Coroutine):
asyncio.create_task(cb_result)

def _add_signal(self, intr_signal, interface):
def on_signal_fn(fn, *, unpack_variants: bool = False):
def _add_signal(self, intr_signal: intr.Signal, interface: intr.Interface) -> None:
def on_signal_fn(fn: Callable, *, unpack_variants: bool = False):
fn_signature = inspect.signature(fn)
if 0 < len(
[
Expand Down Expand Up @@ -173,7 +181,7 @@ def on_signal_fn(fn, *, unpack_variants: bool = False):
SignalHandler(fn, unpack_variants)
)

def off_signal_fn(fn, *, unpack_variants: bool = False):
def off_signal_fn(fn: Callable, *, unpack_variants: bool = False) -> None:
try:
i = self._signal_handlers[intr_signal.name].index(
SignalHandler(fn, unpack_variants)
Expand Down Expand Up @@ -235,7 +243,7 @@ def __init__(
introspection: Union[intr.Node, str, ET.Element],
bus: "message_bus.BaseMessageBus",
ProxyInterface: Type[BaseProxyInterface],
):
) -> None:
assert_object_path_valid(path)
assert_bus_name_valid(bus_name)

Expand Down Expand Up @@ -296,7 +304,7 @@ def get_interface(self, name: str) -> BaseProxyInterface:
for intr_signal in intr_interface.signals:
interface._add_signal(intr_signal, interface)

def get_owner_notify(msg, err):
def get_owner_notify(msg: Message, err: Optional[Exception]) -> None:
if err:
logging.error(f'getting name owner for "{name}" failed, {err}')
return
Expand Down

0 comments on commit ac955b5

Please sign in to comment.