Skip to content

Commit

Permalink
feat: Entity key deserialization (feast-dev#4284)
Browse files Browse the repository at this point in the history
* Add new version of serialization and desrialization

Signed-off-by: cmuhao <[email protected]>

* Add new version of serialization and desrialization

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* add test

Signed-off-by: cmuhao <[email protected]>

* add test

Signed-off-by: cmuhao <[email protected]>

* update doc

Signed-off-by: cmuhao <[email protected]>

---------

Signed-off-by: cmuhao <[email protected]>
  • Loading branch information
HaoXuAI authored Jun 18, 2024
1 parent df46cae commit 83fad15
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 5 deletions.
80 changes: 79 additions & 1 deletion sdk/python/feast/infra/key_encoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,23 @@ def _serialize_val(
return struct.pack("<l", v.int64_val), ValueType.INT64
return struct.pack("<q", v.int64_val), ValueType.INT64
else:
raise ValueError(f"Value type not supported for Firestore: {v}")
raise ValueError(f"Value type not supported for feast feature store: {v}")


def _deserialize_value(value_type, value_bytes) -> ValueProto:
if value_type == ValueType.INT64:
value = struct.unpack("<q", value_bytes)[0]
return ValueProto(int64_val=value)
if value_type == ValueType.INT32:
value = struct.unpack("<i", value_bytes)[0]
return ValueProto(int32_val=value)
elif value_type == ValueType.STRING:
value = value_bytes.decode("utf-8")
return ValueProto(string_val=value)
elif value_type == ValueType.BYTES:
return ValueProto(bytes_val=value_bytes)
else:
raise ValueError(f"Unsupported value type: {value_type}")


def serialize_entity_key_prefix(entity_keys: List[str]) -> bytes:
Expand Down Expand Up @@ -50,6 +66,15 @@ def serialize_entity_key(
serialize to the same byte string[1].
[1] https://developers.google.com/protocol-buffers/docs/encoding
Args:
entity_key_serialization_version: version of the entity key serialization
version 1: int64 values are serialized as 4 bytes
version 2: int64 values are serialized as 8 bytes
version 3: entity_key size is added to the serialization for deserialization purposes
entity_key: EntityKeyProto
Returns: bytes of the serialized entity key
"""
sorted_keys, sorted_values = zip(
*sorted(zip(entity_key.join_keys, entity_key.entity_values))
Expand All @@ -58,6 +83,8 @@ def serialize_entity_key(
output: List[bytes] = []
for k in sorted_keys:
output.append(struct.pack("<I", ValueType.STRING))
if entity_key_serialization_version > 2:
output.append(struct.pack("<I", len(k)))
output.append(k.encode("utf8"))
for v in sorted_values:
val_bytes, value_type = _serialize_val(
Expand All @@ -74,6 +101,57 @@ def serialize_entity_key(
return b"".join(output)


def deserialize_entity_key(
serialized_entity_key: bytes, entity_key_serialization_version=3
) -> EntityKeyProto:
"""
Deserialize entity key from a bytestring. This function can only be used with entity_key_serialization_version > 2.
Args:
entity_key_serialization_version: version of the entity key serialization
serialized_entity_key: serialized entity key bytes
Returns: EntityKeyProto
"""
if entity_key_serialization_version <= 2:
raise ValueError(
"Deserialization of entity key with version <= 2 is not supported. Please use version > 2 by setting entity_key_serialization_version=3"
)
offset = 0
keys = []
values = []
while offset < len(serialized_entity_key):
key_type = struct.unpack_from("<I", serialized_entity_key, offset)[0]
offset += 4

# Read the length of the key
key_length = struct.unpack_from("<I", serialized_entity_key, offset)[0]
offset += 4

if key_type == ValueType.STRING:
key = struct.unpack_from(f"<{key_length}s", serialized_entity_key, offset)[
0
]
keys.append(key.decode("utf-8").rstrip("\x00"))
offset += key_length
else:
raise ValueError(f"Unsupported key type: {key_type}")

(value_type,) = struct.unpack_from("<I", serialized_entity_key, offset)
offset += 4

(value_length,) = struct.unpack_from("<I", serialized_entity_key, offset)
offset += 4

# Read the value based on its type and length
value_bytes = serialized_entity_key[offset : offset + value_length]
value = _deserialize_value(value_type, value_bytes)
values.append(value)
offset += value_length

return EntityKeyProto(join_keys=keys, entity_values=values)


def get_list_val_str(val):
accept_value_types = [
"float_list_val",
Expand Down
8 changes: 5 additions & 3 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,12 @@ class RepoConfig(FeastBaseModel):
used when writing data to the online store.
A value <= 1 uses the serialization scheme used by feast up to Feast 0.22.
A value of 2 uses a newer serialization scheme, supported as of Feast 0.23.
The main difference between the two scheme is that the serialization scheme v1 stored `long` values as `int`s,
which would result in errors trying to serialize a range of values.
v2 fixes this error, but v1 is kept around to ensure backwards compatibility - specifically the ability to read
A value of 3 uses the latest serialization scheme, supported as of Feast 0.38.
The main difference between the three schema is that
v1: the serialization scheme v1 stored `long` values as `int`s, which would result in errors trying to serialize a range of values.
v2: fixes this error, but v1 is kept around to ensure backwards compatibility - specifically the ability to read
feature values for entities that have already been written into the online store.
v3: add entity_key value length to serialized bytes to enable deserialization, which can be used in retrieval of entity_key in document retrieval.
"""

coerce_tz_aware: Optional[bool] = True
Expand Down
71 changes: 70 additions & 1 deletion sdk/python/tests/unit/infra/test_key_encoding_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import pytest

from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.key_encoding_utils import (
_deserialize_value,
_serialize_val,
deserialize_entity_key,
serialize_entity_key,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.protos.feast.types.Value_pb2 import ValueType


def test_serialize_entity_key():
Expand All @@ -28,3 +34,66 @@ def test_serialize_entity_key():
join_keys=["user"], entity_values=[ValueProto(int64_val=int(2**31))]
),
)


def test_deserialize_entity_key():
serialized_entity_key = serialize_entity_key(
EntityKeyProto(
join_keys=["user"], entity_values=[ValueProto(int64_val=int(2**15))]
),
entity_key_serialization_version=3,
)

deserialized_entity_key = deserialize_entity_key(
serialized_entity_key, entity_key_serialization_version=3
)
assert deserialized_entity_key == EntityKeyProto(
join_keys=["user"], entity_values=[ValueProto(int64_val=int(2**15))]
)


def test_serialize_value():
v, t = _serialize_val("string_val", ValueProto(string_val="test"))
assert t == ValueType.STRING
assert v == b"test"

v, t = _serialize_val("bytes_val", ValueProto(bytes_val=b"test"))
assert t == ValueType.BYTES
assert v == b"test"

v, t = _serialize_val("int32_val", ValueProto(int32_val=1))
assert t == ValueType.INT32
assert v == b"\x01\x00\x00\x00"

# default entity_key_serialization_version is 1, so the result should be 4 bytes
v, t = _serialize_val("int64_val", ValueProto(int64_val=1))
assert t == ValueType.INT64
assert v == b"\x01\x00\x00\x00"

# current entity_key_serialization_version is 2, so the result should be 8 bytes
v, t = _serialize_val(
"int64_val", ValueProto(int64_val=1), entity_key_serialization_version=2
)
assert t == ValueType.INT64
assert v == b"\x01\x00\x00\x00\x00\x00\x00\x00"

# new entity_key_serialization_version is 3, the result should be same as version 2
v, t = _serialize_val(
"int64_val", ValueProto(int64_val=1), entity_key_serialization_version=3
)
assert t == ValueType.INT64
assert v == b"\x01\x00\x00\x00\x00\x00\x00\x00"


def test_deserialize_value():
v = _deserialize_value(ValueType.STRING, b"test")
assert v.string_val == "test"

v = _deserialize_value(ValueType.BYTES, b"test")
assert v.bytes_val == b"test"

v = _deserialize_value(ValueType.INT32, b"\x01\x00\x00\x00")
assert v.int32_val == 1

v = _deserialize_value(ValueType.INT64, b"\x01\x00\x00\x00\x00\x00\x00\x00")
assert v.int64_val == 1

0 comments on commit 83fad15

Please sign in to comment.