From eca1d317818d2b938ec3ed3172b1be76a44a93a4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 9 Sep 2022 09:58:12 -0500 Subject: [PATCH] feat: speed up unmarshaller (#1) --- bench/unmarshall.py | 22 + src/dbus_fast/_private/constants.py | 3 + src/dbus_fast/_private/unmarshaller.py | 547 +++++++++++++------------ src/dbus_fast/constants.py | 6 + src/dbus_fast/message.py | 26 +- src/dbus_fast/signature.py | 85 ++-- src/dbus_fast/validators.py | 5 + tests/test_marshaller.py | 57 ++- tests/test_validators.py | 5 +- 9 files changed, 430 insertions(+), 326 deletions(-) create mode 100644 bench/unmarshall.py diff --git a/bench/unmarshall.py b/bench/unmarshall.py new file mode 100644 index 00000000..aaaf4298 --- /dev/null +++ b/bench/unmarshall.py @@ -0,0 +1,22 @@ +import io +import timeit + +from dbus_fast._private.unmarshaller import Unmarshaller + +bluez_rssi_message = ( + "6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465" + "765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b" + "746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e" + "67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000" + "110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001" + "6e00a7ff000000000000" +) + + +def unmarhsall_bluez_rssi_message(): + Unmarshaller(io.BytesIO(bytes.fromhex(bluez_rssi_message))).unmarshall() + + +count = 1000000 +time = timeit.Timer(unmarhsall_bluez_rssi_message).timeit(count) +print(f"Unmarshalling {count} bluetooth rssi messages took {time} seconds") diff --git a/src/dbus_fast/_private/constants.py b/src/dbus_fast/_private/constants.py index 605c3cf0..5ab7aefd 100644 --- a/src/dbus_fast/_private/constants.py +++ b/src/dbus_fast/_private/constants.py @@ -16,3 +16,6 @@ class HeaderField(Enum): SENDER = 7 SIGNATURE = 8 UNIX_FDS = 9 + + +HEADER_NAME_MAP = {field.value: field.name for field in HeaderField} diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 9bcb011c..1c31367a 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -1,315 +1,348 @@ import array +import io import socket -from codecs import decode -from struct import unpack_from +import sys +from struct import Struct +from typing import Any, Callable, Dict, List, Optional, Tuple -from ..constants import MessageFlag, MessageType +from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP, MessageFlag, MessageType from ..errors import InvalidMessageError from ..message import Message -from ..signature import SignatureTree, Variant -from .constants import BIG_ENDIAN, LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField +from ..signature import SignatureTree, SignatureType, Variant +from .constants import ( + BIG_ENDIAN, + HEADER_NAME_MAP, + LITTLE_ENDIAN, + PROTOCOL_VERSION, + HeaderField, +) MAX_UNIX_FDS = 16 +UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"} +UNPACK_LENGTHS = {BIG_ENDIAN: Struct(">III"), LITTLE_ENDIAN: Struct(" bytes: + """reads from the socket, storing any fds sent and handling errors + from the read itself""" + unix_fd_list = array.array("i") - self.readers = { - "y": self.read_byte, - "b": self.read_boolean, - "n": self.read_int16, - "q": self.read_uint16, - "i": self.read_int32, - "u": self.read_uint32, - "x": self.read_int64, - "t": self.read_uint64, - "d": self.read_double, - "h": self.read_uint32, - "o": self.read_string, - "s": self.read_string, - "g": self.read_signature, - "a": self.read_array, - "(": self.read_struct, - "{": self.read_dict_entry, - "v": self.read_variant, - } - - def read(self, n, prefetch=False): - """ - Read from underlying socket into buffer and advance offset accordingly. - - :arg n: - Number of bytes to read. If not enough bytes are available in the - buffer, read more from it. - :arg prefetch: - Do not update current offset after reading. + try: + msg, ancdata, *_ = self.sock.recvmsg( + length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize) + ) + except BlockingIOError: + raise MarshallerStreamEndError() + + for level, type_, data in ancdata: + if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS): + continue + unix_fd_list.frombytes( + data[: len(data) - (len(data) % unix_fd_list.itemsize)] + ) + self.unix_fds.extend(list(unix_fd_list)) - :returns: - Previous offset (before reading). To get the actual read bytes, - use the returned value and self.buf. - """ + return msg - def read_sock(length): - """reads from the socket, storing any fds sent and handling errors - from the read itself""" - if self.sock is not None: - unix_fd_list = array.array("i") - - try: - msg, ancdata, *_ = self.sock.recvmsg( - length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize) - ) - except BlockingIOError: - raise MarshallerStreamEndError() - - for level, type_, data in ancdata: - if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS): - continue - unix_fd_list.frombytes( - data[: len(data) - (len(data) % unix_fd_list.itemsize)] - ) - self.unix_fds.extend(list(unix_fd_list)) - - return msg - else: - return self.stream.read(length) - - # store previously read data in a buffer so we can resume on socket - # interruptions - missing_bytes = n - (len(self.buf) - self.offset) - if missing_bytes > 0: - data = read_sock(missing_bytes) - if data == b"": - raise EOFError() - elif data is None: - raise MarshallerStreamEndError() - self.buf.extend(data) - if len(data) != missing_bytes: - raise MarshallerStreamEndError() - prev = self.offset - if not prefetch: - self.offset += n - return prev - - @staticmethod - def _padding(offset, align): + def read_to_offset(self, offset: int) -> None: """ - Get padding bytes to get to the next align bytes mark. - - For any align value, the correct padding formula is: - - (align - (offset % align)) % align - - However, if align is a power of 2 (always the case here), the slow MOD - operator can be replaced by a bitwise AND: + Read from underlying socket into buffer. - (align - (offset & (align - 1))) & (align - 1) + Raises MarshallerStreamEndError if there is not enough data to be read. - Which can be simplified to: + :arg offset: + The offset to read to. If not enough bytes are available in the + buffer, read more from it. - (-offset) & (align - 1) + :returns: + None """ - return (-offset) & (align - 1) - - def align(self, n): - padding = self._padding(self.offset, n) - if padding > 0: - self.read(padding) - - def read_byte(self, _=None): - return self.buf[self.read(1)] - - def read_boolean(self, _=None): - data = self.read_uint32() - if data: - return True + start_len = len(self.buf) + missing_bytes = offset - (start_len - self.offset) + if self.sock is None: + data = self.stream.read(missing_bytes) else: - return False - - def read_int16(self, _=None): - return self.read_ctype("h", 2) - - def read_uint16(self, _=None): - return self.read_ctype("H", 2) + data = self.read_sock(missing_bytes) + if data == b"": + raise EOFError() + if data is None: + raise MarshallerStreamEndError() + self.buf.extend(data) + if len(data) + start_len != offset: + raise MarshallerStreamEndError() - def read_int32(self, _=None): - return self.read_ctype("i", 4) - - def read_uint32(self, _=None): - return self.read_ctype("I", 4) - - def read_int64(self, _=None): - return self.read_ctype("q", 8) - - def read_uint64(self, _=None): - return self.read_ctype("Q", 8) - - def read_double(self, _=None): - return self.read_ctype("d", 8) - - def read_ctype(self, fmt, size): - self.align(size) - if self.endian == LITTLE_ENDIAN: - fmt = "<" + fmt - else: - fmt = ">" + fmt - o = self.read(size) - return unpack_from(fmt, self.buf, o)[0] + def read_boolean(self, _=None): + return bool(self.read_argument(UINT32_SIGNATURE)) def read_string(self, _=None): - str_length = self.read_uint32() - o = self.read(str_length + 1) # read terminating '\0' byte as well - # avoid buffer copies when slicing - str_mem_slice = memoryview(self.buf)[o : o + str_length] - return decode(str_mem_slice) + str_length = self.read_argument(UINT32_SIGNATURE) + str_start = self.offset + # read terminating '\0' byte as well (str_length + 1) + self.offset += str_length + 1 + return self.buf[str_start : str_start + str_length].decode() def read_signature(self, _=None): - signature_len = self.read_byte() - o = self.read(signature_len + 1) # read terminating '\0' byte as well - # avoid buffer copies when slicing - sig_mem_slice = memoryview(self.buf)[o : o + signature_len] - return decode(sig_mem_slice) + signature_len = self.view[self.offset] # byte + o = self.offset + 1 + # read terminating '\0' byte as well (str_length + 1) + self.offset = o + signature_len + 1 + return self.buf[o : o + signature_len].decode() def read_variant(self, _=None): - signature = self.read_signature() - signature_tree = SignatureTree._get(signature) - value = self.read_argument(signature_tree.types[0]) - return Variant(signature_tree, value) - - def read_struct(self, type_): - self.align(8) - - result = [] - for child_type in type_.children: - result.append(self.read_argument(child_type)) - - return result - - def read_dict_entry(self, type_): - self.align(8) - - key = self.read_argument(type_.children[0]) - value = self.read_argument(type_.children[1]) - - return key, value + tree = SignatureTree._get(self.read_signature()) + # verify in Variant is only useful on construction not unmarshalling + return Variant(tree, self.read_argument(tree.types[0]), verify=False) + + def read_struct(self, type_: SignatureType): + self.offset += -self.offset & 7 # align 8 + return [self.read_argument(child_type) for child_type in type_.children] + + def read_dict_entry(self, type_: SignatureType): + self.offset += -self.offset & 7 # align 8 + return self.read_argument(type_.children[0]), self.read_argument( + type_.children[1] + ) - def read_array(self, type_): - self.align(4) - array_length = self.read_uint32() + def read_array(self, type_: SignatureType): + self.offset += -self.offset & 3 # align 4 for the array + array_length = self.read_argument(UINT32_SIGNATURE) child_type = type_.children[0] if child_type.token in "xtd{(": # the first alignment is not included in the array size - self.align(8) + self.offset += -self.offset & 7 # align 8 + + if child_type.token == "y": + self.offset += array_length + return self.buf[self.offset - array_length : self.offset] beginning_offset = self.offset - result = None if child_type.token == "{": - result = {} + result_dict = {} while self.offset - beginning_offset < array_length: key, value = self.read_dict_entry(child_type) - result[key] = value - elif child_type.token == "y": - o = self.read(array_length) - # avoid buffer copies when slicing - array_mem_slice = memoryview(self.buf)[o : o + array_length] - result = array_mem_slice.tobytes() - else: - result = [] - while self.offset - beginning_offset < array_length: - result.append(self.read_argument(child_type)) - - return result - - def read_argument(self, type_): - t = type_.token - - if t not in self.readers: - raise Exception(f'dont know how to read yet: "{t}"') - - return self.readers[t](type_) - - def _unmarshall(self): - self.offset = 0 - self.read(16, prefetch=True) - self.endian = self.read_byte() - if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN: - raise InvalidMessageError("Expecting endianness as the first byte") - message_type = MessageType(self.read_byte()) - flags = MessageFlag(self.read_byte()) - - protocol_version = self.read_byte() - + result_dict[key] = value + return result_dict + + result_list = [] + while self.offset - beginning_offset < array_length: + result_list.append(self.read_argument(child_type)) + return result_list + + def read_argument(self, type_: SignatureType) -> Any: + """Dispatch to an argument reader or cast/unpack a C type.""" + token = type_.token + reader, ctype, size, struct = self.readers[token] + if reader: # complex type + return reader(self, type_) + self.offset += size + (-self.offset & (size - 1)) # align + if self.can_cast: + return self.view[self.offset - size : self.offset].cast(ctype)[0] + return struct.unpack_from(self.view, self.offset - size)[0] + + def header_fields(self, header_length): + """Header fields are always a(yv).""" + beginning_offset = self.offset + headers = {} + while self.offset - beginning_offset < header_length: + # Now read the y (byte) of struct (yv) + self.offset += (-self.offset & 7) + 1 # align 8 + 1 for 'y' byte + field_0 = self.view[self.offset - 1] + + # Now read the v (variant) of struct (yv) + signature_len = self.view[self.offset] # byte + o = self.offset + 1 + self.offset += signature_len + 2 # one for the byte, one for the '\0' + tree = SignatureTree._get(self.buf[o : o + signature_len].decode()) + headers[HEADER_NAME_MAP[field_0]] = self.read_argument(tree.types[0]) + return headers + + def _read_header(self): + """Read the header of the message.""" + # Signature is of the header is + # BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT) + self.read_to_offset(HEADER_SIGNATURE_SIZE) + buffer = self.buf + endian = buffer[0] + self.message_type = MESSAGE_TYPE_MAP[buffer[1]] + self.flag = MESSAGE_FLAG_MAP[buffer[2]] + protocol_version = buffer[3] + + if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN: + raise InvalidMessageError( + f"Expecting endianness as the first byte, got {endian} from {buffer}" + ) if protocol_version != PROTOCOL_VERSION: raise InvalidMessageError( f"got unknown protocol version: {protocol_version}" ) - body_len = self.read_uint32() - serial = self.read_uint32() - - header_len = self.read_uint32() - msg_len = header_len + self._padding(header_len, 8) + body_len - self.read(msg_len, prefetch=True) - # backtrack offset since header array length needs to be read again - self.offset -= 4 - - header_fields = {} - for field_struct in self.read_argument(SignatureTree._get("a(yv)").types[0]): - field = HeaderField(field_struct[0]) - header_fields[field.name] = field_struct[1].value - - self.align(8) - - path = header_fields.get(HeaderField.PATH.name) - interface = header_fields.get(HeaderField.INTERFACE.name) - member = header_fields.get(HeaderField.MEMBER.name) - error_name = header_fields.get(HeaderField.ERROR_NAME.name) - reply_serial = header_fields.get(HeaderField.REPLY_SERIAL.name) - destination = header_fields.get(HeaderField.DESTINATION.name) - sender = header_fields.get(HeaderField.SENDER.name) - signature = header_fields.get(HeaderField.SIGNATURE.name, "") - signature_tree = SignatureTree._get(signature) - # unix_fds = header_fields.get(HeaderField.UNIX_FDS.name, 0) - - body = [] - - if body_len: - for type_ in signature_tree.types: - body.append(self.read_argument(type_)) - + self.body_len, self.serial, self.header_len = UNPACK_LENGTHS[ + endian + ].unpack_from(buffer, 4) + self.msg_len = ( + self.header_len + (-self.header_len & 7) + self.body_len + ) # align 8 + if (sys.byteorder == "little" and endian == LITTLE_ENDIAN) or ( + sys.byteorder == "big" and endian == BIG_ENDIAN + ): + self.can_cast = True + self.readers = self._readers_by_type[endian] + + def _read_body(self): + """Read the body of the message.""" + self.read_to_offset(HEADER_SIGNATURE_SIZE + self.msg_len) + self.view = memoryview(self.buf) + self.offset = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION + header_fields = self.header_fields(self.header_len) + self.offset += -self.offset & 7 # align 8 + tree = SignatureTree._get(header_fields.get(HeaderField.SIGNATURE.name, "")) self.message = Message( - destination=destination, - path=path, - interface=interface, - member=member, - message_type=message_type, - flags=flags, - error_name=error_name, - reply_serial=reply_serial, - sender=sender, + destination=header_fields.get(HEADER_DESTINATION), + path=header_fields.get(HEADER_PATH), + interface=header_fields.get(HEADER_INTERFACE), + member=header_fields.get(HEADER_MEMBER), + message_type=self.message_type, + flags=self.flag, + error_name=header_fields.get(HEADER_ERROR_NAME), + reply_serial=header_fields.get(HEADER_REPLY_SERIAL), + sender=header_fields.get(HEADER_SENDER), unix_fds=self.unix_fds, - signature=signature_tree, - body=body, - serial=serial, + signature=tree.signature, + body=[self.read_argument(t) for t in tree.types] if self.body_len else [], + serial=self.serial, ) def unmarshall(self): + """Unmarshall the message. + + The underlying read function will raise MarshallerStreamEndError + if there are not enough bytes in the buffer. This allows unmarshall + to be resumed when more data comes in over the wire. + """ try: - self._unmarshall() - return self.message + if not self.message_type: + self._read_header() + self._read_body() except MarshallerStreamEndError: return None + return self.message + + _complex_parsers: Dict[ + str, Tuple[Callable[["Unmarshaller", SignatureType], Any], None, None, None] + ] = { + "b": (read_boolean, None, None, None), + "o": (read_string, None, None, None), + "s": (read_string, None, None, None), + "g": (read_signature, None, None, None), + "a": (read_array, None, None, None), + "(": (read_struct, None, None, None), + "{": (read_dict_entry, None, None, None), + "v": (read_variant, None, None, None), + } + + _ctype_by_endian: Dict[int, Dict[str, Tuple[None, str, int, Struct]]] = { + endian: { + dbus_type: ( + None, + *ctype_size, + Struct(f"{UNPACK_SYMBOL[endian]}{ctype_size[0]}"), + ) + for dbus_type, ctype_size in DBUS_TO_CTYPE.items() + } + for endian in (BIG_ENDIAN, LITTLE_ENDIAN) + } + + _readers_by_type: Dict[int, READER_TYPE] = { + BIG_ENDIAN: {**_ctype_by_endian[BIG_ENDIAN], **_complex_parsers}, + LITTLE_ENDIAN: {**_ctype_by_endian[LITTLE_ENDIAN], **_complex_parsers}, + } diff --git a/src/dbus_fast/constants.py b/src/dbus_fast/constants.py index a8869b00..9aa62fdc 100644 --- a/src/dbus_fast/constants.py +++ b/src/dbus_fast/constants.py @@ -19,6 +19,9 @@ class MessageType(Enum): SIGNAL = 4 #: A broadcast signal to subscribed connections +MESSAGE_TYPE_MAP = {field.value: field for field in MessageType} + + class MessageFlag(IntFlag): """Flags that affect the behavior of sent and received messages""" @@ -28,6 +31,9 @@ class MessageFlag(IntFlag): ALLOW_INTERACTIVE_AUTHORIZATION = 4 +MESSAGE_FLAG_MAP = {field.value: field for field in MessageFlag} + + class NameFlag(IntFlag): """A flag that affects the behavior of a name request.""" diff --git a/src/dbus_fast/message.py b/src/dbus_fast/message.py index 6413d7e0..7c05f9c9 100644 --- a/src/dbus_fast/message.py +++ b/src/dbus_fast/message.py @@ -12,6 +12,13 @@ assert_object_path_valid, ) +REQUIRED_FIELDS = { + MessageType.METHOD_CALL: ("path", "member"), + MessageType.SIGNAL: ("path", "member", "interface"), + MessageType.ERROR: ("error_name", "reply_serial"), + MessageType.METHOD_RETURN: ("reply_serial",), +} + class Message: """A class for sending and receiving messages through the @@ -112,21 +119,12 @@ def __init__( if self.error_name is not None: assert_interface_name_valid(self.error_name) - def require_fields(*fields): - for field in fields: - if not getattr(self, field): - raise InvalidMessageError(f"missing required field: {field}") - - if self.message_type == MessageType.METHOD_CALL: - require_fields("path", "member") - elif self.message_type == MessageType.SIGNAL: - require_fields("path", "member", "interface") - elif self.message_type == MessageType.ERROR: - require_fields("error_name", "reply_serial") - elif self.message_type == MessageType.METHOD_RETURN: - require_fields("reply_serial") - else: + required_fields = REQUIRED_FIELDS.get(self.message_type) + if not required_fields: raise InvalidMessageError(f"got unknown message type: {self.message_type}") + for field in required_fields: + if not getattr(self, field): + raise InvalidMessageError(f"missing required field: {field}") @staticmethod def new_error(msg: "Message", error_name: str, error_text: str) -> "Message": diff --git a/src/dbus_fast/signature.py b/src/dbus_fast/signature.py index aea38ea9..1fd7b006 100644 --- a/src/dbus_fast/signature.py +++ b/src/dbus_fast/signature.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Any, List, Union from .errors import InvalidSignatureError, SignatureBodyMismatchError @@ -21,9 +22,9 @@ class to parse signatures. _tokens = "ybnqiuxtdsogavh({" - def __init__(self, token): + def __init__(self, token: str) -> None: self.token = token - self.children = [] + self.children: List[SignatureType] = [] self._signature = None def __eq__(self, other): @@ -240,7 +241,7 @@ def _verify_array(self, body): child_type.children[0].verify(key) child_type.children[1].verify(value) elif child_type.token == "y": - if not isinstance(body, bytes): + if not isinstance(body, (bytearray, bytes)): raise SignatureBodyMismatchError( f'DBus ARRAY type "a" with BYTE child must be Python type "bytes", got {type(body)}' ) @@ -284,43 +285,33 @@ def verify(self, body: Any) -> bool: """ if body is None: raise SignatureBodyMismatchError('Cannot serialize Python type "None"') - elif self.token == "y": - self._verify_byte(body) - elif self.token == "b": - self._verify_boolean(body) - elif self.token == "n": - self._verify_int16(body) - elif self.token == "q": - self._verify_uint16(body) - elif self.token == "i": - self._verify_int32(body) - elif self.token == "u": - self._verify_uint32(body) - elif self.token == "x": - self._verify_int64(body) - elif self.token == "t": - self._verify_uint64(body) - elif self.token == "d": - self._verify_double(body) - elif self.token == "h": - self._verify_unix_fd(body) - elif self.token == "o": - self._verify_object_path(body) - elif self.token == "s": - self._verify_string(body) - elif self.token == "g": - self._verify_signature(body) - elif self.token == "a": - self._verify_array(body) - elif self.token == "(": - self._verify_struct(body) - elif self.token == "v": - self._verify_variant(body) + validator = self.validators.get(self.token) + if validator: + validator(self, body) else: raise Exception(f"cannot verify type with token {self.token}") return True + validators = { + "y": _verify_byte, + "b": _verify_boolean, + "n": _verify_int16, + "q": _verify_uint16, + "i": _verify_int32, + "u": _verify_uint32, + "x": _verify_int64, + "t": _verify_uint64, + "d": _verify_double, + "h": _verify_uint32, + "o": _verify_string, + "s": _verify_string, + "g": _verify_signature, + "a": _verify_array, + "(": _verify_struct, + "v": _verify_variant, + } + class SignatureTree: """A class that represents a signature as a tree structure for conveniently @@ -338,19 +329,15 @@ class SignatureTree: :class:`InvalidSignatureError` if the given signature is not valid. """ - _cache = {} - @staticmethod - def _get(signature: str = ""): - if signature in SignatureTree._cache: - return SignatureTree._cache[signature] - SignatureTree._cache[signature] = SignatureTree(signature) - return SignatureTree._cache[signature] + @lru_cache(maxsize=None) + def _get(signature: str = "") -> "SignatureTree": + return SignatureTree(signature) def __init__(self, signature: str = ""): self.signature = signature - self.types = [] + self.types: List[SignatureType] = [] if len(signature) > 0xFF: raise InvalidSignatureError("A signature must be less than 256 characters") @@ -411,7 +398,12 @@ class Variant: :class:`SignatureBodyMismatchError` if the signature does not match the body. """ - def __init__(self, signature: Union[str, SignatureTree, SignatureType], value: Any): + def __init__( + self, + signature: Union[str, SignatureTree, SignatureType], + value: Any, + verify: bool = True, + ): signature_str = "" signature_tree = None signature_type = None @@ -429,14 +421,15 @@ def __init__(self, signature: Union[str, SignatureTree, SignatureType], value: A ) if signature_tree: - if len(signature_tree.types) != 1: + if verify and len(signature_tree.types) != 1: raise ValueError( "variants must have a signature for a single complete type" ) signature_str = signature_tree.signature signature_type = signature_tree.types[0] - signature_type.verify(value) + if verify: + signature_type.verify(value) self.type = signature_type self.signature = signature_str diff --git a/src/dbus_fast/validators.py b/src/dbus_fast/validators.py index d0c76ec7..6ff7aae2 100644 --- a/src/dbus_fast/validators.py +++ b/src/dbus_fast/validators.py @@ -1,4 +1,5 @@ import re +from functools import lru_cache from .errors import ( InvalidBusNameError, @@ -13,6 +14,7 @@ _member_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_-]*$") +@lru_cache(maxsize=32) def is_bus_name_valid(name: str) -> bool: """Whether this is a valid bus name. @@ -47,6 +49,7 @@ def is_bus_name_valid(name: str) -> bool: return True +@lru_cache(maxsize=512) def is_object_path_valid(path: str) -> bool: """Whether this is a valid object path. @@ -77,6 +80,7 @@ def is_object_path_valid(path: str) -> bool: return True +@lru_cache(maxsize=32) def is_interface_name_valid(name: str) -> bool: """Whether this is a valid interface name. @@ -107,6 +111,7 @@ def is_interface_name_valid(name: str) -> bool: return True +@lru_cache(maxsize=512) def is_member_name_valid(member: str) -> bool: """Whether this is a valid member name. diff --git a/tests/test_marshaller.py b/tests/test_marshaller.py index 14e3bef6..ab3116f5 100644 --- a/tests/test_marshaller.py +++ b/tests/test_marshaller.py @@ -1,8 +1,11 @@ import io import json import os +from typing import Any, Dict -from dbus_fast import Message, SignatureTree, Variant +import pytest + +from dbus_fast import Message, MessageFlag, MessageType, SignatureTree, Variant from dbus_fast._private.unmarshaller import Unmarshaller @@ -20,6 +23,16 @@ def print_buf(buf): table = json.load(open(os.path.dirname(__file__) + "/data/messages.json")) +def json_to_message(message: Dict[str, Any]) -> Message: + copy = dict(message) + if "message_type" in copy: + copy["message_type"] = MessageType(copy["message_type"]) + if "flags" in copy: + copy["flags"] = MessageFlag(copy["flags"]) + + return Message(**copy) + + # variants are an object in the json def replace_variants(type_, item): if type_.token == "v" and type(item) is not Variant: @@ -56,7 +69,7 @@ def dumper(obj): def test_marshalling_with_table(): for item in table: - message = Message(**item["message"]) + message = json_to_message(item["message"]) body = [] for i, type_ in enumerate(message.signature_tree.types): @@ -79,8 +92,9 @@ def test_marshalling_with_table(): assert buf == data -def test_unmarshalling_with_table(): - for item in table: +@pytest.mark.parametrize("unmarshall_table", (table,)) +def test_unmarshalling_with_table(unmarshall_table): + for item in unmarshall_table: stream = io.BytesIO(bytes.fromhex(item["data"])) unmarshaller = Unmarshaller(stream) @@ -91,7 +105,7 @@ def test_unmarshalling_with_table(): print(json_dump(item["message"])) raise e - message = Message(**item["message"]) + message = json_to_message(item["message"]) body = [] for i, type_ in enumerate(message.signature_tree.types): @@ -114,6 +128,39 @@ def test_unmarshalling_with_table(): ), f"attr doesnt match: {attr}" +def test_unmarshall_can_resume(): + """Verify resume works.""" + bluez_rssi_message = ( + "6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465" + "765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b" + "746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e" + "67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000" + "110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001" + "6e00a7ff000000000000" + ) + message_bytes = bytes.fromhex(bluez_rssi_message) + + class SlowStream(io.IOBase): + """A fake stream that will only give us one byte at a time.""" + + def __init__(self): + self.data = message_bytes + self.pos = 0 + + def read(self, n) -> bytes: + data = self.data[self.pos : self.pos + 1] + self.pos += 1 + return data + + stream = SlowStream() + unmarshaller = Unmarshaller(stream) + + for _ in range(len(bluez_rssi_message)): + if unmarshaller.unmarshall(): + break + assert unmarshaller.message is not None + + def test_ay_buffer(): body = [bytes(10000)] msg = Message(path="/test", member="test", signature="ay", body=body) diff --git a/tests/test_validators.py b/tests/test_validators.py index e184c09d..f7bfbbf8 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -10,7 +10,6 @@ def test_object_path_validator(): valid_paths = ["/", "/foo", "/foo/bar", "/foo/bar/bat"] invalid_paths = [ None, - {}, "", "foo", "foo/bar", @@ -37,7 +36,6 @@ def test_bus_name_validator(): ] invalid_names = [ None, - {}, "", "5foo.bar", "foo.6bar", @@ -57,7 +55,6 @@ def test_interface_name_validator(): valid_names = ["foo.bar", "foo.bar.bat", "_foo._bar", "foo.bar69"] invalid_names = [ None, - {}, "", "5foo.bar", "foo.6bar", @@ -80,7 +77,7 @@ def test_interface_name_validator(): def test_member_name_validator(): valid_members = ["foo", "FooBar", "Bat_Baz69", "foo-bar"] - invalid_members = [None, {}, "", "foo.bar", "5foo", "foo$bar"] + invalid_members = [None, "", "foo.bar", "5foo", "foo$bar"] for member in valid_members: assert is_member_name_valid(member), f'member name should be valid: "{member}"'