Skip to content

Commit

Permalink
Merge pull request #2887 from bramstroker/feat/optimize-discovery
Browse files Browse the repository at this point in the history
Optimize discovery and improve logging
  • Loading branch information
bramstroker authored Jan 5, 2025
2 parents c0c1ce8 + b102a7b commit 6682a7b
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 27 deletions.
43 changes: 25 additions & 18 deletions custom_components/powercalc/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
MANUFACTURER_WLED,
CalculationStrategy,
)
from .group_include.filter import CategoryFilter, CompositeFilter, FilterOperator, LambdaFilter, NotFilter, get_filtered_entity_list
from .group_include.filter import CategoryFilter, CompositeFilter, DomainFilter, FilterOperator, LambdaFilter, NotFilter, get_filtered_entity_list
from .helpers import get_or_create_unique_id
from .power_profile.factory import get_power_profile
from .power_profile.library import ModelInfo, ProfileLibrary
from .power_profile.power_profile import DiscoveryBy, PowerProfile
from .power_profile.power_profile import DEVICE_TYPE_DOMAIN, DiscoveryBy, PowerProfile

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -94,7 +94,10 @@ async def start_discovery(self) -> None:

_LOGGER.debug("Start auto discovery")

_LOGGER.debug("Start entity discovery")
await self.perform_discovery(self.get_entities, self.create_entity_source, DiscoveryBy.ENTITY) # type: ignore[arg-type]

_LOGGER.debug("Start device discovery")
await self.perform_discovery(self.get_devices, self.create_device_source, DiscoveryBy.DEVICE) # type: ignore[arg-type]

_LOGGER.debug("Done auto discovery")
Expand Down Expand Up @@ -124,6 +127,7 @@ async def perform_discovery(
) -> None:
"""Generalized discovery procedure for entities and devices."""
for source in await source_provider():
log_identifier = source.entity_id if discovery_type == DiscoveryBy.ENTITY else source.id
try:
model_info = await self.extract_model_info_from_device_info(source)
if not model_info:
Expand All @@ -133,10 +137,7 @@ async def perform_discovery(

power_profiles = await self.discover_entity(source_entity, model_info, discovery_type)
if not power_profiles:
_LOGGER.debug(
"%s: Model not found in library, skipping discovery",
source_entity.entity_id,
)
_LOGGER.debug("%s: Model not found in library, skipping discovery", log_identifier)
continue

unique_id = self.create_unique_id(
Expand All @@ -146,18 +147,16 @@ async def perform_discovery(
)

if self._is_already_discovered(source_entity, unique_id):
_LOGGER.debug(
"%s: Already setup with discovery, skipping",
source_entity.entity_id,
)
_LOGGER.debug("%s: Already setup with discovery, skipping", log_identifier)
continue

self._init_entity_discovery(model_info, unique_id, source_entity, power_profiles, {})
except Exception: # noqa: BLE001
self._init_entity_discovery(model_info, unique_id, source_entity, log_identifier, power_profiles, {})
except Exception as err: # noqa: BLE001
_LOGGER.error(
"Error during %s discovery: %s",
"%s: Error during %s discovery: %s",
log_identifier,
discovery_type,
source,
err,
)

async def discover_entity(
Expand Down Expand Up @@ -238,6 +237,7 @@ async def init_wled_flow(self, model_info: ModelInfo, source_entity: SourceEntit
model_info,
unique_id,
source_entity,
source_entity.entity_id,
power_profiles=None,
extra_discovery_data={
CONF_MODE: CalculationStrategy.WLED,
Expand Down Expand Up @@ -278,6 +278,7 @@ def _check_already_configured(entity: er.RegistryEntry) -> bool:
LambdaFilter(lambda entity: entity.device_id is None),
LambdaFilter(lambda entity: entity.platform == "mqtt" and "segment" in entity.entity_id),
LambdaFilter(lambda entity: entity.platform == "powercalc"),
NotFilter(DomainFilter(DEVICE_TYPE_DOMAIN.values())),
],
FilterOperator.OR,
)
Expand All @@ -287,19 +288,24 @@ async def get_devices(self) -> list:
"""Fetch device entries."""
return list(dr.async_get(self.hass).devices.values())

async def extract_model_info_from_device_info(self, entry: er.RegistryEntry | dr.DeviceEntry | None) -> ModelInfo | None:
async def extract_model_info_from_device_info(
self,
entry: er.RegistryEntry | dr.DeviceEntry | None,
) -> ModelInfo | None:
"""Try to auto discover manufacturer and model from the known device information."""
if not entry:
return None

log_identifier = entry.entity_id if isinstance(entry, er.RegistryEntry) else entry.id

if isinstance(entry, er.RegistryEntry):
model_info = await self.get_model_information_from_entity(entry)
else:
model_info = await self.get_model_information_from_device(entry)
if not model_info:
_LOGGER.debug(
"%s: Cannot autodiscover model, manufacturer or model unknown from device registry",
entry.id,
log_identifier,
)
return None

Expand All @@ -315,7 +321,7 @@ async def extract_model_info_from_device_info(self, entry: er.RegistryEntry | dr

_LOGGER.debug(
"%s: Found model information on device (manufacturer=%s, model=%s, model_id=%s)",
entry.id,
log_identifier,
model_info.manufacturer,
model_info.model,
model_info.model_id,
Expand Down Expand Up @@ -354,6 +360,7 @@ def _init_entity_discovery(
model_info: ModelInfo,
unique_id: str,
source_entity: SourceEntity,
log_identifier: str,
power_profiles: list[PowerProfile] | None,
extra_discovery_data: dict | None,
) -> None:
Expand Down Expand Up @@ -384,7 +391,7 @@ def _init_entity_discovery(
if source_entity.entity_id != DUMMY_ENTITY_ID:
self.initialized_flows.add(source_entity.entity_id)

_LOGGER.debug("%s: Initiating discovery flow, unique_id=%s", source_entity.entity_id, unique_id)
_LOGGER.debug("%s: Initiating discovery flow, unique_id=%s", log_identifier, unique_id)

discovery_flow.async_create_flow(
self.hass,
Expand Down
8 changes: 4 additions & 4 deletions custom_components/powercalc/group_include/filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from collections.abc import Callable
from collections.abc import Callable, Iterable
from enum import StrEnum
from typing import Protocol, cast

Expand Down Expand Up @@ -109,11 +109,11 @@ def is_valid(self, entity: RegistryEntry) -> bool:


class DomainFilter(EntityFilter):
def __init__(self, domain: str | list) -> None:
self.domain = domain
def __init__(self, domain: str | Iterable[str]) -> None:
self.domain = domain if isinstance(domain, str) else set(domain)

def is_valid(self, entity: RegistryEntry) -> bool:
if isinstance(self.domain, list):
if isinstance(self.domain, set):
return entity.domain in self.domain
return entity.domain == self.domain

Expand Down
24 changes: 19 additions & 5 deletions custom_components/powercalc/group_include/include.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
DOMAIN,
)
from custom_components.powercalc.discovery import get_power_profile_by_source_entity
from custom_components.powercalc.power_profile.power_profile import DEVICE_TYPE_DOMAIN
from custom_components.powercalc.sensors.energy import RealEnergySensor
from custom_components.powercalc.sensors.power import RealPowerSensor

from .filter import EntityFilter, NullFilter, get_filtered_entity_list
from .filter import CompositeFilter, DomainFilter, EntityFilter, LambdaFilter, get_filtered_entity_list

_LOGGER = logging.getLogger(__name__)

Expand All @@ -32,12 +33,12 @@ async def find_entities(

resolved_entities: list[Entity] = []
discoverable_entities: list[str] = []
source_entities = await get_filtered_entity_list(hass, entity_filter or NullFilter())
source_entities = await get_filtered_entity_list(hass, _build_filter(entity_filter))
if _LOGGER.isEnabledFor(logging.DEBUG): # pragma: no cover
_LOGGER.debug("Found possible include entities: %s", [entity.entity_id for entity in source_entities])

source_entity_powercalc_entity_map: dict[str, list] = domain_data[DATA_CONFIGURED_ENTITIES]
powercalc_entities: dict[str, Entity] = domain_data[DATA_ENTITIES]
source_entity_powercalc_entity_map: dict[str, list] = domain_data.get(DATA_CONFIGURED_ENTITIES, {})
powercalc_entities: dict[str, Entity] = domain_data.get(DATA_ENTITIES, {})
for source_entity in source_entities:
if source_entity.entity_id in source_entity_powercalc_entity_map:
resolved_entities.extend(source_entity_powercalc_entity_map[source_entity.entity_id])
Expand All @@ -51,7 +52,7 @@ async def find_entities(
device_class = source_entity.device_class or source_entity.original_device_class
if device_class == SensorDeviceClass.POWER:
resolved_entities.append(RealPowerSensor(source_entity.entity_id, source_entity.unit_of_measurement))
elif device_class == SensorDeviceClass.ENERGY and source_entity.platform != "utility_meter":
elif device_class == SensorDeviceClass.ENERGY:
resolved_entities.append(RealEnergySensor(source_entity.entity_id))

power_profile = await get_power_profile_by_source_entity(
Expand All @@ -62,3 +63,16 @@ async def find_entities(
discoverable_entities.append(source_entity.entity_id)

return resolved_entities, discoverable_entities


def _build_filter(entity_filter: EntityFilter | None) -> EntityFilter:
base_filter = CompositeFilter(
[
DomainFilter(DEVICE_TYPE_DOMAIN.values()),
LambdaFilter(lambda entity: entity.platform != "utility_meter"),
],
)
if not entity_filter:
return base_filter

return CompositeFilter([base_filter, entity_filter])
45 changes: 45 additions & 0 deletions tests/group_include/test_include.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
ENTRY_DATA_POWER_ENTITY,
SensorType,
)
from custom_components.powercalc.group_include.include import find_entities
from custom_components.test.light import MockLight
from tests.common import (
create_discoverable_light,
Expand Down Expand Up @@ -1287,6 +1288,50 @@ async def test_include_logs_warning(hass: HomeAssistant, caplog: pytest.LogCaptu
assert "Could not resolve any entities in group" in caplog.text


async def test_irrelevant_entity_domains_are_skipped(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None:
caplog.set_level(logging.DEBUG)

mock_device_registry(
hass,
{
"device-a": DeviceEntry(
id="device-a",
manufacturer="Signify",
model="LCT012",
),
},
)
mock_registry(
hass,
{
"light.test": RegistryEntry(
entity_id="light.test",
unique_id="2222",
platform="hue",
device_id="device-a",
),
"scene.test": RegistryEntry(
entity_id="scene.test",
unique_id="3333",
platform="hue",
device_id="device-a",
),
"event.test": RegistryEntry(
entity_id="event.test",
unique_id="4444",
platform="hue",
device_id="device-a",
),
},
)
_, discoverable_entities = await find_entities(hass)
assert len(discoverable_entities) == 1
assert "light.test" in discoverable_entities

assert "scene.test" not in caplog.text
assert "event.test" not in caplog.text


def _create_powercalc_config_entry(
hass: HomeAssistant,
source_entity_id: str,
Expand Down
53 changes: 53 additions & 0 deletions tests/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,3 +775,56 @@ async def test_powercalc_sensors_are_ignored_for_discovery(

mock_calls = mock_flow_init.mock_calls
assert len(mock_calls) == 1


@pytest.mark.parametrize(
"entity_entries,expected_entities",
[
(
[
RegistryEntry(
entity_id="switch.test",
unique_id="1111",
platform="hue",
device_id="hue-device",
),
],
["switch.test"],
),
# Entity domains that are not supported must be ignored
(
[
RegistryEntry(
entity_id="scene.test",
unique_id="1111",
platform="hue",
device_id="hue-device",
),
RegistryEntry(
entity_id="event.test",
unique_id="2222",
platform="hue",
device_id="hue-device",
),
],
[],
),
# Powercalc sensors should not be considered for discovery
(
[
RegistryEntry(
entity_id="sensor.test",
unique_id="1111",
platform="powercalc",
device_id="some-device",
),
],
[],
),
],
)
async def test_get_entities(hass: HomeAssistant, entity_entries: list[RegistryEntry], expected_entities: list[str]) -> None:
mock_registry(hass, {entity_entry.entity_id: entity_entry for entity_entry in entity_entries})
discovery_manager = DiscoveryManager(hass, {})
entity_ids = [entity.entity_id for entity in await discovery_manager.get_entities()]
assert entity_ids == expected_entities

0 comments on commit 6682a7b

Please sign in to comment.