Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Support firmware update from HA #249

Merged
merged 8 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions custom_components/aquarea/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,14 @@ def read_stats_json(field_name: str, json_doc: str) -> Optional[float]:
return float(field_value)
return None

def read_board_type(json_doc: str) -> Optional[str]:
j = json.loads(json_doc)
if "board" in j:
return j["board"]
if "voltage" in j:
if j["voltage"] != "3.3":
return "ESP8266"
return None

def ms_to_secs(value: Optional[float]) -> Optional[float]:
if value:
Expand Down Expand Up @@ -1806,6 +1814,16 @@ def build_sensors(mqtt_prefix: str) -> list[HeishaMonSensorEntityDescription]:
entity_category=EntityCategory.DIAGNOSTIC,
state_class=SensorStateClass.MEASUREMENT,
),
HeishaMonSensorEntityDescription(
heishamon_topic_id="STAT1-board",
key=f"{mqtt_prefix}stats",
name="HeishaMon Board type",
state=read_board_type,
device=DeviceType.HEISHAMON,
state_class=SensorStateClass.MEASUREMENT,
device_class=SensorDeviceClass.ENUM,
entity_category=EntityCategory.DIAGNOSTIC,
),
HeishaMonSensorEntityDescription(
heishamon_topic_id="INFO_ip",
key=f"{mqtt_prefix}ip",
Expand Down
120 changes: 115 additions & 5 deletions custom_components/aquarea/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import logging
import json
import aiohttp
from typing import Optional
import asyncio
from typing import Optional, Any
from io import BufferedReader, BytesIO

from homeassistant.components import mqtt
from homeassistant.components.mqtt.client import async_publish
Expand All @@ -21,7 +23,7 @@

from . import build_device_info
from .const import DeviceType
from .definitions import HeishaMonEntityDescription, frozendataclass
from .definitions import HeishaMonEntityDescription, frozendataclass, read_board_type

_LOGGER = logging.getLogger(__name__)
HEISHAMON_REPOSITORY = "Egyras/HeishaMon"
Expand Down Expand Up @@ -87,14 +89,32 @@ def __init__(
self.stats_firmware_contain_version: Optional[bool] = None

self._attr_supported_features = (
UpdateEntityFeature.RELEASE_NOTES | UpdateEntityFeature.INSTALL
UpdateEntityFeature.RELEASE_NOTES | UpdateEntityFeature.INSTALL | UpdateEntityFeature.PROGRESS | UpdateEntityFeature.SPECIFIC_VERSION
)
self._attr_release_url = f"https://github.com/{HEISHAMON_REPOSITORY}/releases"
self._model_type = None
self._release_notes = None
self._attr_progress = False

self._ip_topic = f"{self.discovery_prefix}ip"
self._heishamon_ip = None

self._stats_topic = f"{self.discovery_prefix}stats"

async def async_added_to_hass(self) -> None:
"""Subscribe to MQTT events."""

@callback
def ip_received(message):
self._heishamon_ip = message.payload
await mqtt.async_subscribe(self.hass, self._ip_topic, ip_received, 1)

@callback
def read_model(message):
self._model_type = read_board_type(message.payload)
await mqtt.async_subscribe(self.hass, self._stats_topic, read_model, 1)


@callback
def message_received(message):
"""Handle new MQTT messages."""
Expand Down Expand Up @@ -165,6 +185,96 @@ async def _update_latest_release(self):
self._release_notes = last_release["body"]
self.async_write_ha_state()

@property
def model_to_file(self) -> str | None:
return {
"ESP32": "model-type-large",
"ESP8266": "model-type-small",
None: "UNKNOWN",
}.get(self._model_type, None)


def release_notes(self) -> str | None:
header = f"⚠ Update is not supported via HA. Update is done via heishamon webui\n\n\n"
return header + str(self._release_notes)
return f"⚠️ Automated upgrades will fetch `{self.model_to_file}` binaries.\n\nBeware!\n\n" + str(self._release_notes)

async def async_install(self, version: str | None, backup: bool, **kwargs: Any) -> None:
if self._model_type is None:
raise Exception("Impossible to update automatically because we don't know the board version")
if version is None:
version = self._attr_latest_version
_LOGGER.info(f"Will install latest version ({version}) of the firmware")
else:
_LOGGER.info(f"Will install version {version} of the firmware")
self._attr_progress = 0
async with aiohttp.ClientSession() as session:
resp = await session.get(
f"https://github.com/{HEISHAMON_REPOSITORY}/raw/master/binaries/{self.model_to_file}/HeishaMon.ino.d1-v{version}.bin"
)

if resp.status != 200:
_LOGGER.warn(
f"Impossible to download version {version} from heishamon repository {HEISHAMON_REPOSITORY}"
)
return

firmware_binary = await resp.read()
_LOGGER.info(f"Firmware is {len(firmware_binary)} bytes long")
self._attr_progress = 10
resp = await session.get(
f"https://github.com/{HEISHAMON_REPOSITORY}/raw/master/binaries/{self.model_to_file}/HeishaMon.ino.d1-v{version}.md5"
)

if resp.status != 200:
_LOGGER.warn(
f"Impossible to fetch checksum of version #{version} from heishamon repository {HEISHAMON_REPOSITORY}"
)
return
checksum = await resp.text()
self._attr_progress = 20
_LOGGER.info(f"Downloaded binary and checksum {checksum} of version {version}")

while self._heishamon_ip is None:
_LOGGER.warn("Waiting for an mqtt message to get the ip address of heishamon")
await asyncio.sleep(1)

def track_progress(current, total):
self._attr_progress = int(current / total * 100)
_LOGGER.info(f"Currently read {current} out of {total}: {self._attr_progress}%")


async with aiohttp.ClientSession() as session:
_LOGGER.info(f"Starting upgrade of firmware to version {version} on {self._heishamon_ip}")
to = aiohttp.ClientTimeout(total=300, connect=10)
try:
with ProgressReader(firmware_binary, track_progress) as reader:
resp = await session.post(
f"http://{self._heishamon_ip}/firmware",
data={
'md5': checksum,
# 'firmware': ('firmware.bin', firmware_binary, 'application/octet-stream')
'firmware': reader

},
timeout=to
)
except TimeoutError as e:
_LOGGER.error(f"Timeout while uploading new firmware")
raise e
if resp.status != 200:
_LOGGER.warn(f"Impossible to perform firmware update to version {version}")
return
_LOGGER.info(f"Finished uploading firmware. Heishamon should now be rebooting")

class ProgressReader(BufferedReader):
def __init__(self, binary_data, read_callback=None):
self._read_callback = read_callback
super().__init__(raw=BytesIO(binary_data))
self.length = len(binary_data)

def read(self, size=None):
computed_size = size
if not computed_size:
computed_size = self.length - self.tell()
if self._read_callback:
self._read_callback(self.tell(), self.length)
return super(ProgressReader, self).read(size)
Loading