Skip to content

Commit

Permalink
feat: improve Marshaller performance (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Sep 9, 2022
1 parent e386e22 commit a9e8866
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 92 deletions.
164 changes: 80 additions & 84 deletions src/dbus_fast/_private/marshaller.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,59 @@
from struct import pack
from struct import Struct, error, pack
from typing import Any, Callable, Dict, List, Optional, Tuple

from ..signature import SignatureTree
from ..signature import SignatureTree, SignatureType, Variant

PACK_UINT32 = Struct("<I").pack


class Marshaller:
def __init__(self, signature, body):
def __init__(self, signature: str, body: Any) -> None:
self.signature_tree = SignatureTree._get(signature)
self.signature_tree.verify(body)
self.buffer = bytearray()
self.body = body

self.writers = {
"y": self.write_byte,
"b": self.write_boolean,
"n": self.write_int16,
"q": self.write_uint16,
"i": self.write_int32,
"u": self.write_uint32,
"x": self.write_int64,
"t": self.write_uint64,
"d": self.write_double,
"h": self.write_uint32,
"o": self.write_string,
"s": self.write_string,
"g": self.write_signature,
"a": self.write_array,
"(": self.write_struct,
"{": self.write_dict_entry,
"v": self.write_variant,
}

def align(self, n):
def align(self, n) -> int:
offset = n - len(self.buffer) % n
if offset == 0 or offset == n:
return 0
self.buffer.extend(bytes(offset))
return offset

def write_byte(self, byte, _=None):
self.buffer.append(byte)
return 1

def write_boolean(self, boolean, _=None):
if boolean:
return self.write_uint32(1)
else:
return self.write_uint32(0)

def write_int16(self, int16, _=None):
written = self.align(2)
self.buffer.extend(pack("<h", int16))
return written + 2

def write_uint16(self, uint16, _=None):
written = self.align(2)
self.buffer.extend(pack("<H", uint16))
return written + 2

def write_int32(self, int32, _):
written = self.align(4)
self.buffer.extend(pack("<i", int32))
return written + 4

def write_uint32(self, uint32, _=None):
written = self.align(4)
self.buffer.extend(pack("<I", uint32))
return written + 4

def write_int64(self, int64, _=None):
written = self.align(8)
self.buffer.extend(pack("<q", int64))
return written + 8
def write_boolean(self, boolean: bool, _=None) -> int:
self.buffer.extend(PACK_UINT32(int(boolean)))
return self.align(4) + 4

def write_uint64(self, uint64, _=None):
written = self.align(8)
self.buffer.extend(pack("<Q", uint64))
return written + 8

def write_double(self, double, _=None):
written = self.align(8)
self.buffer.extend(pack("<d", double))
return written + 8

def write_signature(self, signature, _=None):
def write_signature(self, signature: str, _=None) -> int:
signature = signature.encode()
signature_len = len(signature)
self.buffer.append(signature_len)
self.buffer.extend(signature)
self.buffer.append(0)
return signature_len + 2

def write_string(self, value, _=None):
def write_string(self, value: str, _=None) -> int:
value = value.encode()
value_len = len(value)
written = self.write_uint32(value_len)
written = self.align(4) + 4
self.buffer.extend(PACK_UINT32(value_len))
self.buffer.extend(value)
written += value_len
self.buffer.append(0)
written += 1
return written

def write_variant(self, variant, _=None):
def write_variant(self, variant: Variant, _=None) -> int:
written = self.write_signature(variant.signature)
written += self.write_single(variant.type, variant.value)
return written

def write_array(self, array, type_):
def write_array(self, array: Any, type_: SignatureType) -> int:
# TODO max array size is 64MiB (67108864 bytes)
written = self.align(4)
# length placeholder
offset = len(self.buffer)
written += self.write_uint32(0)
written += self.align(4) + 4
self.buffer.extend(PACK_UINT32(0))
child_type = type_.children[0]

if child_type.token in "xtd{(":
Expand All @@ -128,34 +71,87 @@ def write_array(self, array, type_):
for value in array:
array_len += self.write_single(child_type, value)

array_len_packed = pack("<I", array_len)
array_len_packed = PACK_UINT32(array_len)
for i in range(offset, offset + 4):
self.buffer[i] = array_len_packed[i - offset]

return written + array_len

def write_struct(self, array, type_):
def write_struct(self, array: List[Any], type_: SignatureType) -> int:
written = self.align(8)
for i, value in enumerate(array):
written += self.write_single(type_.children[i], value)
return written

def write_dict_entry(self, dict_entry, type_):
def write_dict_entry(self, dict_entry: List[Any], type_: SignatureType) -> int:
written = self.align(8)
written += self.write_single(type_.children[0], dict_entry[0])
written += self.write_single(type_.children[1], dict_entry[1])
return written

def write_single(self, type_, body):
def write_single(self, type_: SignatureType, body: Any) -> int:
t = type_.token

if t not in self.writers:
raise NotImplementedError(f'type isnt implemented yet: "{t}"')
if t not in self._writers:
raise NotImplementedError(f'type is not implemented yet: "{t}"')

return self.writers[t](body, type_)
writer, packer, size = self._writers[t]
if packer and size:
written = self.align(size)
self.buffer.extend(packer(body))
return written + size
return writer(self, body, type_)

def marshall(self):
"""Marshalls the body into a byte array"""
try:
self._construct_buffer()
except error:
self.signature_tree.verify(self.body)
return self.buffer

def _construct_buffer(self):
self.buffer.clear()
for i, type_ in enumerate(self.signature_tree.types):
self.write_single(type_, self.body[i])
return self.buffer
t = type_.token
if t not in self._writers:
raise NotImplementedError(f'type is not implemented yet: "{t}"')

writer, packer, size = self._writers[t]
if packer and size:

# In-line align
offset = size - len(self.buffer) % size
if offset != 0 and offset != size:
self.buffer.extend(bytes(offset))

self.buffer.extend(packer(self.body[i]))
else:
writer(self, self.body[i], type_)

_writers: Dict[
str,
Tuple[
Optional[Callable[[Any, Any], int]],
Optional[Callable[[Any], bytes]],
Optional[int],
],
] = {
"y": (None, Struct("<B").pack, 1),
"b": (write_boolean, None, None),
"n": (None, Struct("<h").pack, 2),
"q": (None, Struct("<H").pack, 2),
"i": (None, Struct("<i").pack, 4),
"u": (None, PACK_UINT32, 4),
"x": (None, Struct("<q").pack, 8),
"t": (None, Struct("<Q").pack, 8),
"d": (None, Struct("<d").pack, 8),
"h": (None, Struct("<I").pack, 4),
"o": (write_string, None, None),
"s": (write_string, None, None),
"g": (write_signature, None, None),
"a": (write_array, None, None),
"(": (write_struct, None, None),
"{": (write_dict_entry, None, None),
"v": (write_variant, None, None),
}
34 changes: 26 additions & 8 deletions src/dbus_fast/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
MessageType.METHOD_RETURN: ("reply_serial",),
}

HEADER_PATH = HeaderField.PATH.value
HEADER_INTERFACE = HeaderField.INTERFACE.value
HEADER_MEMBER = HeaderField.MEMBER.value
HEADER_ERROR_NAME = HeaderField.ERROR_NAME.value
HEADER_REPLY_SERIAL = HeaderField.REPLY_SERIAL.value
HEADER_DESTINATION = HeaderField.DESTINATION.value
HEADER_SIGNATURE = HeaderField.SIGNATURE.value
HEADER_UNIX_FDS = HeaderField.UNIX_FDS.value


class Message:
"""A class for sending and receiving messages through the
Expand Down Expand Up @@ -242,27 +251,36 @@ def _marshall(self, negotiate_unix_fd=False):

fields = []

# No verify here since the marshaller will raise an exception if the
# Variant is invalid.

if self.path:
fields.append([HeaderField.PATH.value, Variant("o", self.path)])
fields.append([HEADER_PATH, Variant("o", self.path, verify=False)])
if self.interface:
fields.append([HeaderField.INTERFACE.value, Variant("s", self.interface)])
fields.append(
[HEADER_INTERFACE, Variant("s", self.interface, verify=False)]
)
if self.member:
fields.append([HeaderField.MEMBER.value, Variant("s", self.member)])
fields.append([HEADER_MEMBER, Variant("s", self.member, verify=False)])
if self.error_name:
fields.append([HeaderField.ERROR_NAME.value, Variant("s", self.error_name)])
fields.append(
[HEADER_ERROR_NAME, Variant("s", self.error_name, verify=False)]
)
if self.reply_serial:
fields.append(
[HeaderField.REPLY_SERIAL.value, Variant("u", self.reply_serial)]
[HEADER_REPLY_SERIAL, Variant("u", self.reply_serial, verify=False)]
)
if self.destination:
fields.append(
[HeaderField.DESTINATION.value, Variant("s", self.destination)]
[HEADER_DESTINATION, Variant("s", self.destination, verify=False)]
)
if self.signature:
fields.append([HeaderField.SIGNATURE.value, Variant("g", self.signature)])
fields.append(
[HEADER_SIGNATURE, Variant("g", self.signature, verify=False)]
)
if self.unix_fds and negotiate_unix_fd:
fields.append(
[HeaderField.UNIX_FDS.value, Variant("u", len(self.unix_fds))]
[HEADER_UNIX_FDS, Variant("u", len(self.unix_fds), verify=False)]
)

header_body = [
Expand Down

0 comments on commit a9e8866

Please sign in to comment.