Skip to content

Commit

Permalink
feat: implement validation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
bramstroker committed Jan 1, 2025
1 parent c3845e4 commit d875291
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 7 deletions.
2 changes: 1 addition & 1 deletion custom_components/powercalc/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ async def _create_schema(
) -> vol.Schema:
"""Create sub profile schema."""
library = await ProfileLibrary.factory(self.hass)
profile = await library.get_profile(model_info)
profile = await library.get_profile(model_info, process_variables=False)
sub_profiles = [selector.SelectOptionDict(value=sub_profile, label=sub_profile) for sub_profile in await profile.get_sub_profiles()]
return vol.Schema(
{
Expand Down
5 changes: 3 additions & 2 deletions custom_components/powercalc/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def get_power_profile_by_source_entity(hass: HomeAssistant, source_entity:
model_info = await discovery_manager.extract_model_info_from_device_info(source_entity.entity_entry)
if not model_info:
return None
profiles = await discovery_manager.discover_entity(source_entity, model_info)
profiles = await discovery_manager.find_power_profiles(model_info, source_entity, DiscoveryBy.ENTITY)
return profiles[0] if profiles else None


Expand Down Expand Up @@ -203,7 +203,7 @@ async def find_power_profiles(

power_profiles = []
for model_info in models:
profile = await get_power_profile(self.hass, {}, model_info=model_info)
profile = await get_power_profile(self.hass, {}, model_info=model_info, process_variables=False)
if not profile or profile.discovery_by != discovery_type: # pragma: no cover
continue
if discovery_type == DiscoveryBy.ENTITY and not await self.is_entity_supported(
Expand Down Expand Up @@ -266,6 +266,7 @@ async def is_entity_supported(
{},
model_info,
log_errors=log_profile_loading_errors,
process_variables=False,
)
except ModelNotSupportedError:
return False
Expand Down
2 changes: 2 additions & 0 deletions custom_components/powercalc/power_profile/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def get_power_profile(
config: dict,
model_info: ModelInfo | None = None,
log_errors: bool = True,
process_variables: bool = True,
) -> PowerProfile | None:
manufacturer = config.get(CONF_MANUFACTURER)
model = config.get(CONF_MODEL)
Expand Down Expand Up @@ -55,6 +56,7 @@ async def get_power_profile(
ModelInfo(manufacturer or "", model or "", model_id),
custom_model_directory,
config.get(CONF_CUSTOM_FIELDS),
process_variables,
)
except LibraryError as err:
if log_errors:
Expand Down
25 changes: 21 additions & 4 deletions custom_components/powercalc/power_profile/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import re
from typing import NamedTuple, cast
from typing import Any, NamedTuple, cast

from homeassistant.core import HomeAssistant
from homeassistant.helpers.singleton import singleton
Expand Down Expand Up @@ -98,6 +98,7 @@ async def get_profile(
model_info: ModelInfo,
custom_directory: str | None = None,
variables: dict[str, str] | None = None,
process_variables: bool = True,
) -> PowerProfile:
"""Get a power profile for a given manufacturer and model."""
# Support multiple LUT in subdirectories
Expand All @@ -106,7 +107,7 @@ async def get_profile(
(model, sub_profile) = model_info.model.split("/", 1)
model_info = ModelInfo(model_info.manufacturer, model, model_info.model_id)

profile = await self.create_power_profile(model_info, custom_directory, variables)
profile = await self.create_power_profile(model_info, custom_directory, variables, process_variables)

if sub_profile:
await profile.select_sub_profile(sub_profile)
Expand All @@ -118,6 +119,7 @@ async def create_power_profile(
model_info: ModelInfo,
custom_directory: str | None = None,
variables: dict[str, str] | None = None,
process_variables: bool = True,
) -> PowerProfile:
"""Create a power profile object from the model JSON data."""

Expand All @@ -128,14 +130,29 @@ async def create_power_profile(
model_info = next(iter(models))

json_data, directory = await self._load_model_data(model_info.manufacturer, model_info.model, custom_directory)
if variables:
json_data = cast(dict, replace_placeholders(json_data, variables))
if json_data.get("fields") and process_variables:
self.validate_variables(json_data, variables or {})
json_data = cast(dict, replace_placeholders(json_data, variables or {}))
if linked_profile := json_data.get("linked_profile", json_data.get("linked_lut")):
linked_manufacturer, linked_model = linked_profile.split("/")
_, directory = await self._load_model_data(linked_manufacturer, linked_model, custom_directory)

return await self._create_power_profile_instance(model_info.manufacturer, model_info.model, directory, json_data)

@staticmethod
def validate_variables(json_data: dict[str, Any], variables: dict[str, str]) -> None:
fields = json_data.get("fields", {}).keys()

# Check if all variables are valid for the model
for variable in variables:
if variable not in fields:
raise LibraryError(f"Variable {variable} is not valid for this model")

# Check if all fields have corresponding variables
missing_fields = [field for field in fields if field not in variables]
if missing_fields:
raise LibraryError(f"Missing variables for fields: {', '.join(missing_fields)}")

async def find_manufacturers(self, manufacturer: str) -> set[str]:
"""Resolve the manufacturer, either from the model info or by loading it."""
return await self._loader.find_manufacturers(manufacturer)
Expand Down
2 changes: 2 additions & 0 deletions custom_components/powercalc/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
CONF_CREATE_ENERGY_SENSOR,
CONF_CREATE_GROUP,
CONF_CREATE_UTILITY_METERS,
CONF_CUSTOM_FIELDS,
CONF_CUSTOM_MODEL_DIRECTORY,
CONF_DAILY_FIXED_ENERGY,
CONF_DELAY,
Expand Down Expand Up @@ -218,6 +219,7 @@
),
vol.Optional(CONF_UNAVAILABLE_POWER): vol.Coerce(float),
vol.Optional(CONF_COMPOSITE): COMPOSITE_SCHEMA,
vol.Optional(CONF_CUSTOM_FIELDS): vol.Schema({cv.string: cv.string}),
}


Expand Down
33 changes: 33 additions & 0 deletions tests/power_profile/device_types/test_custom_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging

import pytest
from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, STATE_ON
from homeassistant.core import HomeAssistant

from custom_components.powercalc.const import CONF_CUSTOM_FIELDS, CONF_MANUFACTURER, CONF_MODEL, DUMMY_ENTITY_ID
from tests.common import get_test_config_dir, run_powercalc_setup


async def test_custom_field_variables_from_yaml_config(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None:
caplog.set_level(logging.ERROR)
hass.config.config_dir = get_test_config_dir()

hass.states.async_set("sensor.test", STATE_ON)
await hass.async_block_till_done()

await run_powercalc_setup(
hass,
{
CONF_ENTITY_ID: DUMMY_ENTITY_ID,
CONF_NAME: "Test",
CONF_MANUFACTURER: "test",
CONF_MODEL: "custom-fields",
CONF_CUSTOM_FIELDS: {
"some_entity": "sensor.test",
},
},
)

assert not caplog.records

assert hass.states.get("sensor.test_power").state == "20.00"

0 comments on commit d875291

Please sign in to comment.