Skip to content

Commit

Permalink
class registration part3
Browse files Browse the repository at this point in the history
  • Loading branch information
chaokunyang committed Dec 14, 2024
1 parent 34f4b52 commit 5aaaf52
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 267 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ private MetaStringBytes createSmallMetaStringBytes(int len, byte encoding, long
LittleEndian.putInt64(data, 0, v1);
LittleEndian.putInt64(data, 8, v2);
long hashCode = MurmurHash3.murmurhash3_x64_128(data, 0, len, 47)[0];
hashCode = ((hashCode) & 0xffffffffffffff00L) | encoding;
hashCode = (hashCode & 0xffffffffffffff00L) | encoding;
MetaStringBytes metaStringBytes = new MetaStringBytes(Arrays.copyOf(data, len), hashCode);
longLongMap.put(v1, v2, metaStringBytes);
return metaStringBytes;
Expand Down
39 changes: 6 additions & 33 deletions python/pyfury/_fury.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,37 +62,9 @@


DEFAULT_DYNAMIC_WRITE_STRING_ID = -1


MAGIC_NUMBER = 0x62D4


class MetaStringBytes:
__slots__ = (
"data",
"length",
"hashcode",
"encoding",
"dynamic_write_string_id",
)

def __init__(self, data, hashcode):
self.data = data
self.length = len(data)
self.hashcode = hashcode
self.encoding = Encoding(hashcode & 0xFF)
self.dynamic_write_string_id = DEFAULT_DYNAMIC_WRITE_STRING_ID

def __eq__(self, other):
return type(other) is MetaStringBytes and other.hashcode == self.hashcode

def __hash__(self):
return self.hashcode

def decode(self, decoder):
return decoder.decode(self.encoding)


class ClassInfo:
__slots__ = (
"cls",
Expand All @@ -108,8 +80,8 @@ def __init__(
cls: type = None,
class_id: int = NO_CLASS_ID,
serializer: Serializer = None,
namespace_bytes: MetaStringBytes = None,
typename_bytes: MetaStringBytes = None,
namespace_bytes = None,
typename_bytes = None,
dynamic_type: bool = False,
):
self.cls = cls
Expand Down Expand Up @@ -187,6 +159,7 @@ def __init__(
self.ref_resolver = MapRefResolver()
else:
self.ref_resolver = NoRefResolver()
self.metastring_resolver = MetaStringResolver(self)
self.class_resolver = ClassResolver(self)
self.class_resolver.initialize()
self.serialization_context = SerializationContext()
Expand Down Expand Up @@ -301,7 +274,7 @@ def serialize_ref(self, buffer, obj, classinfo=None):
if self.ref_resolver.write_ref_or_null(buffer, obj):
return
if classinfo is None:
classinfo = self.class_resolver.get_or_create_classinfo(cls)
classinfo = self.class_resolver.get_classinfo(cls)
self.class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, obj)

Expand All @@ -320,7 +293,7 @@ def serialize_nonref(self, buffer, obj):
buffer.write_bool(obj)
return
else:
classinfo = self.class_resolver.get_or_create_classinfo(cls)
classinfo = self.class_resolver.get_classinfo(cls)
self.class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, obj)

Expand Down Expand Up @@ -490,7 +463,7 @@ def write_ref_pyobject(self, buffer, value, classinfo=None):
if self.ref_resolver.write_ref_or_null(buffer, value):
return
if classinfo is None:
classinfo = self.class_resolver.get_or_create_classinfo(type(value))
classinfo = self.class_resolver.get_classinfo(type(value))
self.class_resolver.write_classinfo(buffer, classinfo)
classinfo.serializer.write(buffer, value)

Expand Down
146 changes: 57 additions & 89 deletions python/pyfury/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyfury.serializer import (
Serializer,
NOT_SUPPORT_CROSS_LANGUAGE,
EnumSerializer,
PickleSerializer,
Numpy1DArraySerializer,
NDArraySerializer,
Expand All @@ -29,6 +30,7 @@
PickleCacheStub,
SMALL_STRING_THRESHOLD,
)
from pyfury._struct import ComplexObjectSerializer
from pyfury.buffer import Buffer
from pyfury.meta.metastring import Encoding, MetaStringEncoder, MetaStringDecoder
from pyfury.type import (
Expand All @@ -53,6 +55,22 @@

DEFAULT_DYNAMIC_WRITE_STRING_ID = -1
DYNAMIC_TYPE_ID = -1
USE_CLASSNAME = 0
USE_CLASS_ID = 1
# preserve 0 as flag for class id not set in ClassInfo`
NO_CLASS_ID = 0
PYINT_CLASS_ID = 1
PYFLOAT_CLASS_ID = 2
PYBOOL_CLASS_ID = 3
STRING_CLASS_ID = 4
PICKLE_CLASS_ID = 5
PICKLE_STRONG_CACHE_CLASS_ID = 6
PICKLE_CACHE_CLASS_ID = 7
# `NOT_NULL_VALUE_FLAG` + `CLASS_ID << 1` in little-endian order
NOT_NULL_PYINT_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (PYINT_CLASS_ID << 9)
NOT_NULL_PYFLOAT_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (PYFLOAT_CLASS_ID << 9)
NOT_NULL_PYBOOL_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (PYBOOL_CLASS_ID << 9)
NOT_NULL_STRING_FLAG = NOT_NULL_VALUE_FLANOT_NULL_STRING_FLAGG & 0b11111111 | (STRING_CLASS_ID << 9)
SMALL_STRING_THRESHOLD = 16


Expand All @@ -61,24 +79,28 @@ class ClassResolver:
"fury",
"_type_tag_to_class_x_lang_map",
"_metastr_to_str",
"_class_id_counter",
"_type_id_counter",
"_classes_info",
"_registered_id_to_class_info",
"_hash_to_metastring",
"_metastr_to_class",
"_hash_to_classinfo",
"_dynamic_id_to_classinfo_list",
"_dynamic_id_to_metastr_list",
"_serializer",
"_dynamic_write_string_id",
"_dynamic_written_metastr",
"_ns_type_to_classinfo",
"_namespace_encoder",
"_namespace_decoder",
"_typename_encoder",
"_typename_decoder",
"require_registration",
)

_classes_info: Dict[type, "ClassInfo"]

def __init__(self, fury):
self.fury = fury
self.metastring_resolver = fury.metastring_resolver
self.language = fury.language
self.require_registration = fury.require_class_registration
self._metastr_to_str = dict()
self._metastr_to_class = dict()
self._hash_to_metastring = dict()
Expand All @@ -87,7 +109,6 @@ def __init__(self, fury):
self._dynamic_written_metastr = []
self._type_id_to_classinfo = dict()
self._type_id_counter = PICKLE_CACHE_CLASS_ID + 1
self._registered_id_to_class_info = list()
self._dynamic_write_string_id = 0
self._classes_info = dict()
self._ns_type_to_classinfo = dict()
Expand All @@ -96,6 +117,9 @@ def __init__(self, fury):
self._typename_encoder = MetaStringEncoder("$", "_")
self._typename_decoder = MetaStringDecoder("$", "_")

from pyfury import MetaStringResolver
self._meta_string_resolver = MetaStringResolver()

def initialize(self):
if self.fury.language == Language.PYTHON:
self._initialize_py()
Expand Down Expand Up @@ -335,6 +359,10 @@ def __register_type(
serializer: Serializer = None,
internal: bool = False,
):
from pyfury import ClassInfo

if serializer is None:
serializer = self._create_serializer(cls)
if typename is None:
classinfo = ClassInfo(cls, type_id, serializer, None, None)
else:
Expand All @@ -348,11 +376,6 @@ def __register_type(
self._ns_type_to_classinfo[(ns_meta_bytes, type_meta_bytes)] = classinfo
self._classes_info[cls] = classinfo
if type_id > 0:
if len(self._registered_id_to_class_info) <= type_id:
self._registered_id_to_class_info.extend(
[None] * (type_id - len(self._registered_id_to_class_info) + 1)
)
self._registered_id_to_class_info[type_id] = classinfo
self._type_id_to_classinfo[type_id] = classinfo
self._classes_info[cls] = classinfo
return classinfo
Expand Down Expand Up @@ -380,25 +403,23 @@ def register_serializer(self, cls: Union[type, TypeVar], serializer):
type_id = classinfo.type_id & 0xFFFFFF00 | TypeId.EXT
self._type_id_to_classinfo[type_id] = classinfo

def get_serializer(self, cls: type = None):
def get_serializer(self, cls: type):
"""
Returns
-------
Returns or create serializer for the provided class
"""
class_info = self._classes_info.get(cls)
if class_info is None:
if self.language != Lanauage.PYTHON:
raise TypeUnregisteredError(f"{cls} not registered")
class_info = self.get_or_create_classinfo(cls)
return class_info.serializer
return self.get_classinfo(cls).serializer

def get_or_create_classinfo(self, cls):
def get_classinfo(self, cls):
class_info = self._classes_info.get(cls)
if class_info is not None:
if class_info.serializer is None:
class_info.serializer = self._create_serializer(cls)
return class_info
if self.language != Language.PYTHON or self.require_registration:
raise TypeError(f"{cls} not registered")
logger.info("Class %s not registered", cls)
serializer = self._create_serializer(cls)
type_id = (
NO_CLASS_ID if type(serializer) is not PickleSerializer else PICKLE_CLASS_ID
Expand All @@ -412,132 +433,79 @@ def get_or_create_classinfo(self, cls):
)

def _create_serializer(self, cls):
if self.language != Language.PYTHON:
raise
mro = cls.__mro__
classinfo_ = self._classes_info.get(cls)
for clz in mro:
for clz in cls.__mro__:
class_info = self._classes_info.get(clz)
if (
class_info
and class_info.serializer
and class_info.serializer.support_subclass()
):
if classinfo_ is None or classinfo_.class_id == NO_CLASS_ID:
logger.info("Class %s not registered", cls)
serializer = type(class_info.serializer)(self.fury, cls)
break
else:
if dataclasses.is_dataclass(cls):
if classinfo_ is None or classinfo_.class_id == NO_CLASS_ID:
logger.info("Class %s not registered", cls)
logger.info("Class %s not registered", cls)
from pyfury import DataClassSerializer

serializer = DataClassSerializer(self.fury, cls)
elif issubclass(cls, enum.Enum):
serializer = EnumSerializer(self.fury, cls)
else:
serializer = PickleSerializer(self.fury, cls)
return serializer

def write_classinfo(self, buffer: Buffer, classinfo: ClassInfo):
def write_classinfo(self, buffer: Buffer, classinfo):
class_id = classinfo.class_id
if class_id != NO_CLASS_ID:
buffer.write_varuint32(class_id << 1)
return
buffer.write_varuint32(1)
self.write_meta_string_bytes(buffer, classinfo.namespace_bytes)
self.write_meta_string_bytes(buffer, classinfo.typename_bytes)
self._meta_string_resolver.write_meta_string_bytes(buffer, classinfo.namespace_bytes)
self._meta_string_resolver.write_meta_string_bytes(buffer, classinfo.typename_bytes)

def read_classinfo(self, buffer):
header = buffer.read_varuint32()
if header & 0b1 == 0:
class_id = header >> 1
classinfo = self._registered_id_to_class_info[class_id]
classinfo = self._type_id_to_classinfo[class_id]
if classinfo.serializer is None:
classinfo.serializer = self._create_serializer(classinfo.cls)
return classinfo
ns_metabytes = self.read_meta_string_bytes(buffer)
type_metabytes = self.read_meta_string_bytes(buffer)
ns_metabytes = self._meta_string_resolver.read_meta_string_bytes(buffer)
type_metabytes = self._meta_string_resolver.read_meta_string_bytes(buffer)
typeinfo = self._ns_type_to_classinfo.get((ns_metabytes, type_metabytes))
if typeinfo is None:
ns = ns_metabytes.decode(self._namespace_decoder)
typename = type_metabytes.decode(self._typename_decoder)
cls = load_class(ns + "#" + typename)
classinfo = self.get_or_create_classinfo(cls)
classinfo = self.get_classinfo(cls)
return classinfo

def xwrite_typeinfo(self, buffer, classinfo):
type_id = classinfo.type_id
internal_type_id = type_id & 0xFF
buffer.write_varuint32(type_id)
if TypeId.is_namespaced_type(type_id):
self.write_meta_string_bytes(buffer, classinfo.namespace_bytes)
self.write_meta_string_bytes(buffer, classinfo.typename_bytes)
if TypeId.is_namespaced_type(internal_type_id):
self._meta_string_resolver.write_meta_string_bytes(buffer, classinfo.namespace_bytes)
self._meta_string_resolver.write_meta_string_bytes(buffer, classinfo.typename_bytes)

def xread_typeinfo(self, buffer):
type_id = buffer.read_varuint32()
internal_type_id = type_id & 0xFF
if TypeId.is_namespaced_type(internal_type_id):
ns_metabytes = self.read_meta_string_bytes(buffer)
type_metabytes = self.read_meta_string_bytes(buffer)
ns_metabytes = self._meta_string_resolver.read_meta_string_bytes(buffer)
type_metabytes = self._meta_string_resolver.read_meta_string_bytes(buffer)
typeinfo = self._ns_type_to_classinfo.get((ns_metabytes, type_metabytes))
if typeinfo is None:
ns = ns_metabytes.decode(self._namespace_decoder)
typename = type_metabytes.decode(self._typename_decoder)
# TODO(chaokunyang) generate a dynamic class and serializer
# when meta share is enabled.
raise TypeUnregisteredError(f"{ns}.{typename} not registered")
return typeinfo
else:
return self._type_id_to_classinfo[type_id]

def write_meta_string_bytes(self, buffer: Buffer, metastr_bytes: MetaStringBytes):
dynamic_write_string_id = metastr_bytes.dynamic_write_string_id
if dynamic_write_string_id == DEFAULT_DYNAMIC_WRITE_STRING_ID:
dynamic_write_string_id = self._dynamic_write_string_id
metastr_bytes.dynamic_write_string_id = dynamic_write_string_id
self._dynamic_write_string_id += 1
self._dynamic_written_metastr.append(metastr_bytes)
buffer.write_varint32(metastr_bytes.length << 1)
if metastr_bytes.length <= SMALL_STRING_THRESHOLD:
# TODO(chaokunyang) support meta string encoding
buffer.write_int8(Encoding.UTF_8.value)
else:
buffer.write_int64(metastr_bytes.hashcode)
buffer.write_bytes(metastr_bytes.data)
else:
buffer.write_varint32(((dynamic_write_string_id + 1) << 1) | 1)

def read_meta_string_bytes(self, buffer: Buffer) -> MetaStringBytes:
header = buffer.read_varint32()
length = header >> 1
if header & 0b1 != 0:
return self._dynamic_id_to_metastr_list[length - 1]
if length <= SMALL_STRING_THRESHOLD:
buffer.read_int8()
if length <= 8:
v1 = buffer.read_bytes_as_int64(length)
v2 = 0
else:
v1 = buffer.read_int64()
v2 = buffer.read_bytes_as_int64(length - 8)
hashcode = v1 * 31 + v2
metastr = self._hash_to_metastring.get(hashcode)
if metastr is None:
str_bytes = buffer.get_bytes(buffer.reader_index - length, length)
metastr = MetaStringBytes(str_bytes, hashcode=hashcode)
self._hash_to_metastring[hashcode] = metastr
else:
hashcode = buffer.read_int64()
reader_index = buffer.reader_index
buffer.check_bound(reader_index, length)
buffer.reader_index = reader_index + length
metastr = self._hash_to_metastring.get(hashcode)
if metastr is None:
str_bytes = buffer.get_bytes(reader_index, length)
metastr = MetaStringBytes(str_bytes, hashcode=hashcode)
self._hash_to_metastring[hashcode] = metastr
self._dynamic_id_to_metastr_list.append(metastr)
return metastr

def reset(self):
self.reset_write()
self.reset_read()
Expand Down
Loading

0 comments on commit 5aaaf52

Please sign in to comment.