Skip to content

Commit

Permalink
Move all include logic to filters
Browse files Browse the repository at this point in the history
  • Loading branch information
bramstroker committed Dec 9, 2023
1 parent a6bbb24 commit d283369
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 95 deletions.
108 changes: 95 additions & 13 deletions custom_components/powercalc/group_include/filter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from __future__ import annotations

import re
from typing import Protocol
from typing import Protocol, cast

from awesomeversion.awesomeversion import AwesomeVersion
from homeassistant.components.group import DOMAIN as GROUP_DOMAIN
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.const import __version__ as HA_VERSION # noqa
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, split_entity_id
from homeassistant.helpers import area_registry, device_registry, entity_registry
from homeassistant.helpers.area_registry import AreaEntry
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.template import Template

from custom_components.powercalc.errors import SensorConfigurationError

if AwesomeVersion(HA_VERSION) >= AwesomeVersion("2023.8.0"):
from enum import StrEnum
else:
Expand All @@ -27,15 +35,7 @@ def create_filter(filter_config: dict) -> IncludeEntityFilter:
filters: list[IncludeEntityFilter] = []
if CONF_DOMAIN in filter_config:
domain_config = filter_config.get(CONF_DOMAIN)
if isinstance(domain_config, list):
filters.append(
CompositeFilter(
[DomainFilter(domain) for domain in domain_config],
FilterOperator.OR,
),
)
elif isinstance(domain_config, str):
filters.append(DomainFilter(domain_config))
filters.append(DomainFilter(domain_config))

return CompositeFilter(filters, FilterOperator.AND)

Expand All @@ -46,12 +46,75 @@ def is_valid(self, entity: RegistryEntry) -> bool:


class DomainFilter(IncludeEntityFilter):
def __init__(self, domain: str) -> None:
self.domain: str = domain
def __init__(self, domain: str | list) -> None:
self.domain = domain

def is_valid(self, entity: RegistryEntry) -> bool:
if isinstance(self.domain, list):
return entity.domain in self.domain
return entity.domain == self.domain

class GroupFilter(IncludeEntityFilter):
def __init__(self, hass: HomeAssistant, group_id: str) -> None:
domain = split_entity_id(group_id)[0]
self.filter = LightGroupFilter(hass, group_id) if domain == LIGHT_DOMAIN else StandardGroupFilter(hass, group_id)

def is_valid(self, entity: RegistryEntry) -> bool:
return self.filter.is_valid(entity)

class StandardGroupFilter(IncludeEntityFilter):
def __init__(self, hass: HomeAssistant, group_id: str) -> None:
entity_reg = entity_registry.async_get(hass)
entity_reg.async_get(group_id)
group_state = hass.states.get(group_id)
if group_state is None:
raise SensorConfigurationError(f"Group state {group_id} not found")
self.entity_ids = group_state.attributes.get(ATTR_ENTITY_ID) or []

def is_valid(self, entity: RegistryEntry) -> bool:
return entity.entity_id in self.entity_ids


class LightGroupFilter(IncludeEntityFilter):
def __init__(self, hass: HomeAssistant, group_id: str) -> None:
light_component = cast(EntityComponent, hass.data.get(LIGHT_DOMAIN))
light_group = next(
filter(lambda entity: entity.entity_id == group_id, light_component.entities),
None,
)
if light_group is None or light_group.platform.platform_name != GROUP_DOMAIN:
raise SensorConfigurationError(f"Light group {group_id} not found")

self.entity_ids = self.find_all_entity_ids_recursively(hass, group_id, [])

def is_valid(self, entity: RegistryEntry) -> bool:
return entity.entity_id in self.entity_ids

def find_all_entity_ids_recursively(self, hass: HomeAssistant, group_entity_id: str, all_entity_ids: list[str]) -> list[str]:
entity_reg = entity_registry.async_get(hass)
light_component = cast(EntityComponent, hass.data.get(LIGHT_DOMAIN))
light_group = next(
filter(lambda entity: entity.entity_id == group_entity_id, light_component.entities),
None,
)

entity_ids = light_group.extra_state_attributes.get(ATTR_ENTITY_ID)
for entity_id in entity_ids:
registry_entry = entity_reg.async_get(entity_id)
if registry_entry is None:
continue

if registry_entry.platform == GROUP_DOMAIN:
self.find_all_entity_ids_recursively(
hass,
registry_entry.entity_id,
all_entity_ids,
)

all_entity_ids.append(entity_id)

return all_entity_ids


class NullFilter(IncludeEntityFilter):
def is_valid(self, entity: RegistryEntry) -> bool:
Expand Down Expand Up @@ -80,6 +143,25 @@ def is_valid(self, entity: RegistryEntry) -> bool:
return entity.entity_id in self.entity_ids


class AreaFilter(IncludeEntityFilter):
def __init__(self, hass: HomeAssistant, area_id_or_name: str) -> None:
area_reg = area_registry.async_get(hass)
area = area_reg.async_get_area(area_id_or_name)
if area is None:
area = area_reg.async_get_area_by_name(str(area_id_or_name))

if area is None or area.id is None:
raise SensorConfigurationError(
f"No area with id or name '{area_id_or_name}' found in your HA instance",
)

self.area: AreaEntry = area

device_reg = device_registry.async_get(hass)
self.area_devices = [device.id for device in device_registry.async_entries_for_area(device_reg, area.id)]

def is_valid(self, entity: RegistryEntry) -> bool:
return entity.area_id == self.area.id or entity.device_id in self.area_devices

class CompositeFilter(IncludeEntityFilter):
def __init__(
Expand Down
96 changes: 30 additions & 66 deletions custom_components/powercalc/group_include/include.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from homeassistant.components.sensor import SensorDeviceClass
from homeassistant.const import ATTR_ENTITY_ID, CONF_DOMAIN, CONF_ENTITY_ID
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import area_registry, device_registry, entity_registry
from homeassistant.helpers import entity_registry
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_platform import split_entity_id
Expand All @@ -27,7 +27,17 @@
from custom_components.powercalc.sensors.energy import RealEnergySensor
from custom_components.powercalc.sensors.power import RealPowerSensor

from .filter import CompositeFilter, FilterOperator, TemplateFilter, WildcardFilter, create_filter
from .filter import (
AreaFilter,
CompositeFilter,
DomainFilter,
FilterOperator,
GroupFilter,
IncludeEntityFilter,
TemplateFilter,
WildcardFilter,
create_filter,
)

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,44 +98,35 @@ def resolve_include_source_entities(
hass: HomeAssistant,
include_config: dict,
) -> dict[str, entity_registry.RegistryEntry | None]:
entities: dict[str, entity_registry.RegistryEntry | None] = {}
entity_reg = entity_registry.async_get(hass)

# Include entities from a certain area
if CONF_AREA in include_config:
area_id = str(include_config.get(CONF_AREA))
_LOGGER.debug("Including entities from area: %s", area_id)
entities = entities | resolve_area_entities(hass, area_id)

# Include entities from a certain group
if CONF_GROUP in include_config:
group_id = str(include_config.get(CONF_GROUP))
_LOGGER.debug("Including entities from group: %s", group_id)
entities = entities | resolve_include_groups(hass, group_id)

base_filters = []
if CONF_GROUP in include_config:
base_filters.append(GroupFilter(hass, include_config.get(CONF_GROUP)))
if CONF_WILDCARD in include_config:
base_filters.append(WildcardFilter(include_config.get(CONF_WILDCARD)))
if CONF_DOMAIN in include_config:
base_filters.append(WildcardFilter(include_config.get(CONF_DOMAIN)))
base_filters.append(DomainFilter(include_config.get(CONF_DOMAIN)))
if CONF_TEMPLATE in include_config:
base_filters.append(TemplateFilter(hass, include_config.get(CONF_TEMPLATE)))
if CONF_AREA in include_config:
base_filters.append(AreaFilter(hass, include_config.get(CONF_AREA)))

entity_filter: IncludeEntityFilter | None = None
if base_filters:
base_filter = CompositeFilter(base_filters, FilterOperator.OR)
entities = entities | {
entry.entity_id: entry for entry in entity_reg.entities.values() if base_filter.is_valid(entry)
}
entity_filter = CompositeFilter(base_filters, FilterOperator.OR)

if CONF_FILTER in include_config:
entity_filter = create_filter(include_config.get(CONF_FILTER)) # type: ignore
entities = {
entity_id: entity
for entity_id, entity in entities.items()
if entity is not None and entity_filter.is_valid(entity)
}
if entity_filter:
entity_filter = CompositeFilter(
[entity_filter, create_filter(include_config.get(CONF_FILTER))],
FilterOperator.AND,
)
else:
entity_filter = create_filter(include_config.get(CONF_FILTER)) # type: ignore

return entities
entity_reg = entity_registry.async_get(hass)
return {
entry.entity_id: entry for entry in entity_reg.entities.values() if entity_filter.is_valid(entry)
}


@callback
Expand Down Expand Up @@ -185,40 +186,3 @@ def resolve_light_group_entities(

return resolved_entities


@callback
def resolve_area_entities(
hass: HomeAssistant,
area_id_or_name: str,
) -> dict[str, entity_registry.RegistryEntry]:
"""Get a listing of al entities in a given area."""
area_reg = area_registry.async_get(hass)
area = area_reg.async_get_area(area_id_or_name)
if area is None:
area = area_reg.async_get_area_by_name(str(area_id_or_name))

if area is None or area.id is None:
raise SensorConfigurationError(
f"No area with id or name '{area_id_or_name}' found in your HA instance",
)

area_id = area.id
entity_reg = entity_registry.async_get(hass)

entities = entity_registry.async_entries_for_area(entity_reg, area_id)

device_reg = device_registry.async_get(hass)
# We also need to add entities tied to a device in the area that don't themselves
# have an area specified since they inherit the area from the device.
entities.extend(
[
entity
for device in device_registry.async_entries_for_area(device_reg, area_id)
for entity in entity_registry.async_entries_for_device(
entity_reg,
device.id,
)
if entity.area_id is None
],
)
return {entity.entity_id: entity for entity in entities}
26 changes: 10 additions & 16 deletions tests/group_include/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from unittest.mock import MagicMock

import pytest
from homeassistant.const import CONF_DOMAIN
from homeassistant.helpers.entity_registry import RegistryEntry

from custom_components.powercalc.group_include.filter import (
Expand All @@ -10,7 +9,6 @@
FilterOperator,
NullFilter,
WildcardFilter,
create_filter,
)


Expand Down Expand Up @@ -41,20 +39,16 @@ async def test_composite_filter(
== expected_result
)


async def test_domain_filter() -> None:
registry_entry = _create_registry_entry()
entity_filter = DomainFilter("switch")
assert entity_filter.is_valid(registry_entry) is True

entity_filter = DomainFilter("light")
assert entity_filter.is_valid(registry_entry) is False


async def test_domain_filter_multiple() -> None:
entity_filter = create_filter({CONF_DOMAIN: ["switch", "light"]})
assert entity_filter.is_valid(_create_registry_entry()) is True

@pytest.mark.parametrize(
"domain,expected_result",
[
("switch", True),
("light", False),
(["switch", "light"], True),
],
)
async def test_domain_filter(domain: str | list, expected_result: bool) -> None:
assert DomainFilter(domain).is_valid(_create_registry_entry()) is expected_result

async def test_null_filter() -> None:
assert NullFilter().is_valid(_create_registry_entry()) is True
Expand Down
17 changes: 17 additions & 0 deletions tests/group_include/test_include.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,23 @@ async def test_include_template(hass: HomeAssistant) -> None:

async def test_include_group(hass: HomeAssistant) -> None:
hass.states.async_set("switch.tv", "on")

mock_registry(
hass,
{
"switch.tv": RegistryEntry(
entity_id="switch.tv",
unique_id="12345",
platform="switch",
),
"switch.soundbar": RegistryEntry(
entity_id="switch.soundbar",
unique_id="123456",
platform="switch",
),
},
)

await async_setup_component(
hass,
SWITCH_DOMAIN,
Expand Down

0 comments on commit d283369

Please sign in to comment.