Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable Vector database and retrieve_online_documents API #4061

Merged
merged 41 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
713768e
feat: add document store
HaoXuAI Mar 31, 2024
58d5d94
feat: add document store
HaoXuAI Mar 31, 2024
2cd73d1
feat: add document store
HaoXuAI Mar 31, 2024
d2e0a59
feat: add document store
HaoXuAI Mar 31, 2024
7079e7f
remove DocumentStore
HaoXuAI Apr 9, 2024
8c9ee97
format
HaoXuAI Apr 9, 2024
513dd39
Merge branch 'master' into feat-documentstore
HaoXuAI Apr 9, 2024
29d98cd
format
HaoXuAI Apr 9, 2024
11eb97f
format
HaoXuAI Apr 9, 2024
865baf2
format
HaoXuAI Apr 9, 2024
47cd117
format
HaoXuAI Apr 9, 2024
3f9f59f
format
HaoXuAI Apr 9, 2024
7935071
remove unused vars
HaoXuAI Apr 9, 2024
ba39f93
add test
HaoXuAI Apr 11, 2024
cf53c71
add test
HaoXuAI Apr 11, 2024
92046af
format
HaoXuAI Apr 11, 2024
d0acd2d
format
HaoXuAI Apr 11, 2024
cc45f73
format
HaoXuAI Apr 11, 2024
006b5c6
format
HaoXuAI Apr 11, 2024
6e0ba03
format
HaoXuAI Apr 11, 2024
a2302be
fix not implemented issue
HaoXuAI Apr 11, 2024
2e6fc55
fix not implemented issue
HaoXuAI Apr 11, 2024
3cbbf21
fix test
HaoXuAI Apr 11, 2024
ec32764
format
HaoXuAI Apr 11, 2024
e2d8008
format
HaoXuAI Apr 12, 2024
523d20f
format
HaoXuAI Apr 12, 2024
5cd085d
format
HaoXuAI Apr 12, 2024
795699e
format
HaoXuAI Apr 12, 2024
67b007f
format
HaoXuAI Apr 12, 2024
33b46bd
update testcontainer
HaoXuAI Apr 12, 2024
82fe5f1
format
HaoXuAI Apr 12, 2024
0618378
fix postgres integration test
HaoXuAI Apr 12, 2024
7de2016
format
HaoXuAI Apr 12, 2024
92fed1d
fix postgres test
HaoXuAI Apr 14, 2024
d4f2639
fix postgres test
HaoXuAI Apr 14, 2024
396d7de
fix postgres test
HaoXuAI Apr 14, 2024
6c38b92
fix postgres test
HaoXuAI Apr 14, 2024
f763dc9
fix postgres test
HaoXuAI Apr 14, 2024
818c055
format
HaoXuAI Apr 14, 2024
a51b555
format
HaoXuAI Apr 15, 2024
2624b22
format
HaoXuAI Apr 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ test-python-universal-postgres-offline:
test-python-universal-postgres-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.postgres_repo_configuration \
PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \
PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \
python -m pytest -n 8 --integration \
-k "not test_universal_cli and \
not test_go_feature_server and \
Expand Down
103 changes: 103 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,6 +1690,72 @@ def _get_online_features(
)
return OnlineResponse(online_features_response)

@log_exceptions_and_usage
def retrieve_online_documents(
    self,
    feature: str,
    query: Union[str, List[float]],
    top_k: int,
) -> OnlineResponse:
    """
    Retrieve the top k document features closest to the query.

    Note that embeddings are a subset of features.

    Args:
        feature: Reference to the document feature to search, in the format
            "feature_view:feature", e.g. "document_fv:document_embeddings".
        query: The query to retrieve the closest document features for. It must
            already be an embedding; a raw string query is rejected with a
            ValueError by ``_retrieve_online_documents``.
        top_k: The number of closest document features to retrieve.

    Returns:
        The OnlineResponse built by ``_retrieve_online_documents``.
    """
    # Review note: the distance metric used to pick the top_k results is not
    # configurable yet. PostgreSQL offers several retrieval algorithms/configs;
    # supporting them was deferred to a follow-up change.
    response = self._retrieve_online_documents(
        feature=feature,
        query=query,
        top_k=top_k,
    )
    return response

def _retrieve_online_documents(
    self,
    feature: str,
    query: Union[str, List[float]],
    top_k: int,
):
    """
    Look up the closest document features and wrap them in an OnlineResponse.

    Args:
        feature: Feature reference in the format "feature_view:feature".
        query: The query embedding. String queries are not supported.
        top_k: The number of closest document features to retrieve.

    Returns:
        An OnlineResponse populated with the requested feature values and a
        "distance" column.

    Raises:
        ValueError: If ``query`` is a string rather than an embedding.
    """
    if isinstance(query, str):
        raise ValueError(
            "Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
        )

    feature_views, _ = self._get_feature_views_to_use(
        features=[feature], allow_cache=True, hide_dummy_entity=False
    )

    # Only the feature name (the part after "feature_view:") is passed on.
    if isinstance(feature, str):
        requested_feature = feature.split(":")[1]
    else:
        requested_feature = feature

    rows = self._retrieve_from_online_store(
        self._get_provider(),
        feature_views[0],
        requested_feature,
        query,
        top_k,
    )

    response = GetOnlineFeaturesResponse(results=[])

    # TODO Refactor to better way of populating result
    # TODO populate entity in the response after returning entity in document_features is supported
    # Each row is (timestamp, status, feature value, distance): index 2 holds
    # the feature value and index 3 the distance.
    for column_name, column_index in ((requested_feature, 2), ("distance", 3)):
        self._populate_result_rows_from_columnar(
            online_features_response=response,
            data={column_name: [row[column_index] for row in rows]},
        )
    return OnlineResponse(response)

@staticmethod
def _get_columnar_entity_values(
rowise: Optional[List[Dict[str, Any]]], columnar: Optional[Dict[str, List[Any]]]
Expand Down Expand Up @@ -1906,6 +1972,43 @@ def _read_from_online_store(
read_row_protos.append((event_timestamps, statuses, values))
return read_row_protos

def _retrieve_from_online_store(
    self,
    provider: Provider,
    table: FeatureView,
    requested_feature: str,
    query: List[float],
    top_k: int,
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]:
    """
    Search and return document features from the online document store.

    Args:
        provider: The provider whose ``retrieve_online_documents`` is queried.
        table: The feature view to search.
        requested_feature: The name of the feature to search on.
        query: The query embedding.
        top_k: The number of closest documents to retrieve.

    Returns:
        One ``(timestamp, status, feature value, distance)`` tuple per document
        returned by the provider. Rows whose feature value or distance is None
        are reported as ``FieldStatus.NOT_FOUND`` with empty ``Value`` protos.
    """
    documents = provider.retrieve_online_documents(
        config=self.config,
        table=table,
        requested_feature=requested_feature,
        query=query,
        top_k=top_k,
    )

    read_row_protos = []
    for row_ts, feature_val, distance_val in documents:
        # Build a fresh Timestamp for every row. Sharing one instance across
        # rows would make every returned tuple reference the same object, so
        # all rows would end up showing the last timestamp written, and a row
        # without a timestamp would inherit the previous row's value instead
        # of the default.
        row_ts_proto = Timestamp()
        if row_ts is not None:
            row_ts_proto.FromDatetime(row_ts)

        if feature_val is None or distance_val is None:
            feature_val = Value()
            distance_val = Value()
            status = FieldStatus.NOT_FOUND
        else:
            status = FieldStatus.PRESENT

        read_row_protos.append((row_ts_proto, status, feature_val, distance_val))
    return read_row_protos

@staticmethod
def _populate_response_from_feature_data(
feature_data: Iterable[
Expand Down
8 changes: 8 additions & 0 deletions sdk/python/feast/infra/key_encoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,11 @@ def serialize_entity_key(
output.append(val_bytes)

return b"".join(output)


def get_val_str(val):
    """
    Return the string form of a list-valued proto's contents.

    The fields "float_list_val", "double_list_val" and "int_list_val" are
    checked in that order; the first one that ``val.HasField`` reports as set
    is converted with ``str()``. Returns None when none of them is set.
    """
    list_fields = ("float_list_val", "double_list_val", "int_list_val")
    populated = next((name for name in list_fields if val.HasField(name)), None)
    if populated is None:
        return None
    return str(getattr(val, populated).val)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from testcontainers.core.waiting_utils import wait_for_logs

from feast.data_source import DataSource
from feast.feature_logging import LoggingDestination
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import (
PostgreSQLOfflineStoreConfig,
PostgreSQLSource,
Expand Down Expand Up @@ -57,6 +58,9 @@ def postgres_container():


class PostgreSQLDataSourceCreator(DataSourceCreator, OnlineStoreCreator):
def create_logged_features_destination(self) -> LoggingDestination:
    """
    Return the destination for logged features.

    Always returns None; the mismatch with the annotated return type is
    silenced with ``type: ignore``.
    NOTE(review): presumably this creator provides no logging destination
    for these tests -- confirm.
    """
    return None  # type: ignore

def __init__(
self, project_name: str, fixture_request: pytest.FixtureRequest, **kwargs
):
Expand Down
97 changes: 93 additions & 4 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

import psycopg2
import pytz
Expand All @@ -12,7 +12,7 @@

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
Expand All @@ -25,6 +25,12 @@
class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
    """Configuration for the PostgreSQL online store, with optional pgvector settings."""

    # Discriminator identifying this online store type.
    type: Literal["postgres"] = "postgres"

    # Whether to enable the pgvector extension for vector similarity search.
    # When enabled, the value column is created as a vector column and feature
    # values are written in their string list form instead of serialized bytes.
    pgvector_enabled: Optional[bool] = False

    # If pgvector is enabled, the length of the vector field; used as the
    # dimension of the vector(...) column type when the table is created.
    vector_len: Optional[int] = 512


class PostgreSQLOnlineStore(OnlineStore):
_conn: Optional[psycopg2._psycopg.connection] = None
Expand Down Expand Up @@ -68,11 +74,19 @@ def online_write_batch(
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
val_str: Union[str, bytes]
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
):
val_str = get_val_str(val)
else:
val_str = val.SerializeToString()
insert_values.append(
(
entity_key_bin,
feature_name,
val.SerializeToString(),
val_str,
timestamp,
created_ts,
)
Expand Down Expand Up @@ -212,14 +226,20 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
value_type = "BYTEA"
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
):
value_type = f'vector({config.online_config["vector_len"]})'
cur.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {}
(
entity_key BYTEA,
feature_name TEXT,
value BYTEA,
value {},
event_ts TIMESTAMPTZ,
created_ts TIMESTAMPTZ,
PRIMARY KEY(entity_key, feature_name)
Expand All @@ -228,6 +248,7 @@ def update(
"""
).format(
sql.Identifier(table_name),
sql.SQL(value_type),
sql.Identifier(f"{table_name}_ek"),
sql.Identifier(table_name),
)
Expand All @@ -251,6 +272,74 @@ def teardown(
logging.exception("Teardown failed")
raise

def retrieve_online_documents(
    self,
    config: RepoConfig,
    table: FeatureView,
    requested_feature: str,
    embedding: List[float],
    top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
    """
    Return the top k stored values of a feature closest to the given embedding.

    Args:
        config: Feast configuration object
        table: FeatureView object as the table to search
        requested_feature: The requested feature as the column to search
        embedding: The query embedding to search for
        top_k: The number of items to return
    Returns:
        List of tuples containing the event timestamp, the document feature
        value and its distance to the query embedding, ordered by ascending
        distance.

    """
    # Render the embedding as a vector literal such as '[3,1,2]' so it can be
    # used in the postgres vector search.
    vector_literal = "[" + ",".join(map(str, embedding)) + "]"

    with self._get_conn(config) as conn, conn.cursor() as cur:
        # Search query template to find the top k items that are closest to the given embedding
        # SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5;
        search_query = sql.SQL(
            """
            SELECT
                entity_key,
                feature_name,
                value,
                value <-> %s as distance,
                event_ts FROM {table_name}
            WHERE feature_name = {feature_name}
            ORDER BY distance
            LIMIT {top_k};
            """
        ).format(
            table_name=sql.Identifier(_table_id(config.project, table)),
            feature_name=sql.Literal(requested_feature),
            top_k=sql.Literal(top_k),
        )
        cur.execute(search_query, (vector_literal,))

        # TODO Deserialize entity_key to return the entity in response
        # TODO Convert to List[float] for value type proto
        matches: List[
            Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]
        ] = [
            (
                event_ts,
                ValueProto(string_val=value),
                ValueProto(float_val=distance),
            )
            for _entity_key, _feature_name, value, distance, event_ts in cur.fetchall()
        ]

    return matches


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import (
PostgreSQLDataSourceCreator,
)
from tests.integration.feature_repos.integration_test_repo_config import (
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
PostgresOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(online_store_creator=PostgreSQLDataSourceCreator),
IntegrationTestRepoConfig(
online_store="postgres", online_store_creator=PostgresOnlineStoreCreator
),
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]

AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator}
27 changes: 27 additions & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,30 @@ def teardown(
entities: Entities whose corresponding infrastructure should be deleted.
"""
pass

def retrieve_online_documents(
    self,
    config: RepoConfig,
    table: FeatureView,
    requested_feature: str,
    embedding: List[float],
    top_k: int,
) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]:
    """
    Retrieves online feature values for the specified embeddings.

    This base implementation always raises; online stores that support
    document retrieval override it.

    Args:
        config: The config for the current feature store.
        table: The feature view whose feature values should be read.
        requested_feature: The name of the feature whose embeddings should be used for retrieval.
        embedding: The embeddings to use for retrieval.
        top_k: The number of nearest neighbors to retrieve.

    Returns:
        object: A list of top k closest documents to the specified embedding. Each item in the list is a
        three-element tuple whose first item is the optional event timestamp for the row; the other two
        items are optional value protos.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    store_name = self.__class__.__name__
    raise NotImplementedError(
        f"Online store {store_name} does not support online retrieval"
    )
17 changes: 17 additions & 0 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,23 @@ def online_read(
)
return result

@log_exceptions_and_usage(sampler=RatioSampler(ratio=0.001))
def retrieve_online_documents(
    self,
    config: RepoConfig,
    table: FeatureView,
    requested_feature: str,
    query: List[float],
    top_k: int,
) -> List:
    """
    Forward a document retrieval request to the configured online store.

    Args:
        config: The config for the current feature store.
        table: The feature view to search.
        requested_feature: The name of the feature to search on.
        query: The query embedding.
        top_k: The number of closest documents to retrieve.

    Returns:
        Whatever the online store's ``retrieve_online_documents`` returns, or
        an empty list when no online store is set.
    """
    set_usage_attribute("provider", self.__class__.__name__)
    if not self.online_store:
        return []
    return self.online_store.retrieve_online_documents(
        config, table, requested_feature, query, top_k
    )

def ingest_df(
self,
feature_view: FeatureView,
Expand Down
Loading
Loading