diff --git a/homeassistant/helpers/dispatcher.py b/homeassistant/helpers/dispatcher.py index e416d939914b6..07112226ecfd8 100644 --- a/homeassistant/helpers/dispatcher.py +++ b/homeassistant/helpers/dispatcher.py @@ -90,20 +90,22 @@ def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None: hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args) +def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str: + """Format error message.""" + return "Exception in {} when dispatching '{}': {}".format( + # Functions wrapped in partial do not have a __name__ + getattr(target, "__name__", None) or str(target), + signal, + args, + ) + + def _generate_job( signal: str, target: Callable[..., Any] ) -> HassJob[..., None | Coroutine[Any, Any, None]]: """Generate a HassJob for a signal and target.""" return HassJob( - catch_log_exception( - target, - lambda *args: "Exception in {} when dispatching '{}': {}".format( - # Functions wrapped in partial do not have a __name__ - getattr(target, "__name__", None) or str(target), - signal, - args, - ), - ), + catch_log_exception(target, partial(_format_err, signal, target)), f"dispatcher {signal}", ) diff --git a/homeassistant/util/logging.py b/homeassistant/util/logging.py index 1328e8ded6014..07ff413a01671 100644 --- a/homeassistant/util/logging.py +++ b/homeassistant/util/logging.py @@ -101,6 +101,39 @@ def log_exception(format_err: Callable[..., Any], *args: Any) -> None: logging.getLogger(module_name).error("%s\n%s", friendly_msg, exc_msg) +async def _async_wrapper( + async_func: Callable[..., Coroutine[Any, Any, None]], + format_err: Callable[..., Any], + *args: Any, +) -> None: + """Catch and log exception.""" + try: + await async_func(*args) + except Exception: # pylint: disable=broad-except + log_exception(format_err, *args) + + +def _sync_wrapper( + func: Callable[..., Any], format_err: Callable[..., Any], *args: Any +) -> None: + """Catch and log exception.""" + try: + func(*args) + except Exception: # pylint: disable=broad-except + log_exception(format_err, *args) + + +@callback +def _callback_wrapper( + func: Callable[..., Any], format_err: Callable[..., Any], *args: Any +) -> None: + """Catch and log exception.""" + try: + func(*args) + except Exception: # pylint: disable=broad-except + log_exception(format_err, *args) + + @overload def catch_log_exception( func: Callable[..., Coroutine[Any, Any, Any]], format_err: Callable[..., Any] @@ -128,35 +161,14 @@ def catch_log_exception( while isinstance(check_func, partial): check_func = check_func.func - wrapper_func: Callable[..., None] | Callable[..., Coroutine[Any, Any, None]] if asyncio.iscoroutinefunction(check_func): async_func = cast(Callable[..., Coroutine[Any, Any, None]], func) + return wraps(async_func)(partial(_async_wrapper, async_func, format_err)) - @wraps(async_func) - async def async_wrapper(*args: Any) -> None: - """Catch and log exception.""" - try: - await async_func(*args) - except Exception: # pylint: disable=broad-except - log_exception(format_err, *args) - - wrapper_func = async_wrapper - - else: - - @wraps(func) - def wrapper(*args: Any) -> None: - """Catch and log exception.""" - try: - func(*args) - except Exception: # pylint: disable=broad-except - log_exception(format_err, *args) - - if is_callback(check_func): - wrapper = callback(wrapper) + if is_callback(check_func): + return wraps(func)(partial(_callback_wrapper, func, format_err)) - wrapper_func = wrapper - return wrapper_func + return wraps(func)(partial(_sync_wrapper, func, format_err)) def catch_log_coro_exception( diff --git a/tests/helpers/test_dispatcher.py b/tests/helpers/test_dispatcher.py index a251b20b0f41a..89d23fb4533cf 100644 --- a/tests/helpers/test_dispatcher.py +++ b/tests/helpers/test_dispatcher.py @@ -144,8 +144,6 @@ def bad_handler(*args): # wrap in partial to test message logging. async_dispatcher_connect(hass, "test", partial(bad_handler)) async_dispatcher_send(hass, "test", "bad") - await hass.async_block_till_done() - await hass.async_block_till_done() assert ( f"Exception in functools.partial({bad_handler}) when dispatching 'test': ('bad',)" @@ -153,6 +151,25 @@ def bad_handler(*args): ) +@pytest.mark.no_fail_on_log_exception +async def test_coro_exception_gets_logged( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test exception raised by signal handler.""" + + async def bad_async_handler(*args): + """Record calls.""" + raise Exception("This is a bad message in a coro") + + # wrap in partial to test message logging. + async_dispatcher_connect(hass, "test", bad_async_handler) + async_dispatcher_send(hass, "test", "bad") + await hass.async_block_till_done() + + assert "bad_async_handler" in caplog.text + assert "when dispatching 'test': ('bad',)" in caplog.text + + async def test_dispatcher_add_dispatcher(hass: HomeAssistant) -> None: """Test adding a dispatcher from a dispatcher.""" calls = [] diff --git a/tests/util/test_logging.py b/tests/util/test_logging.py index a08311cca4f3a..350baa9d4c264 100644 --- a/tests/util/test_logging.py +++ b/tests/util/test_logging.py @@ -7,7 +7,12 @@ import pytest -from homeassistant.core import HomeAssistant, callback, is_callback +from homeassistant.core import ( + HomeAssistant, + callback, + is_callback, + is_callback_check_partial, +) import homeassistant.util.logging as logging_util @@ -93,7 +98,7 @@ async def async_meth(): def callback_meth(): pass - assert is_callback( + assert is_callback_check_partial( logging_util.catch_log_exception(partial(callback_meth), lambda: None) ) @@ -104,3 +109,39 @@ def sync_meth(): assert not is_callback(wrapped) assert not asyncio.iscoroutinefunction(wrapped) + + +@pytest.mark.no_fail_on_log_exception +async def test_catch_log_exception_catches_and_logs() -> None: + """Test it is still a callback after wrapping including partial.""" + saved_args = [] + + def save_args(*args): + saved_args.append(args) + + async def async_meth(): + raise ValueError("failure async") + + func = logging_util.catch_log_exception(async_meth, save_args) + await func("failure async passed") + + assert saved_args == [("failure async passed",)] + saved_args.clear() + + @callback + def callback_meth(): + raise ValueError("failure callback") + + func = logging_util.catch_log_exception(callback_meth, save_args) + func("failure callback passed") + + assert saved_args == [("failure callback passed",)] + saved_args.clear() + + def sync_meth(): + raise ValueError("failure sync") + + func = logging_util.catch_log_exception(sync_meth, save_args) + func("failure sync passed") + + assert saved_args == [("failure sync passed",)]