diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml
new file mode 100644
index 0000000..79ee123
--- /dev/null
+++ b/.idea/codeStyles/codeStyleConfig.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml
new file mode 100644
index 0000000..374ec3d
--- /dev/null
+++ b/.idea/material_theme_project_new.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/ruff.xml b/.idea/ruff.xml
new file mode 100644
index 0000000..045d9eb
--- /dev/null
+++ b/.idea/ruff.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/runConfigurations/pytest_in__.xml b/.idea/runConfigurations/pytest_in__.xml
new file mode 100644
index 0000000..316c583
--- /dev/null
+++ b/.idea/runConfigurations/pytest_in__.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/emulator.py b/emulator.py
index a3ac6b5..11d0b09 100644
--- a/emulator.py
+++ b/emulator.py
@@ -6,7 +6,6 @@
import socket
import time
-import machine
import network
import ubinascii
from ucryptolib import aes
diff --git a/greeclimate/cipher.py b/greeclimate/cipher.py
new file mode 100644
index 0000000..be8aabd
--- /dev/null
+++ b/greeclimate/cipher.py
@@ -0,0 +1,88 @@
+import base64
+import json
+import logging
+from typing import Union, Tuple
+
+from Crypto.Cipher import AES
+
+_logger = logging.getLogger(__name__)
+
+class CipherBase:
+ def __init__(self, key: bytes) -> None:
+ self._key: bytes = key
+
+ @property
+ def key(self) -> str:
+ return self._key.decode()
+
+ @key.setter
+ def key(self, value: str) -> None:
+ self._key = value.encode()
+
+ def encrypt(self, data) -> Tuple[str, Union[str, None]]:
+ raise NotImplementedError
+
+ def decrypt(self, data) -> dict:
+ raise NotImplementedError
+
+
+class CipherV1(CipherBase):
+ def __init__(self, key: bytes = b'a3K8Bx%2r8Y7#xDh') -> None:
+ super().__init__(key)
+
+ def __create_cipher(self) -> AES:
+ return AES.new(self._key, AES.MODE_ECB)
+
+ def __pad(self, s) -> str:
+ return s + (16 - len(s) % 16) * chr(16 - len(s) % 16)
+
+ def encrypt(self, data) -> Tuple[str, Union[str, None]]:
+ _logger.debug("Encrypting data: %s", data)
+ cipher = self.__create_cipher()
+ padded = self.__pad(json.dumps(data)).encode()
+ encrypted = cipher.encrypt(padded)
+ encoded = base64.b64encode(encrypted).decode()
+ _logger.debug("Encrypted data: %s", encoded)
+ return encoded, None
+
+ def decrypt(self, data) -> dict:
+ _logger.debug("Decrypting data: %s", data)
+ cipher = self.__create_cipher()
+ decoded = base64.b64decode(data)
+ decrypted = cipher.decrypt(decoded).decode()
+ t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '')
+ _logger.debug("Decrypted data: %s", t)
+ return json.loads(t)
+
+
+class CipherV2(CipherBase):
+ GCM_NONCE = b'\x54\x40\x78\x44\x49\x67\x5a\x51\x6c\x5e\x63\x13'
+ GCM_AEAD = b'qualcomm-test'
+
+ def __init__(self, key: bytes = b'{yxAHAY_Lm6pbC/<') -> None:
+ super().__init__(key)
+
+ def __create_cipher(self) -> AES:
+ cipher = AES.new(self._key, AES.MODE_GCM, nonce=self.GCM_NONCE)
+ cipher.update(self.GCM_AEAD)
+ return cipher
+
+ def encrypt(self, data) -> Tuple[str, str]:
+ _logger.debug("Encrypting data: %s", data)
+ cipher = self.__create_cipher()
+ encrypted, tag = cipher.encrypt_and_digest(json.dumps(data).encode())
+ encoded = base64.b64encode(encrypted).decode()
+ tag = base64.b64encode(tag).decode()
+ _logger.debug("Encrypted data: %s", encoded)
+ _logger.debug("Cipher digest: %s", tag)
+ return encoded, tag
+
+ def decrypt(self, data) -> dict:
+ _logger.info("Decrypting data: %s", data)
+ cipher = self.__create_cipher()
+ decoded = base64.b64decode(data)
+ decrypted = cipher.decrypt(decoded).decode()
+ t = decrypted.replace(decrypted[decrypted.rindex('}') + 1:], '')
+ _logger.debug("Decrypted data: %s", t)
+ return json.loads(t)
+
diff --git a/greeclimate/device.py b/greeclimate/device.py
index 93ed49e..486a0dc 100644
--- a/greeclimate/device.py
+++ b/greeclimate/device.py
@@ -4,12 +4,12 @@
import re
from asyncio import AbstractEventLoop
from enum import IntEnum, unique
-from typing import List
+from typing import Union
-import greeclimate.network as network
+from greeclimate.cipher import CipherV1, CipherV2
from greeclimate.deviceinfo import DeviceInfo
-from greeclimate.network import DeviceProtocol2, IPAddr
from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError
+from greeclimate.network import DeviceProtocol2
from greeclimate.taskable import Taskable
@@ -155,20 +155,23 @@ class Device(DeviceProtocol2, Taskable):
water_full: A bool to indicate the water tank is full
"""
- def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEventLoop = None):
+ def __init__(self, device_info: DeviceInfo, timeout: int = 120, bind_timeout: int = 10, loop: AbstractEventLoop = None):
"""Initialize the device object
Args:
device_info (DeviceInfo): Information about the physical device
timeout (int): Timeout for device communication
+ bind_timeout (int): Timeout for binding to the device, keep this short to prevent delays determining the
+ correct device cipher to use
loop (AbstractEventLoop): The event loop to run the device operations on
"""
DeviceProtocol2.__init__(self, timeout)
Taskable.__init__(self, loop)
self._logger = logging.getLogger(__name__)
-
self.device_info: DeviceInfo = device_info
-
+
+ self._bind_timeout = bind_timeout
+
""" Device properties """
self.hid = None
self.version = None
@@ -176,7 +179,11 @@ def __init__(self, device_info: DeviceInfo, timeout: int = 120, loop: AbstractEv
self._properties = {}
self._dirty = []
- async def bind(self, key=None):
+ async def bind(
+ self,
+ key: str = None,
+ cipher: Union[CipherV1, CipherV2, None] = None,
+ ):
"""Run the binding procedure.
Binding is a finicky procedure, and happens in 1 of 2 ways:
@@ -187,14 +194,23 @@ async def bind(self, key=None):
Both approaches result in a device_key which is used as like a persistent session id.
Args:
+ cipher (CipherV1 | CipherV2): The cipher type to use for encryption, if None will attempt to detect the correct one
key (str): The device key, when provided binding is a NOOP, if None binding will
- attempt to negotiate the key with the device.
+ attempt to negotiate the key with the device. cipher must be provided.
Raises:
DeviceNotBoundError: If binding was unsuccessful and no key returned
DeviceTimeoutError: The device didn't respond
"""
+ if key:
+ if not cipher:
+ raise ValueError("cipher must be provided when key is provided")
+ else:
+ cipher.key = key
+ self.device_cipher = cipher
+ return
+
if not self.device_info:
raise DeviceNotBoundError
@@ -206,29 +222,38 @@ async def bind(self, key=None):
self._logger.info("Starting device binding to %s", str(self.device_info))
try:
- if key:
- self.device_key = key
+ if cipher is not None:
+ await self.__bind_internal(cipher)
else:
- await self.send(self.create_bind_message(self.device_info))
- # Special case, wait for binding to complete so we know that the device is ready
- task = asyncio.create_task(self.ready.wait())
- await asyncio.wait_for(task, timeout=self._timeout)
+ """ Try binding with CipherV1 first, if that fails try CipherV2"""
+ try:
+ self._logger.info("Attempting to bind to device using CipherV1")
+ await self.__bind_internal(CipherV1())
+ except asyncio.TimeoutError:
+ self._logger.info("Attempting to bind to device using CipherV2")
+ await self.__bind_internal(CipherV2())
except asyncio.TimeoutError:
raise DeviceTimeoutError
- if not self.device_key:
+ if not self.device_cipher:
raise DeviceNotBoundError
else:
- self._logger.info("Bound to device using key %s", self.device_key)
+ self._logger.info("Bound to device using key %s", self.device_cipher.key)
+
+ async def __bind_internal(self, cipher: Union[CipherV1, CipherV2]):
+ """Internal binding procedure, do not call directly"""
+ await self.send(self.create_bind_message(self.device_info), cipher=cipher)
+ task = asyncio.create_task(self.ready.wait())
+ await asyncio.wait_for(task, timeout=self._bind_timeout)
- def handle_device_bound(self, key) -> None:
+ def handle_device_bound(self, key: str) -> None:
"""Handle the device bound message from the device"""
- self.device_key = key
+ self.device_cipher.key = key
async def request_version(self) -> None:
"""Request the firmware version from the device."""
- if not self.device_key:
+ if not self.device_cipher:
await self.bind()
try:
@@ -243,7 +268,7 @@ async def update_state(self, wait_for: float = 30):
Args:
wait_for (object): How long to wait for an update from the device
"""
- if not self.device_key:
+ if not self.device_cipher:
await self.bind()
self._logger.debug("Updating device properties for (%s)", str(self.device_info))
@@ -288,7 +313,7 @@ async def push_state_update(self, wait_for: float = 30):
if not self._dirty:
return
- if not self.device_key:
+ if not self.device_cipher:
await self.bind()
self._logger.debug("Pushing state updates to (%s)", str(self.device_info))
@@ -316,7 +341,7 @@ def __eq__(self, other):
"""Compare two devices for equality based on their properties state and device info."""
return self.device_info == other.device_info \
and self.raw_properties == other.raw_properties \
- and self.device_key == other.device_key
+ and self.device_cipher.key == other.device_cipher.key
def __ne__(self, other):
return not self.__eq__(other)
diff --git a/greeclimate/discovery.py b/greeclimate/discovery.py
index 5e18f48..0eb7015 100644
--- a/greeclimate/discovery.py
+++ b/greeclimate/discovery.py
@@ -6,6 +6,7 @@
from asyncio.events import AbstractEventLoop
from ipaddress import IPv4Address
+from greeclimate.cipher import CipherV1
from greeclimate.device import DeviceInfo
from greeclimate.network import BroadcastListenerProtocol, IPAddr
from greeclimate.taskable import Taskable
@@ -45,6 +46,7 @@ def __init__(
"""
BroadcastListenerProtocol.__init__(self, timeout)
Taskable.__init__(self, loop)
+ self.device_cipher = CipherV1()
self._allow_loopback: bool = allow_loopback
self._device_infos: list[DeviceInfo] = []
self._listeners: list[Listener] = []
diff --git a/greeclimate/network.py b/greeclimate/network.py
index aabea91..b8e432b 100644
--- a/greeclimate/network.py
+++ b/greeclimate/network.py
@@ -1,24 +1,17 @@
import asyncio
-import base64
import json
import logging
import socket
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Dict, Text, Tuple, Union
-
-from Crypto.Cipher import AES
+from typing import Any, Dict, Tuple, Union
+from greeclimate.cipher import CipherBase
from greeclimate.deviceinfo import DeviceInfo
NETWORK_TIMEOUT = 10
-GENERIC_KEY = ["a3K8Bx%2r8Y7#xDh"]
-GCM_NONCE = b'\x54\x40\x78\x44\x49\x67\x5a\x51\x6c\x5e\x63\x13'
-GCM_AEAD = b'qualcomm-test'
-
_LOGGER = logging.getLogger(__name__)
-
IPAddr = Tuple[str, int]
@@ -52,11 +45,13 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None:
timeout (int): Packet send timeout
drained (asyncio.Event): Packet send drain event signal
"""
- self._timeout = timeout
- self._drained = drained or asyncio.Event()
+ self._timeout: int = timeout
+ self._drained: asyncio.Event = drained or asyncio.Event()
self._drained.set()
- self._transport = None
- self._key = None
+
+ self._transport: Union[asyncio.transports.DatagramTransport, None] = None
+ self._cipher: Union[CipherBase, None] = None
+
# This event need to be implemented to handle incoming requests
def packet_received(self, obj, addr: IPAddr) -> None:
@@ -69,14 +64,28 @@ def packet_received(self, obj, addr: IPAddr) -> None:
raise NotImplementedError("packet_received must be implemented in a subclass")
@property
- def device_key(self) -> str:
+ def device_cipher(self) -> CipherBase:
"""Sets the encryption key used for device data."""
- return self._key
+ return self._cipher
+
+ @device_cipher.setter
+ def device_cipher(self, value: CipherBase):
+ """Gets the encryption key used for device data."""
+ self._cipher = value
+
+ @property
+ def device_key(self) -> str:
+ """Gets the encryption key used for device data."""
+ if self._cipher is None:
+ raise ValueError("Cipher object not set")
+ return self._cipher.key
@device_key.setter
def device_key(self, value: str):
- """Gets the encryption key used for device data."""
- self._key = value
+ """Sets the encryption key used for device data."""
+ if self._cipher is None:
+ raise ValueError("Cipher object not set")
+ self._cipher.key = value
def close(self) -> None:
"""Close the UDP transport."""
@@ -116,27 +125,6 @@ def resume_writing(self) -> None:
self._drained.set()
super().resume_writing()
- @staticmethod
- def decrypt_payload(payload, key=GENERIC_KEY[0]):
- cipher = AES.new(key.encode(), AES.MODE_ECB)
- decoded = base64.b64decode(payload)
- decrypted = cipher.decrypt(decoded).decode()
- t = decrypted.replace(decrypted[decrypted.rindex("}") + 1 :], "")
- return json.loads(t)
-
- @staticmethod
- def encrypt_payload(payload, key=GENERIC_KEY[0]):
- def pad(s):
- bs = 16
- return s + (bs - len(s) % bs) * chr(bs - len(s) % bs)
-
- cipher = AES.new(key.encode(), AES.MODE_ECB)
- encrypted = cipher.encrypt(pad(json.dumps(payload)).encode())
- encoded = base64.b64encode(encrypted).decode()
-
- _LOGGER.debug(f"Encrypted payload with key [{key}]: {encoded}")
- return encoded
-
def datagram_received(self, data: bytes, addr: IPAddr) -> None:
"""Handle an incoming datagram."""
if len(data) == 0:
@@ -144,23 +132,30 @@ def datagram_received(self, data: bytes, addr: IPAddr) -> None:
obj = json.loads(data)
- # It could be either a v1 or v2 key
- key = GENERIC_KEY[0] if obj.get("i") == 1 else self._key
-
if obj.get("pack"):
- obj["pack"] = DeviceProtocolBase2.decrypt_payload(obj["pack"], key)
-
- _LOGGER.debug("Received packet from %s:\n%s", addr[0], json.dumps(obj))
+ obj["pack"] = self._cipher.decrypt(obj["pack"])
+ _LOGGER.debug("Received packet from %s:\n<- %s", addr[0], json.dumps(obj))
self.packet_received(obj, addr)
- async def send(self, obj, addr: IPAddr = None) -> None:
- """Send encode and send JSON command to the device."""
- _LOGGER.debug("Sending packet:\n%s", json.dumps(obj))
+ async def send(self, obj, addr: IPAddr = None, cipher: Union[CipherBase, None] = None) -> None:
+ """Send encode and send JSON command to the device.
+
+ Args:
+ addr (object): (Optional) Address to send the message
+ cipher (object): (Optional) Initial cipher to use for SCANNING and BINDING
+ """
+ _LOGGER.debug("Sending packet:\n-> %s", json.dumps(obj))
if obj.get("pack"):
- key = GENERIC_KEY[0] if obj.get("i") == 1 else self._key
- obj["pack"] = DeviceProtocolBase2.encrypt_payload(obj["pack"], key)
+ if obj.get("i") == 1:
+ if cipher is None:
+ raise ValueError("Cipher must be supplied for SCAN or BIND messages")
+ self._cipher = cipher
+
+ obj["pack"], tag = self._cipher.encrypt(obj["pack"])
+ if tag:
+ obj["tag"] = tag
data_bytes = json.dumps(obj).encode()
self._transport.sendto(data_bytes, addr)
@@ -169,7 +164,6 @@ async def send(self, obj, addr: IPAddr = None) -> None:
await asyncio.wait_for(task, self._timeout)
-
class BroadcastListenerProtocol(DeviceProtocolBase2):
"""Special protocol handler for when broadcast is needed."""
@@ -184,6 +178,7 @@ def connection_made(self, transport: asyncio.transports.DatagramTransport) -> No
class DeviceProtocol2(DeviceProtocolBase2):
"""Protocol handler for direct device communication."""
+ _handlers = {}
def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None:
"""Initialize the device protocol object.
@@ -195,7 +190,6 @@ def __init__(self, timeout: int = 10, drained: asyncio.Event = None) -> None:
DeviceProtocolBase2.__init__(self, timeout, drained)
self._ready = asyncio.Event()
self._ready.clear()
- self._handlers = {}
@property
def ready(self) -> asyncio.Event:
@@ -235,12 +229,13 @@ def packet_received(self, obj, addr: IPAddr) -> None:
Response.DATA.value: lambda *args: self.__handle_state_update(*args),
Response.RESULT.value: lambda *args: self.__handle_state_update(*args),
}
- resp = obj.get("pack", {}).get("t")
- handler = handlers.get(resp, self.handle_unknown_packet)
- param = []
try:
+ resp = obj.get("pack", {}).get("t")
+ handler = handlers.get(resp, self.handle_unknown_packet)
param = params.get(resp, lambda o, a: (o, a))(obj, addr)
handler(*param)
+ except AttributeError as e:
+ _LOGGER.exception("Error while handling packet", exc_info=e)
except KeyError as e:
_LOGGER.exception("Error while handling packet", exc_info=e)
else:
@@ -290,6 +285,5 @@ def create_status_message(self, device_info: DeviceInfo, *args) -> Dict[str, Any
return self._generate_payload(Commands.STATUS, device_info, {"cols": list(args)})
def create_command_message(self, device_info: DeviceInfo, **kwargs) -> Dict[str, Any]:
- return self._generate_payload(Commands.CMD, device_info, {"opt": list(kwargs.keys()), "p": list(kwargs.values())})
-
-
+ return self._generate_payload(Commands.CMD, device_info,
+ {"opt": list(kwargs.keys()), "p": list(kwargs.values())})
diff --git a/tests/common.py b/tests/common.py
index 0ed9651..b4e966d 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -1,9 +1,9 @@
-import asyncio
import socket
from socket import SOCK_DGRAM
-from unittest.mock import Mock, create_autospec, patch
+from typing import Tuple, Union
+from unittest.mock import Mock
-from greeclimate.network import DeviceProtocolBase2
+from greeclimate.cipher import CipherV1, CipherBase
DEFAULT_TIMEOUT = 1
DISCOVERY_REQUEST = {"t": "scan"}
@@ -95,10 +95,24 @@ def get_mock_device_info():
def encrypt_payload(data):
"""Encrypt the payload of responses quickly."""
d = data.copy()
- d["pack"] = DeviceProtocolBase2.encrypt_payload(d["pack"])
+ cipher = CipherV1()
+ d["pack"], _ = cipher.encrypt(d["pack"])
return d
+class FakeCipher(CipherBase):
+ """Fake cipher object for testing."""
+
+ def __init__(self, key: bytes) -> None:
+ super().__init__(key)
+
+ def encrypt(self, data) -> Tuple[str, Union[str, None]]:
+ return data, None
+
+ def decrypt(self, data) -> dict:
+ return data
+
+
class Responder:
"""Context manage for easy raw socket responders."""
diff --git a/tests/conftest.py b/tests/conftest.py
index 9a64b00..aadb011 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,9 @@
import pytest
+from greeclimate.device import Device
+from tests.common import FakeCipher
+
MOCK_INTERFACES = ["lo"]
MOCK_LO_IFACE = {
2: [{"addr": "10.0.0.1", "netmask": "255.0.0.0", "peer": "10.255.255.255"}]
@@ -13,6 +16,22 @@
def netifaces_fixture():
"""Patch netifaces interface discover."""
with patch("netifaces.interfaces", return_value=MOCK_INTERFACES), patch(
- "netifaces.ifaddresses", return_value=MOCK_LO_IFACE
+ "netifaces.ifaddresses", return_value=MOCK_LO_IFACE
) as ifaddr_mock:
yield ifaddr_mock
+
+
+@pytest.fixture(name="cipher")
+def cipher_fixture():
+ """Patch the cipher object."""
+ with patch("greeclimate.device.CipherV1") as mock1, patch("greeclimate.device.CipherV2") as mock2:
+ mock1.return_value = FakeCipher(b"1234567890123456")
+ mock2.return_value = FakeCipher(b"1234567890123456")
+ yield mock1, mock2
+
+
+@pytest.fixture(name="send")
+def network_fixture():
+ """Patch the device object."""
+ with patch.object(Device, "send") as mock:
+ yield mock
diff --git a/tests/test_cipher.py b/tests/test_cipher.py
new file mode 100644
index 0000000..ec92be5
--- /dev/null
+++ b/tests/test_cipher.py
@@ -0,0 +1,60 @@
+import base64
+
+import pytest
+
+from greeclimate.cipher import CipherV1, CipherV2, CipherBase
+
+
+@pytest.fixture
+def cipher_v1_key():
+ return b'ThisIsASecretKey'
+
+
+@pytest.fixture
+def cipher_v1(cipher_v1_key):
+ return CipherV1(cipher_v1_key)
+
+
+@pytest.fixture
+def cipher_v2_key():
+ return b'ThisIsASecretKey'
+
+
+@pytest.fixture
+def cipher_v2(cipher_v2_key):
+ return CipherV2(cipher_v2_key)
+
+
+@pytest.fixture
+def plain_text():
+ return {"message": "Hello, World!"}
+
+
+def test_encryption_then_decryption_yields_original(cipher_v1, plain_text):
+ encrypted, _ = cipher_v1.encrypt(plain_text)
+ decrypted = cipher_v1.decrypt(encrypted)
+ assert decrypted == plain_text
+
+
+def test_decryption_with_modified_data_raises_error(cipher_v1, plain_text):
+ _, _ = cipher_v1.encrypt(plain_text)
+ modified_data = base64.b64encode(b"modified data ").decode()
+ with pytest.raises(UnicodeDecodeError):
+ cipher_v1.decrypt(modified_data)
+
+
+def test_encryption_then_decryption_yields_original_with_tag(cipher_v2, plain_text):
+ encrypted, tag = cipher_v2.encrypt(plain_text)
+ decrypted = cipher_v2.decrypt(encrypted)
+ assert decrypted == plain_text
+
+
+def test_cipher_base_not_implemented():
+ fake_key = b'fake'
+
+ with pytest.raises(NotImplementedError):
+ CipherBase(fake_key).encrypt(None)
+
+ with pytest.raises(NotImplementedError):
+ CipherBase(fake_key).decrypt(None)
+
diff --git a/tests/test_device.py b/tests/test_device.py
index 33b8663..111953b 100644
--- a/tests/test_device.py
+++ b/tests/test_device.py
@@ -1,11 +1,10 @@
import asyncio
import enum
-from unittest.mock import patch, AsyncMock
import pytest
+from greeclimate.cipher import CipherV1
from greeclimate.device import Device, DeviceInfo, Props, TemperatureUnits
-from greeclimate.discovery import Discovery
from greeclimate.exceptions import DeviceNotBoundError, DeviceTimeoutError
@@ -158,12 +157,12 @@ def get_mock_state_0c_v3_temp():
async def generate_device_mock_async():
- d = Device(DeviceInfo("192.168.1.29", 7000, "f4911e7aca59", "1e7aca59"))
- await d.bind(key="St8Vw1Yz4Bc7Ef0H")
+ d = Device(DeviceInfo("1.1.1.1", 7000, "f4911e7aca59", "1e7aca59"))
+ await d.bind(key="St8Vw1Yz4Bc7Ef0H", cipher=CipherV1())
return d
-def test_device_info_equality():
+def test_device_info_equality(send):
"""The only way to get the key through binding is by scanning first"""
props = [
@@ -187,7 +186,7 @@ def test_device_info_equality():
@pytest.mark.asyncio
-async def test_get_device_info():
+async def test_get_device_info(cipher, send):
"""Initialize device, check properties."""
info = DeviceInfo(*get_mock_info())
@@ -196,111 +195,147 @@ async def test_get_device_info():
assert device.device_info == info
fake_key = "abcdefgh12345678"
- await device.bind(key=fake_key)
+ await device.bind(key=fake_key, cipher=CipherV1())
- assert device.device_key == fake_key
+ assert device.device_cipher is not None
+ assert device.device_cipher.key == fake_key
@pytest.mark.asyncio
-async def test_device_bind():
+async def test_device_bind(cipher, send):
"""Check that the device returns a device key when binding."""
info = DeviceInfo(*get_mock_info())
device = Device(info, timeout=1)
-
- assert device.device_info == info
-
fake_key = "abcdefgh12345678"
-
- def fake_send(*args):
+
+ def fake_send(*args, **kwargs):
"""Emulate a bind event"""
+ device.device_cipher = CipherV1(fake_key.encode())
device.ready.set()
device.handle_device_bound(fake_key)
+ send.side_effect = fake_send
- with patch.object(Device, "send", side_effect=fake_send) as mock:
- await device.bind()
- assert mock.call_count == 1
+ assert device.device_info == info
+ await device.bind()
+ assert send.call_count == 1
- assert device.device_key == fake_key
+ assert device.device_cipher is not None
+ assert device.device_cipher.key == fake_key
+
+ # Bind with cipher already set
+ await device.bind()
+ assert send.call_count == 2
@pytest.mark.asyncio
-async def test_device_bind_timeout():
+async def test_device_bind_timeout(cipher, send):
"""Check that the device handles timeout errors when binding."""
-
info = DeviceInfo(*get_mock_info())
device = Device(info, timeout=1)
- assert device.device_info == info
-
with pytest.raises(DeviceTimeoutError):
- with patch.object(Device, "send", return_value=None) as mock:
- await device.bind()
- assert mock.call_count == 1
+ await device.bind()
+ assert send.call_count == 1
- assert device.device_key is None
+ assert device.device_cipher is None
@pytest.mark.asyncio
-async def test_device_bind_none():
+async def test_device_bind_none(cipher, send):
"""Check that the device handles bad binding sequences."""
-
info = DeviceInfo(*get_mock_info())
device = Device(info)
- assert device.device_info == info
-
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.ready.set()
+ send.side_effect = fake_send
- fake_key = None
with pytest.raises(DeviceNotBoundError):
- with patch.object(Device, "send", wraps=fake_send) as mock:
- await device.bind()
- assert mock.call_count == 1
+ await device.bind()
+ assert send.call_count == 1
- assert device.device_key is None
+ assert device.device_cipher is None
@pytest.mark.asyncio
-async def test_device_late_bind():
+async def test_device_late_bind_from_update(cipher, send):
"""Check that the device handles late binding sequences."""
info = DeviceInfo(*get_mock_info())
device = Device(info, timeout=1)
- assert device.device_info == info
+ fake_key = "abcdefgh12345678"
+
+ def fake_send(*args, **kwargs):
+ device.device_cipher = CipherV1(fake_key.encode())
+ device.handle_device_bound(fake_key)
+ device.ready.set()
+ send.side_effect = fake_send
+
+ await device.update_state()
+ assert send.call_count == 2
+ assert device.device_cipher.key == fake_key
+
+ device.power = True
+ send.side_effect = None
+ await device.push_state_update()
+
+ assert device.device_cipher is not None
+ assert device.device_cipher.key == fake_key
+
+
+@pytest.mark.asyncio
+async def test_device_late_bind_from_request_version(cipher, send):
+ """Check that the device handles late binding sequences."""
+ info = DeviceInfo(*get_mock_info())
+ device = Device(info, timeout=1)
fake_key = "abcdefgh12345678"
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
+ device.device_cipher = CipherV1(fake_key.encode())
device.handle_device_bound(fake_key)
device.ready.set()
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send) as mock:
- await device.update_state()
- assert mock.call_count == 2
- assert device.device_key == fake_key
+ await device.request_version()
+ assert send.call_count == 2
+ assert device.device_cipher.key == fake_key
+
+
+@pytest.mark.asyncio
+async def test_device_bind_no_cipher(cipher, send):
+ """Check that the device handles late binding sequences."""
+ info = DeviceInfo(*get_mock_info())
+ device = Device(info, timeout=1)
+ fake_key = "abcdefgh12345678"
+
+ with pytest.raises(ValueError):
+ await device.bind(fake_key)
- device.power = True
- with patch.object(Device, "send"):
- await device.push_state_update()
- assert device.device_key == fake_key
+@pytest.mark.asyncio
+async def test_device_bind_no_device_info(cipher, send):
+ """Check that the device handles late binding sequences."""
+ device = Device(None, timeout=1)
+
+ with pytest.raises(DeviceNotBoundError):
+ await device.bind()
@pytest.mark.asyncio
-async def test_update_properties():
+async def test_update_properties(cipher, send):
"""Check that properties can be updates."""
device = await generate_device_mock_async()
for p in Props:
assert device.get_property(p) is None
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
state = get_mock_state()
device.handle_state_update(**state)
+ send.side_effect = fake_send
- with patch.object(Device, "send", side_effect=fake_send) as mock:
- await device.update_state()
+ await device.update_state()
for p in Props:
assert device.get_property(p) is not None
@@ -308,36 +343,29 @@ def fake_send(*args):
@pytest.mark.asyncio
-async def test_update_properties_timeout():
+async def test_update_properties_timeout(cipher, send):
"""Check that timeouts are handled when properties are updates."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
+ send.side_effect = asyncio.TimeoutError
with pytest.raises(DeviceTimeoutError):
- with patch.object(Device, "send", side_effect=asyncio.TimeoutError):
- await device.update_state()
+ await device.update_state()
@pytest.mark.asyncio
-async def test_set_properties_not_dirty():
+async def test_set_properties_not_dirty(cipher, send):
"""Check that the state isn't pushed when properties unchanged."""
device = await generate_device_mock_async()
- with patch.object(Device, "send") as mock_request:
- await device.push_state_update()
- assert mock_request.call_count == 0
+ await device.push_state_update()
+ assert send.call_count == 0
@pytest.mark.asyncio
-async def test_set_properties():
+async def test_set_properties(cipher, send):
"""Check that state is pushed when properties are updated."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
device.power = True
device.mode = 1
device.temperature_units = 1
@@ -355,9 +383,8 @@ async def test_set_properties():
device.power_save = True
device.target_humidity = 30
- with patch.object(Device, "send") as mock_request:
- await device.push_state_update()
- mock_request.assert_called_once()
+ await device.push_state_update()
+ send.assert_called_once()
for p in Props:
if p not in (
@@ -375,13 +402,10 @@ async def test_set_properties():
@pytest.mark.asyncio
-async def test_set_properties_timeout():
+async def test_set_properties_timeout(cipher, send):
"""Check timeout handling when pushing state changes."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
device.power = True
device.mode = 1
device.temperature_units = 1
@@ -397,20 +421,20 @@ async def test_set_properties_timeout():
device.turbo = True
device.steady_heat = True
device.power_save = True
+
+ assert len(device._dirty)
+ send.reset_mock()
+ send.side_effect = [asyncio.TimeoutError, asyncio.TimeoutError, asyncio.TimeoutError]
with pytest.raises(DeviceTimeoutError):
- with patch.object(Device, "send", side_effect=asyncio.TimeoutError):
- await device.push_state_update()
+ await device.push_state_update()
@pytest.mark.asyncio
-async def test_uninitialized_properties():
+async def test_uninitialized_properties(cipher, send):
"""Check uninitialized property handling."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
assert not device.power
assert device.mode is None
assert device.target_temperature is None
@@ -431,19 +455,16 @@ async def test_uninitialized_properties():
@pytest.mark.asyncio
-async def test_update_current_temp_unsupported():
+async def test_update_current_temp_unsupported(cipher, send):
"""Check that properties can be updates."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
state = get_mock_state_no_temperature()
device.handle_state_update(**state)
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send) as mock:
- await device.update_state()
+ await device.update_state()
assert device.get_property(Props.TEMP_SENSOR) is None
assert device.current_temperature == device.target_temperature
@@ -458,19 +479,16 @@ def fake_send(*args):
(62, "362001061147+U-ZX6045RV1.01.bin"),
],
)
-async def test_update_current_temp_v3(temsen, hid):
+async def test_update_current_temp_v3(temsen, hid, cipher, send):
"""Check that properties can be updates."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(TemSen=temsen, hid=hid)
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
-
+ await device.update_state()
+
assert device.get_property(Props.TEMP_SENSOR) is not None
assert device.current_temperature == temsen - 40
@@ -485,97 +503,80 @@ def fake_send(*args):
(23, "362001061217+U-W04NV7.bin"),
],
)
-async def test_update_current_temp_v4(temsen, hid):
+async def test_update_current_temp_v4(temsen, hid, cipher, send):
"""Check that properties can be updates."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(TemSen=temsen, hid=hid)
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
+ await device.update_state()
assert device.get_property(Props.TEMP_SENSOR) is not None
assert device.current_temperature == temsen
@pytest.mark.asyncio
-async def test_update_current_temp_bad():
+async def test_update_current_temp_bad(cipher, send):
"""Check that properties can be updates."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**get_mock_state_bad_temp())
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
+ await device.update_state()
assert device.current_temperature == get_mock_state_bad_temp()["TemSen"] - 40
@pytest.mark.asyncio
-async def test_update_current_temp_0C_v4():
+async def test_update_current_temp_0C_v4(cipher, send):
"""Check that properties can be updates."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**get_mock_state_0c_v4_temp())
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
+ await device.update_state()
assert device.current_temperature == get_mock_state_0c_v4_temp()["TemSen"]
@pytest.mark.asyncio
-async def test_update_current_temp_0C_v3():
+async def test_update_current_temp_0C_v3(cipher, send):
"""Check for devices without a temperature sensor."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**get_mock_state_0c_v3_temp())
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
+ await device.update_state()
assert device.current_temperature == device.target_temperature
@pytest.mark.asyncio
@pytest.mark.parametrize("temperature", [18, 19, 20, 21, 22])
-async def test_send_temperature_celsius(temperature):
+async def test_send_temperature_celsius(temperature, cipher, send):
"""Check that temperature is set and read properly in C."""
state = get_mock_state()
state["TemSen"] = temperature + 40
device = await generate_device_mock_async()
-
- for p in Props:
- assert device.get_property(p) is None
-
device.temperature_units = TemperatureUnits.C
device.target_temperature = temperature
- with patch.object(Device, "send") as mock_push:
- await device.push_state_update()
- assert mock_push.call_count == 1
+ await device.push_state_update()
+ assert send.call_count == 1
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**state)
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
+ await device.update_state()
assert device.current_temperature == temperature
@@ -584,7 +585,7 @@ def fake_send(*args):
@pytest.mark.parametrize(
"temperature", [60, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 86]
)
-async def test_send_temperature_farenheit(temperature):
+async def test_send_temperature_farenheit(temperature, cipher, send):
"""Check that temperature is set and read properly in F."""
temSet = round((temperature - 32.0) * 5.0 / 9.0)
temRec = (int)((((temperature - 32.0) * 5.0 / 9.0) - temSet) > 0)
@@ -595,34 +596,27 @@ async def test_send_temperature_farenheit(temperature):
state["TemUn"] = 1
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
device.temperature_units = TemperatureUnits.F
device.target_temperature = temperature
- with patch.object(Device, "send") as mock_push:
- await device.push_state_update()
- assert mock_push.call_count == 1
+ await device.push_state_update()
+ assert send.call_count == 1
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**state)
+ send.side_effect = fake_send
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
+ await device.update_state()
assert device.current_temperature == temperature
@pytest.mark.asyncio
@pytest.mark.parametrize("temperature", [-270, -61, 61, 100])
-async def test_send_temperature_out_of_range_celsius(temperature):
+async def test_send_temperature_out_of_range_celsius(temperature, cipher, send):
"""Check that bad temperatures raise the appropriate error."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
device.temperature_units = TemperatureUnits.C
with pytest.raises(ValueError):
device.target_temperature = temperature
@@ -630,7 +624,7 @@ async def test_send_temperature_out_of_range_celsius(temperature):
@pytest.mark.asyncio
@pytest.mark.parametrize("temperature", [-270, -61, 141])
-async def test_send_temperature_out_of_range_farenheit_set(temperature):
+async def test_send_temperature_out_of_range_farenheit_set(temperature, cipher, send):
"""Check that bad temperatures raise the appropriate error."""
device = await generate_device_mock_async()
@@ -644,13 +638,10 @@ async def test_send_temperature_out_of_range_farenheit_set(temperature):
@pytest.mark.asyncio
@pytest.mark.parametrize("temperature", [-270, 150])
-async def test_send_temperature_out_of_range_farenheit_get(temperature):
+async def test_send_temperature_out_of_range_farenheit_get(temperature, cipher, send):
"""Check that bad temperatures raise the appropriate error."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
device.set_property(Props.TEMP_SET, 20)
device.set_property(Props.TEMP_SENSOR, temperature)
device.set_property(Props.TEMP_BIT, 0)
@@ -661,25 +652,20 @@ async def test_send_temperature_out_of_range_farenheit_get(temperature):
@pytest.mark.asyncio
-async def test_enable_disable_sleep_mode():
+async def test_enable_disable_sleep_mode(cipher, send):
"""Check that properties can be updates."""
device = await generate_device_mock_async()
- for p in Props:
- assert device.get_property(p) is None
-
device.sleep = True
- with patch.object(Device, "send") as mock_push:
- await device.push_state_update()
- assert mock_push.call_count == 1
+ await device.push_state_update()
+ assert send.call_count == 1
assert device.get_property(Props.SLEEP) == 1
assert device.get_property(Props.SLEEP_MODE) == 1
device.sleep = False
- with patch.object(Device, "send") as mock_push:
- await device.push_state_update()
- assert mock_push.call_count == 1
+ await device.push_state_update()
+ assert send.call_count == 2
assert device.get_property(Props.SLEEP) == 0
assert device.get_property(Props.SLEEP_MODE) == 0
@@ -689,7 +675,7 @@ async def test_enable_disable_sleep_mode():
@pytest.mark.parametrize(
"temperature", [59, 77, 86]
)
-async def test_mismatch_temrec_farenheit(temperature):
+async def test_mismatch_temrec_farenheit(temperature, cipher, send):
"""Check that temperature is set and read properly in F."""
temSet = round((temperature - 32.0) * 5.0 / 9.0)
temRec = (int)((((temperature - 32.0) * 5.0 / 9.0) - temSet) > 0)
@@ -700,47 +686,51 @@ async def test_mismatch_temrec_farenheit(temperature):
state["TemRec"] = (temRec + 1) % 2
state["TemUn"] = 1
device = await generate_device_mock_async()
-
- for p in Props:
- assert device.get_property(p) is None
-
device.temperature_units = TemperatureUnits.F
device.target_temperature = temperature
- with patch.object(Device, "send") as mock_push:
- await device.push_state_update()
- assert mock_push.call_count == 1
+
+ await device.push_state_update()
+ assert send.call_count == 1
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**state)
+ send.side_effect = None
- with patch.object(Device, "send", wraps=fake_send):
- await device.update_state()
+ await device.update_state()
assert device.current_temperature == temperature
@pytest.mark.asyncio
-async def test_device_equality():
+async def test_device_equality(send):
"""Check that two devices with the same info and key are equal."""
info1 = DeviceInfo(*get_mock_info())
device1 = Device(info1)
- await device1.bind(key="fake_key")
+ await device1.bind(key="fake_key", cipher=CipherV1())
info2 = DeviceInfo(*get_mock_info())
device2 = Device(info2)
- await device2.bind(key="fake_key")
+ await device2.bind(key="fake_key", cipher=CipherV1())
assert device1 == device2
# Change the key of the second device
- await device2.bind(key="another_fake_key")
+ await device2.bind(key="another_fake_key", cipher=CipherV1())
assert device1 != device2
# Change the info of the second device
info2 = DeviceInfo(*get_mock_info())
device2 = Device(info2)
device2.power = True
- await device2.bind(key="fake_key")
+ await device2.bind(key="fake_key", cipher=CipherV1())
assert device1 != device2
+
+def test_device_key_set_get():
+ """Check that the device key can be set and retrieved."""
+ device = Device(DeviceInfo(*get_mock_info()))
+ device.device_cipher = CipherV1()
+ device.device_key = "fake_key"
+ assert device.device_key == "fake_key"
+
\ No newline at end of file
diff --git a/tests/test_discovery.py b/tests/test_discovery.py
index 69784fa..d341ff1 100644
--- a/tests/test_discovery.py
+++ b/tests/test_discovery.py
@@ -1,24 +1,18 @@
import asyncio
import json
import socket
-from asyncio.tasks import wait_for
from threading import Thread
-from unittest.mock import MagicMock, PropertyMock, create_autospec, patch
+from unittest.mock import MagicMock, PropertyMock, patch
import pytest
-from greeclimate.device import DeviceInfo
from greeclimate.discovery import Discovery, Listener
-from greeclimate.network import DeviceProtocolBase2
-
from .common import (
DEFAULT_TIMEOUT,
DISCOVERY_REQUEST,
DISCOVERY_RESPONSE,
- DISCOVERY_RESPONSE_NO_CID,
Responder,
- encrypt_payload,
- get_mock_device_info,
+ get_mock_device_info, encrypt_payload,
)
diff --git a/tests/test_issues.py b/tests/test_issues.py
index 7c29438..ba08eb0 100644
--- a/tests/test_issues.py
+++ b/tests/test_issues.py
@@ -15,7 +15,7 @@ async def test_issue_69_TemSen_40_should_not_set_firmware_v4():
for p in Props:
assert device.get_property(p) is None
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**mock_v3_state)
with patch.object(Device, "send", wraps=fake_send()):
@@ -35,7 +35,7 @@ async def test_issue_87_quiet_should_set_2():
assert device.get_property(Props.QUIET) is None
device.quiet = True
- def fake_send(*args):
+ def fake_send(*args, **kwargs):
device.handle_state_update(**mock_v3_state)
with patch.object(Device, "send", wraps=fake_send()) as mock:
diff --git a/tests/test_network.py b/tests/test_network.py
index 0e0dc0f..18c9ffa 100644
--- a/tests/test_network.py
+++ b/tests/test_network.py
@@ -2,8 +2,7 @@
import json
import socket
from threading import Thread
-from typing import Any
-from unittest.mock import create_autospec, patch, MagicMock
+from unittest.mock import patch, MagicMock
import pytest
@@ -14,15 +13,13 @@
IPAddr,
DeviceProtocol2, Commands, Response,
)
-
from .common import (
DEFAULT_RESPONSE,
DEFAULT_TIMEOUT,
DISCOVERY_REQUEST,
DISCOVERY_RESPONSE,
Responder,
- encrypt_payload,
- get_mock_device_info, DEFAULT_REQUEST, generate_response,
+ DEFAULT_REQUEST, generate_response, FakeCipher,
)
from .test_device import get_mock_info
@@ -44,6 +41,7 @@ class FakeDeviceProtocol(DeviceProtocol2):
def __init__(self, drained: asyncio.Event = None):
super().__init__(timeout=1, drained=drained)
self.packets = asyncio.Queue()
+ self.device_cipher = FakeCipher(b"1234567890123456")
def packet_received(self, obj, addr: IPAddr) -> None:
self.packets.put_nowait(obj)
@@ -90,8 +88,8 @@ async def test_set_get_key():
"""Test the encryption key property."""
key = "faketestkey"
dp2 = DeviceProtocolBase2()
- dp2.device_key = key
- assert dp2.device_key == key
+ dp2.device_cipher = FakeCipher(key.encode())
+ assert dp2.device_cipher.key == key
@pytest.mark.asyncio
@@ -151,7 +149,7 @@ def responder(s):
p = json.loads(d)
assert p == DISCOVERY_REQUEST
- p = json.dumps(encrypt_payload(DISCOVERY_RESPONSE))
+ p = json.dumps(DISCOVERY_RESPONSE)
s.sendto(p.encode(), addr)
serv = Thread(target=responder, args=(sock,))
@@ -164,6 +162,7 @@ def responder(s):
local_addr = (addr[0], 0)
dp2 = FakeDiscoveryProtocol()
+ dp2.device_cipher = FakeCipher(b"1234567890123456")
await loop.create_datagram_endpoint(
lambda: dp2,
local_addr=local_addr,
@@ -240,44 +239,30 @@ def responder(s):
lambda: FakeDeviceProtocol(drained=drained), remote_addr=remote_addr
)
- with patch("greeclimate.network.DeviceProtocolBase2.decrypt_payload", new_callable=MagicMock) as mock:
- mock.side_effect = lambda x, y: x
-
- # Send the scan command
- await protocol.send(DEFAULT_REQUEST, None)
+ # Send the scan command
+ await protocol.send(DEFAULT_REQUEST, None, FakeCipher(b"1234567890123456"))
- # Wait on the scan response
- task = asyncio.create_task(protocol.packets.get())
- await asyncio.wait_for(task, DEFAULT_TIMEOUT)
- response = task.result()
+ # Wait on the scan response
+ task = asyncio.create_task(protocol.packets.get())
+ await asyncio.wait_for(task, DEFAULT_TIMEOUT)
+ response = task.result()
- assert response == DEFAULT_RESPONSE
+ assert response == DEFAULT_RESPONSE
sock.close()
serv.join(timeout=DEFAULT_TIMEOUT)
-def test_encrypt_decrypt_payload():
- test_object = {"fake-key": "fake-value"}
-
- encrypted = DeviceProtocolBase2.encrypt_payload(test_object)
- assert encrypted != test_object
-
- decrypted = DeviceProtocolBase2.decrypt_payload(encrypted)
- assert decrypted == test_object
-
-
-@pytest.mark.asyncio
def test_bindok_handling():
"""Test the bindok response."""
response = generate_response({"t": "bindok", "key": "fake-key"})
protocol = DeviceProtocol2(timeout=DEFAULT_TIMEOUT)
- with patch("greeclimate.network.DeviceProtocolBase2.decrypt_payload", new_callable=MagicMock) as mock_decrypt:
- mock_decrypt.side_effect = lambda x, y: x
- with patch.object(DeviceProtocol2, "handle_device_bound") as mock:
- protocol.datagram_received(json.dumps(response).encode(), ("0.0.0.0", 0))
- assert mock.call_count == 1
- assert mock.call_args[0][0] == "fake-key"
+ protocol.device_cipher = FakeCipher(b"1234567890123456")
+
+ with patch.object(DeviceProtocol2, "handle_device_bound") as mock:
+ protocol.datagram_received(json.dumps(response).encode(), ("0.0.0.0", 0))
+ assert mock.call_count == 1
+ assert mock.call_args[0][0] == "fake-key"
def test_create_bind_message():
@@ -512,3 +497,78 @@ async def test_add_and_remove_handler(event_name, data):
# Check that the callback was not called this time
callback.assert_not_called()
+
+def test_packet_received_not_implemented():
+ # Arrange
+ protocol = DeviceProtocolBase2()
+
+ # Act
+ with pytest.raises(NotImplementedError):
+ protocol.packet_received({}, ("0.0.0.0", 0))
+
+
+def test_packet_received_invalid_data():
+ # Arrange
+ protocol = DeviceProtocol2()
+
+ # Act
+ protocol.packet_received(None, ("0.0.0.0", 0))
+ protocol.packet_received({}, ("0.0.0.0", 0))
+ protocol.packet_received({"pack"}, ("0.0.0.0", 0))
+
+ with patch.object(protocol, "handle_unknown_packet") as mock:
+ protocol.packet_received({"pack": {"t": "unknown"}}, ("0.0.0.0", 0))
+ mock.assert_called_once()
+
+
+def test_set_get_cipher():
+ # Arrange
+ protocol = DeviceProtocolBase2()
+ cipher = FakeCipher(b"1234567890123456")
+
+ # Act
+ protocol.device_cipher = cipher
+
+ # Assert
+ assert protocol.device_cipher == cipher
+
+
+def test_cipher_is_not_set():
+ # Arrange
+ protocol = DeviceProtocolBase2()
+
+ # Act
+ key = None
+ with pytest.raises(ValueError):
+ key = protocol.device_key
+
+ assert key is None
+
+ with pytest.raises(ValueError):
+ protocol.device_key = "fake-key"
+
+
+def test_add_invalid_handler():
+ # Arrange
+ protocol = DeviceProtocol2()
+ callback = MagicMock()
+
+ # Act
+ with pytest.raises(ValueError):
+ protocol.add_handler(Response("invalid"), callback)
+
+ with pytest.raises(ValueError):
+ protocol.add_handler(Response("invalid"), callback)
+
+
+def test_device_key_get_set():
+ # Arrange
+ protocol = DeviceProtocolBase2
+ key = "fake-key"
+
+ # Act
+ protocol.device_key = key
+
+ # Assert
+ assert protocol.device_key == key
+
\ No newline at end of file