Skip to content

Commit

Permalink
feat: Return entity key in the retrieval document api (feast-dev#4511)
Browse files Browse the repository at this point in the history
* update entity retrieval and add duckdb

Signed-off-by: cmuhao <[email protected]>

* lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix typo

Signed-off-by: cmuhao <[email protected]>

* fix lint

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

* fix test

Signed-off-by: cmuhao <[email protected]>

---------

Signed-off-by: cmuhao <[email protected]>
  • Loading branch information
HaoXuAI authored Sep 20, 2024
1 parent 4e2eacc commit 5f5caf0
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 61 deletions.
46 changes: 34 additions & 12 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@
FieldStatus,
GetOnlineFeaturesResponse,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import RepoConfig, load_repo_config
from feast.repo_contents import RepoContents
from feast.saved_dataset import SavedDataset, SavedDatasetStorage, ValidationReference
Expand Down Expand Up @@ -1666,20 +1668,29 @@ def retrieve_online_documents(
distance_metric,
)

# TODO Refactor to better way of populating result
# TODO populate entity in the response after returning entity in document_features is supported
# TODO currently not return the vector value since it is same as feature value, if embedding is supported,
# the feature value can be raw text before embedded
document_feature_vals = [feature[2] for feature in document_features]
document_feature_distance_vals = [feature[4] for feature in document_features]
entity_key_vals = [feature[1] for feature in document_features]
join_key_values: Dict[str, List[ValueProto]] = {}
for entity_key_val in entity_key_vals:
if entity_key_val is not None:
for join_key, entity_value in zip(
entity_key_val.join_keys, entity_key_val.entity_values
):
if join_key not in join_key_values:
join_key_values[join_key] = []
join_key_values[join_key].append(entity_value)

document_feature_vals = [feature[4] for feature in document_features]
document_feature_distance_vals = [feature[5] for feature in document_features]
online_features_response = GetOnlineFeaturesResponse(results=[])
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={requested_feature: document_feature_vals},
)
utils._populate_result_rows_from_columnar(
online_features_response=online_features_response,
data={"distance": document_feature_distance_vals},
data={
**join_key_values,
requested_feature: document_feature_vals,
"distance": document_feature_distance_vals,
},
)
return OnlineResponse(online_features_response)

Expand All @@ -1691,7 +1702,11 @@ def _retrieve_from_online_store(
query: List[float],
top_k: int,
distance_metric: Optional[str],
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
) -> List[
Tuple[
Timestamp, Optional[EntityKey], "FieldStatus.ValueType", Value, Value, Value
]
]:
"""
Search and return document features from the online document store.
"""
Expand All @@ -1707,7 +1722,7 @@ def _retrieve_from_online_store(
read_row_protos = []
row_ts_proto = Timestamp()

for row_ts, feature_val, vector_value, distance_val in documents:
for row_ts, entity_key, feature_val, vector_value, distance_val in documents:
# Reset timestamp to default or update if row_ts is not None
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
Expand All @@ -1721,7 +1736,14 @@ def _retrieve_from_online_store(
status = FieldStatus.PRESENT

read_row_protos.append(
(row_ts_proto, status, feature_val, vector_value, distance_val)
(
row_ts_proto,
entity_key,
status,
feature_val,
vector_value,
distance_val,
)
)
return read_row_protos

Expand Down
25 changes: 14 additions & 11 deletions sdk/python/feast/infra/online_stores/contrib/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
from elasticsearch import Elasticsearch, helpers

from feast import Entity, FeatureView, RepoConfig
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.key_encoding_utils import (
get_list_val_str,
serialize_entity_key,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel
from feast.utils import to_naive_utc
from feast.utils import _build_retrieve_online_document_record, to_naive_utc


class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel):
Expand Down Expand Up @@ -224,6 +227,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand All @@ -232,6 +236,7 @@ def retrieve_online_documents(
result: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand All @@ -247,23 +252,21 @@ def retrieve_online_documents(
)
rows = response["hits"]["hits"][0:top_k]
for row in rows:
entity_key = row["_source"]["entity_key"]
feature_value = row["_source"]["feature_value"]
vector_value = row["_source"]["vector_value"]
timestamp = row["_source"]["timestamp"]
distance = row["_score"]
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")

feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(base64.b64decode(feature_value))

vector_value_proto = ValueProto(string_val=str(vector_value))
distance_value_proto = ValueProto(float_val=distance)
result.append(
(
_build_retrieve_online_document_record(
entity_key,
base64.b64decode(feature_value),
str(vector_value),
distance,
timestamp,
feature_value_proto,
vector_value_proto,
distance_value_proto,
config.entity_key_serialization_version,
)
)
return result
37 changes: 15 additions & 22 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import RepoConfig
from feast.utils import _build_retrieve_online_document_record

SUPPORTED_DISTANCE_METRICS_DICT = {
"cosine": "<=>",
Expand Down Expand Up @@ -360,6 +361,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down Expand Up @@ -391,12 +393,11 @@ def retrieve_online_documents(
)

distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]
# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

result: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand All @@ -415,45 +416,37 @@ def retrieve_online_documents(
feature_name,
value,
vector_value,
vector_value {distance_metric_sql} %s as distance,
vector_value {distance_metric_sql} %s::vector as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
distance_metric_sql=distance_metric_sql,
distance_metric_sql=sql.SQL(distance_metric_sql),
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
top_k=sql.Literal(top_k),
),
(query_embedding_str,),
(embedding,),
)
rows = cur.fetchall()

for (
entity_key,
feature_name,
value,
_,
feature_val,
vector_value,
distance,
distance_val,
event_ts,
) in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(bytes(value))

vector_value_proto = ValueProto(string_val=vector_value)
distance_value_proto = ValueProto(float_val=distance)
result.append(
(
_build_retrieve_online_document_record(
entity_key,
feature_val,
vector_value,
distance_val,
event_ts,
feature_value_proto,
vector_value_proto,
distance_value_proto,
config.entity_key_serialization_version,
)
)

Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down
22 changes: 9 additions & 13 deletions sdk/python/feast/infra/online_stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
from feast.protos.feast.core.SqliteTable_pb2 import SqliteTable as SqliteTableProto
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.utils import to_naive_utc
from feast.utils import _build_retrieve_online_document_record, to_naive_utc


class SqliteOnlineStoreConfig(FeastConfigBaseModel):
Expand Down Expand Up @@ -303,6 +302,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down Expand Up @@ -385,26 +385,22 @@ def retrieve_online_documents(
result: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]
] = []

for entity_key, _, string_value, distance, event_ts in rows:
feature_value_proto = ValueProto()
feature_value_proto.ParseFromString(string_value if string_value else b"")
vector_value_proto = ValueProto(
float_list_val=FloatListProto(val=embedding)
)
distance_value_proto = ValueProto(float_val=distance)

result.append(
(
_build_retrieve_online_document_record(
entity_key,
string_value if string_value else b"",
embedding,
distance,
event_ts,
feature_value_proto,
vector_value_proto,
distance_value_proto,
config.entity_key_serialization_version,
)
)

Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def retrieve_online_documents(
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
Expand Down
51 changes: 50 additions & 1 deletion sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
FeatureViewNotFoundException,
RequestDataNotFoundInEntityRowsException,
)
from feast.infra.key_encoding_utils import deserialize_entity_key
from feast.protos.feast.serving.ServingService_pb2 import (
FieldStatus,
GetOnlineFeaturesResponse,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
from feast.protos.feast.types.Value_pb2 import RepeatedValue as RepeatedValueProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.type_map import python_values_to_proto_values
Expand All @@ -49,7 +51,6 @@
from feast.feature_view import FeatureView
from feast.on_demand_feature_view import OnDemandFeatureView


APPLICATION_NAME = "feast-dev/feast"
USER_AGENT = "{}/{}".format(APPLICATION_NAME, get_version())

Expand Down Expand Up @@ -1050,3 +1051,51 @@ def tags_str_to_dict(tags: str = "") -> dict[str, str]:

def _utc_now() -> datetime:
return datetime.now(tz=timezone.utc)


def _build_retrieve_online_document_record(
entity_key: Union[str, bytes],
feature_value: Union[str, bytes],
vector_value: Union[str, List[float]],
distance_value: float,
event_timestamp: datetime,
entity_key_serialization_version: int,
) -> Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[ValueProto],
Optional[ValueProto],
Optional[ValueProto],
]:
if entity_key_serialization_version < 3:
entity_key_proto = None
else:
if isinstance(entity_key, str):
entity_key_proto_bin = entity_key.encode("utf-8")
else:
entity_key_proto_bin = entity_key
entity_key_proto = deserialize_entity_key(
entity_key_proto_bin,
entity_key_serialization_version=entity_key_serialization_version,
)

feature_value_proto = ValueProto()

if isinstance(feature_value, str):
feature_value_proto.ParseFromString(feature_value.encode("utf-8"))
else:
feature_value_proto.ParseFromString(feature_value)

if isinstance(vector_value, str):
vector_value_proto = ValueProto(string_val=vector_value)
else:
vector_value_proto = ValueProto(float_list_val=FloatListProto(val=vector_value))

distance_value_proto = ValueProto(float_val=distance_value)
return (
event_timestamp,
entity_key_proto,
feature_value_proto,
vector_value_proto,
distance_value_proto,
)
Loading

0 comments on commit 5f5caf0

Please sign in to comment.