Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Implement collection serialization protocol #1942

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 253 additions & 48 deletions python/pyfury/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ logger = logging.getLogger(__name__)
ENABLE_FURY_CYTHON_SERIALIZATION = os.environ.get(
"ENABLE_FURY_CYTHON_SERIALIZATION", "True").lower() in ("true", "1")

cdef extern from *:
"""
#define int2obj(obj_addr) ((PyObject *)(obj_addr))
#define obj2int(obj_ref) (Py_INCREF(obj_ref), ((int64_t)(obj_ref)))
"""
object int2obj(int64_t obj_addr)
int64_t obj2int(object obj_ref)


cdef int8_t NULL_FLAG = -3
# This flag indicates that object is a not-null value.
Expand Down Expand Up @@ -1630,6 +1638,18 @@ cdef class BytesSerializer(CrossLanguageCompatibleSerializer):
return fury_buf.to_pybytes()


"""
Collection serialization format:
https://fury.apache.org/docs/specification/fury_xlang_serialization_spec/#list
Has the following changes:
* None has an independent NonType type, so COLLECTION_NOT_SAME_TYPE can also cover the concept of being nullable.
* No flag is needed to indicate that the element type is not the declared type.
"""
cdef int8_t COLLECTION_DEFAULT_FLAG = 0b0
cdef int8_t COLLECTION_TRACKING_REF = 0b1
cdef int8_t COLLECTION_NOT_SAME_TYPE = 0b1000


cdef class CollectionSerializer(Serializer):
cdef ClassResolver class_resolver
cdef MapRefResolver ref_resolver
Expand All @@ -1644,29 +1664,143 @@ cdef class CollectionSerializer(Serializer):
cpdef int16_t get_xtype_id(self):
return -FuryType.LIST.value

cdef pair[int8_t, int64_t] write_header(self, Buffer buffer, value):
cdef int8_t collect_flag = COLLECTION_DEFAULT_FLAG
elem_type = type(next(iter(value)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This create iterator object, maybe we could just init it as None here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to extract the type of the first element and avoid checking elem_type is None in the loop. Since the collection might be a set, an iterator retrieves it.

for s in value:
if type(s) is not elem_type:
collect_flag |= COLLECTION_NOT_SAME_TYPE
break
if self.fury.ref_tracking:
collect_flag |= COLLECTION_TRACKING_REF
buffer.write_varint64((len(value) << 4) | collect_flag)
return pair[int8_t, int64_t](collect_flag, obj2int(elem_type))

cpdef write(self, Buffer buffer, value):
buffer.write_varint32(len(value))
if len(value) == 0:
buffer.write_varint64(0)
return
cdef pair[int8_t, int64_t] header_pair = self.write_header(buffer, value)
cdef int8_t collect_flag = header_pair.first
cdef int64_t elem_type_ptr = header_pair.second
cdef elem_type = <type>int2obj(elem_type_ptr)
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
if elem_type is str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you still write null flag and type here, this should be skipped for same type and not-null elements. could you take org.apache.fury.serializer.collection.AbstractCollectionSerializer#writeSameTypeElements as a reference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Python, the type of None is NoneType, so if all elements are of the same type, either none of them are None or all of them are None.

>>> type(None)
<class 'NoneType'>
>>> type(object()) is type(None)
False

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In such cases, maybe we could write elements like this:

elem_type_ptr = xxx
for elem in list_value:
    if elem_type_ptr == str_type_ptr:
         buffer.write_string(elem)
    elif elem_type_ptr == int_type_ptr:
         buffer.write_varint64(elem)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, write the type flag only once for primitive types

self._write_string(buffer, value)
elif elem_type is int:
self._write_int(buffer, value)
elif elem_type is bool:
self._write_bool(buffer, value)
elif elem_type is float:
self._write_float(buffer, value)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._write_same_type_no_ref(buffer, value, elem_type)
else:
self._write_same_type_ref(buffer, value, elem_type)
else:
for s in value:
cls = type(s)
if cls is str:
buffer.write_int16(NOT_NULL_STRING_FLAG)
buffer.write_string(s)
elif cls is int:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_varint64(s)
elif cls is bool:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_bool(s)
elif cls is float:
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
buffer.write_double(s)
else:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo = class_resolver.get_or_create_classinfo(cls)
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)

cdef inline _write_string(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_STRING_FLAG)
for s in value:
cls = type(s)
if cls is str:
buffer.write_int16(NOT_NULL_STRING_FLAG)
buffer.write_string(s)
elif cls is int:
buffer.write_int16(NOT_NULL_PYINT_FLAG)
buffer.write_varint64(s)
elif cls is bool:
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
buffer.write_bool(s)
elif cls is float:
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
buffer.write_double(s)
buffer.write_string(s)

cdef inline _read_string(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_STRING_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_string())

cdef inline _write_int(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_PYINT_FLAG)
for s in value:
buffer.write_varint64(s)

cdef inline _read_int(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_PYINT_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_varint64())

cdef inline _write_bool(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_PYBOOL_FLAG)
for s in value:
buffer.write_bool(s)

cdef inline _read_bool(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_PYBOOL_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_bool())

cdef inline _write_float(self, Buffer buffer, value):
buffer.write_int16(NOT_NULL_PYFLOAT_FLAG)
for s in value:
buffer.write_double(s)

cdef inline _read_float(self, Buffer buffer, int64_t len_, object collection_):
assert buffer.read_int16() == NOT_NULL_PYFLOAT_FLAG
for i in range(len_):
self._add_element(collection_, i, buffer.read_double())

cpdef _write_same_type_no_ref(self, Buffer buffer, value, elem_type):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.get_or_create_classinfo(elem_type)
class_resolver.write_classinfo(buffer, classinfo)
for s in value:
classinfo.serializer.write(buffer, s)

cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object collection_):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.read_classinfo(buffer)
for i in range(len_):
obj = classinfo.serializer.read(buffer)
self._add_element(collection_, i, obj)

cpdef _write_same_type_ref(self, Buffer buffer, value, elem_type):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.get_or_create_classinfo(elem_type)
class_resolver.write_classinfo(buffer, classinfo)
for s in value:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo.serializer.write(buffer, s)

cpdef _read_same_type_ref(self, Buffer buffer, int64_t len_, object collection_):
cdef MapRefResolver ref_resolver = self.ref_resolver
cdef ClassResolver class_resolver = self.class_resolver
classinfo = class_resolver.read_classinfo(buffer)
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
obj = ref_resolver.get_read_object()
else:
if not ref_resolver.write_ref_or_null(buffer, s):
classinfo = class_resolver.get_or_create_classinfo(cls)
class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, s)
obj = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, obj)
self._add_element(collection_, i, obj)

cpdef _add_element(self, object collection_, int64_t index, object element):
raise NotImplementedError

cpdef xwrite(self, Buffer buffer, value):
cdef int32_t len_ = 0
Expand All @@ -1690,15 +1824,39 @@ cdef class ListSerializer(CollectionSerializer):
cpdef read(self, Buffer buffer):
cdef MapRefResolver ref_resolver = self.fury.ref_resolver
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef list list_ = PyList_New(len_)
ref_resolver.reference(list_)
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyList_SET_ITEM(list_, i, elem)
if len_ == 0:
return list_
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, list_)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, list_)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, list_)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, list_)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, list_)
else:
self._read_same_type_ref(buffer, len_, list_)
else:
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyList_SET_ITEM(list_, i, elem)
return list_

cpdef _add_element(self, object collection_, int64_t index, object element):
Py_INCREF(element)
PyList_SET_ITEM(collection_, index, element)

cpdef xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef list collection_ = PyList_New(len_)
Expand Down Expand Up @@ -1746,14 +1904,38 @@ cdef class TupleSerializer(CollectionSerializer):
cpdef inline read(self, Buffer buffer):
cdef MapRefResolver ref_resolver = self.fury.ref_resolver
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef tuple tuple_ = PyTuple_New(len_)
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyTuple_SET_ITEM(tuple_, i, elem)
if len_ == 0:
return tuple_
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, tuple_)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, tuple_)
else:
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, tuple_)
else:
self._read_same_type_ref(buffer, len_, tuple_)
else:
for i in range(len_):
elem = get_next_elenment(buffer, ref_resolver, class_resolver)
Py_INCREF(elem)
PyTuple_SET_ITEM(tuple_, i, elem)
return tuple_

cpdef inline _add_element(self, object collection_, int64_t index, object element):
Py_INCREF(element)
PyTuple_SET_ITEM(collection_, index, element)

cpdef inline xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef tuple tuple_ = PyTuple_New(len_)
Expand Down Expand Up @@ -1785,31 +1967,54 @@ cdef class SetSerializer(CollectionSerializer):
cdef ClassResolver class_resolver = self.fury.class_resolver
cdef set instance = set()
ref_resolver.reference(instance)
cdef int32_t len_ = buffer.read_varint32()
cdef int64_t len_and_flag = buffer.read_varint64()
cdef int64_t len_ = len_and_flag >> 4
cdef int8_t collect_flag = <int8_t>(len_and_flag & 0xF)
cdef int32_t ref_id
cdef ClassInfo classinfo
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
instance.add(ref_resolver.get_read_object())
continue
# indicates that the object is first read.
classinfo = class_resolver.read_classinfo(buffer)
cls = classinfo.cls
if cls is str:
instance.add(buffer.read_string())
elif cls is int:
instance.add(buffer.read_varint64())
elif cls is bool:
instance.add(buffer.read_bool())
elif cls is float:
instance.add(buffer.read_double())
if len_ == 0:
return instance
if (collect_flag & COLLECTION_NOT_SAME_TYPE) == 0:
type_flag = buffer.get_int16(buffer.reader_index)
if type_flag == NOT_NULL_STRING_FLAG:
self._read_string(buffer, len_, instance)
elif type_flag == NOT_NULL_PYINT_FLAG:
self._read_int(buffer, len_, instance)
elif type_flag == NOT_NULL_PYBOOL_FLAG:
self._read_bool(buffer, len_, instance)
elif type_flag == NOT_NULL_PYFLOAT_FLAG:
self._read_float(buffer, len_, instance)
else:
o = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, o)
instance.add(o)
if (collect_flag & COLLECTION_TRACKING_REF) == 0:
self._read_same_type_no_ref(buffer, len_, instance)
else:
self._read_same_type_ref(buffer, len_, instance)
else:
for i in range(len_):
ref_id = ref_resolver.try_preserve_ref_id(buffer)
if ref_id < NOT_NULL_VALUE_FLAG:
instance.add(ref_resolver.get_read_object())
continue
# indicates that the object is first read.
classinfo = class_resolver.read_classinfo(buffer)
cls = classinfo.cls
if cls is str:
instance.add(buffer.read_string())
elif cls is int:
instance.add(buffer.read_varint64())
elif cls is bool:
instance.add(buffer.read_bool())
elif cls is float:
instance.add(buffer.read_double())
else:
o = classinfo.serializer.read(buffer)
ref_resolver.set_read_object(ref_id, o)
instance.add(o)
return instance

cpdef inline _add_element(self, object collection_, int64_t index, object element):
collection_.add(element)

cpdef inline xread(self, Buffer buffer):
cdef int32_t len_ = buffer.read_varint32()
cdef set instance = set()
Expand Down
Loading