diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 98078da98bf235..3082d5080fed52 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,15 +61,15 @@ repos: name: mypy entry: script/run-in-env.sh mypy language: script - types: [python] + types_or: [python, pyi] require_serial: true files: ^(homeassistant|pylint)/.+\.(py|pyi)$ - id: pylint name: pylint entry: script/run-in-env.sh pylint -j 0 --ignore-missing-annotations=y language: script - types: [python] - files: ^homeassistant/.+\.py$ + types_or: [python, pyi] + files: ^homeassistant/.+\.(py|pyi)$ - id: gen_requirements_all name: gen_requirements_all entry: script/run-in-env.sh python3 -m script.gen_requirements_all diff --git a/homeassistant/util/signal_type.py b/homeassistant/util/signal_type.py index c9b74411ae0007..2552b3515fc95e 100644 --- a/homeassistant/util/signal_type.py +++ b/homeassistant/util/signal_type.py @@ -2,40 +2,20 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any - -@dataclass(frozen=True) -class _SignalTypeBase[*_Ts]: +class _SignalTypeBase[*_Ts](str): """Generic base class for SignalType.""" - name: str - - def __hash__(self) -> int: - """Return hash of name.""" - - return hash(self.name) + __slots__ = () - def __eq__(self, other: object) -> bool: - """Check equality for dict keys to be compatible with str.""" - if isinstance(other, str): - return self.name == other - if isinstance(other, SignalType): - return self.name == other.name - return False - - -@dataclass(frozen=True, eq=False) class SignalType[*_Ts](_SignalTypeBase[*_Ts]): """Generic string class for signal to improve typing.""" + __slots__ = () + -@dataclass(frozen=True, eq=False) class SignalTypeFormat[*_Ts](_SignalTypeBase[*_Ts]): """Generic string class for signal. Requires call to 'format' before use.""" - def format(self, *args: Any, **kwargs: Any) -> SignalType[*_Ts]: - """Format name and return new SignalType instance.""" - return SignalType(self.name.format(*args, **kwargs)) + __slots__ = () diff --git a/homeassistant/util/signal_type.pyi b/homeassistant/util/signal_type.pyi new file mode 100644 index 00000000000000..9987c3a0931799 --- /dev/null +++ b/homeassistant/util/signal_type.pyi @@ -0,0 +1,69 @@ +"""Stub file for signal_type. Provide overload for type checking.""" +# ruff: noqa: PYI021 # Allow docstring + +from typing import Any, assert_type + +__all__ = [ + "SignalType", + "SignalTypeFormat", +] + +class _SignalTypeBase[*_Ts]: + """Custom base class for SignalType. At runtime delegate to str. + + For type checkers pretend to be its own separate class. + """ + + def __init__(self, value: str, /) -> None: ... + def __hash__(self) -> int: ... + def __eq__(self, other: object, /) -> bool: ... + +class SignalType[*_Ts](_SignalTypeBase[*_Ts]): + """Generic string class for signal to improve typing.""" + +class SignalTypeFormat[*_Ts](_SignalTypeBase[*_Ts]): + """Generic string class for signal. Requires call to 'format' before use.""" + + def format(self, *args: Any, **kwargs: Any) -> SignalType[*_Ts]: ... + +def _test_signal_type_typing() -> None: # noqa: PYI048 + """Test SignalType and dispatcher overloads work as intended. + + This is tested during the mypy run. Do not move it to 'tests'! + """ + # pylint: disable=import-outside-toplevel + from homeassistant.core import HomeAssistant + from homeassistant.helpers.dispatcher import ( + async_dispatcher_connect, + async_dispatcher_send, + ) + + hass: HomeAssistant + def test_func(a: int) -> None: ... + def test_func_other(a: int, b: str) -> None: ... + + # No type validation for str signals + signal_str = "signal" + async_dispatcher_connect(hass, signal_str, test_func) + async_dispatcher_connect(hass, signal_str, test_func_other) + async_dispatcher_send(hass, signal_str, 2) + async_dispatcher_send(hass, signal_str, 2, "Hello World") + + # Using SignalType will perform type validation on target and args + signal_1: SignalType[int] = SignalType("signal") + assert_type(signal_1, SignalType[int]) + async_dispatcher_connect(hass, signal_1, test_func) + async_dispatcher_connect(hass, signal_1, test_func_other) # type: ignore[arg-type] + async_dispatcher_send(hass, signal_1, 2) + async_dispatcher_send(hass, signal_1, "Hello World") # type: ignore[misc] + + # SignalTypeFormat cannot be used for dispatcher_connect / dispatcher_send + # Call format() on it first to convert it to a SignalType + signal_format: SignalTypeFormat[int] = SignalTypeFormat("signal_") + signal_2 = signal_format.format("2") + assert_type(signal_format, SignalTypeFormat[int]) + assert_type(signal_2, SignalType[int]) + async_dispatcher_connect(hass, signal_format, test_func) # type: ignore[call-overload] + async_dispatcher_connect(hass, signal_2, test_func) + async_dispatcher_send(hass, signal_format, 2) # type: ignore[call-overload] + async_dispatcher_send(hass, signal_2, 2)