Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up decrypting frames #944

Merged
merged 6 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions aioesphomeapi/_frame_helper/noise.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@ cdef unsigned int NOISE_STATE_READY
cdef unsigned int NOISE_STATE_CLOSED

cdef bytes NOISE_HELLO
cdef object PACK_NONCE

cdef class EncryptCipher:

cdef object _nonce
cdef object _encrypt

cdef bytes encrypt(self, object frame)

cdef class DecryptCipher:

cdef object _nonce
cdef object _decrypt

cdef bytes decrypt(self, object frame)

cdef class APINoiseFrameHelper(APIFrameHelper):

Expand All @@ -20,8 +35,8 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
cdef unsigned int _state
cdef object _server_name
cdef object _proto
cdef object _decrypt
cdef object _encrypt
cdef EncryptCipher _encrypt_cipher
cdef DecryptCipher _decrypt_cipher

@cython.locals(
header=bytes,
Expand Down Expand Up @@ -59,6 +74,7 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
@cython.locals(
type_="unsigned int",
data=bytes,
data_header=bytes,
packet=tuple,
data_len=cython.uint,
frame=bytes,
Expand Down
70 changes: 53 additions & 17 deletions aioesphomeapi/_frame_helper/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from functools import partial
import logging
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
from cryptography.exceptions import InvalidTag
from noise.backends.default import DefaultNoiseBackend # type: ignore[import-untyped]
from noise.backends.default.ciphers import ( # type: ignore[import-untyped]
ChaCha20Cipher,
CryptographyCipher,
)
from noise.connection import NoiseConnection # type: ignore[import-untyped]
from noise.state import CipherState # type: ignore[import-untyped]

from ..core import (
APIConnectionError,
Expand All @@ -30,6 +32,8 @@

PACK_NONCE = partial(Struct("<LQ").pack, 0)

_bytes = bytes


class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
"""ChaCha20 cipher that can be reused."""
Expand Down Expand Up @@ -68,6 +72,44 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
int_ = int


class EncryptCipher:
"""Wrapper around the ChaCha20Poly1305 cipher for encryption."""

__slots__ = ("_nonce", "_encrypt")

def __init__(self, cipher_state: CipherState) -> None:
"""Initialize the cipher wrapper."""
crypto_cipher: CryptographyCipher = cipher_state.cipher
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
self._nonce: int = cipher_state.n
self._encrypt = cipher.encrypt

def encrypt(self, data: _bytes) -> bytes:
"""Encrypt a frame."""
ciphertext = self._encrypt(PACK_NONCE(self._nonce), data, None)
self._nonce += 1
return ciphertext


class DecryptCipher:
"""Wrapper around the ChaCha20Poly1305 cipher for decryption."""

__slots__ = ("_nonce", "_decrypt")

def __init__(self, cipher_state: CipherState) -> None:
"""Initialize the cipher wrapper."""
crypto_cipher: CryptographyCipher = cipher_state.cipher
cipher: ChaCha20Poly1305Reusable = crypto_cipher.cipher
self._nonce: int = cipher_state.n
self._decrypt = cipher.decrypt

def decrypt(self, data: _bytes) -> bytes:
"""Decrypt a frame."""
plaintext = self._decrypt(PACK_NONCE(self._nonce), data, None)
self._nonce += 1
return plaintext


class APINoiseFrameHelper(APIFrameHelper):
"""Frame helper for noise encrypted connections."""

Expand All @@ -77,8 +119,8 @@ class APINoiseFrameHelper(APIFrameHelper):
"_state",
"_server_name",
"_proto",
"_decrypt",
"_encrypt",
"_encrypt_cipher",
"_decrypt_cipher",
)

def __init__(
Expand All @@ -95,8 +137,8 @@ def __init__(
self._expected_name = expected_name
self._state = NOISE_STATE_HELLO
self._server_name: str | None = None
self._decrypt: Callable[[bytes], bytes] | None = None
self._encrypt: Callable[[bytes], bytes] | None = None
self._encrypt_cipher: EncryptCipher | None = None
self._decrypt_cipher: DecryptCipher | None = None
self._setup_proto()

def close(self) -> None:
Expand Down Expand Up @@ -271,14 +313,8 @@ def _handle_handshake(self, msg: bytes) -> None:
self._proto.read_message(msg[1:])
self._state = NOISE_STATE_READY
noise_protocol = self._proto.noise_protocol
self._decrypt = partial(
noise_protocol.cipher_state_decrypt.decrypt_with_ad, # pylint: disable=no-member
None,
)
self._encrypt = partial(
noise_protocol.cipher_state_encrypt.encrypt_with_ad, # pylint: disable=no-member
None,
)
self._decrypt_cipher = DecryptCipher(noise_protocol.cipher_state_decrypt) # pylint: disable=no-member
self._encrypt_cipher = EncryptCipher(noise_protocol.cipher_state_encrypt) # pylint: disable=no-member
self.ready_future.set_result(None)

def write_packets(
Expand All @@ -289,7 +325,7 @@ def write_packets(
Packets are in the format of tuple[protobuf_type, protobuf_data]
"""
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
assert self._encrypt_cipher is not None, "Handshake should be complete"

out: list[bytes] = []
for packet in packets:
Expand All @@ -304,7 +340,7 @@ def write_packets(
data_len & 0xFF,
)
)
frame = self._encrypt(data_header + data)
frame = self._encrypt_cipher.encrypt(data_header + data)
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
out.append(header)
Expand All @@ -315,8 +351,8 @@ def write_packets(
def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""
if TYPE_CHECKING:
assert self._decrypt is not None, "Handshake should be complete"
msg = self._decrypt(frame)
assert self._decrypt_cipher is not None, "Handshake should be complete"
msg = self._decrypt_cipher.decrypt(frame)
# Message layout is
# 2 bytes: message type
# 2 bytes: message length
Expand Down
Loading