Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ak/automation copilot #50

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/integration-addon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ jobs:
- name: Lint with Ruff
run: |
ruff check .
- name: install playwright deps
run: |
playwright install --with-deps chromium
- name: Copy python source + services into build context because builder action doesn't support --build-context
id: copy
run: bash scripts/copy_content_to_addon_context.sh
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/integration-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ jobs:
- name: Lint with Ruff
run: |
ruff check .
- name: install playwright deps
run: |
playwright install --with-deps chromium
- name: Test with pytest
run: |
pytest -v -s -rA -c tests/pytest.ini --deploy-mode=local --replay-mode=replay
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# LLMOps for Home Assistant
# Home ~~Automation~~ Intelligence

Manage your home automation LLM prompts, available LLMs and evaluate changes with MLflow

Expand Down
Empty file.
229 changes: 229 additions & 0 deletions python/src/mindctrl/homeassistant/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import asyncio
import logging
from typing import Union
import time
import httpx
from httpx_ws import aconnect_ws
from pydantic import ValidationError

from .messages import (
AreasResult,
Auth,
AuthChallenge,
AuthOk,
Automation,
Command,
CreateAutomation,
CreateLabel,
Error,
Label,
LabelsResult,
ListAreas,
ListEntities,
ListLabels,
ManyResponsesWrapper,
Result,
SingleResponseWrapper,
UpdateEntityLabels,
)

_logger = logging.getLogger(__name__)


class HassClientError(Exception):
pass


class HassClient(object):
def __init__(self, id: str, hass_url: httpx.URL, token: str):
self.hass_url = hass_url
self._token = token
if not (str(self.hass_url).endswith("api")):
raise ValueError(
f"hass_url must end with 'api'. For example, 'http://homeassistant.local:8123/api'. Got: {self.hass_url}"
)
self.headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
self.websocket_message_id = 0
self._client = httpx.AsyncClient(base_url=self.hass_url, headers=self.headers)
self._ws_session = aconnect_ws(
f"{self._client.base_url}websocket", client=self._client
)
self._authenticated_session = None

# is there a better way to process these?
self._command_results: dict[int, Union[Result, Error]] = {}

@property
def authenticated_session(self):
if not self._authenticated_session:
raise ValueError("Session not authenticated")
return self._authenticated_session

async def __aenter__(self):
session = await self._ws_session.__aenter__()

auth_required_msg = AuthChallenge.model_validate(await session.receive_json())
_logger.info(auth_required_msg)

await session.send_json(Auth(access_token=self._token).model_dump())

auth_ok_msg = AuthOk.model_validate(await session.receive_json())
_logger.info(auth_ok_msg)

self._authenticated_session = session
return self

async def __aexit__(self, exc_type, exc, tb):
if len(self._command_results.keys()) > 0:
_logger.warning(
f"Exiting with {len(self._command_results.keys())} unprocessed messages"
)
await self._ws_session.__aexit__(exc_type, exc, tb)
await self._client.aclose()

async def _receive_message(self, message_id: int) -> Union[Result, None]:
if message_id in self._command_results.keys():
val = self._command_results.pop(message_id)
if not val.success:
assert isinstance(val, Error)
raise HassClientError(f"Error: {val.code} - {val.message}")
assert isinstance(val, Result)
return val

response_json = await self.authenticated_session.receive_json()
responses: list[Union[Result, Error]]
try:
response = SingleResponseWrapper.model_validate(
{"response": response_json}, strict=False
)
responses = [response.response]
except ValidationError as ve:
_logger.info(f"receive: {ve}")
response = ManyResponsesWrapper.model_validate(
{"responses": response_json}, strict=False
)
responses = response.responses

return_response: Union[Result, Error, None] = None
for response in responses:
if response.id != message_id:
self._command_results[response.id] = response
else:
return_response = response

if isinstance(return_response, Error):
raise HassClientError(
f"Error: {return_response.code} - {return_response.message}"
)

return return_response

async def _send_message(self, message: Command) -> Result:
# The atomicity of this is sketchy - add tests
self.websocket_message_id += 1
message.id = self.websocket_message_id
await self.authenticated_session.send_json(message.model_dump())
result = None
while result is None:
result = await self._receive_message(message.id)
if result is None:
_logger.info("Recent receive batch didn't have the response")
await asyncio.sleep(0.1)
return result

@staticmethod
def _current_milli_time():
return round(time.time() * 1000)

async def list_entities(self):
entities = await self._send_message(ListEntities(id=-1))
if entities.result is None:
_logger.warning(f"Unexpected null entities result: {entities}")
return []
return entities.result

# TODO: bad api call pattern, need to revisit
async def list_automations(self) -> list[Automation]:
entities = await self.list_entities()
_logger.debug(entities)
automation_entities = [
entity
for entity in entities
if entity["platform"] == "automation"
]

_logger.info(f"Fetching {len(automation_entities)} automations")
fetch_automation_tasks = []
for entity in automation_entities:
fetch_automation_tasks.append(self.get_automation(entity["unique_id"]))
return await asyncio.gather(*fetch_automation_tasks)

async def list_labels(self):
any_result = await self._send_message(ListLabels(id=-1))
labels = LabelsResult.model_validate_json(any_result.model_dump_json())
return labels.result

async def list_areas(self):
any_result = await self._send_message(ListAreas(id=-1))
areas = AreasResult.model_validate_json(any_result.model_dump_json())
return areas.result

async def create_label(self, label: Label):
await self._send_message(
CreateLabel(
color=label.color,
description=label.description,
icon=label.icon,
name=label.name,
id=-1,
)
)
return

async def add_labels(self, entity_id: str, labels: list[str]):
await self._send_message(
UpdateEntityLabels(
entity_id=entity_id,
labels=labels,
id=-1,
)
)

async def get_automation(self, unique_id: str):
# Creating an automation is REST POST to a unix timestamp
# http://hass-dev.ak:8123/api/config/automation/config/1713577351529
# Request Method:
# POST
# Status Code:
# 200 OK
path = f"config/automation/config/{unique_id}"
get_response = await self._client.get(path)
get_response.raise_for_status()
return Automation.model_validate(get_response.json())

async def create_automation(self, name: str, description: str):
# Creating an automation is REST POST to a unix timestamp
# http://hass-dev.ak:8123/api/config/automation/config/1713577351529
# Request Method:
# POST
# Status Code:
# 200 OK
current_milli_time = HassClient._current_milli_time()
path = f"config/automation/config/{current_milli_time}"
response = await self._client.post(
path,
json=CreateAutomation(
alias=name,
description=description,
mode="single",
action=[],
condition=[],
trigger=[],
).model_dump(),
)
response.raise_for_status()
# response.json() is just {'result': 'ok'}, need to do a get (why?)
return await self.get_automation(str(current_milli_time))
117 changes: 117 additions & 0 deletions python/src/mindctrl/homeassistant/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import Any, Optional, Union
from pydantic import BaseModel


class Message(BaseModel):
type: str


class AuthChallenge(Message):
type: str = "auth_required"
ha_version: str


class AuthOk(Message):
type: str = "auth_ok"
ha_version: str


class Auth(Message):
type: str = "auth"
access_token: str


class Command(Message):
id: int


class CommandResponse(Command):
type: str = "result"
success: bool


class Error(CommandResponse):
success: bool = False
code: str
message: str


class Result(CommandResponse):
success: bool = True
result: Optional[Any]


class ManyResponsesWrapper(BaseModel):
responses: list[Union[Error, Result]]


class SingleResponseWrapper(BaseModel):
response: Union[Error, Result]


class ListEntities(Command):
type: str = "config/entity_registry/list"


class ListLabels(Command):
type: str = "config/label_registry/list"


class ListAreas(Command):
type: str = "config/area_registry/list"


# {"color":"indigo","description":null,"icon":"mdi:account","label_id":"test","name":"test"}
class Label(BaseModel):
color: str
description: Optional[str]
icon: str
label_id: str
name: str


class LabelsResult(Result):
result: list[Label]


class Area(BaseModel):
area_id: str
name: str
aliases: list[str]
floor_id: Optional[str]
icon: Optional[str]
labels: list[str]
picture: Optional[str]


class AreasResult(Result):
result: list[Area]


# {"type":"config/label_registry/create","name":"test","icon":"mdi:account","color":"indigo","id":62}
class CreateLabel(Command):
type: str = "config/label_registry/create"
description: Optional[str]
name: str
icon: str
color: str


# {"type":"config/entity_registry/update","entity_id":"automation.test_zone_automation","labels":["test"],"id":51}
class UpdateEntityLabels(Command):
type: str = "config/entity_registry/update"
entity_id: str
labels: list[str]


class CreateAutomation(BaseModel):
action: list[Any]
alias: str
condition: list[Any]
description: str
mode: str
trigger: list[Any]


class Automation(CreateAutomation):
id: str
Loading