From 62d29861b74bf0e3485785573247c976b60d2c4b Mon Sep 17 00:00:00 2001 From: Bram Date: Thu, 26 Dec 2024 13:58:46 +0100 Subject: [PATCH 1/9] feat: prepare library loading code for same alias for multiple manufacturers (i.e. Tuya) --- custom_components/powercalc/discovery.py | 13 +---- .../powercalc/power_profile/library.py | 57 +++++++++++-------- .../power_profile/loader/composite.py | 11 ++-- .../powercalc/power_profile/loader/local.py | 6 +- .../power_profile/loader/protocol.py | 2 +- .../powercalc/power_profile/loader/remote.py | 7 ++- tests/power_profile/loader/test_local.py | 10 ++-- tests/power_profile/test_library.py | 10 ++-- tests/test_discovery.py | 4 +- 9 files changed, 62 insertions(+), 58 deletions(-) diff --git a/custom_components/powercalc/discovery.py b/custom_components/powercalc/discovery.py index 882ed01c9..24b1c343a 100644 --- a/custom_components/powercalc/discovery.py +++ b/custom_components/powercalc/discovery.py @@ -125,15 +125,7 @@ async def discover_entity( await self.init_wled_flow(model_info, source_entity) return None - manufacturer = await library.find_manufacturer(model_info) - if not manufacturer: - _LOGGER.debug( - "%s: Manufacturer not found in library, skipping discovery", - source_entity.entity_entry.entity_id, - ) - return None - - models = await library.find_models(manufacturer, model_info) + models = await library.find_models(model_info) if not models: _LOGGER.debug( "%s: Model not found in library, skipping discovery", @@ -142,8 +134,7 @@ async def discover_entity( return None power_profiles = [] - for model in models: - model_info = ModelInfo(manufacturer, model) + for model_info in models: profile = await get_power_profile(self.hass, {}, model_info=model_info) if not profile: # pragma: no cover continue diff --git a/custom_components/powercalc/power_profile/library.py b/custom_components/powercalc/power_profile/library.py index 268369b3b..2eecbbac7 100644 --- a/custom_components/powercalc/power_profile/library.py +++ b/custom_components/powercalc/power_profile/library.py @@ -78,17 +78,22 @@ async def get_manufacturer_listing(self, entity_domain: str | None = None) -> li async def get_model_listing(self, manufacturer: str, entity_domain: str | None = None) -> list[str]: """Get listing of available models for a given manufacturer.""" - resolved_manufacturer = await self._loader.find_manufacturer(manufacturer) - if not resolved_manufacturer: + resolved_manufacturers = await self._loader.find_manufacturers(manufacturer) + if not resolved_manufacturers: return [] device_types = get_device_types_from_domain(entity_domain) if entity_domain else None - cache_key = f"{resolved_manufacturer}/{device_types}" - cached_models = self._manufacturer_models.get(cache_key) - if cached_models: - return cached_models - models = await self._loader.get_model_listing(resolved_manufacturer, device_types) - self._manufacturer_models[cache_key] = sorted(models) - return self._manufacturer_models[cache_key] + all_models: list[str] = [] + for manufacturer in resolved_manufacturers: + cache_key = f"{manufacturer}/{device_types}" + cached_models = self._manufacturer_models.get(cache_key) + if cached_models: + all_models.extend(cached_models) + continue + models = await self._loader.get_model_listing(manufacturer, device_types) + self._manufacturer_models[cache_key] = sorted(models) + all_models.extend(models) + + return all_models async def get_profile( self, @@ -116,30 +121,24 @@ async def create_power_profile( ) -> PowerProfile: """Create a power profile object from the model JSON data.""" - manufacturer = model_info.manufacturer - model = model_info.model if not custom_directory: - manufacturer = await self.find_manufacturer(model_info) # type: ignore - if manufacturer is None: - raise LibraryError(f"Manufacturer {model_info.manufacturer} not found") - - models = await self.find_models(manufacturer, model_info) + models = await self.find_models(model_info) if not models: - raise LibraryError(f"Model {manufacturer} {model} not found") - model = next(iter(models)) + raise LibraryError(f"Model {model_info.manufacturer} {model_info.model} not found") + model_info = next(iter(models)) - json_data, directory = await self._load_model_data(manufacturer, model, custom_directory) + json_data, directory = await self._load_model_data(model_info.manufacturer, model_info.model, custom_directory) if linked_profile := json_data.get("linked_lut"): linked_manufacturer, linked_model = linked_profile.split("/") _, directory = await self._load_model_data(linked_manufacturer, linked_model, custom_directory) - return await self._create_power_profile_instance(manufacturer, model, directory, json_data) + return await self._create_power_profile_instance(model_info.manufacturer, model_info.model, directory, json_data) - async def find_manufacturer(self, model_info: ModelInfo) -> str | None: + async def find_manufacturers(self, model_info: ModelInfo) -> set[str]: """Resolve the manufacturer, either from the model info or by loading it.""" - return await self._loader.find_manufacturer(model_info.manufacturer) + return await self._loader.find_manufacturers(model_info.manufacturer) - async def find_models(self, manufacturer: str, model_info: ModelInfo) -> set[str]: + async def find_models(self, model_info: ModelInfo) -> set[ModelInfo]: """Resolve the model identifier, searching for it if no custom directory is provided.""" search: set[str] = set() for model_identifier in (model_info.model_id, model_info.model): @@ -155,7 +154,17 @@ async def find_models(self, manufacturer: str, model_info: ModelInfo) -> set[str if "/" in model_identifier: search.update(model_identifier.split("/")) - return set(await self._loader.find_model(manufacturer, search)) + manufacturers = await self._loader.find_manufacturers(model_info.manufacturer) + found_models: set[ModelInfo] = set() + if not manufacturers: + return found_models + + for manufacturer in manufacturers: + models = await self._loader.find_model(manufacturer, search) + if models: + found_models.update(ModelInfo(manufacturer, model) for model in models) + + return found_models async def _load_model_data(self, manufacturer: str, model: str, custom_directory: str | None) -> tuple[dict, str]: """Load the model data from the appropriate directory.""" diff --git a/custom_components/powercalc/power_profile/loader/composite.py b/custom_components/powercalc/power_profile/loader/composite.py index 3a2f5d72d..94d10ea71 100644 --- a/custom_components/powercalc/power_profile/loader/composite.py +++ b/custom_components/powercalc/power_profile/loader/composite.py @@ -18,16 +18,17 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) - return {manufacturer for loader in self.loaders for manufacturer in await loader.get_manufacturer_listing(device_types)} - async def find_manufacturer(self, search: str) -> str | None: + async def find_manufacturers(self, search: str) -> set[str]: """Check if a manufacturer is available. Also must check aliases.""" search = search.lower() + found_manufacturers = set() for loader in self.loaders: - manufacturer = await loader.find_manufacturer(search) - if manufacturer: - return manufacturer + manufacturers = await loader.find_manufacturers(search) + if manufacturers: + found_manufacturers.update(manufacturers) - return None + return found_manufacturers async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]: """Get listing of available models for a given manufacturer.""" diff --git a/custom_components/powercalc/power_profile/loader/local.py b/custom_components/powercalc/power_profile/loader/local.py index ddf02ed0e..6ff8010a3 100644 --- a/custom_components/powercalc/power_profile/loader/local.py +++ b/custom_components/powercalc/power_profile/loader/local.py @@ -39,15 +39,15 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) - return manufacturers - async def find_manufacturer(self, search: str) -> str | None: + async def find_manufacturers(self, search: str) -> set[str]: """Check if a manufacturer is available.""" _search = search.lower() manufacturer_list = self._manufacturer_model_listing.keys() if _search in manufacturer_list: - return _search + return {_search} - return None + return set() async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]: """Get listing of available models for a given manufacturer. diff --git a/custom_components/powercalc/power_profile/loader/protocol.py b/custom_components/powercalc/power_profile/loader/protocol.py index d570dac19..36d577896 100644 --- a/custom_components/powercalc/power_profile/loader/protocol.py +++ b/custom_components/powercalc/power_profile/loader/protocol.py @@ -10,7 +10,7 @@ async def initialize(self) -> None: async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -> set[str]: """Get listing of possible manufacturers.""" - async def find_manufacturer(self, search: str) -> str | None: + async def find_manufacturers(self, search: str) -> set[str]: """Check if a manufacturer is available. Also must check aliases.""" async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]: diff --git a/custom_components/powercalc/power_profile/loader/remote.py b/custom_components/powercalc/power_profile/loader/remote.py index e46ac0d5f..a90da4b8e 100644 --- a/custom_components/powercalc/power_profile/loader/remote.py +++ b/custom_components/powercalc/power_profile/loader/remote.py @@ -107,10 +107,13 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) - if not device_types or any(device_type in manufacturer.get("device_types", []) for device_type in device_types) } - async def find_manufacturer(self, search: str) -> str | None: + async def find_manufacturers(self, search: str) -> set[str]: """Find the manufacturer in the library.""" - return self.manufacturer_aliases.get(search, None) + manufacturer = self.manufacturer_aliases.get(search, None) + if manufacturer: + return {manufacturer} + return set() async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]: """Get listing of available models for a given manufacturer.""" diff --git a/tests/power_profile/loader/test_local.py b/tests/power_profile/loader/test_local.py index 2ddec1a56..1113998f9 100644 --- a/tests/power_profile/loader/test_local.py +++ b/tests/power_profile/loader/test_local.py @@ -79,14 +79,14 @@ async def test_load_model_returns_none_when_model_not_found(hass: HomeAssistant) @pytest.mark.parametrize( "manufacturer,expected", [ - ["tp-link", "tp-link"], - ["TP-Link", "tp-link"], - ["foo", None], + ["tp-link", {"tp-link"}], + ["TP-Link", {"tp-link"}], + ["foo", set()], ], ) -async def test_find_manufacturer(hass: HomeAssistant, manufacturer: str, expected: str | None) -> None: +async def test_find_manufacturers(hass: HomeAssistant, manufacturer: str, expected: str | None) -> None: loader = await _create_loader(hass) - assert expected == await loader.find_manufacturer(manufacturer) + assert expected == await loader.find_manufacturers(manufacturer) async def test_get_manufacturer_listing(hass: HomeAssistant) -> None: diff --git a/tests/power_profile/test_library.py b/tests/power_profile/test_library.py index e8782acb7..6b76590a8 100644 --- a/tests/power_profile/test_library.py +++ b/tests/power_profile/test_library.py @@ -47,8 +47,8 @@ async def test_model_listing(hass: HomeAssistant, manufacturer: str, expected_mo ) async def test_find_models(hass: HomeAssistant, model_info: ModelInfo, expected_models: set[str]) -> None: library = await ProfileLibrary.factory(hass) - models = await library.find_models(model_info.manufacturer, model_info) - assert models == expected_models + models = await library.find_models(model_info) + assert {model.model for model in models} == expected_models async def test_get_subprofile_listing(hass: HomeAssistant) -> None: @@ -130,7 +130,7 @@ async def test_create_power_profile_raises_library_error(hass: HomeAssistant) -> """When no loader is able to load the model, a LibraryError should be raised.""" mock_loader = LocalLoader(hass, "") mock_loader.load_model = AsyncMock(return_value=None) - mock_loader.find_manufacturer = AsyncMock(return_value="signify") + mock_loader.find_manufacturers = AsyncMock(return_value="signify") mock_loader.find_model = AsyncMock(return_value=ModelInfo("signify", "LCT010")) library = ProfileLibrary(hass, loader=mock_loader) await library.initialize() @@ -142,7 +142,7 @@ async def test_create_power_raise_library_error_when_model_not_found(hass: HomeA """When model is not found in library a LibraryError should be raised""" mock_loader = LocalLoader(hass, "") mock_loader.load_model = AsyncMock(return_value=None) - mock_loader.find_manufacturer = AsyncMock(return_value="signify") + mock_loader.find_manufacturers = AsyncMock(return_value="signify") mock_loader.find_model = AsyncMock(return_value=[]) library = ProfileLibrary(hass, loader=mock_loader) await library.initialize() @@ -154,7 +154,7 @@ async def test_create_power_raise_library_error_when_manufacturer_not_found(hass """When model is not found in library a LibraryError should be raised""" mock_loader = LocalLoader(hass, "") mock_loader.load_model = AsyncMock(return_value=None) - mock_loader.find_manufacturer = AsyncMock(return_value=None) + mock_loader.find_manufacturers = AsyncMock(return_value=None) library = ProfileLibrary(hass, loader=mock_loader) await library.initialize() with pytest.raises(LibraryError): diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 1663f92af..0df3004f1 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -326,8 +326,8 @@ async def test_autodiscover_continues_when_one_entity_fails( ), }, ) - with patch("custom_components.powercalc.power_profile.library.ProfileLibrary.find_manufacturer", new_callable=AsyncMock) as mock_find_models: - mock_find_models.side_effect = [Exception("Test exception"), "signify"] + with patch("custom_components.powercalc.power_profile.library.ProfileLibrary.find_models", new_callable=AsyncMock) as mock_find_models: + mock_find_models.side_effect = [Exception("Test exception"), {ModelInfo("signify", "LCT010")}] await run_powercalc_setup(hass, {}) assert "Error during auto discovery" in caplog.text From a6a44abd24a60ff613ff34cd8c7b4cef2193af22 Mon Sep 17 00:00:00 2001 From: Bram Date: Thu, 26 Dec 2024 14:45:21 +0100 Subject: [PATCH 2/9] feat: implement caching decorator --- custom_components/powercalc/helpers.py | 46 +++++++++++++++++++ .../powercalc/power_profile/loader/remote.py | 6 +++ 2 files changed, 52 insertions(+) diff --git a/custom_components/powercalc/helpers.py b/custom_components/powercalc/helpers.py index 1895e7732..e6621ca64 100644 --- a/custom_components/powercalc/helpers.py +++ b/custom_components/powercalc/helpers.py @@ -2,7 +2,10 @@ import logging import os.path import uuid +from collections.abc import Callable, Coroutine from decimal import Decimal +from functools import wraps +from typing import Any, TypeVar from homeassistant.const import CONF_UNIQUE_ID from homeassistant.helpers.template import Template @@ -69,3 +72,46 @@ def get_or_create_unique_id( return f"pc_{source_unique_id}" return str(uuid.uuid4()) + + +P = TypeVar("P") # Used for positional and keyword argument types +R = TypeVar("R") # Used for return type + + +def make_hashable(arg: Any) -> Any: # noqa: ANN401 + """Convert unhashable arguments to hashable equivalents.""" + if isinstance(arg, set): + return frozenset(arg) + if isinstance(arg, list): + return tuple(arg) + if isinstance(arg, dict): + return frozenset((key, make_hashable(value)) for key, value in arg.items()) + return arg + + +def async_cache(func: Callable[..., Coroutine[Any, Any, R]]) -> Callable[..., Coroutine[Any, Any, R]]: + """ + A decorator to cache results of an async function based on its arguments. + + Args: + func: The asynchronous function to decorate. + + Returns: + A decorated asynchronous function with caching. + """ + cache: dict[tuple[tuple[Any, ...], frozenset], R] = {} + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> R: # noqa: ANN401 + # Make arguments hashable + hashable_args = tuple(make_hashable(arg) for arg in args) + hashable_kwargs = frozenset((key, make_hashable(value)) for key, value in kwargs.items()) + cache_key = (hashable_args, hashable_kwargs) + + if cache_key in cache: + return cache[cache_key] + result = await func(*args, **kwargs) + cache[cache_key] = result + return result + + return wrapper diff --git a/custom_components/powercalc/power_profile/loader/remote.py b/custom_components/powercalc/power_profile/loader/remote.py index a90da4b8e..9f102b2f9 100644 --- a/custom_components/powercalc/power_profile/loader/remote.py +++ b/custom_components/powercalc/power_profile/loader/remote.py @@ -15,6 +15,7 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.storage import STORAGE_DIR +from custom_components.powercalc.helpers import async_cache from custom_components.powercalc.power_profile.error import LibraryLoadingError, ProfileDownloadError from custom_components.powercalc.power_profile.loader.protocol import Loader from custom_components.powercalc.power_profile.power_profile import DeviceType @@ -98,6 +99,7 @@ def _save_to_local_storage(data: bytes) -> None: _LOGGER.debug("Failed to download library.json, falling back to local copy") return await self.hass.async_add_executor_job(_load_local_library_json) # type: ignore + @async_cache async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -> set[str]: """Get listing of available manufacturers.""" @@ -107,6 +109,7 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) - if not device_types or any(device_type in manufacturer.get("device_types", []) for device_type in device_types) } + @async_cache async def find_manufacturers(self, search: str) -> set[str]: """Find the manufacturer in the library.""" @@ -115,6 +118,7 @@ async def find_manufacturers(self, search: str) -> set[str]: return {manufacturer} return set() + @async_cache async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]: """Get listing of available models for a given manufacturer.""" @@ -124,6 +128,7 @@ async def get_model_listing(self, manufacturer: str, device_types: set[DeviceTyp if not device_types or any(device_type in model.get("device_type", [DeviceType.LIGHT]) for device_type in device_types) } + @async_cache async def load_model( self, manufacturer: str, @@ -226,6 +231,7 @@ def _write() -> None: return await self.hass.async_add_executor_job(_write) # type: ignore + @async_cache async def find_model(self, manufacturer: str, search: set[str]) -> list[str]: """Find the model in the library.""" From 8bcdf487405966bdb122a60a3574618fd60225b2 Mon Sep 17 00:00:00 2001 From: Bram Date: Thu, 26 Dec 2024 18:54:33 +0100 Subject: [PATCH 3/9] feat: change from list to set --- custom_components/powercalc/discovery.py | 2 +- .../power_profile/loader/composite.py | 6 +-- .../powercalc/power_profile/loader/local.py | 6 +-- .../power_profile/loader/protocol.py | 2 +- .../powercalc/power_profile/loader/remote.py | 38 +++++++++---------- tests/power_profile/loader/test_local.py | 28 +++++++------- tests/power_profile/loader/test_remote.py | 8 ++-- 7 files changed, 43 insertions(+), 47 deletions(-) diff --git a/custom_components/powercalc/discovery.py b/custom_components/powercalc/discovery.py index 24b1c343a..faddfaede 100644 --- a/custom_components/powercalc/discovery.py +++ b/custom_components/powercalc/discovery.py @@ -117,7 +117,6 @@ async def discover_entity( ) -> list[PowerProfile] | None: """Discover a single entity in Powercalc library and start the discovery flow if supported.""" - library = await self._get_library() if source_entity.entity_entry is None: # pragma: no cover return None @@ -125,6 +124,7 @@ async def discover_entity( await self.init_wled_flow(model_info, source_entity) return None + library = await self._get_library() models = await library.find_models(model_info) if not models: _LOGGER.debug( diff --git a/custom_components/powercalc/power_profile/loader/composite.py b/custom_components/powercalc/power_profile/loader/composite.py index 94d10ea71..79db6a173 100644 --- a/custom_components/powercalc/power_profile/loader/composite.py +++ b/custom_components/powercalc/power_profile/loader/composite.py @@ -43,11 +43,11 @@ async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] | return None - async def find_model(self, manufacturer: str, search: set[str]) -> list[str]: + async def find_model(self, manufacturer: str, search: set[str]) -> set[str]: """Find the model in the library.""" - models = [] + models = set() for loader in self.loaders: - models.extend(await loader.find_model(manufacturer, search)) + models.update(await loader.find_model(manufacturer, search)) return models diff --git a/custom_components/powercalc/power_profile/loader/local.py b/custom_components/powercalc/power_profile/loader/local.py index 6ff8010a3..dfc60b140 100644 --- a/custom_components/powercalc/power_profile/loader/local.py +++ b/custom_components/powercalc/power_profile/loader/local.py @@ -104,19 +104,19 @@ async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] | model_json = lib_model.json_data return model_json, model_path - async def find_model(self, manufacturer: str, search: set[str]) -> list[str]: + async def find_model(self, manufacturer: str, search: set[str]) -> set[str]: """Find a model for a given manufacturer. Also must check aliases.""" _manufacturer = manufacturer.lower() models = self._manufacturer_model_listing.get(_manufacturer) if not models: _LOGGER.info("Manufacturer does not exist in custom library: %s", _manufacturer) - return [] + return set() search_lower = {phrase.lower() for phrase in search} profile = next((models[model] for model in models if model.lower() in search_lower), None) - return [profile.model] if profile else [] + return {profile.model} if profile else set() def _load_custom_library(self) -> None: """Loading custom models and aliases from file system. diff --git a/custom_components/powercalc/power_profile/loader/protocol.py b/custom_components/powercalc/power_profile/loader/protocol.py index 36d577896..9b575c03c 100644 --- a/custom_components/powercalc/power_profile/loader/protocol.py +++ b/custom_components/powercalc/power_profile/loader/protocol.py @@ -19,5 +19,5 @@ async def get_model_listing(self, manufacturer: str, device_types: set[DeviceTyp async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] | None: """Load and optionally download a model profile.""" - async def find_model(self, manufacturer: str, search: set[str]) -> list[str]: + async def find_model(self, manufacturer: str, search: set[str]) -> set[str]: """Check if a model is available. Also must check aliases.""" diff --git a/custom_components/powercalc/power_profile/loader/remote.py b/custom_components/powercalc/power_profile/loader/remote.py index 9f102b2f9..43390def9 100644 --- a/custom_components/powercalc/power_profile/loader/remote.py +++ b/custom_components/powercalc/power_profile/loader/remote.py @@ -112,11 +112,7 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) - @async_cache async def find_manufacturers(self, search: str) -> set[str]: """Find the manufacturer in the library.""" - - manufacturer = self.manufacturer_aliases.get(search, None) - if manufacturer: - return {manufacturer} - return set() + return {self.manufacturer_aliases[search]} if search in self.manufacturer_aliases else set() @async_cache async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]: @@ -128,6 +124,19 @@ async def get_model_listing(self, manufacturer: str, device_types: set[DeviceTyp if not device_types or any(device_type in model.get("device_type", [DeviceType.LIGHT]) for device_type in device_types) } + @async_cache + async def find_model(self, manufacturer: str, search: set[str]) -> set[str]: + """Find the model in the library.""" + + models = self.manufacturer_models.get(manufacturer, []) + result = set() + for model in models: + model_id = model.get("id") + if model_id and (model_id in search or any(alias in search for alias in model.get("aliases", []))): + result.add(model_id) + + return result + @async_cache async def load_model( self, @@ -137,12 +146,12 @@ async def load_model( retry_count: int = 0, ) -> tuple[dict, str] | None: """Load a model, downloading it if necessary, with retry logic.""" - model_info = self._get_model_info(manufacturer.lower(), model) - storage_path = self.get_storage_path(manufacturer.lower(), model) + model_info = self._get_model_info(manufacturer, model) + storage_path = self.get_storage_path(manufacturer, model) model_path = os.path.join(storage_path, "model.json") if await self._needs_update(model_info, model_path, force_update): - await self._download_profile_with_retry(manufacturer.lower(), model, storage_path, model_path) + await self._download_profile_with_retry(manufacturer, model, storage_path, model_path) try: json_data = await self._load_model_json(model_path) @@ -231,19 +240,6 @@ def _write() -> None: return await self.hass.async_add_executor_job(_write) # type: ignore - @async_cache - async def find_model(self, manufacturer: str, search: set[str]) -> list[str]: - """Find the model in the library.""" - - models = self.manufacturer_models.get(manufacturer, []) - result = [] - for model in models: - model_id = model.get("id") - if model_id and (model_id in search or any(alias in search for alias in model.get("aliases", []))): - result.append(model_id) - - return result - @staticmethod def _get_remote_modification_time(model_info: dict) -> float: """Get the remote modification time of the model""" diff --git a/tests/power_profile/loader/test_local.py b/tests/power_profile/loader/test_local.py index 1113998f9..0b0fcb675 100644 --- a/tests/power_profile/loader/test_local.py +++ b/tests/power_profile/loader/test_local.py @@ -35,20 +35,20 @@ async def test_broken_lib_by_missing_model_json(hass: HomeAssistant, caplog: pyt @pytest.mark.parametrize( "manufacturer,search,expected", [ - ["tp-link", {"HS300"}, ["HS300"]], - ["TP-link", {"HS300"}, ["HS300"]], - ["tp-link", {"hs300"}, ["HS300"]], - ["TP-link", {"hs300"}, ["HS300"]], - ["tp-link", {"HS400"}, ["HS400"]], # alias - ["tp-link", {"hs400"}, ["HS400"]], # alias - ["tp-link", {"Hs500"}, ["hs500"]], # alias - ["tp-link", {"bla"}, []], - ["foo", {"bar"}, []], - ["casing", {"CaSinG- Test"}, ["CaSinG- Test"]], - ["casing", {"CasinG- test"}, ["CaSinG- Test"]], - ["casing", {"CASING- TEST"}, ["CaSinG- Test"]], - ["hidden-directories", {".test"}, []], - ["hidden-directories", {".hidden_model"}, []], + ["tp-link", {"HS300"}, {"HS300"}], + ["TP-link", {"HS300"}, {"HS300"}], + ["tp-link", {"hs300"}, {"HS300"}], + ["TP-link", {"hs300"}, {"HS300"}], + ["tp-link", {"HS400"}, {"HS400"}], # alias + ["tp-link", {"hs400"}, {"HS400"}], # alias + ["tp-link", {"Hs500"}, {"hs500"}], # alias + ["tp-link", {"bla"}, set()], + ["foo", {"bar"}, set()], + ["casing", {"CaSinG- Test"}, {"CaSinG- Test"}], + ["casing", {"CasinG- test"}, {"CaSinG- Test"}], + ["casing", {"CASING- TEST"}, {"CaSinG- Test"}], + ["hidden-directories", {".test"}, set()], + ["hidden-directories", {".hidden_model"}, set()], ], ) async def test_find_model(hass: HomeAssistant, manufacturer: str, search: set[str], expected: str | None) -> None: diff --git a/tests/power_profile/loader/test_remote.py b/tests/power_profile/loader/test_remote.py index a69166e71..56a9d8886 100644 --- a/tests/power_profile/loader/test_remote.py +++ b/tests/power_profile/loader/test_remote.py @@ -475,10 +475,10 @@ async def test_profile_redownloaded_when_model_json_corrupt_retry_limit( @pytest.mark.parametrize( "manufacturer,phrases,expected_models,library_dir", [ - ("apple", {"HomePod (gen 2)"}, ["MQJ83"], None), - ("apple", {"Non existing model"}, [], None), - ("signify", {"LCA001", "LCT010"}, ["LCT010", "LCA001"], None), - ("test_manu", {"CCT Light"}, ["model1", "model2"], "multi-profile"), + ("apple", {"HomePod (gen 2)"}, {"MQJ83"}, None), + ("apple", {"Non existing model"}, set(), None), + ("signify", {"LCA001", "LCT010"}, {"LCT010", "LCA001"}, None), + ("test_manu", {"CCT Light"}, {"model1", "model2"}, "multi-profile"), ], ) @pytest.mark.skip_remote_loader_mocking From 82c9ad2257e79817fc99acb26fd06e4ba0244a54 Mon Sep 17 00:00:00 2001 From: Bram Date: Thu, 26 Dec 2024 22:09:57 +0100 Subject: [PATCH 4/9] feat: use singleton pattern --- .../powercalc/power_profile/library.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/custom_components/powercalc/power_profile/library.py b/custom_components/powercalc/power_profile/library.py index 2eecbbac7..9be170988 100644 --- a/custom_components/powercalc/power_profile/library.py +++ b/custom_components/powercalc/power_profile/library.py @@ -6,8 +6,9 @@ from typing import NamedTuple from homeassistant.core import HomeAssistant +from homeassistant.helpers.singleton import singleton -from custom_components.powercalc.const import CONF_DISABLE_LIBRARY_DOWNLOAD, DATA_PROFILE_LIBRARY, DOMAIN, DOMAIN_CONFIG +from custom_components.powercalc.const import CONF_DISABLE_LIBRARY_DOWNLOAD, DOMAIN, DOMAIN_CONFIG from .error import LibraryError from .loader.composite import CompositeLoader @@ -34,19 +35,13 @@ async def initialize(self) -> None: await self._loader.initialize() @staticmethod + @singleton("powercalc_library") async def factory(hass: HomeAssistant) -> ProfileLibrary: """Creates and loads the profile library Makes sure it is only loaded once and instance is saved in hass data registry. """ - if DOMAIN not in hass.data: - hass.data[DOMAIN] = {} - - if DATA_PROFILE_LIBRARY in hass.data[DOMAIN]: - return hass.data[DOMAIN][DATA_PROFILE_LIBRARY] # type: ignore - library = ProfileLibrary(hass, ProfileLibrary.create_loader(hass)) await library.initialize() - hass.data[DOMAIN][DATA_PROFILE_LIBRARY] = library return library @staticmethod @@ -61,7 +56,8 @@ def create_loader(hass: HomeAssistant) -> Loader: if os.path.exists(data_dir) ] - global_config = hass.data[DOMAIN].get(DOMAIN_CONFIG, {}) + domain_config = hass.data.get(DOMAIN, {}) + global_config = domain_config.get(DOMAIN_CONFIG, {}) disable_library_download: bool = bool(global_config.get(CONF_DISABLE_LIBRARY_DOWNLOAD, False)) if not disable_library_download: loaders.append(RemoteLoader(hass)) From 14122dd22f81a6f3033792139c356835df9d32e5 Mon Sep 17 00:00:00 2001 From: Bram Date: Thu, 26 Dec 2024 22:59:21 +0100 Subject: [PATCH 5/9] chore: cleanup --- custom_components/powercalc/const.py | 1 - custom_components/powercalc/power_profile/library.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/custom_components/powercalc/const.py b/custom_components/powercalc/const.py index 98bca705d..3e2866b58 100644 --- a/custom_components/powercalc/const.py +++ b/custom_components/powercalc/const.py @@ -29,7 +29,6 @@ DATA_ENTITIES = "entities" DATA_GROUP_ENTITIES = "group_entities" DATA_USED_UNIQUE_IDS = "used_unique_ids" -DATA_PROFILE_LIBRARY = "profile_library" DATA_STANDBY_POWER_SENSORS = "standby_power_sensors" ENTRY_DATA_ENERGY_ENTITY = "_energy_entity" diff --git a/custom_components/powercalc/power_profile/library.py b/custom_components/powercalc/power_profile/library.py index 9be170988..0ce530d2d 100644 --- a/custom_components/powercalc/power_profile/library.py +++ b/custom_components/powercalc/power_profile/library.py @@ -37,8 +37,9 @@ async def initialize(self) -> None: @staticmethod @singleton("powercalc_library") async def factory(hass: HomeAssistant) -> ProfileLibrary: - """Creates and loads the profile library - Makes sure it is only loaded once and instance is saved in hass data registry. + """ + Creates and loads the profile library. + Make sure we have a single instance throughout the application. """ library = ProfileLibrary(hass, ProfileLibrary.create_loader(hass)) await library.initialize() From 35f947f19908cb7bd51e5fae20db089c3b94f4d3 Mon Sep 17 00:00:00 2001 From: Bram Date: Fri, 27 Dec 2024 11:10:49 +0100 Subject: [PATCH 6/9] feat: add test for multi manufacturer --- tests/power_profile/loader/test_remote.py | 45 +++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/power_profile/loader/test_remote.py b/tests/power_profile/loader/test_remote.py index 56a9d8886..1be025e13 100644 --- a/tests/power_profile/loader/test_remote.py +++ b/tests/power_profile/loader/test_remote.py @@ -16,6 +16,7 @@ from custom_components.powercalc.helpers import get_library_json_path, get_library_path from custom_components.powercalc.power_profile.error import LibraryLoadingError, ProfileDownloadError +from custom_components.powercalc.power_profile.library import ModelInfo, ProfileLibrary from custom_components.powercalc.power_profile.loader.remote import ENDPOINT_DOWNLOAD, ENDPOINT_LIBRARY, RemoteLoader from custom_components.powercalc.power_profile.power_profile import DeviceType from tests.common import get_test_config_dir, get_test_profile_dir @@ -508,3 +509,47 @@ def clear_storage_dir(storage_path: str) -> None: if not os.path.exists(storage_path): return shutil.rmtree(storage_path, ignore_errors=True) + + +async def test_multiple_manufacturer_aliases(hass: HomeAssistant, mock_aioresponse: aioresponses) -> None: + mock_aioresponse.get( + ENDPOINT_LIBRARY, + status=200, + payload={ + "manufacturers": [ + { + "name": "manufacturer1", + "aliases": ["my-alias"], + "models": [ + { + "id": "model1", + "device_type": "light", + "updated_at": "2021-01-01T00:00:00", + }, + ], + }, + { + "name": "manufacturer2", + "aliases": ["my-alias"], + "models": [ + { + "id": "model1", + "device_type": "light", + "updated_at": "2021-01-01T00:00:00", + }, + ], + }, + ], + }, + ) + + library = await ProfileLibrary.factory(hass) + + manufacturers = await library.find_manufacturers("my-alias") + assert manufacturers == {"manufacturer1", "manufacturer2"} + + model_listing = await library.get_model_listing("my-alias", "light") + assert len(model_listing) == 2 + + models = await library.find_models(ModelInfo("my-alias", "model1")) + assert models == {ModelInfo("manufacturer1", "model1"), ModelInfo("manufacturer2", "model1")} \ No newline at end of file From 63c5bef10da5b64071f0603b5a1d19618f093dea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Dec 2024 10:11:01 +0000 Subject: [PATCH 7/9] chore(lint): [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/power_profile/loader/test_remote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/power_profile/loader/test_remote.py b/tests/power_profile/loader/test_remote.py index 1be025e13..55f388556 100644 --- a/tests/power_profile/loader/test_remote.py +++ b/tests/power_profile/loader/test_remote.py @@ -552,4 +552,4 @@ async def test_multiple_manufacturer_aliases(hass: HomeAssistant, mock_aiorespon assert len(model_listing) == 2 models = await library.find_models(ModelInfo("my-alias", "model1")) - assert models == {ModelInfo("manufacturer1", "model1"), ModelInfo("manufacturer2", "model1")} \ No newline at end of file + assert models == {ModelInfo("manufacturer1", "model1"), ModelInfo("manufacturer2", "model1")} From 4c609c82135fd64606f529c2dd205832b85c10b7 Mon Sep 17 00:00:00 2001 From: Bram Date: Fri, 27 Dec 2024 11:28:11 +0100 Subject: [PATCH 8/9] fix: implementation multiple manufacturers --- .../powercalc/power_profile/library.py | 4 +-- .../powercalc/power_profile/loader/remote.py | 26 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/custom_components/powercalc/power_profile/library.py b/custom_components/powercalc/power_profile/library.py index 0ce530d2d..fc74dc963 100644 --- a/custom_components/powercalc/power_profile/library.py +++ b/custom_components/powercalc/power_profile/library.py @@ -131,9 +131,9 @@ async def create_power_profile( return await self._create_power_profile_instance(model_info.manufacturer, model_info.model, directory, json_data) - async def find_manufacturers(self, model_info: ModelInfo) -> set[str]: + async def find_manufacturers(self, manufacturer: str) -> set[str]: """Resolve the manufacturer, either from the model info or by loading it.""" - return await self._loader.find_manufacturers(model_info.manufacturer) + return await self._loader.find_manufacturers(manufacturer) async def find_models(self, model_info: ModelInfo) -> set[ModelInfo]: """Resolve the model identifier, searching for it if no custom directory is provided.""" diff --git a/custom_components/powercalc/power_profile/loader/remote.py b/custom_components/powercalc/power_profile/loader/remote.py index 43390def9..a39073c64 100644 --- a/custom_components/powercalc/power_profile/loader/remote.py +++ b/custom_components/powercalc/power_profile/loader/remote.py @@ -35,7 +35,7 @@ def __init__(self, hass: HomeAssistant) -> None: self.library_contents: dict = {} self.model_infos: dict[str, dict] = {} self.manufacturer_models: dict[str, list[dict]] = {} - self.manufacturer_aliases: dict[str, str] = {} + self.manufacturer_aliases: dict[str, set[str]] = {} self.last_update_time: float | None = None async def initialize(self) -> None: @@ -44,20 +44,20 @@ async def initialize(self) -> None: self.last_update_time = await self.hass.async_add_executor_job(self.get_last_update_time) # type: ignore # Load contents of library JSON into memory - manufacturers: list[dict] = self.library_contents.get("manufacturers", []) + manufacturers = self.library_contents.get("manufacturers", []) + for manufacturer in manufacturers: - models: list[dict] = manufacturer.get("models", []) manufacturer_name = str(manufacturer.get("name")) - for model in models: - model_id = str(model.get("id")) - self.model_infos[f"{manufacturer_name}/{model_id}"] = model - if manufacturer_name not in self.manufacturer_models: - self.manufacturer_models[manufacturer_name] = [] - self.manufacturer_models[manufacturer_name].append(model) - - self.manufacturer_aliases[manufacturer_name.lower()] = manufacturer_name + models = manufacturer.get("models", []) + + # Store model info and group models by manufacturer + self.model_infos.update({f"{manufacturer_name}/{model.get('id')!s}": model for model in models}) + self.manufacturer_models.setdefault(manufacturer_name, []).extend(models) + + # Map manufacturer aliases + self.manufacturer_aliases[manufacturer_name.lower()] = {manufacturer_name} for alias in manufacturer.get("aliases", []): - self.manufacturer_aliases[alias.lower()] = manufacturer_name + self.manufacturer_aliases.setdefault(alias.lower(), set()).add(manufacturer_name) async def load_library_json(self) -> dict[str, Any]: """Load library.json file""" @@ -112,7 +112,7 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) - @async_cache async def find_manufacturers(self, search: str) -> set[str]: """Find the manufacturer in the library.""" - return {self.manufacturer_aliases[search]} if search in self.manufacturer_aliases else set() + return self.manufacturer_aliases.get(search, set()) @async_cache async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]: From 166d44ac16c3a84452dc840b518e3e01aac40592 Mon Sep 17 00:00:00 2001 From: Bram Date: Fri, 27 Dec 2024 11:41:15 +0100 Subject: [PATCH 9/9] fix: 100% coverage --- .../powercalc/sensors/group/tracked_untracked.py | 2 +- tests/test_helpers.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/custom_components/powercalc/sensors/group/tracked_untracked.py b/custom_components/powercalc/sensors/group/tracked_untracked.py index b1f0615ce..f5edb6e63 100644 --- a/custom_components/powercalc/sensors/group/tracked_untracked.py +++ b/custom_components/powercalc/sensors/group/tracked_untracked.py @@ -129,7 +129,7 @@ async def _handle_entity_registry_updated( if action == "update" and "old_entity_id" in event.data: if event.data["old_entity_id"] in self.tracked_entities: return await self.reload() - return None + return None # pragma: no cover if action == "remove" and entity_id in self.tracked_entities: return await self.reload() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index e9528f5fe..5420b9ee6 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -9,7 +9,7 @@ from custom_components.powercalc.common import SourceEntity from custom_components.powercalc.const import DUMMY_ENTITY_ID, CalculationStrategy -from custom_components.powercalc.helpers import evaluate_power, get_or_create_unique_id +from custom_components.powercalc.helpers import evaluate_power, get_or_create_unique_id, make_hashable @pytest.mark.parametrize( @@ -53,3 +53,15 @@ async def test_wled_unique_id() -> None: source_entity = SourceEntity("wled", "light.wled", "light", device_entry=device_entry) unique_id = get_or_create_unique_id({}, source_entity, mock_instance) assert unique_id == "pc_123456" + + +@pytest.mark.parametrize( + "value,output", + [ + ({"a", "b", "c"}, frozenset({"a", "b", "c"})), + (["a", "b", "c"], ("a", "b", "c")), + ({"a": 1, "b": 2}, frozenset([("a", 1), ("b", 2)])), + ], +) +async def test_make_hashable(value: set | list | dict, output: tuple | frozenset) -> None: + assert make_hashable(value) == output