Skip to content

Commit

Permalink
Fix performance regression with SignalType (#117920)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p authored May 22, 2024
1 parent 5229f0d commit 5c9c71b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 28 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ repos:
name: mypy
entry: script/run-in-env.sh mypy
language: script
types: [python]
types_or: [python, pyi]
require_serial: true
files: ^(homeassistant|pylint)/.+\.(py|pyi)$
- id: pylint
name: pylint
entry: script/run-in-env.sh pylint -j 0 --ignore-missing-annotations=y
language: script
types: [python]
files: ^homeassistant/.+\.py$
types_or: [python, pyi]
files: ^homeassistant/.+\.(py|pyi)$
- id: gen_requirements_all
name: gen_requirements_all
entry: script/run-in-env.sh python3 -m script.gen_requirements_all
Expand Down
30 changes: 5 additions & 25 deletions homeassistant/util/signal_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,20 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class _SignalTypeBase[*_Ts]:
class _SignalTypeBase[*_Ts](str):
"""Generic base class for SignalType."""

name: str

def __hash__(self) -> int:
"""Return hash of name."""

return hash(self.name)
__slots__ = ()

def __eq__(self, other: object) -> bool:
"""Check equality for dict keys to be compatible with str."""

if isinstance(other, str):
return self.name == other
if isinstance(other, SignalType):
return self.name == other.name
return False


@dataclass(frozen=True, eq=False)
class SignalType[*_Ts](_SignalTypeBase[*_Ts]):
"""Generic string class for signal to improve typing."""

__slots__ = ()


@dataclass(frozen=True, eq=False)
class SignalTypeFormat[*_Ts](_SignalTypeBase[*_Ts]):
"""Generic string class for signal. Requires call to 'format' before use."""

def format(self, *args: Any, **kwargs: Any) -> SignalType[*_Ts]:
"""Format name and return new SignalType instance."""
return SignalType(self.name.format(*args, **kwargs))
__slots__ = ()
69 changes: 69 additions & 0 deletions homeassistant/util/signal_type.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Stub file for signal_type. Provide overload for type checking."""
# ruff: noqa: PYI021 # Allow docstring

from typing import Any, assert_type

__all__ = [
"SignalType",
"SignalTypeFormat",
]

class _SignalTypeBase[*_Ts]:
"""Custom base class for SignalType. At runtime delegate to str.
For type checkers pretend to be its own separate class.
"""

def __init__(self, value: str, /) -> None: ...
def __hash__(self) -> int: ...
def __eq__(self, other: object, /) -> bool: ...

class SignalType[*_Ts](_SignalTypeBase[*_Ts]):
"""Generic string class for signal to improve typing."""

class SignalTypeFormat[*_Ts](_SignalTypeBase[*_Ts]):
"""Generic string class for signal. Requires call to 'format' before use."""

def format(self, *args: Any, **kwargs: Any) -> SignalType[*_Ts]: ...

def _test_signal_type_typing() -> None: # noqa: PYI048
"""Test SignalType and dispatcher overloads work as intended.
This is tested during the mypy run. Do not move it to 'tests'!
"""
# pylint: disable=import-outside-toplevel
from homeassistant.core import HomeAssistant
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)

hass: HomeAssistant
def test_func(a: int) -> None: ...
def test_func_other(a: int, b: str) -> None: ...

# No type validation for str signals
signal_str = "signal"
async_dispatcher_connect(hass, signal_str, test_func)
async_dispatcher_connect(hass, signal_str, test_func_other)
async_dispatcher_send(hass, signal_str, 2)
async_dispatcher_send(hass, signal_str, 2, "Hello World")

# Using SignalType will perform type validation on target and args
signal_1: SignalType[int] = SignalType("signal")
assert_type(signal_1, SignalType[int])
async_dispatcher_connect(hass, signal_1, test_func)
async_dispatcher_connect(hass, signal_1, test_func_other) # type: ignore[arg-type]
async_dispatcher_send(hass, signal_1, 2)
async_dispatcher_send(hass, signal_1, "Hello World") # type: ignore[misc]

# SignalTypeFormat cannot be used for dispatcher_connect / dispatcher_send
# Call format() on it first to convert it to a SignalType
signal_format: SignalTypeFormat[int] = SignalTypeFormat("signal_")
signal_2 = signal_format.format("2")
assert_type(signal_format, SignalTypeFormat[int])
assert_type(signal_2, SignalType[int])
async_dispatcher_connect(hass, signal_format, test_func) # type: ignore[call-overload]
async_dispatcher_connect(hass, signal_2, test_func)
async_dispatcher_send(hass, signal_format, 2) # type: ignore[call-overload]
async_dispatcher_send(hass, signal_2, 2)

0 comments on commit 5c9c71b

Please sign in to comment.