Skip to content

Commit

Permalink
feat(python): Implement collection serialization protocol (#1942)
Browse files Browse the repository at this point in the history
## What does this PR do?

Implement a new format for collection serialization in pyfury.

## Related issues

## Does this PR introduce any user-facing change?

- [ ] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?

## Benchmark

```
fury_tuple: Mean +- std dev: [base] 259 us +- 6 us -> [collection] 256 us +- 5 us: 1.01x faster
fury_large_tuple: Mean +- std dev: [base] 92.7 ms +- 5.5 ms -> [collection] 63.7 ms +- 4.8 ms: 1.46x faster
fury_list: Mean +- std dev: [base] 277 us +- 6 us -> [collection] 267 us +- 3 us: 1.04x faster
fury_large_list: Mean +- std dev: [base] 92.8 ms +- 5.3 ms -> [collection] 66.5 ms +- 3.0 ms: 1.40x faster

Geometric mean: 1.21x faster
```
  • Loading branch information
penguin-wwy authored Nov 18, 2024
1 parent 0189c93 commit fb2172b
Showing 1 changed file with 253 additions and 48 deletions.
301 changes: 253 additions & 48 deletions python/pyfury/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ logger = logging.getLogger(__name__)
ENABLE_FURY_CYTHON_SERIALIZATION = os.environ.get(
"ENABLE_FURY_CYTHON_SERIALIZATION", "True").lower() in ("true", "1")

cdef extern from *:
"""
#define int2obj(obj_addr) ((PyObject *)(obj_addr))
#define obj2int(obj_ref) (Py_INCREF(obj_ref), ((int64_t)(obj_ref)))
"""
object int2obj(int64_t obj_addr)
int64_t obj2int(object obj_ref)


cdef int8_t NULL_FLAG = -3
# This flag indicates that object is a not-null value.
Expand Down Expand Up @@ -1630,6 +1638,18 @@ cdef class BytesSerializer(CrossLanguageCompatibleSerializer):
return fury_buf.to_pybytes()


"""
Collection serialization format:
https://fury.apache.org/docs/specification/fury_xlang_serialization_spec/#list
Has the following changes:
* None has an independent NonType type, so COLLECTION_NOT_SAME_TYPE can also cover the concept of being nullable.
* No flag is needed to indicate that the element type is not the declared type.
"""
cdef int8_t COLLECTION_DEFAULT_FLAG = 0b0
cdef int8_t COLLECTION_TRACKING_REF = 0b1
cdef int8_t COLLECTION_NOT_SAME_TYPE = 0b1000


cdef class CollectionSerializer(Serializer):
cdef ClassResolver class_resolver
cdef MapRefResolver ref_resolver
Expand All @@ -1644,29 +1664,143 @@ cdef class CollectionSerializer(Serializer):
cpdef int16_t get_xtype_id(self):
return -FuryType.LIST.value

cdef pair[int8_t, int64_t] write_header(self, Buffer buffer, value):
cdef int8_t collect_flag = COLLECTION_DEFAULT_FLAG
elem_type = type(next(iter(value)))
for s in value:
if type(s) is not elem_type:
collect_flag |= COLLECTION_NOT_SAME_TYPE
break
if self.fury.ref_tracking:
collect_flag |= COLLECTION_TRACKING_REF
buffer.write_varint64((len(value) << 4) | collect_flag)
return pair[int8_t, int64_t](collect_flag, obj2int(elem_type))

cpdef write(self, Buffer buffer, value):
buffer.write_varint32(len(value))
if len(value) == 0:
buffer.write_varint64(0)
return
cdef pair[int8_t, int64_t] header_pair = self.write_header(buffer, value)
cdef int8_t collect_flag = header_pair.first
cdef int64_t elem_type_ptr = header_pair.second
cdef elem_type = <type>int2obj(elem_type_ptr)
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
if elem_type is str:
self._write_string(buffer, value)
elif elem_type is int:
self._write_int(buffer, value)
elif elem_type is bool:
self._write_bool(buffer, value)
elif elem_type is float:
self._write_float(buffer, value)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._write_same_type_no_ref(buffer, value, elem_type)
else:
self._write_same_type_ref(buffer, value, elem_type)
else:
for s in value:
cls = type(s)
if cls is str:
buffer.write_int16(NOT_NULL_STRING_FLAG)
buffer.write_string(s)
elif cls is int:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_varint64(s)
elif cls is bool:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_bool(s)
elif cls is float:
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
buffer.write_double(s)
else:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo = class_resolver.get_or_create_classinfo(cls)
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)

cdef inline _write_string(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_STRING_FLAG)
for s in value:
cls = type(s)
if cls is str:
buffer.write_int16(NOT_NULL_STRING_FLAG)
buffer.write_string(s)
elif cls is int:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_varint64(s)
elif cls is bool:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_bool(s)
elif cls is float:
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
buffer.write_double(s)
buffer.write_string(s)

cdef inline _read_string(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_STRING_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_string())

cdef inline _write_int(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_PYINT_FLAG)
for s in value:
buffer.write_varint64(s)

cdef inline _read_int(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_PYINT_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_varint64())

cdef inline _write_bool(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
for s in value:
buffer.write_bool(s)

cdef inline _read_bool(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_PYBOOL_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_bool())

cdef inline _write_float(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
for s in value:
buffer.write_double(s)

cdef inline _read_float(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_PYFLOAT_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_double())

cpdef _write_same_type_no_ref(self, Buffer buffer, value, elem_type):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.get_or_create_classinfo(elem_type)
class_resolver.write_classinfo(buffer, classinfo)
for s in value:
classinfo.serializer.write(buffer, s)

cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object collection_):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.read_classinfo(buffer)
for i in range(len_):
obj = classinfo.serializer.read(buffer)
self._add_element(collection_, i, obj)

cpdef _write_same_type_ref(self, Buffer buffer, value, elem_type):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.get_or_create_classinfo(elem_type)
class_resolver.write_classinfo(buffer, classinfo)
for s in value:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo.serializer.write(buffer, s)

cpdef _read_same_type_ref(self, Buffer buffer, int64_t len_, object collection_):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.read_classinfo(buffer)
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
obj = ref_resolver.get_read_object()
else:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo = class_resolver.get_or_create_classinfo(cls)
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)
obj = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, obj)
self._add_element(collection_, i, obj)

cpdef _add_element(self, object collection_, int64_t index, object element):
raise NotImplementedError

cpdef xwrite(self, Buffer buffer, value):
cdef int32_t len_ = 0
Expand All @@ -1690,15 +1824,39 @@ cdef class ListSerializer(CollectionSerializer):
cpdef read(self, Buffer buffer):
cdef MapRefResolver ref_resolver = self.fury.ref_resolver
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef list list_ = PyList_New(len_)
ref_resolver.reference(list_)
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyList_SET_ITEM(list_, i, elem)
if len_ == 0:
return list_
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, list_)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, list_)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, list_)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, list_)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, list_)
else:
self._read_same_type_ref(buffer, len_, list_)
else:
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyList_SET_ITEM(list_, i, elem)
return list_

cpdef _add_element(self, object collection_, int64_t index, object element):
Py_INCREF(element)
PyList_SET_ITEM(collection_, index, element)

cpdef xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef list collection_ = PyList_New(len_)
Expand Down Expand Up @@ -1746,14 +1904,38 @@ cdef class TupleSerializer(CollectionSerializer):
cpdef inline read(self, Buffer buffer):
cdef MapRefResolver ref_resolver = self.fury.ref_resolver
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef tuple tuple_ = PyTuple_New(len_)
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyTuple_SET_ITEM(tuple_, i, elem)
if len_ == 0:
return tuple_
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, tuple_)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, tuple_)
else:
self._read_same_type_ref(buffer, len_, tuple_)
else:
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyTuple_SET_ITEM(tuple_, i, elem)
return tuple_

cpdef inline _add_element(self, object collection_, int64_t index, object element):
Py_INCREF(element)
PyTuple_SET_ITEM(collection_, index, element)

cpdef inline xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef tuple tuple_ = PyTuple_New(len_)
Expand Down Expand Up @@ -1785,31 +1967,54 @@ cdef class SetSerializer(CollectionSerializer):
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef set instance = set()
ref_resolver.reference(instance)
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef int32_t ref_id
cdef ClassInfo classinfo
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
instance.add(ref_resolver.get_read_object())
continue
# indicates that the object is first read.
classinfo = class_resolver.read_classinfo(buffer)
cls = classinfo.cls
if cls is str:
instance.add(buffer.read_string())
elif cls is int:
instance.add(buffer.read_varint64())
elif cls is bool:
instance.add(buffer.read_bool())
elif cls is float:
instance.add(buffer.read_double())
if len_ == 0:
return instance
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, instance)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, instance)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, instance)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, instance)
else:
o = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, o)
instance.add(o)
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, instance)
else:
self._read_same_type_ref(buffer, len_, instance)
else:
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
instance.add(ref_resolver.get_read_object())
continue
# indicates that the object is first read.
classinfo = class_resolver.read_classinfo(buffer)
cls = classinfo.cls
if cls is str:
instance.add(buffer.read_string())
elif cls is int:
instance.add(buffer.read_varint64())
elif cls is bool:
instance.add(buffer.read_bool())
elif cls is float:
instance.add(buffer.read_double())
else:
o = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, o)
instance.add(o)
return instance

cpdef inline _add_element(self, object collection_, int64_t index, object element):
collection_.add(element)

cpdef inline xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef set instance = set()
Expand Down

0 comments on commit fb2172b

Please sign in to comment.