From fb2172b7f88e3c64fb54bd4dfe1a46c7eaad8c61 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 18 Nov 2024 10:35:57 +0800 Subject: [PATCH] feat(python): Implement collection serialization protocol (#1942) ## What does this PR do? Implement a new format for collection serialization in pyfury. ## Related issues ## Does this PR introduce any user-facing change? - [ ] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? ## Benchmark ``` fury_tuple: Mean +- std dev: [base] 259 us +- 6 us -> [collection] 256 us +- 5 us: 1.01x faster fury_large_tuple: Mean +- std dev: [base] 92.7 ms +- 5.5 ms -> [collection] 63.7 ms +- 4.8 ms: 1.46x faster fury_list: Mean +- std dev: [base] 277 us +- 6 us -> [collection] 267 us +- 3 us: 1.04x faster fury_large_list: Mean +- std dev: [base] 92.8 ms +- 5.3 ms -> [collection] 66.5 ms +- 3.0 ms: 1.40x faster Geometric mean: 1.21x faster ``` --- python/pyfury/_serialization.pyx | 301 ++++++++++++++++++++++++++----- 1 file changed, 253 insertions(+), 48 deletions(-) diff --git a/python/pyfury/_serialization.pyx b/python/pyfury/_serialization.pyx index 1e2280acd1..0175f6151a 100644 --- a/python/pyfury/_serialization.pyx +++ b/python/pyfury/_serialization.pyx @@ -69,6 +69,14 @@ logger = logging.getLogger(__name__) ENABLE_FURY_CYTHON_SERIALIZATION = os.environ.get( "ENABLE_FURY_CYTHON_SERIALIZATION", "True").lower() in ("true", "1") +cdef extern from *: + """ + #define int2obj(obj_addr) ((PyObject *)(obj_addr)) + #define obj2int(obj_ref) (Py_INCREF(obj_ref), ((int64_t)(obj_ref))) + """ + object int2obj(int64_t obj_addr) + int64_t obj2int(object obj_ref) + cdef int8_t NULL_FLAG = -3 # This flag indicates that object is a not-null value. @@ -1630,6 +1638,18 @@ cdef class BytesSerializer(CrossLanguageCompatibleSerializer): return fury_buf.to_pybytes() +""" +Collection serialization format: +https://fury.apache.org/docs/specification/fury_xlang_serialization_spec/#list +Has the following changes: +* None has an independent NonType type, so COLLECTION_NOT_SAME_TYPE can also cover the concept of being nullable. +* No flag is needed to indicate that the element type is not the declared type. +""" +cdef int8_t COLLECTION_DEFAULT_FLAG = 0b0 +cdef int8_t COLLECTION_TRACKING_REF = 0b1 +cdef int8_t COLLECTION_NOT_SAME_TYPE = 0b1000 + + cdef class CollectionSerializer(Serializer): cdef ClassResolver class_resolver cdef MapRefResolver ref_resolver @@ -1644,29 +1664,143 @@ cdef class CollectionSerializer(Serializer): cpdef int16_t get_xtype_id(self): return -FuryType.LIST.value + cdef pair[int8_t, int64_t] write_header(self, Buffer buffer, value): + cdef int8_t collect_flag = COLLECTION_DEFAULT_FLAG + elem_type = type(next(iter(value))) + for s in value: + if type(s) is not elem_type: + collect_flag |= COLLECTION_NOT_SAME_TYPE + break + if self.fury.ref_tracking: + collect_flag |= COLLECTION_TRACKING_REF + buffer.write_varint64((len(value) << 4) | collect_flag) + return pair[int8_t, int64_t](collect_flag, obj2int(elem_type)) + cpdef write(self, Buffer buffer, value): - buffer.write_varint32(len(value)) + if len(value) == 0: + buffer.write_varint64(0) + return + cdef pair[int8_t, int64_t] header_pair = self.write_header(buffer, value) + cdef int8_t collect_flag = header_pair.first + cdef int64_t elem_type_ptr = header_pair.second + cdef elem_type = int2obj(elem_type_ptr) cdef MapRefResolver ref_resolver = self.ref_resolver cdef ClassResolver class_resolver = self.class_resolver + if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0: + if elem_type is str: + self._write_string(buffer, value) + elif elem_type is int: + self._write_int(buffer, value) + elif elem_type is bool: + self._write_bool(buffer, value) + elif elem_type is float: + self._write_float(buffer, value) + else: + if (collect_flag & COLLECTION_TRACKING_REF) == 0: + self._write_same_type_no_ref(buffer, value, elem_type) + else: + self._write_same_type_ref(buffer, value, elem_type) + else: + for s in value: + cls = type(s) + if cls is str: + buffer.write_int16(NOT_NULL_STRING_FLAG) + buffer.write_string(s) + elif cls is int: + buffer.write_int16(NOT_NULL_PYINT_FLAG) + buffer.write_varint64(s) + elif cls is bool: + buffer.write_int16(NOT_NULL_PYBOOL_FLAG) + buffer.write_bool(s) + elif cls is float: + buffer.write_int16(NOT_NULL_PYFLOAT_FLAG) + buffer.write_double(s) + else: + if not ref_resolver.write_ref_or_null(buffer, s): + classinfo = class_resolver.get_or_create_classinfo(cls) + class_resolver.write_classinfo(buffer, classinfo) + classinfo.serializer.write(buffer, s) + + cdef inline _write_string(self, Buffer buffer, value): + buffer.write_int16(NOT_NULL_STRING_FLAG) for s in value: - cls = type(s) - if cls is str: - buffer.write_int16(NOT_NULL_STRING_FLAG) - buffer.write_string(s) - elif cls is int: - buffer.write_int16(NOT_NULL_PYINT_FLAG) - buffer.write_varint64(s) - elif cls is bool: - buffer.write_int16(NOT_NULL_PYBOOL_FLAG) - buffer.write_bool(s) - elif cls is float: - buffer.write_int16(NOT_NULL_PYFLOAT_FLAG) - buffer.write_double(s) + buffer.write_string(s) + + cdef inline _read_string(self, Buffer buffer, int64_t len_, object collection_): + assert buffer.read_int16() == NOT_NULL_STRING_FLAG + for i in range(len_): + self._add_element(collection_, i, buffer.read_string()) + + cdef inline _write_int(self, Buffer buffer, value): + buffer.write_int16(NOT_NULL_PYINT_FLAG) + for s in value: + buffer.write_varint64(s) + + cdef inline _read_int(self, Buffer buffer, int64_t len_, object collection_): + assert buffer.read_int16() == NOT_NULL_PYINT_FLAG + for i in range(len_): + self._add_element(collection_, i, buffer.read_varint64()) + + cdef inline _write_bool(self, Buffer buffer, value): + buffer.write_int16(NOT_NULL_PYBOOL_FLAG) + for s in value: + buffer.write_bool(s) + + cdef inline _read_bool(self, Buffer buffer, int64_t len_, object collection_): + assert buffer.read_int16() == NOT_NULL_PYBOOL_FLAG + for i in range(len_): + self._add_element(collection_, i, buffer.read_bool()) + + cdef inline _write_float(self, Buffer buffer, value): + buffer.write_int16(NOT_NULL_PYFLOAT_FLAG) + for s in value: + buffer.write_double(s) + + cdef inline _read_float(self, Buffer buffer, int64_t len_, object collection_): + assert buffer.read_int16() == NOT_NULL_PYFLOAT_FLAG + for i in range(len_): + self._add_element(collection_, i, buffer.read_double()) + + cpdef _write_same_type_no_ref(self, Buffer buffer, value, elem_type): + cdef MapRefResolver ref_resolver = self.ref_resolver + cdef ClassResolver class_resolver = self.class_resolver + classinfo = class_resolver.get_or_create_classinfo(elem_type) + class_resolver.write_classinfo(buffer, classinfo) + for s in value: + classinfo.serializer.write(buffer, s) + + cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object collection_): + cdef MapRefResolver ref_resolver = self.ref_resolver + cdef ClassResolver class_resolver = self.class_resolver + classinfo = class_resolver.read_classinfo(buffer) + for i in range(len_): + obj = classinfo.serializer.read(buffer) + self._add_element(collection_, i, obj) + + cpdef _write_same_type_ref(self, Buffer buffer, value, elem_type): + cdef MapRefResolver ref_resolver = self.ref_resolver + cdef ClassResolver class_resolver = self.class_resolver + classinfo = class_resolver.get_or_create_classinfo(elem_type) + class_resolver.write_classinfo(buffer, classinfo) + for s in value: + if not ref_resolver.write_ref_or_null(buffer, s): + classinfo.serializer.write(buffer, s) + + cpdef _read_same_type_ref(self, Buffer buffer, int64_t len_, object collection_): + cdef MapRefResolver ref_resolver = self.ref_resolver + cdef ClassResolver class_resolver = self.class_resolver + classinfo = class_resolver.read_classinfo(buffer) + for i in range(len_): + ref_id = ref_resolver.try_preserve_ref_id(buffer) + if ref_id < NOT_NULL_VALUE_FLAG: + obj = ref_resolver.get_read_object() else: - if not ref_resolver.write_ref_or_null(buffer, s): - classinfo = class_resolver.get_or_create_classinfo(cls) - class_resolver.write_classinfo(buffer, classinfo) - classinfo.serializer.write(buffer, s) + obj = classinfo.serializer.read(buffer) + ref_resolver.set_read_object(ref_id, obj) + self._add_element(collection_, i, obj) + + cpdef _add_element(self, object collection_, int64_t index, object element): + raise NotImplementedError cpdef xwrite(self, Buffer buffer, value): cdef int32_t len_ = 0 @@ -1690,15 +1824,39 @@ cdef class ListSerializer(CollectionSerializer): cpdef read(self, Buffer buffer): cdef MapRefResolver ref_resolver = self.fury.ref_resolver cdef ClassResolver class_resolver = self.fury.class_resolver - cdef int32_t len_ = buffer.read_varint32() + cdef int64_t len_and_flag = buffer.read_varint64() + cdef int64_t len_ = len_and_flag >> 4 + cdef int8_t collect_flag = (len_and_flag & 0xF) cdef list list_ = PyList_New(len_) ref_resolver.reference(list_) - for i in range(len_): - elem = get_next_elenment(buffer, ref_resolver, class_resolver) - Py_INCREF(elem) - PyList_SET_ITEM(list_, i, elem) + if len_ == 0: + return list_ + if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0: + type_flag = buffer.get_int16(buffer.reader_index) + if type_flag == NOT_NULL_STRING_FLAG: + self._read_string(buffer, len_, list_) + elif type_flag == NOT_NULL_PYINT_FLAG: + self._read_int(buffer, len_, list_) + elif type_flag == NOT_NULL_PYBOOL_FLAG: + self._read_bool(buffer, len_, list_) + elif type_flag == NOT_NULL_PYFLOAT_FLAG: + self._read_float(buffer, len_, list_) + else: + if (collect_flag & COLLECTION_TRACKING_REF) == 0: + self._read_same_type_no_ref(buffer, len_, list_) + else: + self._read_same_type_ref(buffer, len_, list_) + else: + for i in range(len_): + elem = get_next_elenment(buffer, ref_resolver, class_resolver) + Py_INCREF(elem) + PyList_SET_ITEM(list_, i, elem) return list_ + cpdef _add_element(self, object collection_, int64_t index, object element): + Py_INCREF(element) + PyList_SET_ITEM(collection_, index, element) + cpdef xread(self, Buffer buffer): cdef int32_t len_ = buffer.read_varint32() cdef list collection_ = PyList_New(len_) @@ -1746,14 +1904,38 @@ cdef class TupleSerializer(CollectionSerializer): cpdef inline read(self, Buffer buffer): cdef MapRefResolver ref_resolver = self.fury.ref_resolver cdef ClassResolver class_resolver = self.fury.class_resolver - cdef int32_t len_ = buffer.read_varint32() + cdef int64_t len_and_flag = buffer.read_varint64() + cdef int64_t len_ = len_and_flag >> 4 + cdef int8_t collect_flag = (len_and_flag & 0xF) cdef tuple tuple_ = PyTuple_New(len_) - for i in range(len_): - elem = get_next_elenment(buffer, ref_resolver, class_resolver) - Py_INCREF(elem) - PyTuple_SET_ITEM(tuple_, i, elem) + if len_ == 0: + return tuple_ + if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0: + type_flag = buffer.get_int16(buffer.reader_index) + if type_flag == NOT_NULL_STRING_FLAG: + self._read_string(buffer, len_, tuple_) + elif type_flag == NOT_NULL_PYINT_FLAG: + self._read_int(buffer, len_, tuple_) + elif type_flag == NOT_NULL_PYBOOL_FLAG: + self._read_bool(buffer, len_, tuple_) + elif type_flag == NOT_NULL_PYFLOAT_FLAG: + self._read_float(buffer, len_, tuple_) + else: + if (collect_flag & COLLECTION_TRACKING_REF) == 0: + self._read_same_type_no_ref(buffer, len_, tuple_) + else: + self._read_same_type_ref(buffer, len_, tuple_) + else: + for i in range(len_): + elem = get_next_elenment(buffer, ref_resolver, class_resolver) + Py_INCREF(elem) + PyTuple_SET_ITEM(tuple_, i, elem) return tuple_ + cpdef inline _add_element(self, object collection_, int64_t index, object element): + Py_INCREF(element) + PyTuple_SET_ITEM(collection_, index, element) + cpdef inline xread(self, Buffer buffer): cdef int32_t len_ = buffer.read_varint32() cdef tuple tuple_ = PyTuple_New(len_) @@ -1785,31 +1967,54 @@ cdef class SetSerializer(CollectionSerializer): cdef ClassResolver class_resolver = self.fury.class_resolver cdef set instance = set() ref_resolver.reference(instance) - cdef int32_t len_ = buffer.read_varint32() + cdef int64_t len_and_flag = buffer.read_varint64() + cdef int64_t len_ = len_and_flag >> 4 + cdef int8_t collect_flag = (len_and_flag & 0xF) cdef int32_t ref_id cdef ClassInfo classinfo - for i in range(len_): - ref_id = ref_resolver.try_preserve_ref_id(buffer) - if ref_id < NOT_NULL_VALUE_FLAG: - instance.add(ref_resolver.get_read_object()) - continue - # indicates that the object is first read. - classinfo = class_resolver.read_classinfo(buffer) - cls = classinfo.cls - if cls is str: - instance.add(buffer.read_string()) - elif cls is int: - instance.add(buffer.read_varint64()) - elif cls is bool: - instance.add(buffer.read_bool()) - elif cls is float: - instance.add(buffer.read_double()) + if len_ == 0: + return instance + if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0: + type_flag = buffer.get_int16(buffer.reader_index) + if type_flag == NOT_NULL_STRING_FLAG: + self._read_string(buffer, len_, instance) + elif type_flag == NOT_NULL_PYINT_FLAG: + self._read_int(buffer, len_, instance) + elif type_flag == NOT_NULL_PYBOOL_FLAG: + self._read_bool(buffer, len_, instance) + elif type_flag == NOT_NULL_PYFLOAT_FLAG: + self._read_float(buffer, len_, instance) else: - o = classinfo.serializer.read(buffer) - ref_resolver.set_read_object(ref_id, o) - instance.add(o) + if (collect_flag & COLLECTION_TRACKING_REF) == 0: + self._read_same_type_no_ref(buffer, len_, instance) + else: + self._read_same_type_ref(buffer, len_, instance) + else: + for i in range(len_): + ref_id = ref_resolver.try_preserve_ref_id(buffer) + if ref_id < NOT_NULL_VALUE_FLAG: + instance.add(ref_resolver.get_read_object()) + continue + # indicates that the object is first read. + classinfo = class_resolver.read_classinfo(buffer) + cls = classinfo.cls + if cls is str: + instance.add(buffer.read_string()) + elif cls is int: + instance.add(buffer.read_varint64()) + elif cls is bool: + instance.add(buffer.read_bool()) + elif cls is float: + instance.add(buffer.read_double()) + else: + o = classinfo.serializer.read(buffer) + ref_resolver.set_read_object(ref_id, o) + instance.add(o) return instance + cpdef inline _add_element(self, object collection_, int64_t index, object element): + collection_.add(element) + cpdef inline xread(self, Buffer buffer): cdef int32_t len_ = buffer.read_varint32() cdef set instance = set()