Skip to content

Commit

Permalink
Move template to filter architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
bramstroker committed Dec 9, 2023
1 parent b96b7b7 commit a6bbb24
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
12 changes: 12 additions & 0 deletions custom_components/powercalc/group_include/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from awesomeversion.awesomeversion import AwesomeVersion
from homeassistant.const import __version__ as HA_VERSION # noqa
from homeassistant.core import HomeAssistant
from homeassistant.helpers.template import Template

if AwesomeVersion(HA_VERSION) >= AwesomeVersion("2023.8.0"):
from enum import StrEnum
Expand Down Expand Up @@ -67,6 +69,16 @@ def create_regex(pattern: str) -> str:
pattern = pattern.replace("?", ".")
return pattern.replace("*", ".*")

class TemplateFilter(IncludeEntityFilter):
def __init__(self, hass: HomeAssistant, template: str|Template) -> None:
if not isinstance(template, Template):
template = Template(template)
template.hass = hass
self.entity_ids = template.async_render()

def is_valid(self, entity: RegistryEntry) -> bool:
return entity.entity_id in self.entity_ids



class CompositeFilter(IncludeEntityFilter):
Expand Down
16 changes: 3 additions & 13 deletions custom_components/powercalc/group_include/include.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_platform import split_entity_id
from homeassistant.helpers.template import Template

from custom_components.powercalc.const import (
CONF_AREA,
Expand All @@ -28,7 +27,7 @@
from custom_components.powercalc.sensors.energy import RealEnergySensor
from custom_components.powercalc.sensors.power import RealPowerSensor

from .filter import CompositeFilter, FilterOperator, WildcardFilter, create_filter
from .filter import CompositeFilter, FilterOperator, TemplateFilter, WildcardFilter, create_filter

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,22 +103,13 @@ def resolve_include_source_entities(
_LOGGER.debug("Including entities from group: %s", group_id)
entities = entities | resolve_include_groups(hass, group_id)

# Include entities by evaluating a template
if CONF_TEMPLATE in include_config:
template: Template = include_config.get(CONF_TEMPLATE) # type: ignore
template.hass = hass

_LOGGER.debug("Including entities from template")
entity_ids = template.async_render()
entities = entities | {
entity_id: entity_reg.async_get(entity_id) for entity_id in entity_ids
}

base_filters = []
if CONF_WILDCARD in include_config:
base_filters.append(WildcardFilter(include_config.get(CONF_WILDCARD)))
if CONF_DOMAIN in include_config:
base_filters.append(WildcardFilter(include_config.get(CONF_DOMAIN)))
if CONF_TEMPLATE in include_config:
base_filters.append(TemplateFilter(hass, include_config.get(CONF_TEMPLATE)))

if base_filters:
base_filter = CompositeFilter(base_filters, FilterOperator.OR)
Expand Down

0 comments on commit a6bbb24

Please sign in to comment.