Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support same alias for multiple manufacturers in library loading code #2852

Merged
merged 9 commits into from
Dec 27, 2024
1 change: 0 additions & 1 deletion custom_components/powercalc/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
DATA_ENTITIES = "entities"
DATA_GROUP_ENTITIES = "group_entities"
DATA_USED_UNIQUE_IDS = "used_unique_ids"
DATA_PROFILE_LIBRARY = "profile_library"
DATA_STANDBY_POWER_SENSORS = "standby_power_sensors"

ENTRY_DATA_ENERGY_ENTITY = "_energy_entity"
Expand Down
15 changes: 3 additions & 12 deletions custom_components/powercalc/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,15 @@ async def discover_entity(
) -> list[PowerProfile] | None:
"""Discover a single entity in Powercalc library and start the discovery flow if supported."""

library = await self._get_library()
if source_entity.entity_entry is None: # pragma: no cover
return None

if self.is_wled_light(model_info, source_entity.entity_entry):
await self.init_wled_flow(model_info, source_entity)
return None

manufacturer = await library.find_manufacturer(model_info)
if not manufacturer:
_LOGGER.debug(
"%s: Manufacturer not found in library, skipping discovery",
source_entity.entity_entry.entity_id,
)
return None

models = await library.find_models(manufacturer, model_info)
library = await self._get_library()
models = await library.find_models(model_info)
if not models:
_LOGGER.debug(
"%s: Model not found in library, skipping discovery",
Expand All @@ -142,8 +134,7 @@ async def discover_entity(
return None

power_profiles = []
for model in models:
model_info = ModelInfo(manufacturer, model)
for model_info in models:
profile = await get_power_profile(self.hass, {}, model_info=model_info)
if not profile: # pragma: no cover
continue
Expand Down
46 changes: 46 additions & 0 deletions custom_components/powercalc/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import logging
import os.path
import uuid
from collections.abc import Callable, Coroutine
from decimal import Decimal
from functools import wraps
from typing import Any, TypeVar

from homeassistant.const import CONF_UNIQUE_ID
from homeassistant.helpers.template import Template
Expand Down Expand Up @@ -69,3 +72,46 @@ def get_or_create_unique_id(
return f"pc_{source_unique_id}"

return str(uuid.uuid4())


P = TypeVar("P") # Used for positional and keyword argument types
R = TypeVar("R") # Used for return type


def make_hashable(arg: Any) -> Any: # noqa: ANN401
"""Convert unhashable arguments to hashable equivalents."""
if isinstance(arg, set):
return frozenset(arg)
if isinstance(arg, list):
return tuple(arg)
if isinstance(arg, dict):
return frozenset((key, make_hashable(value)) for key, value in arg.items())
return arg


def async_cache(func: Callable[..., Coroutine[Any, Any, R]]) -> Callable[..., Coroutine[Any, Any, R]]:
"""
A decorator to cache results of an async function based on its arguments.

Args:
func: The asynchronous function to decorate.

Returns:
A decorated asynchronous function with caching.
"""
cache: dict[tuple[tuple[Any, ...], frozenset], R] = {}

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> R: # noqa: ANN401
# Make arguments hashable
hashable_args = tuple(make_hashable(arg) for arg in args)
hashable_kwargs = frozenset((key, make_hashable(value)) for key, value in kwargs.items())
cache_key = (hashable_args, hashable_kwargs)

if cache_key in cache:
return cache[cache_key]
result = await func(*args, **kwargs)
cache[cache_key] = result
return result

return wrapper
76 changes: 41 additions & 35 deletions custom_components/powercalc/power_profile/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from typing import NamedTuple

from homeassistant.core import HomeAssistant
from homeassistant.helpers.singleton import singleton

from custom_components.powercalc.const import CONF_DISABLE_LIBRARY_DOWNLOAD, DATA_PROFILE_LIBRARY, DOMAIN, DOMAIN_CONFIG
from custom_components.powercalc.const import CONF_DISABLE_LIBRARY_DOWNLOAD, DOMAIN, DOMAIN_CONFIG

from .error import LibraryError
from .loader.composite import CompositeLoader
Expand All @@ -34,19 +35,14 @@ async def initialize(self) -> None:
await self._loader.initialize()

@staticmethod
@singleton("powercalc_library")
async def factory(hass: HomeAssistant) -> ProfileLibrary:
"""Creates and loads the profile library
Makes sure it is only loaded once and instance is saved in hass data registry.
"""
if DOMAIN not in hass.data:
hass.data[DOMAIN] = {}

if DATA_PROFILE_LIBRARY in hass.data[DOMAIN]:
return hass.data[DOMAIN][DATA_PROFILE_LIBRARY] # type: ignore

Creates and loads the profile library.
Make sure we have a single instance throughout the application.
"""
library = ProfileLibrary(hass, ProfileLibrary.create_loader(hass))
await library.initialize()
hass.data[DOMAIN][DATA_PROFILE_LIBRARY] = library
return library

@staticmethod
Expand All @@ -61,7 +57,8 @@ def create_loader(hass: HomeAssistant) -> Loader:
if os.path.exists(data_dir)
]

global_config = hass.data[DOMAIN].get(DOMAIN_CONFIG, {})
domain_config = hass.data.get(DOMAIN, {})
global_config = domain_config.get(DOMAIN_CONFIG, {})
disable_library_download: bool = bool(global_config.get(CONF_DISABLE_LIBRARY_DOWNLOAD, False))
if not disable_library_download:
loaders.append(RemoteLoader(hass))
Expand All @@ -78,17 +75,22 @@ async def get_manufacturer_listing(self, entity_domain: str | None = None) -> li
async def get_model_listing(self, manufacturer: str, entity_domain: str | None = None) -> list[str]:
"""Get listing of available models for a given manufacturer."""

resolved_manufacturer = await self._loader.find_manufacturer(manufacturer)
if not resolved_manufacturer:
resolved_manufacturers = await self._loader.find_manufacturers(manufacturer)
if not resolved_manufacturers:
return []
device_types = get_device_types_from_domain(entity_domain) if entity_domain else None
cache_key = f"{resolved_manufacturer}/{device_types}"
cached_models = self._manufacturer_models.get(cache_key)
if cached_models:
return cached_models
models = await self._loader.get_model_listing(resolved_manufacturer, device_types)
self._manufacturer_models[cache_key] = sorted(models)
return self._manufacturer_models[cache_key]
all_models: list[str] = []
for manufacturer in resolved_manufacturers:
cache_key = f"{manufacturer}/{device_types}"
cached_models = self._manufacturer_models.get(cache_key)
if cached_models:
all_models.extend(cached_models)
continue
models = await self._loader.get_model_listing(manufacturer, device_types)
self._manufacturer_models[cache_key] = sorted(models)
all_models.extend(models)

return all_models

async def get_profile(
self,
Expand Down Expand Up @@ -116,30 +118,24 @@ async def create_power_profile(
) -> PowerProfile:
"""Create a power profile object from the model JSON data."""

manufacturer = model_info.manufacturer
model = model_info.model
if not custom_directory:
manufacturer = await self.find_manufacturer(model_info) # type: ignore
if manufacturer is None:
raise LibraryError(f"Manufacturer {model_info.manufacturer} not found")

models = await self.find_models(manufacturer, model_info)
models = await self.find_models(model_info)
if not models:
raise LibraryError(f"Model {manufacturer} {model} not found")
model = next(iter(models))
raise LibraryError(f"Model {model_info.manufacturer} {model_info.model} not found")
model_info = next(iter(models))

json_data, directory = await self._load_model_data(manufacturer, model, custom_directory)
json_data, directory = await self._load_model_data(model_info.manufacturer, model_info.model, custom_directory)
if linked_profile := json_data.get("linked_lut"):
linked_manufacturer, linked_model = linked_profile.split("/")
_, directory = await self._load_model_data(linked_manufacturer, linked_model, custom_directory)

return await self._create_power_profile_instance(manufacturer, model, directory, json_data)
return await self._create_power_profile_instance(model_info.manufacturer, model_info.model, directory, json_data)

async def find_manufacturer(self, model_info: ModelInfo) -> str | None:
async def find_manufacturers(self, manufacturer: str) -> set[str]:
"""Resolve the manufacturer, either from the model info or by loading it."""
return await self._loader.find_manufacturer(model_info.manufacturer)
return await self._loader.find_manufacturers(manufacturer)

async def find_models(self, manufacturer: str, model_info: ModelInfo) -> set[str]:
async def find_models(self, model_info: ModelInfo) -> set[ModelInfo]:
"""Resolve the model identifier, searching for it if no custom directory is provided."""
search: set[str] = set()
for model_identifier in (model_info.model_id, model_info.model):
Expand All @@ -155,7 +151,17 @@ async def find_models(self, manufacturer: str, model_info: ModelInfo) -> set[str
if "/" in model_identifier:
search.update(model_identifier.split("/"))

return set(await self._loader.find_model(manufacturer, search))
manufacturers = await self._loader.find_manufacturers(model_info.manufacturer)
found_models: set[ModelInfo] = set()
if not manufacturers:
return found_models

for manufacturer in manufacturers:
models = await self._loader.find_model(manufacturer, search)
if models:
found_models.update(ModelInfo(manufacturer, model) for model in models)

return found_models

async def _load_model_data(self, manufacturer: str, model: str, custom_directory: str | None) -> tuple[dict, str]:
"""Load the model data from the appropriate directory."""
Expand Down
17 changes: 9 additions & 8 deletions custom_components/powercalc/power_profile/loader/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -

return {manufacturer for loader in self.loaders for manufacturer in await loader.get_manufacturer_listing(device_types)}

async def find_manufacturer(self, search: str) -> str | None:
async def find_manufacturers(self, search: str) -> set[str]:
"""Check if a manufacturer is available. Also must check aliases."""

search = search.lower()
found_manufacturers = set()
for loader in self.loaders:
manufacturer = await loader.find_manufacturer(search)
if manufacturer:
return manufacturer
manufacturers = await loader.find_manufacturers(search)
if manufacturers:
found_manufacturers.update(manufacturers)

return None
return found_manufacturers

async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]:
"""Get listing of available models for a given manufacturer."""
Expand All @@ -42,11 +43,11 @@ async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] |

return None

async def find_model(self, manufacturer: str, search: set[str]) -> list[str]:
async def find_model(self, manufacturer: str, search: set[str]) -> set[str]:
"""Find the model in the library."""

models = []
models = set()
for loader in self.loaders:
models.extend(await loader.find_model(manufacturer, search))
models.update(await loader.find_model(manufacturer, search))

return models
12 changes: 6 additions & 6 deletions custom_components/powercalc/power_profile/loader/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -

return manufacturers

async def find_manufacturer(self, search: str) -> str | None:
async def find_manufacturers(self, search: str) -> set[str]:
"""Check if a manufacturer is available."""

_search = search.lower()
manufacturer_list = self._manufacturer_model_listing.keys()
if _search in manufacturer_list:
return _search
return {_search}

return None
return set()

async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]:
"""Get listing of available models for a given manufacturer.
Expand Down Expand Up @@ -104,19 +104,19 @@ async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] |
model_json = lib_model.json_data
return model_json, model_path

async def find_model(self, manufacturer: str, search: set[str]) -> list[str]:
async def find_model(self, manufacturer: str, search: set[str]) -> set[str]:
"""Find a model for a given manufacturer. Also must check aliases."""
_manufacturer = manufacturer.lower()

models = self._manufacturer_model_listing.get(_manufacturer)
if not models:
_LOGGER.info("Manufacturer does not exist in custom library: %s", _manufacturer)
return []
return set()

search_lower = {phrase.lower() for phrase in search}

profile = next((models[model] for model in models if model.lower() in search_lower), None)
return [profile.model] if profile else []
return {profile.model} if profile else set()

def _load_custom_library(self) -> None:
"""Loading custom models and aliases from file system.
Expand Down
4 changes: 2 additions & 2 deletions custom_components/powercalc/power_profile/loader/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def initialize(self) -> None:
async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -> set[str]:
"""Get listing of possible manufacturers."""

async def find_manufacturer(self, search: str) -> str | None:
async def find_manufacturers(self, search: str) -> set[str]:
"""Check if a manufacturer is available. Also must check aliases."""

async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]:
Expand All @@ -19,5 +19,5 @@ async def get_model_listing(self, manufacturer: str, device_types: set[DeviceTyp
async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] | None:
"""Load and optionally download a model profile."""

async def find_model(self, manufacturer: str, search: set[str]) -> list[str]:
async def find_model(self, manufacturer: str, search: set[str]) -> set[str]:
"""Check if a model is available. Also must check aliases."""
Loading
Loading