From c8e5bc3403a15c7bd9944ddb998a9d634b747a06 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 6 Jul 2024 17:29:54 -0400 Subject: [PATCH 01/15] Starting work on adding new ciphers --- .idea/codeStyles/codeStyleConfig.xml | 5 ++ .idea/material_theme_project_new.xml | 17 ++++++ .idea/ruff.xml | 8 +++ emulator.py | 1 - greeclimate/cipher.py | 87 ++++++++++++++++++++++++++++ greeclimate/device.py | 23 ++++---- greeclimate/network.py | 72 ++++++++--------------- tests/common.py | 23 ++++++-- tests/conftest.py | 14 ++++- tests/test_cipher.py | 48 +++++++++++++++ tests/test_device.py | 21 ++++--- tests/test_discovery.py | 10 +--- tests/test_network.py | 49 ++++++---------- 13 files changed, 262 insertions(+), 116 deletions(-) create mode 100644 .idea/codeStyles/codeStyleConfig.xml create mode 100644 .idea/material_theme_project_new.xml create mode 100644 .idea/ruff.xml create mode 100644 greeclimate/cipher.py create mode 100644 tests/test_cipher.py diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 0000000..79ee123 --- /dev/null +++ b/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml new file mode 100644 index 0000000..374ec3d --- /dev/null +++ b/.idea/material_theme_project_new.xml @@ -0,0 +1,17 @@ + + + + + + + \ No newline at end of file diff --git a/.idea/ruff.xml b/.idea/ruff.xml new file mode 100644 index 0000000..bb95353 --- /dev/null +++ b/.idea/ruff.xml @@ -0,0 +1,8 @@ + + + + + \ No newline at end of file diff --git a/emulator.py b/emulator.py index a3ac6b5..11d0b09 100644 --- a/emulator.py +++ b/emulator.py @@ -6,7 +6,6 @@ import socket import time -import machine import network import ubinascii from ucryptolib import aes diff --git a/greeclimate/cipher.py b/greeclimate/cipher.py new file mode 100644 index 0000000..b021889 --- /dev/null +++ b/greeclimate/cipher.py @@ -0,0 +1,87 @@ +import base64 +import json +import logging +from typing import Union, Tuple + +from Crypto.Cipher import AES + + +class CipherBase: + def __init__(self, key: bytes) -> None: + self._key: bytes = key + + @property + def key(self) -> str: + return self._key.decode() + + @key.setter + def key(self, value: str) -> None: + self._key = value.encode() + + def encrypt(self, data) -> Tuple[str, Union[str, None]]: + raise NotImplementedError + + def decrypt(self, data) -> dict: + raise NotImplementedError + + +class CipherV1(CipherBase): + def __init__(self, key: bytes) -> None: + super().__init__(key) + + def __create_cipher(self) -> AES: + return AES.new(self._key, AES.MODE_ECB) + + def __pad(self, s) -> str: + return s + (16 - len(s) % 16) * chr(16 - len(s) % 16) + + def encrypt(self, data) -> Tuple[str, Union[str, None]]: + logging.debug("Encrypting data: %s", data) + cipher = self.__create_cipher() + padded = self.__pad(json.dumps(data)).encode() + encrypted = cipher.encrypt(padded) + encoded = base64.b64encode(encrypted).decode() + logging.info("Encrypted data: %s", encoded) + return encoded, None + + def decrypt(self, data) -> dict: + logging.info("Decrypting data: %s", data) + cipher = self.__create_cipher() + decoded = base64.b64decode(data) + decrypted = cipher.decrypt(decoded).decode() + t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '') + logging.debug("Decrypted data: %s", t) + return json.loads(t) + + +class CipherV2(CipherBase): + GCM_NONCE = b'\x54\x40\x78\x44\x49\x67\x5a\x51\x6c\x5e\x63\x13' + GCM_AEAD = b'qualcomm-test' + + def __init__(self, key: bytes) -> None: + super().__init__(key) + + def __create_cipher(self) -> AES: + cipher = AES.new(self._key, AES.MODE_GCM, nonce=self.GCM_NONCE) + cipher.update(self.GCM_AEAD) + return cipher + + def encrypt(self, data) -> Tuple[str, str]: + logging.debug("Encrypting data: %s", data) + cipher = self.__create_cipher() + encrypted, tag = cipher.encrypt_and_digest(json.dumps(data).encode()) + encoded = base64.b64encode(encrypted).decode() + tag = base64.b64encode(tag).decode() + logging.debug("Encrypted data: %s", encoded) + logging.debug("Cipher digest: %s", tag) + return encoded, tag + + def decrypt(self, data) -> dict: + logging.info("Decrypting data: %s", data) + cipher = self.__create_cipher() + decoded = base64.b64decode(data) + decrypted = cipher.decrypt(decoded).decode() + t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '') + logging.debug("Decrypted data: %s", t) + return json.loads(t) + diff --git a/greeclimate/device.py b/greeclimate/device.py index 93ed49e..ebd674f 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -4,11 +4,10 @@ import re from asyncio import AbstractEventLoop from enum import IntEnum, unique -from typing import List -import greeclimate.network as network +from greeclimate.cipher import CipherV1 from greeclimate.deviceinfo import DeviceInfo -from greeclimate.network import DeviceProtocol2, IPAddr +from greeclimate.network import DeviceProtocol2 from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError from greeclimate.taskable import Taskable @@ -207,7 +206,7 @@ async def bind(self, key=None): try: if key: - self.device_key = key + self.device_cipher = CipherV1(key.encode()) else: await self.send(self.create_bind_message(self.device_info)) # Special case, wait for binding to complete so we know that the device is ready @@ -217,18 +216,18 @@ async def bind(self, key=None): except asyncio.TimeoutError: raise DeviceTimeoutError - if not self.device_key: + if not self.device_cipher: raise DeviceNotBoundError else: - self._logger.info("Bound to device using key %s", self.device_key) + self._logger.info("Bound to device using key %s", self.device_cipher.key) - def handle_device_bound(self, key) -> None: + def handle_device_bound(self, key: str) -> None: """Handle the device bound message from the device""" - self.device_key = key + self.device_cipher = CipherV1(key.encode()) async def request_version(self) -> None: """Request the firmware version from the device.""" - if not self.device_key: + if not self.device_cipher: await self.bind() try: @@ -243,7 +242,7 @@ async def update_state(self, wait_for: float = 30): Args: wait_for (object): How long to wait for an update from the device """ - if not self.device_key: + if not self.device_cipher: await self.bind() self._logger.debug("Updating device properties for (%s)", str(self.device_info)) @@ -288,7 +287,7 @@ async def push_state_update(self, wait_for: float = 30): if not self._dirty: return - if not self.device_key: + if not self.device_cipher: await self.bind() self._logger.debug("Pushing state updates to (%s)", str(self.device_info)) @@ -316,7 +315,7 @@ def __eq__(self, other): """Compare two devices for equality based on their properties state and device info.""" return self.device_info == other.device_info \ and self.raw_properties == other.raw_properties \ - and self.device_key == other.device_key + and self.device_cipher.key == other.device_cipher.key def __ne__(self, other): return not self.__eq__(other) diff --git a/greeclimate/network.py b/greeclimate/network.py index aabea91..083cd46 100644 --- a/greeclimate/network.py +++ b/greeclimate/network.py @@ -1,24 +1,22 @@ import asyncio -import base64 import json import logging import socket from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Text, Tuple, Union - -from Crypto.Cipher import AES +from typing import Any, Dict, Tuple, Union +from greeclimate.cipher import CipherBase, CipherV1 from greeclimate.deviceinfo import DeviceInfo NETWORK_TIMEOUT = 10 -GENERIC_KEY = ["a3K8Bx%2r8Y7#xDh"] -GCM_NONCE = b'\x54\x40\x78\x44\x49\x67\x5a\x51\x6c\x5e\x63\x13' -GCM_AEAD = b'qualcomm-test' +GENERIC_CIPHERS_KEYS = [ + b'a3K8Bx%2r8Y7#xDh', + b'{yxAHAY_Lm6pbC/<' +] _LOGGER = logging.getLogger(__name__) - IPAddr = Tuple[str, int] @@ -52,11 +50,11 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: timeout (int): Packet send timeout drained (asyncio.Event): Packet send drain event signal """ - self._timeout = timeout - self._drained = drained or asyncio.Event() + self._timeout: int = timeout + self._drained: asyncio.Event = drained or asyncio.Event() self._drained.set() - self._transport = None - self._key = None + self._transport: Union[asyncio.transports.DatagramTransport, None] = None + self._cipher: Union[CipherBase, None] = None # This event need to be implemented to handle incoming requests def packet_received(self, obj, addr: IPAddr) -> None: @@ -69,14 +67,14 @@ def packet_received(self, obj, addr: IPAddr) -> None: raise NotImplementedError("packet_received must be implemented in a subclass") @property - def device_key(self) -> str: + def device_cipher(self) -> CipherBase: """Sets the encryption key used for device data.""" - return self._key + return self._cipher - @device_key.setter - def device_key(self, value: str): + @device_cipher.setter + def device_cipher(self, value: CipherBase): """Gets the encryption key used for device data.""" - self._key = value + self._cipher = value def close(self) -> None: """Close the UDP transport.""" @@ -116,27 +114,6 @@ def resume_writing(self) -> None: self._drained.set() super().resume_writing() - @staticmethod - def decrypt_payload(payload, key=GENERIC_KEY[0]): - cipher = AES.new(key.encode(), AES.MODE_ECB) - decoded = base64.b64decode(payload) - decrypted = cipher.decrypt(decoded).decode() - t = decrypted.replace(decrypted[decrypted.rindex("}") + 1 :], "") - return json.loads(t) - - @staticmethod - def encrypt_payload(payload, key=GENERIC_KEY[0]): - def pad(s): - bs = 16 - return s + (bs - len(s) % bs) * chr(bs - len(s) % bs) - - cipher = AES.new(key.encode(), AES.MODE_ECB) - encrypted = cipher.encrypt(pad(json.dumps(payload)).encode()) - encoded = base64.b64encode(encrypted).decode() - - _LOGGER.debug(f"Encrypted payload with key [{key}]: {encoded}") - return encoded - def datagram_received(self, data: bytes, addr: IPAddr) -> None: """Handle an incoming datagram.""" if len(data) == 0: @@ -145,22 +122,21 @@ def datagram_received(self, data: bytes, addr: IPAddr) -> None: obj = json.loads(data) # It could be either a v1 or v2 key - key = GENERIC_KEY[0] if obj.get("i") == 1 else self._key + cipher = CipherV1(GENERIC_CIPHERS_KEYS[0]) if obj.get("i") == 1 else self._cipher if obj.get("pack"): - obj["pack"] = DeviceProtocolBase2.decrypt_payload(obj["pack"], key) - - _LOGGER.debug("Received packet from %s:\n%s", addr[0], json.dumps(obj)) + obj["pack"] = cipher.decrypt(obj["pack"]) + _LOGGER.debug("Received packet from %s:\n<- %s", addr[0], json.dumps(obj)) self.packet_received(obj, addr) async def send(self, obj, addr: IPAddr = None) -> None: """Send encode and send JSON command to the device.""" - _LOGGER.debug("Sending packet:\n%s", json.dumps(obj)) + _LOGGER.debug("Sending packet:\n-> %s", json.dumps(obj)) if obj.get("pack"): - key = GENERIC_KEY[0] if obj.get("i") == 1 else self._key - obj["pack"] = DeviceProtocolBase2.encrypt_payload(obj["pack"], key) + cipher = CipherV1(GENERIC_CIPHERS_KEYS[0]) if obj.get("i") == 1 else self._cipher + obj["pack"], _ = cipher.encrypt(obj["pack"]) data_bytes = json.dumps(obj).encode() self._transport.sendto(data_bytes, addr) @@ -169,7 +145,6 @@ async def send(self, obj, addr: IPAddr = None) -> None: await asyncio.wait_for(task, self._timeout) - class BroadcastListenerProtocol(DeviceProtocolBase2): """Special protocol handler for when broadcast is needed.""" @@ -290,6 +265,5 @@ def create_status_message(self, device_info: DeviceInfo, *args) -> Dict[str, Any return self._generate_payload(Commands.STATUS, device_info, {"cols": list(args)}) def create_command_message(self, device_info: DeviceInfo, **kwargs) -> Dict[str, Any]: - return self._generate_payload(Commands.CMD, device_info, {"opt": list(kwargs.keys()), "p": list(kwargs.values())}) - - + return self._generate_payload(Commands.CMD, device_info, + {"opt": list(kwargs.keys()), "p": list(kwargs.values())}) diff --git a/tests/common.py b/tests/common.py index 0ed9651..ff72c4b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,9 +1,10 @@ -import asyncio import socket from socket import SOCK_DGRAM -from unittest.mock import Mock, create_autospec, patch +from typing import Tuple, Union +from unittest.mock import Mock -from greeclimate.network import DeviceProtocolBase2 +from greeclimate.cipher import CipherV1, CipherBase +from greeclimate.network import GENERIC_CIPHERS_KEYS DEFAULT_TIMEOUT = 1 DISCOVERY_REQUEST = {"t": "scan"} @@ -95,10 +96,24 @@ def get_mock_device_info(): def encrypt_payload(data): """Encrypt the payload of responses quickly.""" d = data.copy() - d["pack"] = DeviceProtocolBase2.encrypt_payload(d["pack"]) + cipher = CipherV1(GENERIC_CIPHERS_KEYS[0]) + d["pack"], _ = cipher.encrypt(d["pack"]) return d +class FakeCipher(CipherBase): + """Fake cipher object for testing.""" + + def __init__(self, key: bytes) -> None: + super().__init__(key) + + def encrypt(self, data) -> Tuple[str, Union[str, None]]: + return data, None + + def decrypt(self, data) -> dict: + return data + + class Responder: """Context manage for easy raw socket responders.""" diff --git a/tests/conftest.py b/tests/conftest.py index 9a64b00..3081be2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ import pytest +from tests.common import FakeCipher + MOCK_INTERFACES = ["lo"] MOCK_LO_IFACE = { 2: [{"addr": "10.0.0.1", "netmask": "255.0.0.0", "peer": "10.255.255.255"}] @@ -13,6 +15,16 @@ def netifaces_fixture(): """Patch netifaces interface discover.""" with patch("netifaces.interfaces", return_value=MOCK_INTERFACES), patch( - "netifaces.ifaddresses", return_value=MOCK_LO_IFACE + "netifaces.ifaddresses", return_value=MOCK_LO_IFACE ) as ifaddr_mock: yield ifaddr_mock + + + +@pytest.fixture(name="cipher") +def cipher_fixture(): + """Patch the cipher object.""" + with patch("greeclimate.network.CipherV1") as mock1, patch("greeclimate.cipher.CipherV2") as mock2: + mock1.return_value = FakeCipher(b"1234567890123456") + mock2.return_value = FakeCipher(b"1234567890123456") + yield diff --git a/tests/test_cipher.py b/tests/test_cipher.py new file mode 100644 index 0000000..baaaabb --- /dev/null +++ b/tests/test_cipher.py @@ -0,0 +1,48 @@ +import base64 +import pytest +from greeclimate.cipher import CipherV1, CipherV2 + + +@pytest.fixture +def cipher_v1_key(): + return b'ThisIsASecretKey' + + +@pytest.fixture +def cipher_v1(cipher_v1_key): + return CipherV1(cipher_v1_key) + + +@pytest.fixture +def cipher_v2_key(): + return b'ThisIsASecretKey' + + +@pytest.fixture +def cipher_v2(cipher_v2_key): + return CipherV2(cipher_v2_key) + + +@pytest.fixture +def plain_text(): + return {"message": "Hello, World!"} + + +def test_encryption_then_decryption_yields_original(cipher_v1, plain_text): + encrypted, _ = cipher_v1.encrypt(plain_text) + decrypted = cipher_v1.decrypt(encrypted) + assert decrypted == plain_text + + +def test_decryption_with_modified_data_raises_error(cipher_v1, plain_text): + _, _ = cipher_v1.encrypt(plain_text) + modified_data = base64.b64encode(b"modified data ").decode() + with pytest.raises(UnicodeDecodeError): + cipher_v1.decrypt(modified_data) + + +def test_encryption_then_decryption_yields_original_with_tag(cipher_v2, plain_text): + encrypted, tag = cipher_v2.encrypt(plain_text) + decrypted = cipher_v2.decrypt(encrypted) + assert decrypted == plain_text + diff --git a/tests/test_device.py b/tests/test_device.py index 33b8663..d7d0162 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -1,11 +1,10 @@ import asyncio import enum -from unittest.mock import patch, AsyncMock +from unittest.mock import patch import pytest from greeclimate.device import Device, DeviceInfo, Props, TemperatureUnits -from greeclimate.discovery import Discovery from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError @@ -187,7 +186,7 @@ def test_device_info_equality(): @pytest.mark.asyncio -async def test_get_device_info(): +async def test_get_device_info(cipher): """Initialize device, check properties.""" info = DeviceInfo(*get_mock_info()) @@ -198,7 +197,8 @@ async def test_get_device_info(): fake_key = "abcdefgh12345678" await device.bind(key=fake_key) - assert device.device_key == fake_key + assert device.device_cipher is not None + assert device.device_cipher.key == fake_key @pytest.mark.asyncio @@ -221,7 +221,8 @@ def fake_send(*args): await device.bind() assert mock.call_count == 1 - assert device.device_key == fake_key + assert device.device_cipher is not None + assert device.device_cipher.key == fake_key @pytest.mark.asyncio @@ -238,7 +239,7 @@ async def test_device_bind_timeout(): await device.bind() assert mock.call_count == 1 - assert device.device_key is None + assert device.device_cipher is None @pytest.mark.asyncio @@ -259,7 +260,7 @@ def fake_send(*args): await device.bind() assert mock.call_count == 1 - assert device.device_key is None + assert device.device_cipher is None @pytest.mark.asyncio @@ -278,13 +279,15 @@ def fake_send(*args): with patch.object(Device, "send", wraps=fake_send) as mock: await device.update_state() assert mock.call_count == 2 - assert device.device_key == fake_key + assert device.device_cipher.key == fake_key device.power = True with patch.object(Device, "send"): await device.push_state_update() - assert device.device_key == fake_key + + assert device.device_cipher is not None + assert device.device_cipher.key == fake_key @pytest.mark.asyncio diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 69784fa..d341ff1 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -1,24 +1,18 @@ import asyncio import json import socket -from asyncio.tasks import wait_for from threading import Thread -from unittest.mock import MagicMock, PropertyMock, create_autospec, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest -from greeclimate.device import DeviceInfo from greeclimate.discovery import Discovery, Listener -from greeclimate.network import DeviceProtocolBase2 - from .common import ( DEFAULT_TIMEOUT, DISCOVERY_REQUEST, DISCOVERY_RESPONSE, - DISCOVERY_RESPONSE_NO_CID, Responder, - encrypt_payload, - get_mock_device_info, + get_mock_device_info, encrypt_payload, ) diff --git a/tests/test_network.py b/tests/test_network.py index 0e0dc0f..8dbae35 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -2,8 +2,7 @@ import json import socket from threading import Thread -from typing import Any -from unittest.mock import create_autospec, patch, MagicMock +from unittest.mock import patch, MagicMock import pytest @@ -14,7 +13,6 @@ IPAddr, DeviceProtocol2, Commands, Response, ) - from .common import ( DEFAULT_RESPONSE, DEFAULT_TIMEOUT, @@ -22,7 +20,7 @@ DISCOVERY_RESPONSE, Responder, encrypt_payload, - get_mock_device_info, DEFAULT_REQUEST, generate_response, + DEFAULT_REQUEST, generate_response, FakeCipher, ) from .test_device import get_mock_info @@ -44,6 +42,7 @@ class FakeDeviceProtocol(DeviceProtocol2): def __init__(self, drained: asyncio.Event = None): super().__init__(timeout=1, drained=drained) self.packets = asyncio.Queue() + self.device_cipher = FakeCipher(b"1234567890123456") def packet_received(self, obj, addr: IPAddr) -> None: self.packets.put_nowait(obj) @@ -219,7 +218,7 @@ async def test_broadcast_timeout(addr, family): @pytest.mark.asyncio @pytest.mark.parametrize("addr,family", [(("127.0.0.1", 7000), socket.AF_INET)]) -async def test_datagram_connect(addr, family): +async def test_datagram_connect(addr, family, cipher): """Create a socket responder, an async connection, test send and recv.""" with Responder(family, addr[1], bcast=False) as sock: @@ -240,44 +239,30 @@ def responder(s): lambda: FakeDeviceProtocol(drained=drained), remote_addr=remote_addr ) - with patch("greeclimate.network.DeviceProtocolBase2.decrypt_payload", new_callable=MagicMock) as mock: - mock.side_effect = lambda x, y: x - - # Send the scan command - await protocol.send(DEFAULT_REQUEST, None) + # Send the scan command + await protocol.send(DEFAULT_REQUEST, None) - # Wait on the scan response - task = asyncio.create_task(protocol.packets.get()) - await asyncio.wait_for(task, DEFAULT_TIMEOUT) - response = task.result() + # Wait on the scan response + task = asyncio.create_task(protocol.packets.get()) + await asyncio.wait_for(task, DEFAULT_TIMEOUT) + response = task.result() - assert response == DEFAULT_RESPONSE + assert response == DEFAULT_RESPONSE sock.close() serv.join(timeout=DEFAULT_TIMEOUT) -def test_encrypt_decrypt_payload(): - test_object = {"fake-key": "fake-value"} - - encrypted = DeviceProtocolBase2.encrypt_payload(test_object) - assert encrypted != test_object - - decrypted = DeviceProtocolBase2.decrypt_payload(encrypted) - assert decrypted == test_object - - @pytest.mark.asyncio -def test_bindok_handling(): +def test_bindok_handling(cipher): """Test the bindok response.""" response = generate_response({"t": "bindok", "key": "fake-key"}) protocol = DeviceProtocol2(timeout=DEFAULT_TIMEOUT) - with patch("greeclimate.network.DeviceProtocolBase2.decrypt_payload", new_callable=MagicMock) as mock_decrypt: - mock_decrypt.side_effect = lambda x, y: x - with patch.object(DeviceProtocol2, "handle_device_bound") as mock: - protocol.datagram_received(json.dumps(response).encode(), ("0.0.0.0", 0)) - assert mock.call_count == 1 - assert mock.call_args[0][0] == "fake-key" + + with patch.object(DeviceProtocol2, "handle_device_bound") as mock: + protocol.datagram_received(json.dumps(response).encode(), ("0.0.0.0", 0)) + assert mock.call_count == 1 + assert mock.call_args[0][0] == "fake-key" def test_create_bind_message(): From 94b441263873d0815f7938607b9a04edd4b56fd7 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 6 Jul 2024 17:53:32 -0400 Subject: [PATCH 02/15] Adding GCM digest to outgoing packets --- greeclimate/network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/greeclimate/network.py b/greeclimate/network.py index 083cd46..ab4e1b3 100644 --- a/greeclimate/network.py +++ b/greeclimate/network.py @@ -136,7 +136,9 @@ async def send(self, obj, addr: IPAddr = None) -> None: if obj.get("pack"): cipher = CipherV1(GENERIC_CIPHERS_KEYS[0]) if obj.get("i") == 1 else self._cipher - obj["pack"], _ = cipher.encrypt(obj["pack"]) + obj["pack"], tag = cipher.encrypt(obj["pack"]) + if tag: + obj["tag"] = tag data_bytes = json.dumps(obj).encode() self._transport.sendto(data_bytes, addr) From 0319337e015b10e9b857329764dc9185495d93f6 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sun, 7 Jul 2024 18:50:20 -0400 Subject: [PATCH 03/15] Hookup new Cipher type with auto-detect Tests do not pass yet though... --- .idea/runConfigurations/pytest_in__.xml | 20 ++++++++ greeclimate/device.py | 62 ++++++++++++++++++------- greeclimate/network.py | 50 +++++++++++++------- tests/common.py | 4 +- tests/conftest.py | 2 +- tests/test_device.py | 32 +++++++------ tests/test_issues.py | 4 +- 7 files changed, 119 insertions(+), 55 deletions(-) create mode 100644 .idea/runConfigurations/pytest_in__.xml diff --git a/.idea/runConfigurations/pytest_in__.xml b/.idea/runConfigurations/pytest_in__.xml new file mode 100644 index 0000000..316c583 --- /dev/null +++ b/.idea/runConfigurations/pytest_in__.xml @@ -0,0 +1,20 @@ + + + + + \ No newline at end of file diff --git a/greeclimate/device.py b/greeclimate/device.py index ebd674f..ea569e2 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -4,11 +4,12 @@ import re from asyncio import AbstractEventLoop from enum import IntEnum, unique +from typing import Union -from greeclimate.cipher import CipherV1 +from greeclimate.cipher import CipherV1, CipherV2, CipherBase from greeclimate.deviceinfo import DeviceInfo -from greeclimate.network import DeviceProtocol2 from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError +from greeclimate.network import DeviceProtocol2 from greeclimate.taskable import Taskable @@ -43,6 +44,12 @@ class Props(enum.Enum): UNKNOWN_HEATCOOLTYPE = "HeatCoolType" +GENERIC_CIPHERS_KEYS = { + CipherV1: b'a3K8Bx%2r8Y7#xDh', + CipherV2: b'{yxAHAY_Lm6pbC/<' +} + + @unique class TemperatureUnits(IntEnum): C = 0 @@ -154,6 +161,13 @@ class Device(DeviceProtocol2, Taskable): water_full: A bool to indicate the water tank is full """ + """ Device properties """ + hid = None + version = None + check_version = True + _properties = {} + _dirty = [] + def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEventLoop = None): """Initialize the device object @@ -165,17 +179,9 @@ def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEv DeviceProtocol2.__init__(self, timeout) Taskable.__init__(self, loop) self._logger = logging.getLogger(__name__) - self.device_info: DeviceInfo = device_info - """ Device properties """ - self.hid = None - self.version = None - self.check_version = True - self._properties = {} - self._dirty = [] - - async def bind(self, key=None): + async def bind(self, key: str = None, cipher_type: Union[type[Union[CipherV1, CipherV2]], None] = None): """Run the binding procedure. Binding is a finicky procedure, and happens in 1 of 2 ways: @@ -186,14 +192,18 @@ async def bind(self, key=None): Both approaches result in a device_key which is used as like a persistent session id. Args: + cipher_type (type): The cipher type to use for encryption, if None will attempt to detect the correct one key (str): The device key, when provided binding is a NOOP, if None binding will - attempt to negotiate the key with the device. + attempt to negotiate the key with the device. cipher_type must be provided. Raises: DeviceNotBoundError: If binding was unsuccessful and no key returned DeviceTimeoutError: The device didn't respond """ + if key and not cipher_type: + raise ValueError("cipher_type must be provided when key is provided") + if not self.device_info: raise DeviceNotBoundError @@ -206,12 +216,18 @@ async def bind(self, key=None): try: if key: - self.device_cipher = CipherV1(key.encode()) + self.device_cipher = cipher_type(key.encode()) else: - await self.send(self.create_bind_message(self.device_info)) - # Special case, wait for binding to complete so we know that the device is ready - task = asyncio.create_task(self.ready.wait()) - await asyncio.wait_for(task, timeout=self._timeout) + if cipher_type is not None: + await self.__bind_internal(cipher_type) + else: + """ Try binding with CipherV1 first, if that fails try CipherV2""" + try: + self._logger.info("Attempting to bind to device using CipherV1") + await self.__bind_internal(CipherV1) + except asyncio.TimeoutError: + self._logger.info("Attempting to bind to device using CipherV2") + await self.__bind_internal(CipherV2) except asyncio.TimeoutError: raise DeviceTimeoutError @@ -221,9 +237,19 @@ async def bind(self, key=None): else: self._logger.info("Bound to device using key %s", self.device_cipher.key) + async def __bind_internal(self, cipher_type: type[Union[CipherV1, CipherV2]]): + """Internal binding procedure, do not call directly""" + default_key = GENERIC_CIPHERS_KEYS.get(cipher_type) + await self.send(self.create_bind_message(self.device_info), cipher=cipher_type(default_key)) + task = asyncio.create_task(self.ready.wait()) + await asyncio.wait_for(task, timeout=self._timeout) + def handle_device_bound(self, key: str) -> None: """Handle the device bound message from the device""" - self.device_cipher = CipherV1(key.encode()) + cipher_type = type(self.device_cipher) + if not issubclass(cipher_type, CipherBase): + raise ValueError(f"Invalid cipher type {cipher_type}") + self.device_cipher = cipher_type(key.encode()) async def request_version(self) -> None: """Request the firmware version from the device.""" diff --git a/greeclimate/network.py b/greeclimate/network.py index ab4e1b3..ab443ce 100644 --- a/greeclimate/network.py +++ b/greeclimate/network.py @@ -6,15 +6,10 @@ from enum import Enum from typing import Any, Dict, Tuple, Union -from greeclimate.cipher import CipherBase, CipherV1 +from greeclimate.cipher import CipherBase from greeclimate.deviceinfo import DeviceInfo NETWORK_TIMEOUT = 10 -GENERIC_CIPHERS_KEYS = [ - b'a3K8Bx%2r8Y7#xDh', - b'{yxAHAY_Lm6pbC/<' -] - _LOGGER = logging.getLogger(__name__) IPAddr = Tuple[str, int] @@ -43,6 +38,9 @@ class IPInterface: class DeviceProtocolBase2(asyncio.DatagramProtocol): """Event driven device protocol class.""" + _transport: Union[asyncio.transports.DatagramTransport, None] = None + _cipher: Union[CipherBase, None] = None + def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: """Initialize the device protocol object. @@ -53,8 +51,6 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: self._timeout: int = timeout self._drained: asyncio.Event = drained or asyncio.Event() self._drained.set() - self._transport: Union[asyncio.transports.DatagramTransport, None] = None - self._cipher: Union[CipherBase, None] = None # This event need to be implemented to handle incoming requests def packet_received(self, obj, addr: IPAddr) -> None: @@ -76,6 +72,20 @@ def device_cipher(self, value: CipherBase): """Gets the encryption key used for device data.""" self._cipher = value + @property + def device_key(self) -> str: + """Gets the encryption key used for device data.""" + if self._cipher is None: + raise ValueError("Cipher object not set") + return self._cipher.key + + @device_key.setter + def device_key(self, value: str): + """Sets the encryption key used for device data.""" + if self._cipher is None: + raise ValueError("Cipher object not set") + self._cipher.key = value + def close(self) -> None: """Close the UDP transport.""" try: @@ -121,22 +131,28 @@ def datagram_received(self, data: bytes, addr: IPAddr) -> None: obj = json.loads(data) - # It could be either a v1 or v2 key - cipher = CipherV1(GENERIC_CIPHERS_KEYS[0]) if obj.get("i") == 1 else self._cipher - if obj.get("pack"): - obj["pack"] = cipher.decrypt(obj["pack"]) + obj["pack"] = self._cipher.decrypt(obj["pack"]) _LOGGER.debug("Received packet from %s:\n<- %s", addr[0], json.dumps(obj)) self.packet_received(obj, addr) - async def send(self, obj, addr: IPAddr = None) -> None: - """Send encode and send JSON command to the device.""" + async def send(self, obj, addr: IPAddr = None, cipher: Union[CipherBase, None] = None) -> None: + """Send encode and send JSON command to the device. + + Args: + addr (object): (Optional) Address to send the message + cipher (object): (Optional) Initial cipher to use for SCANNING and BINDING + """ _LOGGER.debug("Sending packet:\n-> %s", json.dumps(obj)) if obj.get("pack"): - cipher = CipherV1(GENERIC_CIPHERS_KEYS[0]) if obj.get("i") == 1 else self._cipher - obj["pack"], tag = cipher.encrypt(obj["pack"]) + if obj.get("i") == 1: + if cipher is None: + raise ValueError("Cipher must be supplied for SCAN or BIND messages") + self._cipher = cipher + + obj["pack"], tag = self._cipher.encrypt(obj["pack"]) if tag: obj["tag"] = tag @@ -161,6 +177,7 @@ def connection_made(self, transport: asyncio.transports.DatagramTransport) -> No class DeviceProtocol2(DeviceProtocolBase2): """Protocol handler for direct device communication.""" + _handlers = {} def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: """Initialize the device protocol object. @@ -172,7 +189,6 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: DeviceProtocolBase2.__init__(self, timeout, drained) self._ready = asyncio.Event() self._ready.clear() - self._handlers = {} @property def ready(self) -> asyncio.Event: diff --git a/tests/common.py b/tests/common.py index ff72c4b..bb8e139 100644 --- a/tests/common.py +++ b/tests/common.py @@ -4,7 +4,7 @@ from unittest.mock import Mock from greeclimate.cipher import CipherV1, CipherBase -from greeclimate.network import GENERIC_CIPHERS_KEYS +from greeclimate.device import GENERIC_CIPHERS_KEYS DEFAULT_TIMEOUT = 1 DISCOVERY_REQUEST = {"t": "scan"} @@ -96,7 +96,7 @@ def get_mock_device_info(): def encrypt_payload(data): """Encrypt the payload of responses quickly.""" d = data.copy() - cipher = CipherV1(GENERIC_CIPHERS_KEYS[0]) + cipher = CipherV1(GENERIC_CIPHERS_KEYS[CipherV1]) d["pack"], _ = cipher.encrypt(d["pack"]) return d diff --git a/tests/conftest.py b/tests/conftest.py index 3081be2..5ff5938 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ def netifaces_fixture(): @pytest.fixture(name="cipher") def cipher_fixture(): """Patch the cipher object.""" - with patch("greeclimate.network.CipherV1") as mock1, patch("greeclimate.cipher.CipherV2") as mock2: + with patch("greeclimate.device.CipherV1") as mock1, patch("greeclimate.device.CipherV2") as mock2: mock1.return_value = FakeCipher(b"1234567890123456") mock2.return_value = FakeCipher(b"1234567890123456") yield diff --git a/tests/test_device.py b/tests/test_device.py index d7d0162..05b6195 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -4,6 +4,7 @@ import pytest +from greeclimate.cipher import CipherV1 from greeclimate.device import Device, DeviceInfo, Props, TemperatureUnits from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError @@ -158,7 +159,7 @@ def get_mock_state_0c_v3_temp(): async def generate_device_mock_async(): d = Device(DeviceInfo("192.168.1.29", 7000, "f4911e7aca59", "1e7aca59")) - await d.bind(key="St8Vw1Yz4Bc7Ef0H") + await d.bind(key="St8Vw1Yz4Bc7Ef0H", cipher_type=CipherV1) return d @@ -195,7 +196,7 @@ async def test_get_device_info(cipher): assert device.device_info == info fake_key = "abcdefgh12345678" - await device.bind(key=fake_key) + await device.bind(key=fake_key, cipher_type=CipherV1) assert device.device_cipher is not None assert device.device_cipher.key == fake_key @@ -212,8 +213,9 @@ async def test_device_bind(): fake_key = "abcdefgh12345678" - def fake_send(*args): + def fake_send(*args, **kwargs): """Emulate a bind event""" + device.device_cipher = CipherV1(fake_key.encode()) device.ready.set() device.handle_device_bound(fake_key) @@ -251,7 +253,7 @@ async def test_device_bind_none(): assert device.device_info == info - def fake_send(*args): + def fake_send(*args, **kwargs): device.ready.set() fake_key = None @@ -272,7 +274,7 @@ async def test_device_late_bind(): fake_key = "abcdefgh12345678" - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_device_bound(fake_key) device.ready.set() @@ -298,7 +300,7 @@ async def test_update_properties(): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): state = get_mock_state() device.handle_state_update(**state) @@ -441,7 +443,7 @@ async def test_update_current_temp_unsupported(): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): state = get_mock_state_no_temperature() device.handle_state_update(**state) @@ -468,7 +470,7 @@ async def test_update_current_temp_v3(temsen, hid): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(TemSen=temsen, hid=hid) with patch.object(Device, "send", wraps=fake_send): @@ -495,7 +497,7 @@ async def test_update_current_temp_v4(temsen, hid): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(TemSen=temsen, hid=hid) with patch.object(Device, "send", wraps=fake_send): @@ -513,7 +515,7 @@ async def test_update_current_temp_bad(): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_bad_temp()) with patch.object(Device, "send", wraps=fake_send): @@ -530,7 +532,7 @@ async def test_update_current_temp_0C_v4(): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_0c_v4_temp()) with patch.object(Device, "send", wraps=fake_send): @@ -547,7 +549,7 @@ async def test_update_current_temp_0C_v3(): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_0c_v3_temp()) with patch.object(Device, "send", wraps=fake_send): @@ -574,7 +576,7 @@ async def test_send_temperature_celsius(temperature): await device.push_state_update() assert mock_push.call_count == 1 - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**state) with patch.object(Device, "send", wraps=fake_send): @@ -608,7 +610,7 @@ async def test_send_temperature_farenheit(temperature): await device.push_state_update() assert mock_push.call_count == 1 - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**state) with patch.object(Device, "send", wraps=fake_send): @@ -713,7 +715,7 @@ async def test_mismatch_temrec_farenheit(temperature): await device.push_state_update() assert mock_push.call_count == 1 - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**state) with patch.object(Device, "send", wraps=fake_send): diff --git a/tests/test_issues.py b/tests/test_issues.py index 7c29438..ba08eb0 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -15,7 +15,7 @@ async def test_issue_69_TemSen_40_should_not_set_firmware_v4(): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**mock_v3_state) with patch.object(Device, "send", wraps=fake_send()): @@ -35,7 +35,7 @@ async def test_issue_87_quiet_should_set_2(): assert device.get_property(Props.QUIET) is None device.quiet = True - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**mock_v3_state) with patch.object(Device, "send", wraps=fake_send()) as mock: From 68b442fa049a3051e823c4d4a16e3f6fcadbe159 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 13 Jul 2024 16:33:37 -0400 Subject: [PATCH 04/15] Clean up device tests, working. --- greeclimate/device.py | 4 +- tests/conftest.py | 13 ++- tests/test_device.py | 227 ++++++++++++++++-------------------------- 3 files changed, 99 insertions(+), 145 deletions(-) diff --git a/greeclimate/device.py b/greeclimate/device.py index ea569e2..88bdfdb 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -6,7 +6,7 @@ from enum import IntEnum, unique from typing import Union -from greeclimate.cipher import CipherV1, CipherV2, CipherBase +from greeclimate.cipher import CipherV1, CipherV2 from greeclimate.deviceinfo import DeviceInfo from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError from greeclimate.network import DeviceProtocol2 @@ -247,8 +247,6 @@ async def __bind_internal(self, cipher_type: type[Union[CipherV1, CipherV2]]): def handle_device_bound(self, key: str) -> None: """Handle the device bound message from the device""" cipher_type = type(self.device_cipher) - if not issubclass(cipher_type, CipherBase): - raise ValueError(f"Invalid cipher type {cipher_type}") self.device_cipher = cipher_type(key.encode()) async def request_version(self) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 5ff5938..e36f639 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import pytest +from greeclimate.device import Device from tests.common import FakeCipher MOCK_INTERFACES = ["lo"] @@ -20,11 +21,17 @@ def netifaces_fixture(): yield ifaddr_mock - -@pytest.fixture(name="cipher") +@pytest.fixture(name="cipher", autouse=False, scope="function") def cipher_fixture(): """Patch the cipher object.""" with patch("greeclimate.device.CipherV1") as mock1, patch("greeclimate.device.CipherV2") as mock2: mock1.return_value = FakeCipher(b"1234567890123456") mock2.return_value = FakeCipher(b"1234567890123456") - yield + yield mock1, mock2 + + +@pytest.fixture(name="send", autouse=False, scope="function") +def network_fixture(): + """Patch the device object.""" + with patch.object(Device, "send") as mock: + yield mock diff --git a/tests/test_device.py b/tests/test_device.py index 05b6195..9abf4f9 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -158,12 +158,12 @@ def get_mock_state_0c_v3_temp(): async def generate_device_mock_async(): - d = Device(DeviceInfo("192.168.1.29", 7000, "f4911e7aca59", "1e7aca59")) + d = Device(DeviceInfo("1.1.1.1", 7000, "f4911e7aca59", "1e7aca59")) await d.bind(key="St8Vw1Yz4Bc7Ef0H", cipher_type=CipherV1) return d -def test_device_info_equality(): +def test_device_info_equality(send): """The only way to get the key through binding is by scanning first""" props = [ @@ -187,7 +187,7 @@ def test_device_info_equality(): @pytest.mark.asyncio -async def test_get_device_info(cipher): +async def test_get_device_info(cipher, send): """Initialize device, check properties.""" info = DeviceInfo(*get_mock_info()) @@ -203,97 +203,86 @@ async def test_get_device_info(cipher): @pytest.mark.asyncio -async def test_device_bind(): +async def test_device_bind(cipher, send): """Check that the device returns a device key when binding.""" info = DeviceInfo(*get_mock_info()) device = Device(info, timeout=1) - - assert device.device_info == info - fake_key = "abcdefgh12345678" - + def fake_send(*args, **kwargs): """Emulate a bind event""" device.device_cipher = CipherV1(fake_key.encode()) device.ready.set() device.handle_device_bound(fake_key) + send.side_effect = fake_send - with patch.object(Device, "send", side_effect=fake_send) as mock: - await device.bind() - assert mock.call_count == 1 + assert device.device_info == info + await device.bind() + assert send.call_count == 1 assert device.device_cipher is not None assert device.device_cipher.key == fake_key @pytest.mark.asyncio -async def test_device_bind_timeout(): +async def test_device_bind_timeout(cipher, send): """Check that the device handles timeout errors when binding.""" - info = DeviceInfo(*get_mock_info()) device = Device(info, timeout=1) - assert device.device_info == info - with pytest.raises(DeviceTimeoutError): - with patch.object(Device, "send", return_value=None) as mock: - await device.bind() - assert mock.call_count == 1 + await device.bind() + assert send.call_count == 1 assert device.device_cipher is None @pytest.mark.asyncio -async def test_device_bind_none(): +async def test_device_bind_none(cipher, send): """Check that the device handles bad binding sequences.""" - info = DeviceInfo(*get_mock_info()) device = Device(info) - assert device.device_info == info - def fake_send(*args, **kwargs): device.ready.set() + send.side_effect = fake_send - fake_key = None with pytest.raises(DeviceNotBoundError): - with patch.object(Device, "send", wraps=fake_send) as mock: - await device.bind() - assert mock.call_count == 1 + await device.bind() + assert send.call_count == 1 assert device.device_cipher is None @pytest.mark.asyncio -async def test_device_late_bind(): +async def test_device_late_bind(cipher, send): """Check that the device handles late binding sequences.""" info = DeviceInfo(*get_mock_info()) device = Device(info, timeout=1) - assert device.device_info == info - fake_key = "abcdefgh12345678" def fake_send(*args, **kwargs): + device.device_cipher = CipherV1(fake_key.encode()) device.handle_device_bound(fake_key) device.ready.set() + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send) as mock: - await device.update_state() - assert mock.call_count == 2 - assert device.device_cipher.key == fake_key + await device.update_state() + assert send.call_count == 2 + assert device.device_cipher.key == fake_key - device.power = True + device.power = True - with patch.object(Device, "send"): - await device.push_state_update() - - assert device.device_cipher is not None - assert device.device_cipher.key == fake_key + send.side_effect = None + await device.push_state_update() + + assert device.device_cipher is not None + assert device.device_cipher.key == fake_key @pytest.mark.asyncio -async def test_update_properties(): +async def test_update_properties(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() @@ -303,9 +292,9 @@ async def test_update_properties(): def fake_send(*args, **kwargs): state = get_mock_state() device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", side_effect=fake_send) as mock: - await device.update_state() + await device.update_state() for p in Props: assert device.get_property(p) is not None @@ -313,30 +302,29 @@ def fake_send(*args, **kwargs): @pytest.mark.asyncio -async def test_update_properties_timeout(): +async def test_update_properties_timeout(cipher, send): """Check that timeouts are handled when properties are updates.""" device = await generate_device_mock_async() for p in Props: assert device.get_property(p) is None + send.side_effect = asyncio.TimeoutError with pytest.raises(DeviceTimeoutError): - with patch.object(Device, "send", side_effect=asyncio.TimeoutError): - await device.update_state() + await device.update_state() @pytest.mark.asyncio -async def test_set_properties_not_dirty(): +async def test_set_properties_not_dirty(cipher, send): """Check that the state isn't pushed when properties unchanged.""" device = await generate_device_mock_async() - with patch.object(Device, "send") as mock_request: - await device.push_state_update() - assert mock_request.call_count == 0 + await device.push_state_update() + assert send.call_count == 0 @pytest.mark.asyncio -async def test_set_properties(): +async def test_set_properties(cipher, send): """Check that state is pushed when properties are updated.""" device = await generate_device_mock_async() @@ -380,7 +368,7 @@ async def test_set_properties(): @pytest.mark.asyncio -async def test_set_properties_timeout(): +async def test_set_properties_timeout(cipher, send): """Check timeout handling when pushing state changes.""" device = await generate_device_mock_async() @@ -409,7 +397,7 @@ async def test_set_properties_timeout(): @pytest.mark.asyncio -async def test_uninitialized_properties(): +async def test_uninitialized_properties(cipher, send): """Check uninitialized property handling.""" device = await generate_device_mock_async() @@ -436,7 +424,7 @@ async def test_uninitialized_properties(): @pytest.mark.asyncio -async def test_update_current_temp_unsupported(): +async def test_update_current_temp_unsupported(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() @@ -446,9 +434,9 @@ async def test_update_current_temp_unsupported(): def fake_send(*args, **kwargs): state = get_mock_state_no_temperature() device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send) as mock: - await device.update_state() + await device.update_state() assert device.get_property(Props.TEMP_SENSOR) is None assert device.current_temperature == device.target_temperature @@ -463,19 +451,16 @@ def fake_send(*args, **kwargs): (62, "362001061147+U-ZX6045RV1.01.bin"), ], ) -async def test_update_current_temp_v3(temsen, hid): +async def test_update_current_temp_v3(temsen, hid, cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - def fake_send(*args, **kwargs): device.handle_state_update(TemSen=temsen, hid=hid) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() - + await device.update_state() + assert device.get_property(Props.TEMP_SENSOR) is not None assert device.current_temperature == temsen - 40 @@ -490,97 +475,80 @@ def fake_send(*args, **kwargs): (23, "362001061217+U-W04NV7.bin"), ], ) -async def test_update_current_temp_v4(temsen, hid): +async def test_update_current_temp_v4(temsen, hid, cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - def fake_send(*args, **kwargs): device.handle_state_update(TemSen=temsen, hid=hid) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.get_property(Props.TEMP_SENSOR) is not None assert device.current_temperature == temsen @pytest.mark.asyncio -async def test_update_current_temp_bad(): +async def test_update_current_temp_bad(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_bad_temp()) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == get_mock_state_bad_temp()["TemSen"] - 40 @pytest.mark.asyncio -async def test_update_current_temp_0C_v4(): +async def test_update_current_temp_0C_v4(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_0c_v4_temp()) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == get_mock_state_0c_v4_temp()["TemSen"] @pytest.mark.asyncio -async def test_update_current_temp_0C_v3(): +async def test_update_current_temp_0C_v3(cipher, send): """Check for devices without a temperature sensor.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_0c_v3_temp()) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == device.target_temperature @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [18, 19, 20, 21, 22]) -async def test_send_temperature_celsius(temperature): +async def test_send_temperature_celsius(temperature, cipher, send): """Check that temperature is set and read properly in C.""" state = get_mock_state() state["TemSen"] = temperature + 40 device = await generate_device_mock_async() - - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.C device.target_temperature = temperature - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 1 def fake_send(*args, **kwargs): device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == temperature @@ -589,7 +557,7 @@ def fake_send(*args, **kwargs): @pytest.mark.parametrize( "temperature", [60, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 86] ) -async def test_send_temperature_farenheit(temperature): +async def test_send_temperature_farenheit(temperature, cipher, send): """Check that temperature is set and read properly in F.""" temSet = round((temperature - 32.0) * 5.0 / 9.0) temRec = (int)((((temperature - 32.0) * 5.0 / 9.0) - temSet) > 0) @@ -600,34 +568,27 @@ async def test_send_temperature_farenheit(temperature): state["TemUn"] = 1 device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.F device.target_temperature = temperature - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 1 def fake_send(*args, **kwargs): device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == temperature @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [-270, -61, 61, 100]) -async def test_send_temperature_out_of_range_celsius(temperature): +async def test_send_temperature_out_of_range_celsius(temperature, cipher, send): """Check that bad temperatures raise the appropriate error.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.C with pytest.raises(ValueError): device.target_temperature = temperature @@ -635,7 +596,7 @@ async def test_send_temperature_out_of_range_celsius(temperature): @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [-270, -61, 141]) -async def test_send_temperature_out_of_range_farenheit_set(temperature): +async def test_send_temperature_out_of_range_farenheit_set(temperature, cipher, send): """Check that bad temperatures raise the appropriate error.""" device = await generate_device_mock_async() @@ -649,13 +610,10 @@ async def test_send_temperature_out_of_range_farenheit_set(temperature): @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [-270, 150]) -async def test_send_temperature_out_of_range_farenheit_get(temperature): +async def test_send_temperature_out_of_range_farenheit_get(temperature, cipher, send): """Check that bad temperatures raise the appropriate error.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.set_property(Props.TEMP_SET, 20) device.set_property(Props.TEMP_SENSOR, temperature) device.set_property(Props.TEMP_BIT, 0) @@ -666,25 +624,20 @@ async def test_send_temperature_out_of_range_farenheit_get(temperature): @pytest.mark.asyncio -async def test_enable_disable_sleep_mode(): +async def test_enable_disable_sleep_mode(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.sleep = True - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 1 assert device.get_property(Props.SLEEP) == 1 assert device.get_property(Props.SLEEP_MODE) == 1 device.sleep = False - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 2 assert device.get_property(Props.SLEEP) == 0 assert device.get_property(Props.SLEEP_MODE) == 0 @@ -694,7 +647,7 @@ async def test_enable_disable_sleep_mode(): @pytest.mark.parametrize( "temperature", [59, 77, 86] ) -async def test_mismatch_temrec_farenheit(temperature): +async def test_mismatch_temrec_farenheit(temperature, cipher, send): """Check that temperature is set and read properly in F.""" temSet = round((temperature - 32.0) * 5.0 / 9.0) temRec = (int)((((temperature - 32.0) * 5.0 / 9.0) - temSet) > 0) @@ -705,47 +658,43 @@ async def test_mismatch_temrec_farenheit(temperature): state["TemRec"] = (temRec + 1) % 2 state["TemUn"] = 1 device = await generate_device_mock_async() - - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.F device.target_temperature = temperature - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + + await device.push_state_update() + assert send.call_count == 1 def fake_send(*args, **kwargs): device.handle_state_update(**state) + send.side_effect = None - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == temperature @pytest.mark.asyncio -async def test_device_equality(): +async def test_device_equality(send): """Check that two devices with the same info and key are equal.""" info1 = DeviceInfo(*get_mock_info()) device1 = Device(info1) - await device1.bind(key="fake_key") + await device1.bind(key="fake_key", cipher_type=CipherV1) info2 = DeviceInfo(*get_mock_info()) device2 = Device(info2) - await device2.bind(key="fake_key") + await device2.bind(key="fake_key", cipher_type=CipherV1) assert device1 == device2 # Change the key of the second device - await device2.bind(key="another_fake_key") + await device2.bind(key="another_fake_key", cipher_type=CipherV1) assert device1 != device2 # Change the info of the second device info2 = DeviceInfo(*get_mock_info()) device2 = Device(info2) device2.power = True - await device2.bind(key="fake_key") + await device2.bind(key="fake_key", cipher_type=CipherV1) assert device1 != device2 From 9e71503d26ceb9fa6550b59d2114bb1f78d84c73 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 13 Jul 2024 17:13:43 -0400 Subject: [PATCH 05/15] Fix instancing bug. --- greeclimate/device.py | 15 ++++++++------- greeclimate/network.py | 7 ++++--- tests/conftest.py | 4 ++-- tests/test_device.py | 28 +++++++--------------------- 4 files changed, 21 insertions(+), 33 deletions(-) diff --git a/greeclimate/device.py b/greeclimate/device.py index 88bdfdb..e4a0515 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -161,13 +161,6 @@ class Device(DeviceProtocol2, Taskable): water_full: A bool to indicate the water tank is full """ - """ Device properties """ - hid = None - version = None - check_version = True - _properties = {} - _dirty = [] - def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEventLoop = None): """Initialize the device object @@ -180,6 +173,14 @@ def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEv Taskable.__init__(self, loop) self._logger = logging.getLogger(__name__) self.device_info: DeviceInfo = device_info + + """ Device properties """ + self.hid = None + self.version = None + self.check_version = True + self._properties = {} + self._dirty = [] + async def bind(self, key: str = None, cipher_type: Union[type[Union[CipherV1, CipherV2]], None] = None): """Run the binding procedure. diff --git a/greeclimate/network.py b/greeclimate/network.py index ab443ce..9608cee 100644 --- a/greeclimate/network.py +++ b/greeclimate/network.py @@ -38,9 +38,6 @@ class IPInterface: class DeviceProtocolBase2(asyncio.DatagramProtocol): """Event driven device protocol class.""" - _transport: Union[asyncio.transports.DatagramTransport, None] = None - _cipher: Union[CipherBase, None] = None - def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: """Initialize the device protocol object. @@ -52,6 +49,10 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: self._drained: asyncio.Event = drained or asyncio.Event() self._drained.set() + self._transport: Union[asyncio.transports.DatagramTransport, None] = None + self._cipher: Union[CipherBase, None] = None + + # This event need to be implemented to handle incoming requests def packet_received(self, obj, addr: IPAddr) -> None: """Event called when a packet is received and decoded. diff --git a/tests/conftest.py b/tests/conftest.py index e36f639..aadb011 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ def netifaces_fixture(): yield ifaddr_mock -@pytest.fixture(name="cipher", autouse=False, scope="function") +@pytest.fixture(name="cipher") def cipher_fixture(): """Patch the cipher object.""" with patch("greeclimate.device.CipherV1") as mock1, patch("greeclimate.device.CipherV2") as mock2: @@ -30,7 +30,7 @@ def cipher_fixture(): yield mock1, mock2 -@pytest.fixture(name="send", autouse=False, scope="function") +@pytest.fixture(name="send") def network_fixture(): """Patch the device object.""" with patch.object(Device, "send") as mock: diff --git a/tests/test_device.py b/tests/test_device.py index 9abf4f9..0230f4d 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -1,6 +1,5 @@ import asyncio import enum -from unittest.mock import patch import pytest @@ -306,9 +305,6 @@ async def test_update_properties_timeout(cipher, send): """Check that timeouts are handled when properties are updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - send.side_effect = asyncio.TimeoutError with pytest.raises(DeviceTimeoutError): await device.update_state() @@ -328,9 +324,6 @@ async def test_set_properties(cipher, send): """Check that state is pushed when properties are updated.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.power = True device.mode = 1 device.temperature_units = 1 @@ -348,9 +341,8 @@ async def test_set_properties(cipher, send): device.power_save = True device.target_humidity = 30 - with patch.object(Device, "send") as mock_request: - await device.push_state_update() - mock_request.assert_called_once() + await device.push_state_update() + send.assert_called_once() for p in Props: if p not in ( @@ -372,9 +364,6 @@ async def test_set_properties_timeout(cipher, send): """Check timeout handling when pushing state changes.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.power = True device.mode = 1 device.temperature_units = 1 @@ -390,10 +379,13 @@ async def test_set_properties_timeout(cipher, send): device.turbo = True device.steady_heat = True device.power_save = True + + assert len(device._dirty) + send.reset_mock() + send.side_effect = [asyncio.TimeoutError, asyncio.TimeoutError, asyncio.TimeoutError] with pytest.raises(DeviceTimeoutError): - with patch.object(Device, "send", side_effect=asyncio.TimeoutError): - await device.push_state_update() + await device.push_state_update() @pytest.mark.asyncio @@ -401,9 +393,6 @@ async def test_uninitialized_properties(cipher, send): """Check uninitialized property handling.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - assert not device.power assert device.mode is None assert device.target_temperature is None @@ -428,9 +417,6 @@ async def test_update_current_temp_unsupported(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - def fake_send(*args, **kwargs): state = get_mock_state_no_temperature() device.handle_state_update(**state) From 2434100427e259c621fa6d5b8c2946a097429d11 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 13 Jul 2024 17:41:08 -0400 Subject: [PATCH 06/15] Move default keys into cipher classes --- .idea/ruff.xml | 3 +-- greeclimate/cipher.py | 4 ++-- greeclimate/device.py | 9 +-------- tests/common.py | 3 +-- 4 files changed, 5 insertions(+), 14 deletions(-) diff --git a/.idea/ruff.xml b/.idea/ruff.xml index bb95353..045d9eb 100644 --- a/.idea/ruff.xml +++ b/.idea/ruff.xml @@ -1,8 +1,7 @@ - \ No newline at end of file diff --git a/greeclimate/cipher.py b/greeclimate/cipher.py index b021889..0084cbd 100644 --- a/greeclimate/cipher.py +++ b/greeclimate/cipher.py @@ -26,7 +26,7 @@ def decrypt(self, data) -> dict: class CipherV1(CipherBase): - def __init__(self, key: bytes) -> None: + def __init__(self, key: bytes = b'a3K8Bx%2r8Y7#xDh') -> None: super().__init__(key) def __create_cipher(self) -> AES: @@ -58,7 +58,7 @@ class CipherV2(CipherBase): GCM_NONCE = b'\x54\x40\x78\x44\x49\x67\x5a\x51\x6c\x5e\x63\x13' GCM_AEAD = b'qualcomm-test' - def __init__(self, key: bytes) -> None: + def __init__(self, key: bytes = b'{yxAHAY_Lm6pbC/<') -> None: super().__init__(key) def __create_cipher(self) -> AES: diff --git a/greeclimate/device.py b/greeclimate/device.py index e4a0515..9045a6a 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -44,12 +44,6 @@ class Props(enum.Enum): UNKNOWN_HEATCOOLTYPE = "HeatCoolType" -GENERIC_CIPHERS_KEYS = { - CipherV1: b'a3K8Bx%2r8Y7#xDh', - CipherV2: b'{yxAHAY_Lm6pbC/<' -} - - @unique class TemperatureUnits(IntEnum): C = 0 @@ -240,8 +234,7 @@ async def bind(self, key: str = None, cipher_type: Union[type[Union[CipherV1, Ci async def __bind_internal(self, cipher_type: type[Union[CipherV1, CipherV2]]): """Internal binding procedure, do not call directly""" - default_key = GENERIC_CIPHERS_KEYS.get(cipher_type) - await self.send(self.create_bind_message(self.device_info), cipher=cipher_type(default_key)) + await self.send(self.create_bind_message(self.device_info), cipher=cipher_type()) task = asyncio.create_task(self.ready.wait()) await asyncio.wait_for(task, timeout=self._timeout) diff --git a/tests/common.py b/tests/common.py index bb8e139..b4e966d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -4,7 +4,6 @@ from unittest.mock import Mock from greeclimate.cipher import CipherV1, CipherBase -from greeclimate.device import GENERIC_CIPHERS_KEYS DEFAULT_TIMEOUT = 1 DISCOVERY_REQUEST = {"t": "scan"} @@ -96,7 +95,7 @@ def get_mock_device_info(): def encrypt_payload(data): """Encrypt the payload of responses quickly.""" d = data.copy() - cipher = CipherV1(GENERIC_CIPHERS_KEYS[CipherV1]) + cipher = CipherV1() d["pack"], _ = cipher.encrypt(d["pack"]) return d From 37eac7423ef8a14176bdb77e35472dd1ff12fead Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 13 Jul 2024 17:45:21 -0400 Subject: [PATCH 07/15] Fix missing cipher set in discovery --- greeclimate/discovery.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/greeclimate/discovery.py b/greeclimate/discovery.py index 5e18f48..0eb7015 100644 --- a/greeclimate/discovery.py +++ b/greeclimate/discovery.py @@ -6,6 +6,7 @@ from asyncio.events import AbstractEventLoop from ipaddress import IPv4Address +from greeclimate.cipher import CipherV1 from greeclimate.device import DeviceInfo from greeclimate.network import BroadcastListenerProtocol, IPAddr from greeclimate.taskable import Taskable @@ -45,6 +46,7 @@ def __init__( """ BroadcastListenerProtocol.__init__(self, timeout) Taskable.__init__(self, loop) + self.device_cipher = CipherV1() self._allow_loopback: bool = allow_loopback self._device_infos: list[DeviceInfo] = [] self._listeners: list[Listener] = [] From 5c2c08df10e904077720aec107753f52112c4ff1 Mon Sep 17 00:00:00 2001 From: Rami Mosleh Date: Mon, 29 Jul 2024 03:54:53 +0300 Subject: [PATCH 08/15] Update device.py and test_device.py (#95) Co-authored-by: Rami Mousleh --- greeclimate/device.py | 48 +++++++++++++++++++++++-------------------- tests/test_device.py | 12 +++++------ 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/greeclimate/device.py b/greeclimate/device.py index 9045a6a..3198a5c 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -175,8 +175,11 @@ def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEv self._properties = {} self._dirty = [] - - async def bind(self, key: str = None, cipher_type: Union[type[Union[CipherV1, CipherV2]], None] = None): + async def bind( + self, + key: str = None, + cipher: Union[CipherV1, CipherV2, None] = None, + ): """Run the binding procedure. Binding is a finicky procedure, and happens in 1 of 2 ways: @@ -187,17 +190,22 @@ async def bind(self, key: str = None, cipher_type: Union[type[Union[CipherV1, Ci Both approaches result in a device_key which is used as like a persistent session id. Args: - cipher_type (type): The cipher type to use for encryption, if None will attempt to detect the correct one + cipher (CipherV1 | CipherV2): The cipher type to use for encryption, if None will attempt to detect the correct one key (str): The device key, when provided binding is a NOOP, if None binding will - attempt to negotiate the key with the device. cipher_type must be provided. + attempt to negotiate the key with the device. cipher must be provided. Raises: DeviceNotBoundError: If binding was unsuccessful and no key returned DeviceTimeoutError: The device didn't respond """ - if key and not cipher_type: - raise ValueError("cipher_type must be provided when key is provided") + if key: + if not cipher: + raise ValueError("cipher must be provided when key is provided") + else: + cipher.key = key + self.device_cipher = cipher + return if not self.device_info: raise DeviceNotBoundError @@ -210,19 +218,16 @@ async def bind(self, key: str = None, cipher_type: Union[type[Union[CipherV1, Ci self._logger.info("Starting device binding to %s", str(self.device_info)) try: - if key: - self.device_cipher = cipher_type(key.encode()) + if cipher is not None: + await self.__bind_internal(cipher) else: - if cipher_type is not None: - await self.__bind_internal(cipher_type) - else: - """ Try binding with CipherV1 first, if that fails try CipherV2""" - try: - self._logger.info("Attempting to bind to device using CipherV1") - await self.__bind_internal(CipherV1) - except asyncio.TimeoutError: - self._logger.info("Attempting to bind to device using CipherV2") - await self.__bind_internal(CipherV2) + """ Try binding with CipherV1 first, if that fails try CipherV2""" + try: + self._logger.info("Attempting to bind to device using CipherV1") + await self.__bind_internal(CipherV1()) + except asyncio.TimeoutError: + self._logger.info("Attempting to bind to device using CipherV2") + await self.__bind_internal(CipherV2()) except asyncio.TimeoutError: raise DeviceTimeoutError @@ -232,16 +237,15 @@ async def bind(self, key: str = None, cipher_type: Union[type[Union[CipherV1, Ci else: self._logger.info("Bound to device using key %s", self.device_cipher.key) - async def __bind_internal(self, cipher_type: type[Union[CipherV1, CipherV2]]): + async def __bind_internal(self, cipher: Union[CipherV1, CipherV2]): """Internal binding procedure, do not call directly""" - await self.send(self.create_bind_message(self.device_info), cipher=cipher_type()) + await self.send(self.create_bind_message(self.device_info), cipher=cipher) task = asyncio.create_task(self.ready.wait()) await asyncio.wait_for(task, timeout=self._timeout) def handle_device_bound(self, key: str) -> None: """Handle the device bound message from the device""" - cipher_type = type(self.device_cipher) - self.device_cipher = cipher_type(key.encode()) + self.device_cipher.key = key async def request_version(self) -> None: """Request the firmware version from the device.""" diff --git a/tests/test_device.py b/tests/test_device.py index 0230f4d..110757c 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -158,7 +158,7 @@ def get_mock_state_0c_v3_temp(): async def generate_device_mock_async(): d = Device(DeviceInfo("1.1.1.1", 7000, "f4911e7aca59", "1e7aca59")) - await d.bind(key="St8Vw1Yz4Bc7Ef0H", cipher_type=CipherV1) + await d.bind(key="St8Vw1Yz4Bc7Ef0H", cipher=CipherV1()) return d @@ -195,7 +195,7 @@ async def test_get_device_info(cipher, send): assert device.device_info == info fake_key = "abcdefgh12345678" - await device.bind(key=fake_key, cipher_type=CipherV1) + await device.bind(key=fake_key, cipher=CipherV1()) assert device.device_cipher is not None assert device.device_cipher.key == fake_key @@ -665,22 +665,22 @@ async def test_device_equality(send): info1 = DeviceInfo(*get_mock_info()) device1 = Device(info1) - await device1.bind(key="fake_key", cipher_type=CipherV1) + await device1.bind(key="fake_key", cipher=CipherV1()) info2 = DeviceInfo(*get_mock_info()) device2 = Device(info2) - await device2.bind(key="fake_key", cipher_type=CipherV1) + await device2.bind(key="fake_key", cipher=CipherV1()) assert device1 == device2 # Change the key of the second device - await device2.bind(key="another_fake_key", cipher_type=CipherV1) + await device2.bind(key="another_fake_key", cipher=CipherV1()) assert device1 != device2 # Change the info of the second device info2 = DeviceInfo(*get_mock_info()) device2 = Device(info2) device2.power = True - await device2.bind(key="fake_key", cipher_type=CipherV1) + await device2.bind(key="fake_key", cipher=CipherV1()) assert device1 != device2 From 80b3de4b4f4b76fed95b7ec23d6f6456bdb55a4f Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 3 Aug 2024 17:48:00 -0400 Subject: [PATCH 09/15] Fix network tests --- .idea/ruff.xml | 3 ++- tests/test_network.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.idea/ruff.xml b/.idea/ruff.xml index 045d9eb..bb95353 100644 --- a/.idea/ruff.xml +++ b/.idea/ruff.xml @@ -1,7 +1,8 @@ - \ No newline at end of file diff --git a/tests/test_network.py b/tests/test_network.py index 8dbae35..ef54671 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -19,7 +19,6 @@ DISCOVERY_REQUEST, DISCOVERY_RESPONSE, Responder, - encrypt_payload, DEFAULT_REQUEST, generate_response, FakeCipher, ) from .test_device import get_mock_info @@ -89,8 +88,8 @@ async def test_set_get_key(): """Test the encryption key property.""" key = "faketestkey" dp2 = DeviceProtocolBase2() - dp2.device_key = key - assert dp2.device_key == key + dp2.device_cipher = FakeCipher(key.encode()) + assert dp2.device_cipher.key == key @pytest.mark.asyncio @@ -150,7 +149,7 @@ def responder(s): p = json.loads(d) assert p == DISCOVERY_REQUEST - p = json.dumps(encrypt_payload(DISCOVERY_RESPONSE)) + p = json.dumps(DISCOVERY_RESPONSE) s.sendto(p.encode(), addr) serv = Thread(target=responder, args=(sock,)) @@ -163,6 +162,7 @@ def responder(s): local_addr = (addr[0], 0) dp2 = FakeDiscoveryProtocol() + dp2.device_cipher = FakeCipher(b"1234567890123456") await loop.create_datagram_endpoint( lambda: dp2, local_addr=local_addr, @@ -218,7 +218,7 @@ async def test_broadcast_timeout(addr, family): @pytest.mark.asyncio @pytest.mark.parametrize("addr,family", [(("127.0.0.1", 7000), socket.AF_INET)]) -async def test_datagram_connect(addr, family, cipher): +async def test_datagram_connect(addr, family): """Create a socket responder, an async connection, test send and recv.""" with Responder(family, addr[1], bcast=False) as sock: @@ -240,7 +240,7 @@ def responder(s): ) # Send the scan command - await protocol.send(DEFAULT_REQUEST, None) + await protocol.send(DEFAULT_REQUEST, None, FakeCipher(b"1234567890123456")) # Wait on the scan response task = asyncio.create_task(protocol.packets.get()) @@ -253,11 +253,11 @@ def responder(s): serv.join(timeout=DEFAULT_TIMEOUT) -@pytest.mark.asyncio -def test_bindok_handling(cipher): +def test_bindok_handling(): """Test the bindok response.""" response = generate_response({"t": "bindok", "key": "fake-key"}) protocol = DeviceProtocol2(timeout=DEFAULT_TIMEOUT) + protocol.device_cipher = FakeCipher(b"1234567890123456") with patch.object(DeviceProtocol2, "handle_device_bound") as mock: protocol.datagram_received(json.dumps(response).encode(), ("0.0.0.0", 0)) From 667342c51fd43445d0fc4e205900e86a2e9b09b2 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 3 Aug 2024 17:57:43 -0400 Subject: [PATCH 10/15] Use separate timeout for bind tests --- greeclimate/device.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/greeclimate/device.py b/greeclimate/device.py index 3198a5c..486a0dc 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -155,12 +155,14 @@ class Device(DeviceProtocol2, Taskable): water_full: A bool to indicate the water tank is full """ - def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEventLoop = None): + def __init__(self, device_info: DeviceInfo, timeout: int = 120, bind_timeout: int = 10, loop: AbstractEventLoop = None): """Initialize the device object Args: device_info (DeviceInfo): Information about the physical device timeout (int): Timeout for device communication + bind_timeout (int): Timeout for binding to the device, keep this short to prevent delays determining the + correct device cipher to use loop (AbstractEventLoop): The event loop to run the device operations on """ DeviceProtocol2.__init__(self, timeout) @@ -168,6 +170,8 @@ def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEv self._logger = logging.getLogger(__name__) self.device_info: DeviceInfo = device_info + self._bind_timeout = bind_timeout + """ Device properties """ self.hid = None self.version = None @@ -241,7 +245,7 @@ async def __bind_internal(self, cipher: Union[CipherV1, CipherV2]): """Internal binding procedure, do not call directly""" await self.send(self.create_bind_message(self.device_info), cipher=cipher) task = asyncio.create_task(self.ready.wait()) - await asyncio.wait_for(task, timeout=self._timeout) + await asyncio.wait_for(task, timeout=self._bind_timeout) def handle_device_bound(self, key: str) -> None: """Handle the device bound message from the device""" From dcf1fac53e31c5fbdcaf98719265f1517a441d10 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 3 Aug 2024 18:22:57 -0400 Subject: [PATCH 11/15] Add coverage to network and device --- greeclimate/network.py | 7 +++--- tests/test_device.py | 20 +++++++++++++++++ tests/test_network.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/greeclimate/network.py b/greeclimate/network.py index 9608cee..b8e432b 100644 --- a/greeclimate/network.py +++ b/greeclimate/network.py @@ -229,12 +229,13 @@ def packet_received(self, obj, addr: IPAddr) -> None: Response.DATA.value: lambda *args: self.__handle_state_update(*args), Response.RESULT.value: lambda *args: self.__handle_state_update(*args), } - resp = obj.get("pack", {}).get("t") - handler = handlers.get(resp, self.handle_unknown_packet) - param = [] try: + resp = obj.get("pack", {}).get("t") + handler = handlers.get(resp, self.handle_unknown_packet) param = params.get(resp, lambda o, a: (o, a))(obj, addr) handler(*param) + except AttributeError as e: + _LOGGER.exception("Error while handling packet", exc_info=e) except KeyError as e: _LOGGER.exception("Error while handling packet", exc_info=e) else: diff --git a/tests/test_device.py b/tests/test_device.py index 110757c..8ef0683 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -280,6 +280,26 @@ def fake_send(*args, **kwargs): assert device.device_cipher.key == fake_key +@pytest.mark.asyncio +async def test_device_bind_no_cipher(cipher, send): + """Check that the device handles late binding sequences.""" + info = DeviceInfo(*get_mock_info()) + device = Device(info, timeout=1) + fake_key = "abcdefgh12345678" + + with pytest.raises(ValueError): + await device.bind(fake_key) + + +@pytest.mark.asyncio +async def test_device_bind_no_device_info(cipher, send): + """Check that the device handles late binding sequences.""" + device = Device(None, timeout=1) + + with pytest.raises(DeviceNotBoundError): + await device.bind() + + @pytest.mark.asyncio async def test_update_properties(cipher, send): """Check that properties can be updates.""" diff --git a/tests/test_network.py b/tests/test_network.py index ef54671..5a6caad 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -497,3 +497,52 @@ async def test_add_and_remove_handler(event_name, data): # Check that the callback was not called this time callback.assert_not_called() + +def test_packet_received_not_implemented(): + # Arrange + protocol = DeviceProtocolBase2() + + # Act + with pytest.raises(NotImplementedError): + protocol.packet_received({}, ("0.0.0.0", 0)) + + +def test_packet_received_invalid_data(): + # Arrange + protocol = DeviceProtocol2() + + # Act + protocol.packet_received(None, ("0.0.0.0", 0)) + protocol.packet_received({"pack"}, ("0.0.0.0", 0)) + + with patch.object(protocol, "handle_unknown_packet") as mock: + protocol.packet_received({"pack": {"t": "unknown"}}, ("0.0.0.0", 0)) + mock.assert_called_once() + + +def test_cipher_is_not_set(): + # Arrange + protocol = DeviceProtocolBase2() + + # Act + key = None + with pytest.raises(ValueError): + key = protocol.device_key + + assert key is None + + with pytest.raises(ValueError): + protocol.device_key = "fake-key" + + +def test_add_invalid_handler(): + # Arrange + protocol = DeviceProtocol2() + callback = MagicMock() + + # Act + with pytest.raises(ValueError): + protocol.add_handler(Response("invalid"), callback) + + with pytest.raises(ValueError): + protocol.add_handler(Response("invalid"), callback) From c49be0ff8a22633fbd111a20fdef1f333b8edd7f Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 3 Aug 2024 18:27:43 -0400 Subject: [PATCH 12/15] Add coverage to network and device --- tests/test_network.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_network.py b/tests/test_network.py index 5a6caad..692aae8 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -513,6 +513,7 @@ def test_packet_received_invalid_data(): # Act protocol.packet_received(None, ("0.0.0.0", 0)) + protocol.packet_received({}, ("0.0.0.0", 0)) protocol.packet_received({"pack"}, ("0.0.0.0", 0)) with patch.object(protocol, "handle_unknown_packet") as mock: @@ -520,6 +521,18 @@ def test_packet_received_invalid_data(): mock.assert_called_once() +def test_set_get_cipher(): + # Arrange + protocol = DeviceProtocolBase2() + cipher = FakeCipher(b"1234567890123456") + + # Act + protocol.device_cipher = cipher + + # Assert + assert protocol.device_cipher == cipher + + def test_cipher_is_not_set(): # Arrange protocol = DeviceProtocolBase2() From cea40acb1de9586bf7a3375a9edeb59f5dc4f312 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 3 Aug 2024 18:43:51 -0400 Subject: [PATCH 13/15] Add coverage to network and device --- .idea/ruff.xml | 3 +-- tests/test_cipher.py | 14 +++++++++++++- tests/test_network.py | 13 +++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/.idea/ruff.xml b/.idea/ruff.xml index bb95353..045d9eb 100644 --- a/.idea/ruff.xml +++ b/.idea/ruff.xml @@ -1,8 +1,7 @@ - \ No newline at end of file diff --git a/tests/test_cipher.py b/tests/test_cipher.py index baaaabb..ec92be5 100644 --- a/tests/test_cipher.py +++ b/tests/test_cipher.py @@ -1,6 +1,8 @@ import base64 + import pytest -from greeclimate.cipher import CipherV1, CipherV2 + +from greeclimate.cipher import CipherV1, CipherV2, CipherBase @pytest.fixture @@ -46,3 +48,13 @@ def test_encryption_then_decryption_yields_original_with_tag(cipher_v2, plain_te decrypted = cipher_v2.decrypt(encrypted) assert decrypted == plain_text + +def test_cipher_base_not_implemented(): + fake_key = b'fake' + + with pytest.raises(NotImplementedError): + CipherBase(fake_key).encrypt(None) + + with pytest.raises(NotImplementedError): + CipherBase(fake_key).decrypt(None) + diff --git a/tests/test_network.py b/tests/test_network.py index 692aae8..18c9ffa 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -559,3 +559,16 @@ def test_add_invalid_handler(): with pytest.raises(ValueError): protocol.add_handler(Response("invalid"), callback) + + +def test_device_key_get_set(): + # Arrange + protocol = DeviceProtocolBase2 + key = "fake-key" + + # Act + protocol.device_key = key + + # Assert + assert protocol.device_key == key + \ No newline at end of file From f595eab152f7867774a131d09a6c6d63f880ccf1 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Sat, 3 Aug 2024 19:02:52 -0400 Subject: [PATCH 14/15] Add coverage to device --- tests/test_device.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_device.py b/tests/test_device.py index 8ef0683..111953b 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -222,6 +222,10 @@ def fake_send(*args, **kwargs): assert device.device_cipher is not None assert device.device_cipher.key == fake_key + + # Bind with cipher already set + await device.bind() + assert send.call_count == 2 @pytest.mark.asyncio @@ -255,7 +259,7 @@ def fake_send(*args, **kwargs): @pytest.mark.asyncio -async def test_device_late_bind(cipher, send): +async def test_device_late_bind_from_update(cipher, send): """Check that the device handles late binding sequences.""" info = DeviceInfo(*get_mock_info()) device = Device(info, timeout=1) @@ -280,6 +284,24 @@ def fake_send(*args, **kwargs): assert device.device_cipher.key == fake_key +@pytest.mark.asyncio +async def test_device_late_bind_from_request_version(cipher, send): + """Check that the device handles late binding sequences.""" + info = DeviceInfo(*get_mock_info()) + device = Device(info, timeout=1) + fake_key = "abcdefgh12345678" + + def fake_send(*args, **kwargs): + device.device_cipher = CipherV1(fake_key.encode()) + device.handle_device_bound(fake_key) + device.ready.set() + send.side_effect = fake_send + + await device.request_version() + assert send.call_count == 2 + assert device.device_cipher.key == fake_key + + @pytest.mark.asyncio async def test_device_bind_no_cipher(cipher, send): """Check that the device handles late binding sequences.""" @@ -704,3 +726,11 @@ async def test_device_equality(send): await device2.bind(key="fake_key", cipher=CipherV1()) assert device1 != device2 + +def test_device_key_set_get(): + """Check that the device key can be set and retrieved.""" + device = Device(DeviceInfo(*get_mock_info())) + device.device_cipher = CipherV1() + device.device_key = "fake_key" + assert device.device_key == "fake_key" + \ No newline at end of file From f8ba8bcea58af0e6670835d997822c87026c5005 Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Mon, 5 Aug 2024 17:22:01 -0400 Subject: [PATCH 15/15] Fix cipher logging --- greeclimate/cipher.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/greeclimate/cipher.py b/greeclimate/cipher.py index 0084cbd..be8aabd 100644 --- a/greeclimate/cipher.py +++ b/greeclimate/cipher.py @@ -5,6 +5,7 @@ from Crypto.Cipher import AES +_logger = logging.getLogger(__name__) class CipherBase: def __init__(self, key: bytes) -> None: @@ -36,21 +37,21 @@ def __pad(self, s) -> str: return s + (16 - len(s) % 16) * chr(16 - len(s) % 16) def encrypt(self, data) -> Tuple[str, Union[str, None]]: - logging.debug("Encrypting data: %s", data) + _logger.debug("Encrypting data: %s", data) cipher = self.__create_cipher() padded = self.__pad(json.dumps(data)).encode() encrypted = cipher.encrypt(padded) encoded = base64.b64encode(encrypted).decode() - logging.info("Encrypted data: %s", encoded) + _logger.debug("Encrypted data: %s", encoded) return encoded, None def decrypt(self, data) -> dict: - logging.info("Decrypting data: %s", data) + _logger.debug("Decrypting data: %s", data) cipher = self.__create_cipher() decoded = base64.b64decode(data) decrypted = cipher.decrypt(decoded).decode() t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '') - logging.debug("Decrypted data: %s", t) + _logger.debug("Decrypted data: %s", t) return json.loads(t) @@ -67,21 +68,21 @@ def __create_cipher(self) -> AES: return cipher def encrypt(self, data) -> Tuple[str, str]: - logging.debug("Encrypting data: %s", data) + _logger.debug("Encrypting data: %s", data) cipher = self.__create_cipher() encrypted, tag = cipher.encrypt_and_digest(json.dumps(data).encode()) encoded = base64.b64encode(encrypted).decode() tag = base64.b64encode(tag).decode() - logging.debug("Encrypted data: %s", encoded) - logging.debug("Cipher digest: %s", tag) + _logger.debug("Encrypted data: %s", encoded) + _logger.debug("Cipher digest: %s", tag) return encoded, tag def decrypt(self, data) -> dict: - logging.info("Decrypting data: %s", data) + _logger.info("Decrypting data: %s", data) cipher = self.__create_cipher() decoded = base64.b64decode(data) decrypted = cipher.decrypt(decoded).decode() t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '') - logging.debug("Decrypted data: %s", t) + _logger.debug("Decrypted data: %s", t) return json.loads(t)