diff --git a/custom_components/powercalc/group_include/filter.py b/custom_components/powercalc/group_include/filter.py index 890b09c16..f0b5db0e8 100644 --- a/custom_components/powercalc/group_include/filter.py +++ b/custom_components/powercalc/group_include/filter.py @@ -5,6 +5,8 @@ from awesomeversion.awesomeversion import AwesomeVersion from homeassistant.const import __version__ as HA_VERSION # noqa +from homeassistant.core import HomeAssistant +from homeassistant.helpers.template import Template if AwesomeVersion(HA_VERSION) >= AwesomeVersion("2023.8.0"): from enum import StrEnum @@ -67,6 +69,16 @@ def create_regex(pattern: str) -> str: pattern = pattern.replace("?", ".") return pattern.replace("*", ".*") +class TemplateFilter(IncludeEntityFilter): + def __init__(self, hass: HomeAssistant, template: str|Template) -> None: + if not isinstance(template, Template): + template = Template(template) + template.hass = hass + self.entity_ids = template.async_render() + + def is_valid(self, entity: RegistryEntry) -> bool: + return entity.entity_id in self.entity_ids + class CompositeFilter(IncludeEntityFilter): diff --git a/custom_components/powercalc/group_include/include.py b/custom_components/powercalc/group_include/include.py index 0cfe0af55..d6030f1e4 100644 --- a/custom_components/powercalc/group_include/include.py +++ b/custom_components/powercalc/group_include/include.py @@ -11,7 +11,6 @@ from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_platform import split_entity_id -from homeassistant.helpers.template import Template from custom_components.powercalc.const import ( CONF_AREA, @@ -28,7 +27,7 @@ from custom_components.powercalc.sensors.energy import RealEnergySensor from custom_components.powercalc.sensors.power import RealPowerSensor -from .filter import CompositeFilter, FilterOperator, WildcardFilter, create_filter +from .filter import CompositeFilter, FilterOperator, TemplateFilter, WildcardFilter, create_filter _LOGGER = logging.getLogger(__name__) @@ -104,22 +103,13 @@ def resolve_include_source_entities( _LOGGER.debug("Including entities from group: %s", group_id) entities = entities | resolve_include_groups(hass, group_id) - # Include entities by evaluating a template - if CONF_TEMPLATE in include_config: - template: Template = include_config.get(CONF_TEMPLATE) # type: ignore - template.hass = hass - - _LOGGER.debug("Including entities from template") - entity_ids = template.async_render() - entities = entities | { - entity_id: entity_reg.async_get(entity_id) for entity_id in entity_ids - } - base_filters = [] if CONF_WILDCARD in include_config: base_filters.append(WildcardFilter(include_config.get(CONF_WILDCARD))) if CONF_DOMAIN in include_config: base_filters.append(WildcardFilter(include_config.get(CONF_DOMAIN))) + if CONF_TEMPLATE in include_config: + base_filters.append(TemplateFilter(hass, include_config.get(CONF_TEMPLATE))) if base_filters: base_filter = CompositeFilter(base_filters, FilterOperator.OR)