Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce overhead to connect dispatcher #105715

Merged
merged 1 commit into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions homeassistant/helpers/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,22 @@ def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)


def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
"""Format error message."""
return "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
getattr(target, "__name__", None) or str(target),
signal,
args,
)


def _generate_job(
signal: str, target: Callable[..., Any]
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
"""Generate a HassJob for a signal and target."""
return HassJob(
catch_log_exception(
target,
lambda *args: "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
getattr(target, "__name__", None) or str(target),
signal,
args,
),
),
catch_log_exception(target, partial(_format_err, signal, target)),
f"dispatcher {signal}",
)

Expand Down
62 changes: 37 additions & 25 deletions homeassistant/util/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,39 @@ def log_exception(format_err: Callable[..., Any], *args: Any) -> None:
logging.getLogger(module_name).error("%s\n%s", friendly_msg, exc_msg)


async def _async_wrapper(
async_func: Callable[..., Coroutine[Any, Any, None]],
format_err: Callable[..., Any],
*args: Any,
) -> None:
"""Catch and log exception."""
try:
await async_func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)


def _sync_wrapper(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> None:
"""Catch and log exception."""
try:
func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)


@callback
def _callback_wrapper(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> None:
"""Catch and log exception."""
try:
func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)


@overload
def catch_log_exception(
func: Callable[..., Coroutine[Any, Any, Any]], format_err: Callable[..., Any]
Expand Down Expand Up @@ -128,35 +161,14 @@ def catch_log_exception(
while isinstance(check_func, partial):
check_func = check_func.func

wrapper_func: Callable[..., None] | Callable[..., Coroutine[Any, Any, None]]
if asyncio.iscoroutinefunction(check_func):
async_func = cast(Callable[..., Coroutine[Any, Any, None]], func)
return wraps(async_func)(partial(_async_wrapper, async_func, format_err))

@wraps(async_func)
async def async_wrapper(*args: Any) -> None:
"""Catch and log exception."""
try:
await async_func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)

wrapper_func = async_wrapper

else:

@wraps(func)
def wrapper(*args: Any) -> None:
"""Catch and log exception."""
try:
func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)

if is_callback(check_func):
wrapper = callback(wrapper)
if is_callback(check_func):
return wraps(func)(partial(_callback_wrapper, func, format_err))

wrapper_func = wrapper
return wrapper_func
return wraps(func)(partial(_sync_wrapper, func, format_err))


def catch_log_coro_exception(
Expand Down
21 changes: 19 additions & 2 deletions tests/helpers/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,32 @@ def bad_handler(*args):
# wrap in partial to test message logging.
async_dispatcher_connect(hass, "test", partial(bad_handler))
async_dispatcher_send(hass, "test", "bad")
await hass.async_block_till_done()
await hass.async_block_till_done()

assert (
f"Exception in functools.partial({bad_handler}) when dispatching 'test': ('bad',)"
in caplog.text
)


@pytest.mark.no_fail_on_log_exception
async def test_coro_exception_gets_logged(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test exception raised by signal handler."""

async def bad_async_handler(*args):
"""Record calls."""
raise Exception("This is a bad message in a coro")

# wrap in partial to test message logging.
async_dispatcher_connect(hass, "test", bad_async_handler)
async_dispatcher_send(hass, "test", "bad")
await hass.async_block_till_done()

assert "bad_async_handler" in caplog.text
assert "when dispatching 'test': ('bad',)" in caplog.text


async def test_dispatcher_add_dispatcher(hass: HomeAssistant) -> None:
"""Test adding a dispatcher from a dispatcher."""
calls = []
Expand Down
45 changes: 43 additions & 2 deletions tests/util/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

import pytest

from homeassistant.core import HomeAssistant, callback, is_callback
from homeassistant.core import (
HomeAssistant,
callback,
is_callback,
is_callback_check_partial,
)
import homeassistant.util.logging as logging_util


Expand Down Expand Up @@ -93,7 +98,7 @@ async def async_meth():
def callback_meth():
pass

assert is_callback(
assert is_callback_check_partial(
logging_util.catch_log_exception(partial(callback_meth), lambda: None)
)

Expand All @@ -104,3 +109,39 @@ def sync_meth():

assert not is_callback(wrapped)
assert not asyncio.iscoroutinefunction(wrapped)


@pytest.mark.no_fail_on_log_exception
async def test_catch_log_exception_catches_and_logs() -> None:
"""Test it is still a callback after wrapping including partial."""
saved_args = []

def save_args(*args):
saved_args.append(args)

async def async_meth():
raise ValueError("failure async")

func = logging_util.catch_log_exception(async_meth, save_args)
await func("failure async passed")

assert saved_args == [("failure async passed",)]
saved_args.clear()

@callback
def callback_meth():
raise ValueError("failure callback")

func = logging_util.catch_log_exception(callback_meth, save_args)
func("failure callback passed")

assert saved_args == [("failure callback passed",)]
saved_args.clear()

def sync_meth():
raise ValueError("failure sync")

func = logging_util.catch_log_exception(sync_meth, save_args)
func("failure sync passed")

assert saved_args == [("failure sync passed",)]