From 7122cdd82597af82109a17bc32dcbfb97c78073c Mon Sep 17 00:00:00 2001 From: Clifford Roche <1007595+cmroche@users.noreply.github.com> Date: Mon, 5 Aug 2024 17:24:46 -0400 Subject: [PATCH] feat: Support GCM encryption for Gree devices (#92) --------- Co-authored-by: Rami Mosleh Co-authored-by: Rami Mousleh --- .idea/codeStyles/codeStyleConfig.xml | 5 + .idea/material_theme_project_new.xml | 17 ++ .idea/ruff.xml | 7 + .idea/runConfigurations/pytest_in__.xml | 20 ++ emulator.py | 1 - greeclimate/cipher.py | 88 ++++++ greeclimate/device.py | 69 +++-- greeclimate/discovery.py | 2 + greeclimate/network.py | 108 ++++---- tests/common.py | 22 +- tests/conftest.py | 21 +- tests/test_cipher.py | 60 +++++ tests/test_device.py | 344 ++++++++++++------------ tests/test_discovery.py | 10 +- tests/test_issues.py | 4 +- tests/test_network.py | 130 ++++++--- 16 files changed, 601 insertions(+), 307 deletions(-) create mode 100644 .idea/codeStyles/codeStyleConfig.xml create mode 100644 .idea/material_theme_project_new.xml create mode 100644 .idea/ruff.xml create mode 100644 .idea/runConfigurations/pytest_in__.xml create mode 100644 greeclimate/cipher.py create mode 100644 tests/test_cipher.py diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 0000000..79ee123 --- /dev/null +++ b/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml new file mode 100644 index 0000000..374ec3d --- /dev/null +++ b/.idea/material_theme_project_new.xml @@ -0,0 +1,17 @@ + + + + + + + \ No newline at end of file diff --git a/.idea/ruff.xml b/.idea/ruff.xml new file mode 100644 index 0000000..045d9eb --- /dev/null +++ b/.idea/ruff.xml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/.idea/runConfigurations/pytest_in__.xml b/.idea/runConfigurations/pytest_in__.xml new file mode 100644 index 0000000..316c583 --- /dev/null +++ b/.idea/runConfigurations/pytest_in__.xml @@ -0,0 +1,20 @@ + + + + + \ No newline at end of file diff --git a/emulator.py b/emulator.py index a3ac6b5..11d0b09 100644 --- a/emulator.py +++ b/emulator.py @@ -6,7 +6,6 @@ import socket import time -import machine import network import ubinascii from ucryptolib import aes diff --git a/greeclimate/cipher.py b/greeclimate/cipher.py new file mode 100644 index 0000000..be8aabd --- /dev/null +++ b/greeclimate/cipher.py @@ -0,0 +1,88 @@ +import base64 +import json +import logging +from typing import Union, Tuple + +from Crypto.Cipher import AES + +_logger = logging.getLogger(__name__) + +class CipherBase: + def __init__(self, key: bytes) -> None: + self._key: bytes = key + + @property + def key(self) -> str: + return self._key.decode() + + @key.setter + def key(self, value: str) -> None: + self._key = value.encode() + + def encrypt(self, data) -> Tuple[str, Union[str, None]]: + raise NotImplementedError + + def decrypt(self, data) -> dict: + raise NotImplementedError + + +class CipherV1(CipherBase): + def __init__(self, key: bytes = b'a3K8Bx%2r8Y7#xDh') -> None: + super().__init__(key) + + def __create_cipher(self) -> AES: + return AES.new(self._key, AES.MODE_ECB) + + def __pad(self, s) -> str: + return s + (16 - len(s) % 16) * chr(16 - len(s) % 16) + + def encrypt(self, data) -> Tuple[str, Union[str, None]]: + _logger.debug("Encrypting data: %s", data) + cipher = self.__create_cipher() + padded = self.__pad(json.dumps(data)).encode() + encrypted = cipher.encrypt(padded) + encoded = base64.b64encode(encrypted).decode() + _logger.debug("Encrypted data: %s", encoded) + return encoded, None + + def decrypt(self, data) -> dict: + _logger.debug("Decrypting data: %s", data) + cipher = self.__create_cipher() + decoded = base64.b64decode(data) + decrypted = cipher.decrypt(decoded).decode() + t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '') + _logger.debug("Decrypted data: %s", t) + return json.loads(t) + + +class CipherV2(CipherBase): + GCM_NONCE = b'\x54\x40\x78\x44\x49\x67\x5a\x51\x6c\x5e\x63\x13' + GCM_AEAD = b'qualcomm-test' + + def __init__(self, key: bytes = b'{yxAHAY_Lm6pbC/<') -> None: + super().__init__(key) + + def __create_cipher(self) -> AES: + cipher = AES.new(self._key, AES.MODE_GCM, nonce=self.GCM_NONCE) + cipher.update(self.GCM_AEAD) + return cipher + + def encrypt(self, data) -> Tuple[str, str]: + _logger.debug("Encrypting data: %s", data) + cipher = self.__create_cipher() + encrypted, tag = cipher.encrypt_and_digest(json.dumps(data).encode()) + encoded = base64.b64encode(encrypted).decode() + tag = base64.b64encode(tag).decode() + _logger.debug("Encrypted data: %s", encoded) + _logger.debug("Cipher digest: %s", tag) + return encoded, tag + + def decrypt(self, data) -> dict: + _logger.info("Decrypting data: %s", data) + cipher = self.__create_cipher() + decoded = base64.b64decode(data) + decrypted = cipher.decrypt(decoded).decode() + t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '') + _logger.debug("Decrypted data: %s", t) + return json.loads(t) + diff --git a/greeclimate/device.py b/greeclimate/device.py index 93ed49e..486a0dc 100644 --- a/greeclimate/device.py +++ b/greeclimate/device.py @@ -4,12 +4,12 @@ import re from asyncio import AbstractEventLoop from enum import IntEnum, unique -from typing import List +from typing import Union -import greeclimate.network as network +from greeclimate.cipher import CipherV1, CipherV2 from greeclimate.deviceinfo import DeviceInfo -from greeclimate.network import DeviceProtocol2, IPAddr from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError +from greeclimate.network import DeviceProtocol2 from greeclimate.taskable import Taskable @@ -155,20 +155,23 @@ class Device(DeviceProtocol2, Taskable): water_full: A bool to indicate the water tank is full """ - def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEventLoop = None): + def __init__(self, device_info: DeviceInfo, timeout: int = 120, bind_timeout: int = 10, loop: AbstractEventLoop = None): """Initialize the device object Args: device_info (DeviceInfo): Information about the physical device timeout (int): Timeout for device communication + bind_timeout (int): Timeout for binding to the device, keep this short to prevent delays determining the + correct device cipher to use loop (AbstractEventLoop): The event loop to run the device operations on """ DeviceProtocol2.__init__(self, timeout) Taskable.__init__(self, loop) self._logger = logging.getLogger(__name__) - self.device_info: DeviceInfo = device_info - + + self._bind_timeout = bind_timeout + """ Device properties """ self.hid = None self.version = None @@ -176,7 +179,11 @@ def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEv self._properties = {} self._dirty = [] - async def bind(self, key=None): + async def bind( + self, + key: str = None, + cipher: Union[CipherV1, CipherV2, None] = None, + ): """Run the binding procedure. Binding is a finicky procedure, and happens in 1 of 2 ways: @@ -187,14 +194,23 @@ async def bind(self, key=None): Both approaches result in a device_key which is used as like a persistent session id. Args: + cipher (CipherV1 | CipherV2): The cipher type to use for encryption, if None will attempt to detect the correct one key (str): The device key, when provided binding is a NOOP, if None binding will - attempt to negotiate the key with the device. + attempt to negotiate the key with the device. cipher must be provided. Raises: DeviceNotBoundError: If binding was unsuccessful and no key returned DeviceTimeoutError: The device didn't respond """ + if key: + if not cipher: + raise ValueError("cipher must be provided when key is provided") + else: + cipher.key = key + self.device_cipher = cipher + return + if not self.device_info: raise DeviceNotBoundError @@ -206,29 +222,38 @@ async def bind(self, key=None): self._logger.info("Starting device binding to %s", str(self.device_info)) try: - if key: - self.device_key = key + if cipher is not None: + await self.__bind_internal(cipher) else: - await self.send(self.create_bind_message(self.device_info)) - # Special case, wait for binding to complete so we know that the device is ready - task = asyncio.create_task(self.ready.wait()) - await asyncio.wait_for(task, timeout=self._timeout) + """ Try binding with CipherV1 first, if that fails try CipherV2""" + try: + self._logger.info("Attempting to bind to device using CipherV1") + await self.__bind_internal(CipherV1()) + except asyncio.TimeoutError: + self._logger.info("Attempting to bind to device using CipherV2") + await self.__bind_internal(CipherV2()) except asyncio.TimeoutError: raise DeviceTimeoutError - if not self.device_key: + if not self.device_cipher: raise DeviceNotBoundError else: - self._logger.info("Bound to device using key %s", self.device_key) + self._logger.info("Bound to device using key %s", self.device_cipher.key) + + async def __bind_internal(self, cipher: Union[CipherV1, CipherV2]): + """Internal binding procedure, do not call directly""" + await self.send(self.create_bind_message(self.device_info), cipher=cipher) + task = asyncio.create_task(self.ready.wait()) + await asyncio.wait_for(task, timeout=self._bind_timeout) - def handle_device_bound(self, key) -> None: + def handle_device_bound(self, key: str) -> None: """Handle the device bound message from the device""" - self.device_key = key + self.device_cipher.key = key async def request_version(self) -> None: """Request the firmware version from the device.""" - if not self.device_key: + if not self.device_cipher: await self.bind() try: @@ -243,7 +268,7 @@ async def update_state(self, wait_for: float = 30): Args: wait_for (object): How long to wait for an update from the device """ - if not self.device_key: + if not self.device_cipher: await self.bind() self._logger.debug("Updating device properties for (%s)", str(self.device_info)) @@ -288,7 +313,7 @@ async def push_state_update(self, wait_for: float = 30): if not self._dirty: return - if not self.device_key: + if not self.device_cipher: await self.bind() self._logger.debug("Pushing state updates to (%s)", str(self.device_info)) @@ -316,7 +341,7 @@ def __eq__(self, other): """Compare two devices for equality based on their properties state and device info.""" return self.device_info == other.device_info \ and self.raw_properties == other.raw_properties \ - and self.device_key == other.device_key + and self.device_cipher.key == other.device_cipher.key def __ne__(self, other): return not self.__eq__(other) diff --git a/greeclimate/discovery.py b/greeclimate/discovery.py index 5e18f48..0eb7015 100644 --- a/greeclimate/discovery.py +++ b/greeclimate/discovery.py @@ -6,6 +6,7 @@ from asyncio.events import AbstractEventLoop from ipaddress import IPv4Address +from greeclimate.cipher import CipherV1 from greeclimate.device import DeviceInfo from greeclimate.network import BroadcastListenerProtocol, IPAddr from greeclimate.taskable import Taskable @@ -45,6 +46,7 @@ def __init__( """ BroadcastListenerProtocol.__init__(self, timeout) Taskable.__init__(self, loop) + self.device_cipher = CipherV1() self._allow_loopback: bool = allow_loopback self._device_infos: list[DeviceInfo] = [] self._listeners: list[Listener] = [] diff --git a/greeclimate/network.py b/greeclimate/network.py index aabea91..b8e432b 100644 --- a/greeclimate/network.py +++ b/greeclimate/network.py @@ -1,24 +1,17 @@ import asyncio -import base64 import json import logging import socket from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Text, Tuple, Union - -from Crypto.Cipher import AES +from typing import Any, Dict, Tuple, Union +from greeclimate.cipher import CipherBase from greeclimate.deviceinfo import DeviceInfo NETWORK_TIMEOUT = 10 -GENERIC_KEY = ["a3K8Bx%2r8Y7#xDh"] -GCM_NONCE = b'\x54\x40\x78\x44\x49\x67\x5a\x51\x6c\x5e\x63\x13' -GCM_AEAD = b'qualcomm-test' - _LOGGER = logging.getLogger(__name__) - IPAddr = Tuple[str, int] @@ -52,11 +45,13 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: timeout (int): Packet send timeout drained (asyncio.Event): Packet send drain event signal """ - self._timeout = timeout - self._drained = drained or asyncio.Event() + self._timeout: int = timeout + self._drained: asyncio.Event = drained or asyncio.Event() self._drained.set() - self._transport = None - self._key = None + + self._transport: Union[asyncio.transports.DatagramTransport, None] = None + self._cipher: Union[CipherBase, None] = None + # This event need to be implemented to handle incoming requests def packet_received(self, obj, addr: IPAddr) -> None: @@ -69,14 +64,28 @@ def packet_received(self, obj, addr: IPAddr) -> None: raise NotImplementedError("packet_received must be implemented in a subclass") @property - def device_key(self) -> str: + def device_cipher(self) -> CipherBase: """Sets the encryption key used for device data.""" - return self._key + return self._cipher + + @device_cipher.setter + def device_cipher(self, value: CipherBase): + """Gets the encryption key used for device data.""" + self._cipher = value + + @property + def device_key(self) -> str: + """Gets the encryption key used for device data.""" + if self._cipher is None: + raise ValueError("Cipher object not set") + return self._cipher.key @device_key.setter def device_key(self, value: str): - """Gets the encryption key used for device data.""" - self._key = value + """Sets the encryption key used for device data.""" + if self._cipher is None: + raise ValueError("Cipher object not set") + self._cipher.key = value def close(self) -> None: """Close the UDP transport.""" @@ -116,27 +125,6 @@ def resume_writing(self) -> None: self._drained.set() super().resume_writing() - @staticmethod - def decrypt_payload(payload, key=GENERIC_KEY[0]): - cipher = AES.new(key.encode(), AES.MODE_ECB) - decoded = base64.b64decode(payload) - decrypted = cipher.decrypt(decoded).decode() - t = decrypted.replace(decrypted[decrypted.rindex("}") + 1 :], "") - return json.loads(t) - - @staticmethod - def encrypt_payload(payload, key=GENERIC_KEY[0]): - def pad(s): - bs = 16 - return s + (bs - len(s) % bs) * chr(bs - len(s) % bs) - - cipher = AES.new(key.encode(), AES.MODE_ECB) - encrypted = cipher.encrypt(pad(json.dumps(payload)).encode()) - encoded = base64.b64encode(encrypted).decode() - - _LOGGER.debug(f"Encrypted payload with key [{key}]: {encoded}") - return encoded - def datagram_received(self, data: bytes, addr: IPAddr) -> None: """Handle an incoming datagram.""" if len(data) == 0: @@ -144,23 +132,30 @@ def datagram_received(self, data: bytes, addr: IPAddr) -> None: obj = json.loads(data) - # It could be either a v1 or v2 key - key = GENERIC_KEY[0] if obj.get("i") == 1 else self._key - if obj.get("pack"): - obj["pack"] = DeviceProtocolBase2.decrypt_payload(obj["pack"], key) - - _LOGGER.debug("Received packet from %s:\n%s", addr[0], json.dumps(obj)) + obj["pack"] = self._cipher.decrypt(obj["pack"]) + _LOGGER.debug("Received packet from %s:\n<- %s", addr[0], json.dumps(obj)) self.packet_received(obj, addr) - async def send(self, obj, addr: IPAddr = None) -> None: - """Send encode and send JSON command to the device.""" - _LOGGER.debug("Sending packet:\n%s", json.dumps(obj)) + async def send(self, obj, addr: IPAddr = None, cipher: Union[CipherBase, None] = None) -> None: + """Send encode and send JSON command to the device. + + Args: + addr (object): (Optional) Address to send the message + cipher (object): (Optional) Initial cipher to use for SCANNING and BINDING + """ + _LOGGER.debug("Sending packet:\n-> %s", json.dumps(obj)) if obj.get("pack"): - key = GENERIC_KEY[0] if obj.get("i") == 1 else self._key - obj["pack"] = DeviceProtocolBase2.encrypt_payload(obj["pack"], key) + if obj.get("i") == 1: + if cipher is None: + raise ValueError("Cipher must be supplied for SCAN or BIND messages") + self._cipher = cipher + + obj["pack"], tag = self._cipher.encrypt(obj["pack"]) + if tag: + obj["tag"] = tag data_bytes = json.dumps(obj).encode() self._transport.sendto(data_bytes, addr) @@ -169,7 +164,6 @@ async def send(self, obj, addr: IPAddr = None) -> None: await asyncio.wait_for(task, self._timeout) - class BroadcastListenerProtocol(DeviceProtocolBase2): """Special protocol handler for when broadcast is needed.""" @@ -184,6 +178,7 @@ def connection_made(self, transport: asyncio.transports.DatagramTransport) -> No class DeviceProtocol2(DeviceProtocolBase2): """Protocol handler for direct device communication.""" + _handlers = {} def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: """Initialize the device protocol object. @@ -195,7 +190,6 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None: DeviceProtocolBase2.__init__(self, timeout, drained) self._ready = asyncio.Event() self._ready.clear() - self._handlers = {} @property def ready(self) -> asyncio.Event: @@ -235,12 +229,13 @@ def packet_received(self, obj, addr: IPAddr) -> None: Response.DATA.value: lambda *args: self.__handle_state_update(*args), Response.RESULT.value: lambda *args: self.__handle_state_update(*args), } - resp = obj.get("pack", {}).get("t") - handler = handlers.get(resp, self.handle_unknown_packet) - param = [] try: + resp = obj.get("pack", {}).get("t") + handler = handlers.get(resp, self.handle_unknown_packet) param = params.get(resp, lambda o, a: (o, a))(obj, addr) handler(*param) + except AttributeError as e: + _LOGGER.exception("Error while handling packet", exc_info=e) except KeyError as e: _LOGGER.exception("Error while handling packet", exc_info=e) else: @@ -290,6 +285,5 @@ def create_status_message(self, device_info: DeviceInfo, *args) -> Dict[str, Any return self._generate_payload(Commands.STATUS, device_info, {"cols": list(args)}) def create_command_message(self, device_info: DeviceInfo, **kwargs) -> Dict[str, Any]: - return self._generate_payload(Commands.CMD, device_info, {"opt": list(kwargs.keys()), "p": list(kwargs.values())}) - - + return self._generate_payload(Commands.CMD, device_info, + {"opt": list(kwargs.keys()), "p": list(kwargs.values())}) diff --git a/tests/common.py b/tests/common.py index 0ed9651..b4e966d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,9 +1,9 @@ -import asyncio import socket from socket import SOCK_DGRAM -from unittest.mock import Mock, create_autospec, patch +from typing import Tuple, Union +from unittest.mock import Mock -from greeclimate.network import DeviceProtocolBase2 +from greeclimate.cipher import CipherV1, CipherBase DEFAULT_TIMEOUT = 1 DISCOVERY_REQUEST = {"t": "scan"} @@ -95,10 +95,24 @@ def get_mock_device_info(): def encrypt_payload(data): """Encrypt the payload of responses quickly.""" d = data.copy() - d["pack"] = DeviceProtocolBase2.encrypt_payload(d["pack"]) + cipher = CipherV1() + d["pack"], _ = cipher.encrypt(d["pack"]) return d +class FakeCipher(CipherBase): + """Fake cipher object for testing.""" + + def __init__(self, key: bytes) -> None: + super().__init__(key) + + def encrypt(self, data) -> Tuple[str, Union[str, None]]: + return data, None + + def decrypt(self, data) -> dict: + return data + + class Responder: """Context manage for easy raw socket responders.""" diff --git a/tests/conftest.py b/tests/conftest.py index 9a64b00..aadb011 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,9 @@ import pytest +from greeclimate.device import Device +from tests.common import FakeCipher + MOCK_INTERFACES = ["lo"] MOCK_LO_IFACE = { 2: [{"addr": "10.0.0.1", "netmask": "255.0.0.0", "peer": "10.255.255.255"}] @@ -13,6 +16,22 @@ def netifaces_fixture(): """Patch netifaces interface discover.""" with patch("netifaces.interfaces", return_value=MOCK_INTERFACES), patch( - "netifaces.ifaddresses", return_value=MOCK_LO_IFACE + "netifaces.ifaddresses", return_value=MOCK_LO_IFACE ) as ifaddr_mock: yield ifaddr_mock + + +@pytest.fixture(name="cipher") +def cipher_fixture(): + """Patch the cipher object.""" + with patch("greeclimate.device.CipherV1") as mock1, patch("greeclimate.device.CipherV2") as mock2: + mock1.return_value = FakeCipher(b"1234567890123456") + mock2.return_value = FakeCipher(b"1234567890123456") + yield mock1, mock2 + + +@pytest.fixture(name="send") +def network_fixture(): + """Patch the device object.""" + with patch.object(Device, "send") as mock: + yield mock diff --git a/tests/test_cipher.py b/tests/test_cipher.py new file mode 100644 index 0000000..ec92be5 --- /dev/null +++ b/tests/test_cipher.py @@ -0,0 +1,60 @@ +import base64 + +import pytest + +from greeclimate.cipher import CipherV1, CipherV2, CipherBase + + +@pytest.fixture +def cipher_v1_key(): + return b'ThisIsASecretKey' + + +@pytest.fixture +def cipher_v1(cipher_v1_key): + return CipherV1(cipher_v1_key) + + +@pytest.fixture +def cipher_v2_key(): + return b'ThisIsASecretKey' + + +@pytest.fixture +def cipher_v2(cipher_v2_key): + return CipherV2(cipher_v2_key) + + +@pytest.fixture +def plain_text(): + return {"message": "Hello, World!"} + + +def test_encryption_then_decryption_yields_original(cipher_v1, plain_text): + encrypted, _ = cipher_v1.encrypt(plain_text) + decrypted = cipher_v1.decrypt(encrypted) + assert decrypted == plain_text + + +def test_decryption_with_modified_data_raises_error(cipher_v1, plain_text): + _, _ = cipher_v1.encrypt(plain_text) + modified_data = base64.b64encode(b"modified data ").decode() + with pytest.raises(UnicodeDecodeError): + cipher_v1.decrypt(modified_data) + + +def test_encryption_then_decryption_yields_original_with_tag(cipher_v2, plain_text): + encrypted, tag = cipher_v2.encrypt(plain_text) + decrypted = cipher_v2.decrypt(encrypted) + assert decrypted == plain_text + + +def test_cipher_base_not_implemented(): + fake_key = b'fake' + + with pytest.raises(NotImplementedError): + CipherBase(fake_key).encrypt(None) + + with pytest.raises(NotImplementedError): + CipherBase(fake_key).decrypt(None) + diff --git a/tests/test_device.py b/tests/test_device.py index 33b8663..111953b 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -1,11 +1,10 @@ import asyncio import enum -from unittest.mock import patch, AsyncMock import pytest +from greeclimate.cipher import CipherV1 from greeclimate.device import Device, DeviceInfo, Props, TemperatureUnits -from greeclimate.discovery import Discovery from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError @@ -158,12 +157,12 @@ def get_mock_state_0c_v3_temp(): async def generate_device_mock_async(): - d = Device(DeviceInfo("192.168.1.29", 7000, "f4911e7aca59", "1e7aca59")) - await d.bind(key="St8Vw1Yz4Bc7Ef0H") + d = Device(DeviceInfo("1.1.1.1", 7000, "f4911e7aca59", "1e7aca59")) + await d.bind(key="St8Vw1Yz4Bc7Ef0H", cipher=CipherV1()) return d -def test_device_info_equality(): +def test_device_info_equality(send): """The only way to get the key through binding is by scanning first""" props = [ @@ -187,7 +186,7 @@ def test_device_info_equality(): @pytest.mark.asyncio -async def test_get_device_info(): +async def test_get_device_info(cipher, send): """Initialize device, check properties.""" info = DeviceInfo(*get_mock_info()) @@ -196,111 +195,147 @@ async def test_get_device_info(): assert device.device_info == info fake_key = "abcdefgh12345678" - await device.bind(key=fake_key) + await device.bind(key=fake_key, cipher=CipherV1()) - assert device.device_key == fake_key + assert device.device_cipher is not None + assert device.device_cipher.key == fake_key @pytest.mark.asyncio -async def test_device_bind(): +async def test_device_bind(cipher, send): """Check that the device returns a device key when binding.""" info = DeviceInfo(*get_mock_info()) device = Device(info, timeout=1) - - assert device.device_info == info - fake_key = "abcdefgh12345678" - - def fake_send(*args): + + def fake_send(*args, **kwargs): """Emulate a bind event""" + device.device_cipher = CipherV1(fake_key.encode()) device.ready.set() device.handle_device_bound(fake_key) + send.side_effect = fake_send - with patch.object(Device, "send", side_effect=fake_send) as mock: - await device.bind() - assert mock.call_count == 1 + assert device.device_info == info + await device.bind() + assert send.call_count == 1 - assert device.device_key == fake_key + assert device.device_cipher is not None + assert device.device_cipher.key == fake_key + + # Bind with cipher already set + await device.bind() + assert send.call_count == 2 @pytest.mark.asyncio -async def test_device_bind_timeout(): +async def test_device_bind_timeout(cipher, send): """Check that the device handles timeout errors when binding.""" - info = DeviceInfo(*get_mock_info()) device = Device(info, timeout=1) - assert device.device_info == info - with pytest.raises(DeviceTimeoutError): - with patch.object(Device, "send", return_value=None) as mock: - await device.bind() - assert mock.call_count == 1 + await device.bind() + assert send.call_count == 1 - assert device.device_key is None + assert device.device_cipher is None @pytest.mark.asyncio -async def test_device_bind_none(): +async def test_device_bind_none(cipher, send): """Check that the device handles bad binding sequences.""" - info = DeviceInfo(*get_mock_info()) device = Device(info) - assert device.device_info == info - - def fake_send(*args): + def fake_send(*args, **kwargs): device.ready.set() + send.side_effect = fake_send - fake_key = None with pytest.raises(DeviceNotBoundError): - with patch.object(Device, "send", wraps=fake_send) as mock: - await device.bind() - assert mock.call_count == 1 + await device.bind() + assert send.call_count == 1 - assert device.device_key is None + assert device.device_cipher is None @pytest.mark.asyncio -async def test_device_late_bind(): +async def test_device_late_bind_from_update(cipher, send): """Check that the device handles late binding sequences.""" info = DeviceInfo(*get_mock_info()) device = Device(info, timeout=1) - assert device.device_info == info + fake_key = "abcdefgh12345678" + + def fake_send(*args, **kwargs): + device.device_cipher = CipherV1(fake_key.encode()) + device.handle_device_bound(fake_key) + device.ready.set() + send.side_effect = fake_send + + await device.update_state() + assert send.call_count == 2 + assert device.device_cipher.key == fake_key + + device.power = True + send.side_effect = None + await device.push_state_update() + + assert device.device_cipher is not None + assert device.device_cipher.key == fake_key + + +@pytest.mark.asyncio +async def test_device_late_bind_from_request_version(cipher, send): + """Check that the device handles late binding sequences.""" + info = DeviceInfo(*get_mock_info()) + device = Device(info, timeout=1) fake_key = "abcdefgh12345678" - def fake_send(*args): + def fake_send(*args, **kwargs): + device.device_cipher = CipherV1(fake_key.encode()) device.handle_device_bound(fake_key) device.ready.set() + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send) as mock: - await device.update_state() - assert mock.call_count == 2 - assert device.device_key == fake_key + await device.request_version() + assert send.call_count == 2 + assert device.device_cipher.key == fake_key + + +@pytest.mark.asyncio +async def test_device_bind_no_cipher(cipher, send): + """Check that the device handles late binding sequences.""" + info = DeviceInfo(*get_mock_info()) + device = Device(info, timeout=1) + fake_key = "abcdefgh12345678" + + with pytest.raises(ValueError): + await device.bind(fake_key) - device.power = True - with patch.object(Device, "send"): - await device.push_state_update() - assert device.device_key == fake_key +@pytest.mark.asyncio +async def test_device_bind_no_device_info(cipher, send): + """Check that the device handles late binding sequences.""" + device = Device(None, timeout=1) + + with pytest.raises(DeviceNotBoundError): + await device.bind() @pytest.mark.asyncio -async def test_update_properties(): +async def test_update_properties(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): state = get_mock_state() device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", side_effect=fake_send) as mock: - await device.update_state() + await device.update_state() for p in Props: assert device.get_property(p) is not None @@ -308,36 +343,29 @@ def fake_send(*args): @pytest.mark.asyncio -async def test_update_properties_timeout(): +async def test_update_properties_timeout(cipher, send): """Check that timeouts are handled when properties are updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - + send.side_effect = asyncio.TimeoutError with pytest.raises(DeviceTimeoutError): - with patch.object(Device, "send", side_effect=asyncio.TimeoutError): - await device.update_state() + await device.update_state() @pytest.mark.asyncio -async def test_set_properties_not_dirty(): +async def test_set_properties_not_dirty(cipher, send): """Check that the state isn't pushed when properties unchanged.""" device = await generate_device_mock_async() - with patch.object(Device, "send") as mock_request: - await device.push_state_update() - assert mock_request.call_count == 0 + await device.push_state_update() + assert send.call_count == 0 @pytest.mark.asyncio -async def test_set_properties(): +async def test_set_properties(cipher, send): """Check that state is pushed when properties are updated.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.power = True device.mode = 1 device.temperature_units = 1 @@ -355,9 +383,8 @@ async def test_set_properties(): device.power_save = True device.target_humidity = 30 - with patch.object(Device, "send") as mock_request: - await device.push_state_update() - mock_request.assert_called_once() + await device.push_state_update() + send.assert_called_once() for p in Props: if p not in ( @@ -375,13 +402,10 @@ async def test_set_properties(): @pytest.mark.asyncio -async def test_set_properties_timeout(): +async def test_set_properties_timeout(cipher, send): """Check timeout handling when pushing state changes.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.power = True device.mode = 1 device.temperature_units = 1 @@ -397,20 +421,20 @@ async def test_set_properties_timeout(): device.turbo = True device.steady_heat = True device.power_save = True + + assert len(device._dirty) + send.reset_mock() + send.side_effect = [asyncio.TimeoutError, asyncio.TimeoutError, asyncio.TimeoutError] with pytest.raises(DeviceTimeoutError): - with patch.object(Device, "send", side_effect=asyncio.TimeoutError): - await device.push_state_update() + await device.push_state_update() @pytest.mark.asyncio -async def test_uninitialized_properties(): +async def test_uninitialized_properties(cipher, send): """Check uninitialized property handling.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - assert not device.power assert device.mode is None assert device.target_temperature is None @@ -431,19 +455,16 @@ async def test_uninitialized_properties(): @pytest.mark.asyncio -async def test_update_current_temp_unsupported(): +async def test_update_current_temp_unsupported(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - - def fake_send(*args): + def fake_send(*args, **kwargs): state = get_mock_state_no_temperature() device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send) as mock: - await device.update_state() + await device.update_state() assert device.get_property(Props.TEMP_SENSOR) is None assert device.current_temperature == device.target_temperature @@ -458,19 +479,16 @@ def fake_send(*args): (62, "362001061147+U-ZX6045RV1.01.bin"), ], ) -async def test_update_current_temp_v3(temsen, hid): +async def test_update_current_temp_v3(temsen, hid, cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(TemSen=temsen, hid=hid) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() - + await device.update_state() + assert device.get_property(Props.TEMP_SENSOR) is not None assert device.current_temperature == temsen - 40 @@ -485,97 +503,80 @@ def fake_send(*args): (23, "362001061217+U-W04NV7.bin"), ], ) -async def test_update_current_temp_v4(temsen, hid): +async def test_update_current_temp_v4(temsen, hid, cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(TemSen=temsen, hid=hid) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.get_property(Props.TEMP_SENSOR) is not None assert device.current_temperature == temsen @pytest.mark.asyncio -async def test_update_current_temp_bad(): +async def test_update_current_temp_bad(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_bad_temp()) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == get_mock_state_bad_temp()["TemSen"] - 40 @pytest.mark.asyncio -async def test_update_current_temp_0C_v4(): +async def test_update_current_temp_0C_v4(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_0c_v4_temp()) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == get_mock_state_0c_v4_temp()["TemSen"] @pytest.mark.asyncio -async def test_update_current_temp_0C_v3(): +async def test_update_current_temp_0C_v3(cipher, send): """Check for devices without a temperature sensor.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**get_mock_state_0c_v3_temp()) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == device.target_temperature @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [18, 19, 20, 21, 22]) -async def test_send_temperature_celsius(temperature): +async def test_send_temperature_celsius(temperature, cipher, send): """Check that temperature is set and read properly in C.""" state = get_mock_state() state["TemSen"] = temperature + 40 device = await generate_device_mock_async() - - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.C device.target_temperature = temperature - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 1 - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == temperature @@ -584,7 +585,7 @@ def fake_send(*args): @pytest.mark.parametrize( "temperature", [60, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 86] ) -async def test_send_temperature_farenheit(temperature): +async def test_send_temperature_farenheit(temperature, cipher, send): """Check that temperature is set and read properly in F.""" temSet = round((temperature - 32.0) * 5.0 / 9.0) temRec = (int)((((temperature - 32.0) * 5.0 / 9.0) - temSet) > 0) @@ -595,34 +596,27 @@ async def test_send_temperature_farenheit(temperature): state["TemUn"] = 1 device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.F device.target_temperature = temperature - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 1 - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**state) + send.side_effect = fake_send - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == temperature @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [-270, -61, 61, 100]) -async def test_send_temperature_out_of_range_celsius(temperature): +async def test_send_temperature_out_of_range_celsius(temperature, cipher, send): """Check that bad temperatures raise the appropriate error.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.C with pytest.raises(ValueError): device.target_temperature = temperature @@ -630,7 +624,7 @@ async def test_send_temperature_out_of_range_celsius(temperature): @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [-270, -61, 141]) -async def test_send_temperature_out_of_range_farenheit_set(temperature): +async def test_send_temperature_out_of_range_farenheit_set(temperature, cipher, send): """Check that bad temperatures raise the appropriate error.""" device = await generate_device_mock_async() @@ -644,13 +638,10 @@ async def test_send_temperature_out_of_range_farenheit_set(temperature): @pytest.mark.asyncio @pytest.mark.parametrize("temperature", [-270, 150]) -async def test_send_temperature_out_of_range_farenheit_get(temperature): +async def test_send_temperature_out_of_range_farenheit_get(temperature, cipher, send): """Check that bad temperatures raise the appropriate error.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.set_property(Props.TEMP_SET, 20) device.set_property(Props.TEMP_SENSOR, temperature) device.set_property(Props.TEMP_BIT, 0) @@ -661,25 +652,20 @@ async def test_send_temperature_out_of_range_farenheit_get(temperature): @pytest.mark.asyncio -async def test_enable_disable_sleep_mode(): +async def test_enable_disable_sleep_mode(cipher, send): """Check that properties can be updates.""" device = await generate_device_mock_async() - for p in Props: - assert device.get_property(p) is None - device.sleep = True - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 1 assert device.get_property(Props.SLEEP) == 1 assert device.get_property(Props.SLEEP_MODE) == 1 device.sleep = False - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + await device.push_state_update() + assert send.call_count == 2 assert device.get_property(Props.SLEEP) == 0 assert device.get_property(Props.SLEEP_MODE) == 0 @@ -689,7 +675,7 @@ async def test_enable_disable_sleep_mode(): @pytest.mark.parametrize( "temperature", [59, 77, 86] ) -async def test_mismatch_temrec_farenheit(temperature): +async def test_mismatch_temrec_farenheit(temperature, cipher, send): """Check that temperature is set and read properly in F.""" temSet = round((temperature - 32.0) * 5.0 / 9.0) temRec = (int)((((temperature - 32.0) * 5.0 / 9.0) - temSet) > 0) @@ -700,47 +686,51 @@ async def test_mismatch_temrec_farenheit(temperature): state["TemRec"] = (temRec + 1) % 2 state["TemUn"] = 1 device = await generate_device_mock_async() - - for p in Props: - assert device.get_property(p) is None - device.temperature_units = TemperatureUnits.F device.target_temperature = temperature - with patch.object(Device, "send") as mock_push: - await device.push_state_update() - assert mock_push.call_count == 1 + + await device.push_state_update() + assert send.call_count == 1 - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**state) + send.side_effect = None - with patch.object(Device, "send", wraps=fake_send): - await device.update_state() + await device.update_state() assert device.current_temperature == temperature @pytest.mark.asyncio -async def test_device_equality(): +async def test_device_equality(send): """Check that two devices with the same info and key are equal.""" info1 = DeviceInfo(*get_mock_info()) device1 = Device(info1) - await device1.bind(key="fake_key") + await device1.bind(key="fake_key", cipher=CipherV1()) info2 = DeviceInfo(*get_mock_info()) device2 = Device(info2) - await device2.bind(key="fake_key") + await device2.bind(key="fake_key", cipher=CipherV1()) assert device1 == device2 # Change the key of the second device - await device2.bind(key="another_fake_key") + await device2.bind(key="another_fake_key", cipher=CipherV1()) assert device1 != device2 # Change the info of the second device info2 = DeviceInfo(*get_mock_info()) device2 = Device(info2) device2.power = True - await device2.bind(key="fake_key") + await device2.bind(key="fake_key", cipher=CipherV1()) assert device1 != device2 + +def test_device_key_set_get(): + """Check that the device key can be set and retrieved.""" + device = Device(DeviceInfo(*get_mock_info())) + device.device_cipher = CipherV1() + device.device_key = "fake_key" + assert device.device_key == "fake_key" + \ No newline at end of file diff --git a/tests/test_discovery.py b/tests/test_discovery.py index 69784fa..d341ff1 100644 --- a/tests/test_discovery.py +++ b/tests/test_discovery.py @@ -1,24 +1,18 @@ import asyncio import json import socket -from asyncio.tasks import wait_for from threading import Thread -from unittest.mock import MagicMock, PropertyMock, create_autospec, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest -from greeclimate.device import DeviceInfo from greeclimate.discovery import Discovery, Listener -from greeclimate.network import DeviceProtocolBase2 - from .common import ( DEFAULT_TIMEOUT, DISCOVERY_REQUEST, DISCOVERY_RESPONSE, - DISCOVERY_RESPONSE_NO_CID, Responder, - encrypt_payload, - get_mock_device_info, + get_mock_device_info, encrypt_payload, ) diff --git a/tests/test_issues.py b/tests/test_issues.py index 7c29438..ba08eb0 100644 --- a/tests/test_issues.py +++ b/tests/test_issues.py @@ -15,7 +15,7 @@ async def test_issue_69_TemSen_40_should_not_set_firmware_v4(): for p in Props: assert device.get_property(p) is None - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**mock_v3_state) with patch.object(Device, "send", wraps=fake_send()): @@ -35,7 +35,7 @@ async def test_issue_87_quiet_should_set_2(): assert device.get_property(Props.QUIET) is None device.quiet = True - def fake_send(*args): + def fake_send(*args, **kwargs): device.handle_state_update(**mock_v3_state) with patch.object(Device, "send", wraps=fake_send()) as mock: diff --git a/tests/test_network.py b/tests/test_network.py index 0e0dc0f..18c9ffa 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -2,8 +2,7 @@ import json import socket from threading import Thread -from typing import Any -from unittest.mock import create_autospec, patch, MagicMock +from unittest.mock import patch, MagicMock import pytest @@ -14,15 +13,13 @@ IPAddr, DeviceProtocol2, Commands, Response, ) - from .common import ( DEFAULT_RESPONSE, DEFAULT_TIMEOUT, DISCOVERY_REQUEST, DISCOVERY_RESPONSE, Responder, - encrypt_payload, - get_mock_device_info, DEFAULT_REQUEST, generate_response, + DEFAULT_REQUEST, generate_response, FakeCipher, ) from .test_device import get_mock_info @@ -44,6 +41,7 @@ class FakeDeviceProtocol(DeviceProtocol2): def __init__(self, drained: asyncio.Event = None): super().__init__(timeout=1, drained=drained) self.packets = asyncio.Queue() + self.device_cipher = FakeCipher(b"1234567890123456") def packet_received(self, obj, addr: IPAddr) -> None: self.packets.put_nowait(obj) @@ -90,8 +88,8 @@ async def test_set_get_key(): """Test the encryption key property.""" key = "faketestkey" dp2 = DeviceProtocolBase2() - dp2.device_key = key - assert dp2.device_key == key + dp2.device_cipher = FakeCipher(key.encode()) + assert dp2.device_cipher.key == key @pytest.mark.asyncio @@ -151,7 +149,7 @@ def responder(s): p = json.loads(d) assert p == DISCOVERY_REQUEST - p = json.dumps(encrypt_payload(DISCOVERY_RESPONSE)) + p = json.dumps(DISCOVERY_RESPONSE) s.sendto(p.encode(), addr) serv = Thread(target=responder, args=(sock,)) @@ -164,6 +162,7 @@ def responder(s): local_addr = (addr[0], 0) dp2 = FakeDiscoveryProtocol() + dp2.device_cipher = FakeCipher(b"1234567890123456") await loop.create_datagram_endpoint( lambda: dp2, local_addr=local_addr, @@ -240,44 +239,30 @@ def responder(s): lambda: FakeDeviceProtocol(drained=drained), remote_addr=remote_addr ) - with patch("greeclimate.network.DeviceProtocolBase2.decrypt_payload", new_callable=MagicMock) as mock: - mock.side_effect = lambda x, y: x - - # Send the scan command - await protocol.send(DEFAULT_REQUEST, None) + # Send the scan command + await protocol.send(DEFAULT_REQUEST, None, FakeCipher(b"1234567890123456")) - # Wait on the scan response - task = asyncio.create_task(protocol.packets.get()) - await asyncio.wait_for(task, DEFAULT_TIMEOUT) - response = task.result() + # Wait on the scan response + task = asyncio.create_task(protocol.packets.get()) + await asyncio.wait_for(task, DEFAULT_TIMEOUT) + response = task.result() - assert response == DEFAULT_RESPONSE + assert response == DEFAULT_RESPONSE sock.close() serv.join(timeout=DEFAULT_TIMEOUT) -def test_encrypt_decrypt_payload(): - test_object = {"fake-key": "fake-value"} - - encrypted = DeviceProtocolBase2.encrypt_payload(test_object) - assert encrypted != test_object - - decrypted = DeviceProtocolBase2.decrypt_payload(encrypted) - assert decrypted == test_object - - -@pytest.mark.asyncio def test_bindok_handling(): """Test the bindok response.""" response = generate_response({"t": "bindok", "key": "fake-key"}) protocol = DeviceProtocol2(timeout=DEFAULT_TIMEOUT) - with patch("greeclimate.network.DeviceProtocolBase2.decrypt_payload", new_callable=MagicMock) as mock_decrypt: - mock_decrypt.side_effect = lambda x, y: x - with patch.object(DeviceProtocol2, "handle_device_bound") as mock: - protocol.datagram_received(json.dumps(response).encode(), ("0.0.0.0", 0)) - assert mock.call_count == 1 - assert mock.call_args[0][0] == "fake-key" + protocol.device_cipher = FakeCipher(b"1234567890123456") + + with patch.object(DeviceProtocol2, "handle_device_bound") as mock: + protocol.datagram_received(json.dumps(response).encode(), ("0.0.0.0", 0)) + assert mock.call_count == 1 + assert mock.call_args[0][0] == "fake-key" def test_create_bind_message(): @@ -512,3 +497,78 @@ async def test_add_and_remove_handler(event_name, data): # Check that the callback was not called this time callback.assert_not_called() + +def test_packet_received_not_implemented(): + # Arrange + protocol = DeviceProtocolBase2() + + # Act + with pytest.raises(NotImplementedError): + protocol.packet_received({}, ("0.0.0.0", 0)) + + +def test_packet_received_invalid_data(): + # Arrange + protocol = DeviceProtocol2() + + # Act + protocol.packet_received(None, ("0.0.0.0", 0)) + protocol.packet_received({}, ("0.0.0.0", 0)) + protocol.packet_received({"pack"}, ("0.0.0.0", 0)) + + with patch.object(protocol, "handle_unknown_packet") as mock: + protocol.packet_received({"pack": {"t": "unknown"}}, ("0.0.0.0", 0)) + mock.assert_called_once() + + +def test_set_get_cipher(): + # Arrange + protocol = DeviceProtocolBase2() + cipher = FakeCipher(b"1234567890123456") + + # Act + protocol.device_cipher = cipher + + # Assert + assert protocol.device_cipher == cipher + + +def test_cipher_is_not_set(): + # Arrange + protocol = DeviceProtocolBase2() + + # Act + key = None + with pytest.raises(ValueError): + key = protocol.device_key + + assert key is None + + with pytest.raises(ValueError): + protocol.device_key = "fake-key" + + +def test_add_invalid_handler(): + # Arrange + protocol = DeviceProtocol2() + callback = MagicMock() + + # Act + with pytest.raises(ValueError): + protocol.add_handler(Response("invalid"), callback) + + with pytest.raises(ValueError): + protocol.add_handler(Response("invalid"), callback) + + +def test_device_key_get_set(): + # Arrange + protocol = DeviceProtocolBase2 + key = "fake-key" + + # Act + protocol.device_key = key + + # Assert + assert protocol.device_key == key + \ No newline at end of file