From 269055e25178956715c163547c9f0a33a5892a75 Mon Sep 17 00:00:00 2001 From: Achal Shah Date: Tue, 7 Jun 2022 09:21:25 -0700 Subject: [PATCH] fix: Add columns for user metadata in the tables (#2760) * fix: Add columns for user metadata in the tables Signed-off-by: Achal Shah * registry -> base registry Signed-off-by: Achal Shah * metadata methods Signed-off-by: Achal Shah * metadata methods Signed-off-by: Achal Shah * tests Signed-off-by: Achal Shah * one more test assert Signed-off-by: Achal Shah * cr update Signed-off-by: Achal Shah --- sdk/python/feast/diff/registry_diff.py | 11 +- sdk/python/feast/feature_logging.py | 8 +- sdk/python/feast/feature_store.py | 19 +- .../feast/infra/offline_stores/bigquery.py | 6 +- sdk/python/feast/infra/offline_stores/file.py | 6 +- .../infra/offline_stores/offline_store.py | 6 +- .../infra/offline_stores/offline_utils.py | 8 +- .../feast/infra/offline_stores/redshift.py | 6 +- .../feast/infra/offline_stores/snowflake.py | 6 +- .../feast/infra/passthrough_provider.py | 10 +- sdk/python/feast/infra/provider.py | 10 +- sdk/python/feast/infra/registry_stores/sql.py | 233 +++++++++++++++--- sdk/python/feast/registry.py | 170 +++++++++++-- sdk/python/tests/foo_provider.py | 10 +- .../registration/test_sql_registry.py | 10 + 15 files changed, 400 insertions(+), 119 deletions(-) diff --git a/sdk/python/feast/diff/registry_diff.py b/sdk/python/feast/diff/registry_diff.py index 6b38a190fe..95b106867c 100644 --- a/sdk/python/feast/diff/registry_diff.py +++ b/sdk/python/feast/diff/registry_diff.py @@ -26,7 +26,7 @@ from feast.protos.feast.core.ValidationProfile_pb2 import ( ValidationReference as ValidationReferenceProto, ) -from feast.registry import FEAST_OBJECT_TYPES, FeastObjectType, Registry +from feast.registry import FEAST_OBJECT_TYPES, BaseRegistry, FeastObjectType from feast.repo_contents import RepoContents @@ -161,7 +161,7 @@ def diff_registry_objects( def extract_objects_for_keep_delete_update_add( - registry: Registry, 
current_project: str, desired_repo_contents: RepoContents, + registry: BaseRegistry, current_project: str, desired_repo_contents: RepoContents, ) -> Tuple[ Dict[FeastObjectType, Set[FeastObject]], Dict[FeastObjectType, Set[FeastObject]], @@ -208,7 +208,7 @@ def extract_objects_for_keep_delete_update_add( def diff_between( - registry: Registry, current_project: str, desired_repo_contents: RepoContents, + registry: BaseRegistry, current_project: str, desired_repo_contents: RepoContents, ) -> RegistryDiff: """ Returns the difference between the current and desired repo states. @@ -267,7 +267,10 @@ def diff_between( def apply_diff_to_registry( - registry: Registry, registry_diff: RegistryDiff, project: str, commit: bool = True + registry: BaseRegistry, + registry_diff: RegistryDiff, + project: str, + commit: bool = True, ): """ Applies the given diff to the given Feast project in the registry. diff --git a/sdk/python/feast/feature_logging.py b/sdk/python/feast/feature_logging.py index 04f30ab81a..275bde72ec 100644 --- a/sdk/python/feast/feature_logging.py +++ b/sdk/python/feast/feature_logging.py @@ -17,8 +17,8 @@ ) if TYPE_CHECKING: - from feast import FeatureService - from feast.registry import Registry + from feast.feature_service import FeatureService + from feast.registry import BaseRegistry REQUEST_ID_FIELD = "__request_id" @@ -33,7 +33,7 @@ class LoggingSource: """ @abc.abstractmethod - def get_schema(self, registry: "Registry") -> pa.Schema: + def get_schema(self, registry: "BaseRegistry") -> pa.Schema: """ Generate schema for logs destination. 
""" raise NotImplementedError @@ -48,7 +48,7 @@ def __init__(self, feature_service: "FeatureService", project: str): self._feature_service = feature_service self._project = project - def get_schema(self, registry: "Registry") -> pa.Schema: + def get_schema(self, registry: "BaseRegistry") -> pa.Schema: fields: Dict[str, pa.DataType] = {} for projection in self._feature_service.feature_view_projections: diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 7824db4a39..7a5a8299eb 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -74,14 +74,13 @@ from feast.infra.registry_stores.sql import SqlRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.online_response import OnlineResponse -from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto from feast.protos.feast.serving.ServingService_pb2 import ( FieldStatus, GetOnlineFeaturesResponse, ) from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import RepeatedValue, Value -from feast.registry import Registry +from feast.registry import BaseRegistry, Registry from feast.repo_config import RepoConfig, load_repo_config from feast.repo_contents import RepoContents from feast.request_feature_view import RequestFeatureView @@ -113,7 +112,7 @@ class FeatureStore: config: RepoConfig repo_path: Path - _registry: Registry + _registry: BaseRegistry _provider: Provider _go_server: "EmbeddedOnlineFeatureServer" @@ -142,8 +141,9 @@ def __init__( if registry_config.registry_type == "sql": self._registry = SqlRegistry(registry_config, None) else: - self._registry = Registry(registry_config, repo_path=self.repo_path) - self._registry._initialize_registry() + r = Registry(registry_config, repo_path=self.repo_path) + r._initialize_registry() + self._registry = r self._provider = get_provider(self.config, self.repo_path) self._go_server = None @@ -153,7 +153,7 
@@ def version(self) -> str: return get_version() @property - def registry(self) -> Registry: + def registry(self) -> BaseRegistry: """Gets the registry of this feature store.""" return self._registry @@ -644,12 +644,7 @@ def _plan( # Compute the desired difference between the current infra, as stored in the registry, # and the desired infra. self._registry.refresh() - current_infra_proto = ( - self._registry.cached_registry_proto.infra.__deepcopy__() - if hasattr(self._registry, "cached_registry_proto") - and self._registry.cached_registry_proto - else InfraProto() - ) + current_infra_proto = self._registry.proto().infra.__deepcopy__() desired_registry_proto = desired_repo_contents.to_registry_proto() new_infra = self._provider.plan_infra(self.config, desired_registry_proto) new_infra_proto = new_infra.to_proto() diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index e9d8bdccbf..f095caef9b 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -39,7 +39,7 @@ RetrievalMetadata, ) from feast.on_demand_feature_view import OnDemandFeatureView -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig from ...saved_dataset import SavedDatasetStorage @@ -169,7 +169,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: @@ -262,7 +262,7 @@ def write_logged_features( data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, - registry: Registry, + registry: BaseRegistry, ): destination = logging_config.destination assert isinstance(destination, BigQueryLoggingDestination) diff --git a/sdk/python/feast/infra/offline_stores/file.py 
b/sdk/python/feast/infra/offline_stores/file.py index c675751739..809dbc12a8 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -32,7 +32,7 @@ _get_requested_feature_views_to_features_dict, _run_dask_field_mapping, ) -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.usage import log_exceptions_and_usage @@ -113,7 +113,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: @@ -380,7 +380,7 @@ def write_logged_features( data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, - registry: Registry, + registry: BaseRegistry, ): destination = logging_config.destination assert isinstance(destination, FileLoggingDestination) diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 2996a1ed59..cc06ad54c1 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -25,7 +25,7 @@ from feast.feature_logging import LoggingConfig, LoggingSource from feast.feature_view import FeatureView from feast.on_demand_feature_view import OnDemandFeatureView -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.repo_config import RepoConfig from feast.saved_dataset import SavedDatasetStorage @@ -211,7 +211,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: @@ -252,7 +252,7 @@ def 
write_logged_features( data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, - registry: Registry, + registry: BaseRegistry, ): """ Write logged features to a specified destination (taken from logging_config) in the offline store. diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index dad0ca5b78..893180f19f 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -8,7 +8,6 @@ from jinja2 import BaseLoader, Environment from pandas import Timestamp -import feast from feast.errors import ( EntityTimestampInferenceException, FeastEntityDFMissingColumnsError, @@ -17,7 +16,7 @@ from feast.importer import import_class from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.provider import _get_requested_feature_views_to_features_dict -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.utils import to_naive_utc DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp" @@ -55,8 +54,9 @@ def assert_expected_columns_in_entity_df( raise FeastEntityDFMissingColumnsError(expected_columns, missing_keys) +# TODO: Remove project and registry from the interface and call sites. 
def get_expected_join_keys( - project: str, feature_views: List["feast.FeatureView"], registry: Registry + project: str, feature_views: List[FeatureView], registry: BaseRegistry ) -> Set[str]: join_keys = set() for feature_view in feature_views: @@ -95,7 +95,7 @@ class FeatureViewQueryContext: def get_feature_view_query_context( feature_refs: List[str], feature_views: List[FeatureView], - registry: Registry, + registry: BaseRegistry, project: str, entity_df_timestamp_range: Tuple[datetime, datetime], ) -> List[FeatureViewQueryContext]: diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 74ba83cb00..a5483e8140 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -38,7 +38,7 @@ SavedDatasetRedshiftStorage, ) from feast.infra.utils import aws_utils -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.usage import log_exceptions_and_usage @@ -176,7 +176,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: @@ -269,7 +269,7 @@ def write_logged_features( data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, - registry: Registry, + registry: BaseRegistry, ): destination = logging_config.destination assert isinstance(destination, RedshiftLoggingDestination) diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index d39acc9f08..73c785eecf 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -44,7 +44,7 @@ write_pandas, write_parquet, ) 
-from feast.registry import Registry +from feast.registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage from feast.usage import log_exceptions_and_usage @@ -206,7 +206,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: @@ -284,7 +284,7 @@ def write_logged_features( data: Union[pyarrow.Table, Path], source: LoggingSource, logging_config: LoggingConfig, - registry: Registry, + registry: BaseRegistry, ): assert isinstance(logging_config.destination, SnowflakeLoggingDestination) diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index a53788dc85..f04d03eb99 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -21,7 +21,7 @@ ) from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.repo_config import RepoConfig from feast.saved_dataset import SavedDataset from feast.usage import RatioSampler, log_exceptions_and_usage, set_usage_attribute @@ -138,7 +138,7 @@ def materialize_single_feature_view( feature_view: FeatureView, start_date: datetime, end_date: datetime, - registry: Registry, + registry: BaseRegistry, project: str, tqdm_builder: Callable[[int], tqdm], ) -> None: @@ -194,7 +194,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pandas.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool, ) -> RetrievalJob: @@ -240,7 +240,7 @@ def write_feature_service_logs( feature_service: 
FeatureService, logs: Union[pyarrow.Table, str], config: RepoConfig, - registry: Registry, + registry: BaseRegistry, ): assert ( feature_service.logging_config is not None @@ -260,7 +260,7 @@ def retrieve_feature_service_logs( start_date: datetime, end_date: datetime, config: RepoConfig, - registry: Registry, + registry: BaseRegistry, ) -> RetrievalJob: assert ( feature_service.logging_config is not None diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index cd82b7d416..e6c3da86a5 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -19,7 +19,7 @@ from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.repo_config import RepoConfig from feast.saved_dataset import SavedDataset from feast.type_map import python_values_to_proto_values @@ -133,7 +133,7 @@ def materialize_single_feature_view( feature_view: FeatureView, start_date: datetime, end_date: datetime, - registry: Registry, + registry: BaseRegistry, project: str, tqdm_builder: Callable[[int], tqdm], ) -> None: @@ -146,7 +146,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pandas.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool, ) -> RetrievalJob: @@ -192,7 +192,7 @@ def write_feature_service_logs( feature_service: FeatureService, logs: Union[pyarrow.Table, Path], config: RepoConfig, - registry: Registry, + registry: BaseRegistry, ): """ Write features and entities logged by a feature server to an offline store. 
@@ -211,7 +211,7 @@ def retrieve_feature_service_logs( start_date: datetime, end_date: datetime, config: RepoConfig, - registry: Registry, + registry: BaseRegistry, ) -> RetrievalJob: """ Read logged features from an offline store for a given time window [from, to). diff --git a/sdk/python/feast/infra/registry_stores/sql.py b/sdk/python/feast/infra/registry_stores/sql.py index 1a45dec68a..d34bb2fa8b 100644 --- a/sdk/python/feast/infra/registry_stores/sql.py +++ b/sdk/python/feast/infra/registry_stores/sql.py @@ -1,7 +1,7 @@ from datetime import datetime from pathlib import Path from threading import Lock -from typing import List, Optional +from typing import Any, List, Optional from sqlalchemy import ( # type: ignore BigInteger, @@ -31,6 +31,7 @@ ) from feast.feature_service import FeatureService from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra from feast.on_demand_feature_view import OnDemandFeatureView from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto @@ -41,10 +42,14 @@ from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( OnDemandFeatureView as OnDemandFeatureViewProto, ) +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.protos.feast.core.RequestFeatureView_pb2 import ( RequestFeatureView as RequestFeatureViewProto, ) from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto +from feast.protos.feast.core.StreamFeatureView_pb2 import ( + StreamFeatureView as StreamFeatureViewProto, +) from feast.protos.feast.core.ValidationProfile_pb2 import ( ValidationReference as ValidationReferenceProto, ) @@ -79,6 +84,7 @@ Column("last_updated_timestamp", BigInteger, nullable=False), Column("materialized_intervals", LargeBinary, nullable=True), Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), ) 
request_feature_views = Table( @@ -87,6 +93,16 @@ Column("feature_view_name", String(50), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), +) + +stream_feature_views = Table( + "stream_feature_views", + metadata, + Column("feature_view_name", String(50), primary_key=True), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), ) on_demand_feature_views = Table( @@ -95,6 +111,7 @@ Column("feature_view_name", String(50), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("feature_view_proto", LargeBinary, nullable=False), + Column("user_metadata", LargeBinary, nullable=True), ) feature_services = Table( @@ -153,10 +170,29 @@ def teardown(self): def refresh(self): pass + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ): + return self._get_object( + stream_feature_views, + name, + project, + StreamFeatureViewProto, + StreamFeatureView, + "feature_view_name", + "feature_view_proto", + FeatureViewNotFoundException, + ) + def list_stream_feature_views( self, project: str, allow_cache: bool = False ) -> List[StreamFeatureView]: - return [] + return self._list_objects( + stream_feature_views, + StreamFeatureViewProto, + StreamFeatureView, + "feature_view_proto", + ) def apply_entity(self, entity: Entity, project: str, commit: bool = True): return self._apply_object(entities, "entity_name", entity, "entity_proto") @@ -201,6 +237,18 @@ def get_on_demand_feature_view( FeatureViewNotFoundException, ) + def get_request_feature_view(self, name: str, project: str): + return self._get_object( + request_feature_views, + name, + project, + RequestFeatureViewProto, + RequestFeatureView, + "feature_view_name", + "feature_view_proto", + 
FeatureViewNotFoundException, + ) + def get_feature_service( self, name: str, project: str, allow_cache: bool = False ) -> FeatureService: @@ -247,41 +295,46 @@ def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity] return self._list_objects(entities, EntityProto, Entity, "entity_proto") def delete_entity(self, name: str, project: str, commit: bool = True): - with self.engine.connect() as conn: - stmt = delete(entities).where(entities.c.entity_name == name) - rows = conn.execute(stmt) - if rows.rowcount < 1: - raise EntityNotFoundException(name, project) + return self._delete_object( + entities, name, project, "entity_name", EntityNotFoundException + ) def delete_feature_view(self, name: str, project: str, commit: bool = True): deleted_count = 0 - for table in {feature_views, request_feature_views, on_demand_feature_views}: - with self.engine.connect() as conn: - stmt = delete(table).where(table.c.feature_view_name == name) - rows = conn.execute(stmt) - deleted_count += rows.rowcount + for table in { + feature_views, + request_feature_views, + on_demand_feature_views, + stream_feature_views, + }: + deleted_count += self._delete_object( + table, name, project, "feature_view_name", None + ) if deleted_count == 0: raise FeatureViewNotFoundException(name, project) def delete_feature_service(self, name: str, project: str, commit: bool = True): - with self.engine.connect() as conn: - stmt = delete(feature_services).where( - feature_services.c.feature_service_name == name - ) - rows = conn.execute(stmt) - if rows.rowcount < 1: - raise FeatureServiceNotFoundException(name, project) + return self._delete_object( + feature_services, + name, + project, + "feature_service_name", + FeatureServiceNotFoundException, + ) def get_data_source( self, name: str, project: str, allow_cache: bool = False ) -> DataSource: - with self.engine.connect() as conn: - stmt = select(data_sources).where(data_sources.c.entity_name == name) - row = 
conn.execute(stmt).first() - if row: - ds_proto = DataSourceProto.FromString(row["data_source_proto"]) - return DataSource.from_proto(ds_proto) - raise DataSourceObjectNotFoundException(name, project=project) + return self._get_object( + data_sources, + name, + project, + DataSourceProto, + DataSource, + "data_source_name", + "data_source_proto", + DataSourceObjectNotFoundException, + ) def list_data_sources( self, project: str, allow_cache: bool = False @@ -300,8 +353,9 @@ def apply_data_source( def apply_feature_view( self, feature_view: BaseFeatureView, project: str, commit: bool = True ): - # TODO(achals): Stream feature views need to be supported. - if isinstance(feature_view, FeatureView): + if isinstance(feature_view, StreamFeatureView): + fv_table = stream_feature_views + elif isinstance(feature_view, FeatureView): fv_table = feature_views elif isinstance(feature_view, OnDemandFeatureView): fv_table = on_demand_feature_views @@ -406,8 +460,100 @@ def apply_materialization( pass def delete_validation_reference(self, name: str, project: str, commit: bool = True): + self._delete_object( + validation_references, + name, + project, + "validation_reference_name", + ValidationReferenceNotFound, + ) + + def update_infra(self, infra: Infra, project: str, commit: bool = True): + pass + + def get_infra(self, project: str, allow_cache: bool = False) -> Infra: pass + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + if isinstance(feature_view, StreamFeatureView): + table = stream_feature_views + elif isinstance(feature_view, FeatureView): + table = feature_views + elif isinstance(feature_view, OnDemandFeatureView): + table = on_demand_feature_views + elif isinstance(feature_view, RequestFeatureView): + table = request_feature_views + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + + name = feature_view.name + with self.engine.connect() as conn: + stmt = 
select(table).where(getattr(table.c, "feature_view_name") == name) + row = conn.execute(stmt).first() + update_datetime = datetime.utcnow() + update_time = int(update_datetime.timestamp()) + if row: + values = { + "user_metadata": metadata_bytes, + "last_updated_timestamp": update_time, + } + update_stmt = ( + update(table) + .where(getattr(table.c, "feature_view_name") == name) + .values(values,) + ) + conn.execute(update_stmt) + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + if isinstance(feature_view, StreamFeatureView): + table = stream_feature_views + elif isinstance(feature_view, FeatureView): + table = feature_views + elif isinstance(feature_view, OnDemandFeatureView): + table = on_demand_feature_views + elif isinstance(feature_view, RequestFeatureView): + table = request_feature_views + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + + name = feature_view.name + with self.engine.connect() as conn: + stmt = select(table).where(getattr(table.c, "feature_view_name") == name) + row = conn.execute(stmt).first() + if row: + return row["user_metadata"] + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def proto(self) -> RegistryProto: + r = RegistryProto() + project = "" + # TODO(achal): Support Infra object, and last_updated_timestamp. 
+ for lister, registry_proto_field in [ + (self.list_entities, r.entities), + (self.list_feature_views, r.feature_views), + (self.list_data_sources, r.data_sources), + (self.list_on_demand_feature_views, r.on_demand_feature_views), + (self.list_request_feature_views, r.request_feature_views), + (self.list_stream_feature_views, r.stream_feature_views), + (self.list_feature_services, r.feature_services), + (self.list_saved_datasets, r.saved_datasets), + (self.list_validation_references, r.validation_references), + ]: + objs: List[Any] = lister(project) # type: ignore + registry_proto_field.extend([obj.to_proto() for obj in objs]) + + return r + def commit(self): pass @@ -422,6 +568,7 @@ def _apply_object( update_time = int(update_datetime.timestamp()) if hasattr(obj, "last_updated_timestamp"): obj.last_updated_timestamp = update_datetime + if row: values = { proto_field_name: obj.to_proto().SerializeToString(), @@ -442,18 +589,13 @@ def _apply_object( insert_stmt = insert(table).values(values,) conn.execute(insert_stmt) - def _list_objects(self, table, proto_class, python_class, proto_field_name): + def _delete_object(self, table, name, project, id_field_name, not_found_exception): with self.engine.connect() as conn: - stmt = select(table) - rows = conn.execute(stmt).all() - if rows: - return [ - python_class.from_proto( - proto_class.FromString(row[proto_field_name]) - ) - for row in rows - ] - return [] + stmt = delete(table).where(getattr(table.c, id_field_name) == name) + rows = conn.execute(stmt) + if rows.rowcount < 1 and not_found_exception: + raise not_found_exception(name, project) + return rows.rowcount def _get_object( self, @@ -473,3 +615,16 @@ def _get_object( _proto = proto_class.FromString(row[proto_field_name]) return python_class.from_proto(_proto) raise not_found_exception(name, project) + + def _list_objects(self, table, proto_class, python_class, proto_field_name): + with self.engine.connect() as conn: + stmt = select(table) + rows = 
conn.execute(stmt).all() + if rows: + return [ + python_class.from_proto( + proto_class.FromString(row[proto_field_name]) + ) + for row in rows + ] + return [] diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index fe37aa8dc2..e993533c8b 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -84,7 +84,7 @@ class FeastObjectType(Enum): @staticmethod def get_objects_from_registry( - registry: "Registry", project: str + registry: "BaseRegistry", project: str ) -> Dict["FeastObjectType", List[Any]]: return { FeastObjectType.DATA_SOURCE: registry.list_data_sources(project=project), @@ -337,9 +337,22 @@ def delete_feature_view(self, name: str, project: str, commit: bool = True): """ # stream feature view operations - # TODO: Needs to be implemented. - # def get_stream_feature_view(self): - # ... + @abstractmethod + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ): + """ + Retrieves a stream feature view. + + Args: + name: Name of stream feature view + project: Feast project that this feature view belongs to + allow_cache: Allow returning feature view from the cached registry + + Returns: + Returns either the specified feature view, or raises an exception if + none is found + """ @abstractmethod def list_stream_feature_views( @@ -423,10 +436,20 @@ def list_feature_views( """ # request feature view operations - # TODO: Needs to be implemented. - # @abstractmethod - # def get_request_feature_view(self, name: str, project: str): - # ... + @abstractmethod + def get_request_feature_view(self, name: str, project: str) -> RequestFeatureView: + """ + Retrieves a request feature view. 
+ + Args: + name: Name of request feature view + project: Feast project that this feature view belongs to + allow_cache: Allow returning feature view from the cached registry + + Returns: + Returns either the specified feature view, or raises an exception if + none is found + """ @abstractmethod def list_request_feature_views( @@ -494,22 +517,19 @@ def get_saved_dataset( none is found """ - # TODO: Needs to be implemented. - # def delete_saved_dataset( - # self, name: str, project: str, allow_cache: bool = False - # ): - # """ - # Retrieves a saved dataset. - # - # Args: - # name: Name of dataset - # project: Feast project that this dataset belongs to - # allow_cache: Whether to allow returning this dataset from a cached registry - # - # Returns: - # Returns either the specified SavedDataset, or raises an exception if - # none is found - # """ + def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False): + """ + Delete a saved dataset. + + Args: + name: Name of dataset + project: Feast project that this dataset belongs to + allow_cache: Whether to allow returning this dataset from a cached registry + + Returns: + Returns either the specified SavedDataset, or raises an exception if + none is found + """ @abstractmethod def list_saved_datasets( @@ -572,8 +592,68 @@ def get_validation_reference( """ # TODO: Needs to be implemented. - # def list_validation_references(self): - # ... + def list_validation_references( + self, project: str, allow_cache: bool = False + ) -> List[ValidationReference]: + + """ + Retrieve a list of validation references from the registry + + Args: + allow_cache: Allow returning feature views from the cached registry + project: Filter feature views based on project name + + Returns: + List of request feature views + """ + + @abstractmethod + def update_infra(self, infra: Infra, project: str, commit: bool = True): + """ + Updates the stored Infra object. + + Args: + infra: The new Infra object to be stored. 
+ project: Feast project that the Infra object refers to + commit: Whether the change should be persisted immediately + """ + + @abstractmethod + def get_infra(self, project: str, allow_cache: bool = False) -> Infra: + """ + Retrieves the stored Infra object. + + Args: + project: Feast project that the Infra object refers to + allow_cache: Whether to allow returning this entity from a cached registry + + Returns: + The stored Infra object. + """ + + @abstractmethod + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + ... + + @abstractmethod + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + ... + + @abstractmethod + def proto(self) -> RegistryProto: + """ + Retrieves a proto version of the registry. + + Returns: + The registry proto object. + """ @abstractmethod def commit(self): @@ -589,6 +669,19 @@ class Registry(BaseRegistry): Registry: A registry allows for the management and persistence of feature definitions and related metadata. """ + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + pass + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + pass + # The cached_registry_proto object is used for both reads and writes. In particular, # all write operations refresh the cache and modify it in memory; the write must # then be persisted to the underlying RegistryStore with a call to commit(). @@ -1115,6 +1208,28 @@ def list_feature_views( feature_views.append(FeatureView.from_proto(feature_view_proto)) return feature_views + def get_request_feature_view(self, name: str, project: str): + """ + Retrieves a feature view. 
+
+        Args:
+            name: Name of request feature view
+            project: Feast project that this request feature view
+                belongs to
+
+        Returns:
+            Returns either the specified request feature view, or raises an
+            exception if none is found
+        """
+        registry_proto = self._get_registry_proto(allow_cache=False)
+        for feature_view_proto in registry_proto.feature_views:
+            if (
+                feature_view_proto.spec.name == name
+                and feature_view_proto.spec.project == project
+            ):
+                return RequestFeatureView.from_proto(feature_view_proto)
+        raise FeatureViewNotFoundException(name, project)
+
     def list_request_feature_views(
         self, project: str, allow_cache: bool = False
    ) -> List[RequestFeatureView]:
@@ -1469,6 +1584,9 @@ def teardown(self):
         """Tears down (removes) the registry."""
         self._registry_store.teardown()
 
+    def proto(self) -> RegistryProto:
+        return self.cached_registry_proto or RegistryProto()
+
     def to_dict(self, project: str) -> Dict[str, List[Any]]:
         """Returns a dictionary representation of the registry contents for the
         specified project. 
diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 2d61c36273..bd6f9811e8 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -11,7 +11,7 @@ from feast.infra.provider import Provider from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto -from feast.registry import Registry +from feast.registry import BaseRegistry from feast.saved_dataset import SavedDataset @@ -52,7 +52,7 @@ def materialize_single_feature_view( feature_view: FeatureView, start_date: datetime, end_date: datetime, - registry: Registry, + registry: BaseRegistry, project: str, tqdm_builder: Callable[[int], tqdm], ) -> None: @@ -64,7 +64,7 @@ def get_historical_features( feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pandas.DataFrame, str], - registry: Registry, + registry: BaseRegistry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: @@ -87,7 +87,7 @@ def write_feature_service_logs( feature_service: FeatureService, logs: Union[pyarrow.Table, Path], config: RepoConfig, - registry: Registry, + registry: BaseRegistry, ): pass @@ -97,6 +97,6 @@ def retrieve_feature_service_logs( start_date: datetime, end_date: datetime, config: RepoConfig, - registry: Registry, + registry: BaseRegistry, ) -> RetrievalJob: pass diff --git a/sdk/python/tests/integration/registration/test_sql_registry.py b/sdk/python/tests/integration/registration/test_sql_registry.py index efad9f2c81..c96d83ce0a 100644 --- a/sdk/python/tests/integration/registration/test_sql_registry.py +++ b/sdk/python/tests/integration/registration/test_sql_registry.py @@ -24,6 +24,7 @@ from feast import Feature, FileSource, RequestSource from feast.data_format import ParquetFormat from feast.entity import Entity +from feast.errors import FeatureViewNotFoundException from feast.feature_view import FeatureView from feast.field import Field from 
feast.infra.registry_stores.sql import SqlRegistry @@ -259,9 +260,18 @@ def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame: project = "project" + with pytest.raises(FeatureViewNotFoundException): + sql_registry.get_user_metadata(project, location_features_from_push) + # Register Feature View sql_registry.apply_feature_view(location_features_from_push, project) + assert not sql_registry.get_user_metadata(project, location_features_from_push) + + b = "metadata".encode("utf-8") + sql_registry.apply_user_metadata(project, location_features_from_push, b) + assert sql_registry.get_user_metadata(project, location_features_from_push) == b + feature_views = sql_registry.list_on_demand_feature_views(project) # List Feature Views