Skip to content

Commit

Permalink
feat(python): Implement collection serialization protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy committed Nov 13, 2024
1 parent 0189c93 commit b517b62
Showing 1 changed file with 245 additions and 48 deletions.
293 changes: 245 additions & 48 deletions python/pyfury/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,18 @@ cdef class BytesSerializer(CrossLanguageCompatibleSerializer):
return fury_buf.to_pybytes()


"""
Collection serialization format:
https://fury.apache.org/docs/specification/fury_xlang_serialization_spec/#list
Has the following changes:
* None has an independent NonType type, so COLLECTION_NOT_SAME_TYPE can also cover the concept of being nullable.
* No flag is needed to indicate that the element type is not the declared type.
"""
cdef int8_t COLLECTION_DEFAULT_FLAG = 0b0
cdef int8_t COLLECTION_TRACKING_REF = 0b1
cdef int8_t COLLECTION_NOT_SAME_TYPE = 0b1000


cdef class CollectionSerializer(Serializer):
cdef ClassResolver class_resolver
cdef MapRefResolver ref_resolver
Expand All @@ -1644,29 +1656,143 @@ cdef class CollectionSerializer(Serializer):
cpdef int16_t get_xtype_id(self):
return -FuryType.LIST.value

cpdef int8_t write_header(self, Buffer buffer, value):
cdef int8_t collect_flag = COLLECTION_DEFAULT_FLAG
elem_type = type(next(iter(value)))
for s in value:
if type(s) is not elem_type:
collect_flag |= COLLECTION_NOT_SAME_TYPE
break
if self.fury.ref_tracking:
collect_flag |= COLLECTION_TRACKING_REF
buffer.write_varint64((len(value) << 4) | collect_flag)
return collect_flag

cpdef write(self, Buffer buffer, value):
buffer.write_varint32(len(value))
if len(value) == 0:
buffer.write_varint64(0)
return
cdef int8_t collect_flag = self.write_header(buffer, value)
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
elem_type = type(next(iter(value)))
if elem_type is str:
self._write_string(buffer, value)
elif elem_type is int:
self._write_int(buffer, value)
elif elem_type is bool:
self._write_bool(buffer, value)
elif elem_type is float:
self._write_float(buffer, value)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._write_same_type_no_ref(buffer, value)
else:
self._write_same_type_ref(buffer, value)
else:
for s in value:
cls = type(s)
if cls is str:
buffer.write_int16(NOT_NULL_STRING_FLAG)
buffer.write_string(s)
elif cls is int:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_varint64(s)
elif cls is bool:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_bool(s)
elif cls is float:
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
buffer.write_double(s)
else:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo = class_resolver.get_or_create_classinfo(cls)
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)

cdef inline _write_string(self, Buffer buffer, value):
for s in value:
cls = type(s)
if cls is str:
buffer.write_int16(NOT_NULL_STRING_FLAG)
buffer.write_string(s)
elif cls is int:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_varint64(s)
elif cls is bool:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_bool(s)
elif cls is float:
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
buffer.write_double(s)
buffer.write_int16(NOT_NULL_STRING_FLAG)
buffer.write_string(s)

cdef inline _read_string(self, Buffer buffer, int64_t len_, object collection_):
for i in range(len_):
assert buffer.read_int16() == NOT_NULL_STRING_FLAG
self._add_element(collection_, i, buffer.read_string())

cdef inline _write_int(self, Buffer buffer, value):
for s in value:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_varint64(s)

cdef inline _read_int(self, Buffer buffer, int64_t len_, object collection_):
for i in range(len_):
assert buffer.read_int16() == NOT_NULL_PYINT_FLAG
self._add_element(collection_, i, buffer.read_varint64())

cdef inline _write_bool(self, Buffer buffer, value):
for s in value:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_bool(s)

cdef inline _read_bool(self, Buffer buffer, int64_t len_, object collection_):
for i in range(len_):
assert buffer.read_int16() == NOT_NULL_PYBOOL_FLAG
self._add_element(collection_, i, buffer.read_bool())

cdef inline _write_float(self, Buffer buffer, value):
for s in value:
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
buffer.write_double(s)

cdef inline _read_float(self, Buffer buffer, int64_t len_, object collection_):
for i in range(len_):
assert buffer.read_int16() == NOT_NULL_PYFLOAT_FLAG
self._add_element(collection_, i, buffer.read_double())

cpdef _write_same_type_no_ref(self, Buffer buffer, value):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
elem_type = type(next(iter(value)))
classinfo = class_resolver.get_or_create_classinfo(elem_type)
for s in value:
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)

cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object collection_):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.read_classinfo(buffer)
for i in range(len_):
obj = classinfo.serializer.read(buffer)
self._add_element(collection_, i, obj)

cpdef _write_same_type_ref(self, Buffer buffer, value):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
elem_type = type(next(iter(value)))
classinfo = class_resolver.get_or_create_classinfo(elem_type)
for s in value:
if not ref_resolver.write_ref_or_null(buffer, s):
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)

cpdef _read_same_type_ref(self, Buffer buffer, int64_t len_, object collection_):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
obj = ref_resolver.get_read_object()
else:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo = class_resolver.get_or_create_classinfo(cls)
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)
classinfo = class_resolver.read_classinfo(buffer)
obj = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, obj)
self._add_element(collection_, i, obj)

cpdef _add_element(self, object collection_, int64_t index, object element):
raise NotImplementedError

cpdef xwrite(self, Buffer buffer, value):
cdef int32_t len_ = 0
Expand All @@ -1690,15 +1816,39 @@ cdef class ListSerializer(CollectionSerializer):
cpdef read(self, Buffer buffer):
cdef MapRefResolver ref_resolver = self.fury.ref_resolver
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef list list_ = PyList_New(len_)
ref_resolver.reference(list_)
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyList_SET_ITEM(list_, i, elem)
if len_ == 0:
return list_
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, list_)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, list_)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, list_)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, list_)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, list_)
else:
self._read_same_type_ref(buffer, len_, list_)
else:
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyList_SET_ITEM(list_, i, elem)
return list_

cpdef _add_element(self, object collection_, int64_t index, object element):
Py_INCREF(element)
PyList_SET_ITEM(collection_, index, element)

cpdef xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef list collection_ = PyList_New(len_)
Expand Down Expand Up @@ -1746,14 +1896,38 @@ cdef class TupleSerializer(CollectionSerializer):
cpdef inline read(self, Buffer buffer):
cdef MapRefResolver ref_resolver = self.fury.ref_resolver
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef tuple tuple_ = PyTuple_New(len_)
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyTuple_SET_ITEM(tuple_, i, elem)
if len_ == 0:
return tuple_
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, tuple_)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, tuple_)
else:
self._read_same_type_ref(buffer, len_, tuple_)
else:
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyTuple_SET_ITEM(tuple_, i, elem)
return tuple_

cpdef inline _add_element(self, object collection_, int64_t index, object element):
Py_INCREF(element)
PyTuple_SET_ITEM(collection_, index, element)

cpdef inline xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef tuple tuple_ = PyTuple_New(len_)
Expand Down Expand Up @@ -1785,31 +1959,54 @@ cdef class SetSerializer(CollectionSerializer):
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef set instance = set()
ref_resolver.reference(instance)
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef int32_t ref_id
cdef ClassInfo classinfo
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
instance.add(ref_resolver.get_read_object())
continue
# indicates that the object is first read.
classinfo = class_resolver.read_classinfo(buffer)
cls = classinfo.cls
if cls is str:
instance.add(buffer.read_string())
elif cls is int:
instance.add(buffer.read_varint64())
elif cls is bool:
instance.add(buffer.read_bool())
elif cls is float:
instance.add(buffer.read_double())
if len_ == 0:
return instance
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, instance)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, instance)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, instance)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, instance)
else:
o = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, o)
instance.add(o)
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, instance)
else:
self._read_same_type_ref(buffer, len_, instance)
else:
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
instance.add(ref_resolver.get_read_object())
continue
# indicates that the object is first read.
classinfo = class_resolver.read_classinfo(buffer)
cls = classinfo.cls
if cls is str:
instance.add(buffer.read_string())
elif cls is int:
instance.add(buffer.read_varint64())
elif cls is bool:
instance.add(buffer.read_bool())
elif cls is float:
instance.add(buffer.read_double())
else:
o = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, o)
instance.add(o)
return instance

cpdef inline _add_element(self, object collection_, int64_t index, object element):
collection_.add(element)

cpdef inline xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef set instance = set()
Expand Down

0 comments on commit b517b62

Please sign in to comment.