Skip to content

Commit

Permalink
Reduce overhead to connect dispatcher (#105715)
Browse files Browse the repository at this point in the history
Reduce overhead connect dispatcher

- We tend to have 1000s (or 10000s) of connected dispatchers which
  makes these prime targets to reduce overhead/memory

- Instead of creating new functions to wrap log exceptions each time
  use partials which reuses the function body and only create new
  arguments

Previous optimizations #103307 #93602
  • Loading branch information
bdraco authored Dec 16, 2023
1 parent 1271f16 commit 47f8e08
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 38 deletions.
20 changes: 11 additions & 9 deletions homeassistant/helpers/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,22 @@ def dispatcher_send(hass: HomeAssistant, signal: str, *args: Any) -> None:
hass.loop.call_soon_threadsafe(async_dispatcher_send, hass, signal, *args)


def _format_err(signal: str, target: Callable[..., Any], *args: Any) -> str:
"""Format error message."""
return "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
getattr(target, "__name__", None) or str(target),
signal,
args,
)


def _generate_job(
signal: str, target: Callable[..., Any]
) -> HassJob[..., None | Coroutine[Any, Any, None]]:
"""Generate a HassJob for a signal and target."""
return HassJob(
catch_log_exception(
target,
lambda *args: "Exception in {} when dispatching '{}': {}".format(
# Functions wrapped in partial do not have a __name__
getattr(target, "__name__", None) or str(target),
signal,
args,
),
),
catch_log_exception(target, partial(_format_err, signal, target)),
f"dispatcher {signal}",
)

Expand Down
62 changes: 37 additions & 25 deletions homeassistant/util/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,39 @@ def log_exception(format_err: Callable[..., Any], *args: Any) -> None:
logging.getLogger(module_name).error("%s\n%s", friendly_msg, exc_msg)


async def _async_wrapper(
async_func: Callable[..., Coroutine[Any, Any, None]],
format_err: Callable[..., Any],
*args: Any,
) -> None:
"""Catch and log exception."""
try:
await async_func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)


def _sync_wrapper(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> None:
"""Catch and log exception."""
try:
func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)


@callback
def _callback_wrapper(
func: Callable[..., Any], format_err: Callable[..., Any], *args: Any
) -> None:
"""Catch and log exception."""
try:
func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)


@overload
def catch_log_exception(
func: Callable[..., Coroutine[Any, Any, Any]], format_err: Callable[..., Any]
Expand Down Expand Up @@ -128,35 +161,14 @@ def catch_log_exception(
while isinstance(check_func, partial):
check_func = check_func.func

wrapper_func: Callable[..., None] | Callable[..., Coroutine[Any, Any, None]]
if asyncio.iscoroutinefunction(check_func):
async_func = cast(Callable[..., Coroutine[Any, Any, None]], func)
return wraps(async_func)(partial(_async_wrapper, async_func, format_err))

@wraps(async_func)
async def async_wrapper(*args: Any) -> None:
"""Catch and log exception."""
try:
await async_func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)

wrapper_func = async_wrapper

else:

@wraps(func)
def wrapper(*args: Any) -> None:
"""Catch and log exception."""
try:
func(*args)
except Exception: # pylint: disable=broad-except
log_exception(format_err, *args)

if is_callback(check_func):
wrapper = callback(wrapper)
if is_callback(check_func):
return wraps(func)(partial(_callback_wrapper, func, format_err))

wrapper_func = wrapper
return wrapper_func
return wraps(func)(partial(_sync_wrapper, func, format_err))


def catch_log_coro_exception(
Expand Down
21 changes: 19 additions & 2 deletions tests/helpers/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,32 @@ def bad_handler(*args):
# wrap in partial to test message logging.
async_dispatcher_connect(hass, "test", partial(bad_handler))
async_dispatcher_send(hass, "test", "bad")
await hass.async_block_till_done()
await hass.async_block_till_done()

assert (
f"Exception in functools.partial({bad_handler}) when dispatching 'test': ('bad',)"
in caplog.text
)


@pytest.mark.no_fail_on_log_exception
async def test_coro_exception_gets_logged(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test exception raised by signal handler."""

async def bad_async_handler(*args):
"""Record calls."""
raise Exception("This is a bad message in a coro")

# wrap in partial to test message logging.
async_dispatcher_connect(hass, "test", bad_async_handler)
async_dispatcher_send(hass, "test", "bad")
await hass.async_block_till_done()

assert "bad_async_handler" in caplog.text
assert "when dispatching 'test': ('bad',)" in caplog.text


async def test_dispatcher_add_dispatcher(hass: HomeAssistant) -> None:
"""Test adding a dispatcher from a dispatcher."""
calls = []
Expand Down
45 changes: 43 additions & 2 deletions tests/util/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

import pytest

from homeassistant.core import HomeAssistant, callback, is_callback
from homeassistant.core import (
HomeAssistant,
callback,
is_callback,
is_callback_check_partial,
)
import homeassistant.util.logging as logging_util


Expand Down Expand Up @@ -93,7 +98,7 @@ async def async_meth():
def callback_meth():
pass

assert is_callback(
assert is_callback_check_partial(
logging_util.catch_log_exception(partial(callback_meth), lambda: None)
)

Expand All @@ -104,3 +109,39 @@ def sync_meth():

assert not is_callback(wrapped)
assert not asyncio.iscoroutinefunction(wrapped)


@pytest.mark.no_fail_on_log_exception
async def test_catch_log_exception_catches_and_logs() -> None:
"""Test it is still a callback after wrapping including partial."""
saved_args = []

def save_args(*args):
saved_args.append(args)

async def async_meth():
raise ValueError("failure async")

func = logging_util.catch_log_exception(async_meth, save_args)
await func("failure async passed")

assert saved_args == [("failure async passed",)]
saved_args.clear()

@callback
def callback_meth():
raise ValueError("failure callback")

func = logging_util.catch_log_exception(callback_meth, save_args)
func("failure callback passed")

assert saved_args == [("failure callback passed",)]
saved_args.clear()

def sync_meth():
raise ValueError("failure sync")

func = logging_util.catch_log_exception(sync_meth, save_args)
func("failure sync passed")

assert saved_args == [("failure sync passed",)]

0 comments on commit 47f8e08

Please sign in to comment.