diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index ab88eca89f..a9f5c09404 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -165,7 +165,7 @@ async def get_online_features(body=Depends(get_body)): ) @app.post("/push", dependencies=[Depends(inject_user_details)]) - def push(body=Depends(get_body)): + async def push(body=Depends(get_body)): request = PushFeaturesRequest(**json.loads(body)) df = pd.DataFrame(request.df) actions = [] @@ -201,13 +201,22 @@ def push(body=Depends(get_body)): for feature_view in fvs_with_push_sources: assert_permissions(resource=feature_view, actions=actions) - store.push( + push_params = dict( push_source_name=request.push_source_name, df=df, allow_registry_cache=request.allow_registry_cache, to=to, ) + should_push_async = ( + store._get_provider().async_supported.online.write + and to in [PushMode.ONLINE, PushMode.ONLINE_AND_OFFLINE] + ) + if should_push_async: + await store.push_async(**push_params) + else: + store.push(**push_params) + @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)]) def write_to_online_store(body=Depends(get_body)): request = WriteToFeatureStoreRequest(**json.loads(body)) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 9b1a35303e..f9fa0a7881 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import itertools import os import warnings @@ -33,6 +34,7 @@ import pandas as pd import pyarrow as pa from colorama import Fore, Style +from fastapi.concurrency import run_in_threadpool from google.protobuf.timestamp_pb2 import Timestamp from tqdm import tqdm @@ -1423,26 +1425,13 @@ def tqdm_builder(length): end_date, ) - def push( - self, - push_source_name: str, - df: pd.DataFrame, - allow_registry_cache: bool = True, - to: PushMode = PushMode.ONLINE, - ): - """ - Push features to a push source. This updates all the feature views that have the push source as stream source. - - Args: - push_source_name: The name of the push source we want to push data to. - df: The data being pushed. - allow_registry_cache: Whether to allow cached versions of the registry. - to: Whether to push to online or offline store. Defaults to online store only. - """ + def _fvs_for_push_source_or_raise( + self, push_source_name: str, allow_cache: bool + ) -> set[FeatureView]: from feast.data_source import PushSource - all_fvs = self.list_feature_views(allow_cache=allow_registry_cache) - all_fvs += self.list_stream_feature_views(allow_cache=allow_registry_cache) + all_fvs = self.list_feature_views(allow_cache=allow_cache) + all_fvs += self.list_stream_feature_views(allow_cache=allow_cache) fvs_with_push_sources = { fv @@ -1457,7 +1446,27 @@ def push( if not fvs_with_push_sources: raise PushSourceNotFoundException(push_source_name) - for fv in fvs_with_push_sources: + return fvs_with_push_sources + + def push( + self, + push_source_name: str, + df: pd.DataFrame, + allow_registry_cache: bool = True, + to: PushMode = PushMode.ONLINE, + ): + """ + Push features to a push source. This updates all the feature views that have the push source as stream source. + + Args: + push_source_name: The name of the push source we want to push data to. + df: The data being pushed. + allow_registry_cache: Whether to allow cached versions of the registry. + to: Whether to push to online or offline store. Defaults to online store only. + """ + for fv in self._fvs_for_push_source_or_raise( + push_source_name, allow_registry_cache + ): if to == PushMode.ONLINE or to == PushMode.ONLINE_AND_OFFLINE: self.write_to_online_store( fv.name, df, allow_registry_cache=allow_registry_cache @@ -1467,22 +1476,42 @@ def push( fv.name, df, allow_registry_cache=allow_registry_cache ) - def write_to_online_store( + async def push_async( + self, + push_source_name: str, + df: pd.DataFrame, + allow_registry_cache: bool = True, + to: PushMode = PushMode.ONLINE, + ): + fvs = self._fvs_for_push_source_or_raise(push_source_name, allow_registry_cache) + + if to == PushMode.ONLINE or to == PushMode.ONLINE_AND_OFFLINE: + _ = await asyncio.gather( + *[ + self.write_to_online_store_async( + fv.name, df, allow_registry_cache=allow_registry_cache + ) + for fv in fvs + ] + ) + + if to == PushMode.OFFLINE or to == PushMode.ONLINE_AND_OFFLINE: + + def _offline_write(): + for fv in fvs: + self.write_to_offline_store( + fv.name, df, allow_registry_cache=allow_registry_cache + ) + + await run_in_threadpool(_offline_write) + + def _get_feature_view_and_df_for_online_write( self, feature_view_name: str, df: Optional[pd.DataFrame] = None, inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None, allow_registry_cache: bool = True, ): - """ - Persists a dataframe to the online store. - - Args: - feature_view_name: The feature view to which the dataframe corresponds. - df: The dataframe to be persisted. - inputs: Optional the dictionary object to be written - allow_registry_cache (optional): Whether to allow retrieving feature views from a cached registry. - """ feature_view_dict = { fv_proto.name: fv_proto for fv_proto in self.list_all_feature_views(allow_registry_cache) @@ -1509,10 +1538,60 @@ def write_to_online_store( df = pd.DataFrame(df) except Exception as _: raise DataFrameSerializationError(df) + return feature_view, df + + def write_to_online_store( + self, + feature_view_name: str, + df: Optional[pd.DataFrame] = None, + inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None, + allow_registry_cache: bool = True, + ): + """ + Persists a dataframe to the online store. + Args: + feature_view_name: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. + inputs: Optional the dictionary object to be written + allow_registry_cache (optional): Whether to allow retrieving feature views from a cached registry. + """ + + feature_view, df = self._get_feature_view_and_df_for_online_write( + feature_view_name=feature_view_name, + df=df, + inputs=inputs, + allow_registry_cache=allow_registry_cache, + ) provider = self._get_provider() provider.ingest_df(feature_view, df) + async def write_to_online_store_async( + self, + feature_view_name: str, + df: Optional[pd.DataFrame] = None, + inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None, + allow_registry_cache: bool = True, + ): + """ + Persists a dataframe to the online store asynchronously. + + Args: + feature_view_name: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. + inputs: Optional the dictionary object to be written + allow_registry_cache (optional): Whether to allow retrieving feature views from a cached registry. + """ + + feature_view, df = self._get_feature_view_and_df_for_online_write( + feature_view_name=feature_view_name, + df=df, + inputs=inputs, + allow_registry_cache=allow_registry_cache, + ) + provider = self._get_provider() + await provider.ingest_df_async(feature_view, df) + def write_to_offline_store( self, feature_view_name: str, diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index be2065040b..15dd843ba8 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -67,6 +67,33 @@ def online_write_batch( """ pass + async def online_write_batch_async( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + """ + Writes a batch of feature rows to the online store asynchronously. + + If a tz-naive timestamp is passed to this method, it is assumed to be UTC. + + Args: + config: The config for the current feature store. + table: Feature view to which these feature rows correspond. + data: A list of quadruplets containing feature data. Each quadruplet contains an entity + key, a dict containing feature values, an event timestamp for the row, and the created + timestamp for the row if it exists. + progress: Function to be called once a batch of rows is written to the online store, used + to show progress. + """ + raise NotImplementedError( + f"Online store {self.__class__.__name__} does not support online write batch async" + ) + @abstractmethod def online_read( self, diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index ea75cf5ff2..9482b808a9 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -188,6 +188,20 @@ def online_write_batch( if self.online_store: self.online_store.online_write_batch(config, table, data, progress) + async def online_write_batch_async( + self, + config: RepoConfig, + table: Union[FeatureView, BaseFeatureView, OnDemandFeatureView], + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + if self.online_store: + await self.online_store.online_write_batch_async( + config, table, data, progress + ) + def offline_write_batch( self, config: RepoConfig, @@ -291,8 +305,8 @@ def retrieve_online_documents( ) return result - def ingest_df( - self, + @staticmethod + def _prep_rows_to_write_for_ingestion( feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], df: pd.DataFrame, field_mapping: Optional[Dict] = None, @@ -307,10 +321,6 @@ def ingest_df( for entity in feature_view.entity_columns } rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys) - - self.online_write_batch( - self.repo_config, feature_view, rows_to_write, progress=None - ) else: if hasattr(feature_view, "entity_columns"): join_keys = { @@ -336,9 +346,37 @@ def ingest_df( join_keys[entity.name] = entity.dtype.to_value_type() rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys) - self.online_write_batch( - self.repo_config, feature_view, rows_to_write, progress=None - ) + return rows_to_write + + def ingest_df( + self, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], + df: pd.DataFrame, + field_mapping: Optional[Dict] = None, + ): + rows_to_write = self._prep_rows_to_write_for_ingestion( + feature_view=feature_view, + df=df, + field_mapping=field_mapping, + ) + self.online_write_batch( + self.repo_config, feature_view, rows_to_write, progress=None + ) + + async def ingest_df_async( + self, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], + df: pd.DataFrame, + field_mapping: Optional[Dict] = None, + ): + rows_to_write = self._prep_rows_to_write_for_ingestion( + feature_view=feature_view, + df=df, + field_mapping=field_mapping, + ) + await self.online_write_batch_async( + self.repo_config, feature_view, rows_to_write, progress=None + ) def ingest_df_to_offline_store(self, feature_view: FeatureView, table: pa.Table): if feature_view.batch_source.field_mapping is not None: diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index fb483d194e..47b7c65ef0 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -141,6 +141,32 @@ def online_write_batch( """ pass + @abstractmethod + async def online_write_batch_async( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + """ + Writes a batch of feature rows to the online store asynchronously. + + If a tz-naive timestamp is passed to this method, it is assumed to be UTC. + + Args: + config: The config for the current feature store. + table: Feature view to which these feature rows correspond. + data: A list of quadruplets containing feature data. Each quadruplet contains an entity + key, a dict containing feature values, an event timestamp for the row, and the created + timestamp for the row if it exists. + progress: Function to be called once a batch of rows is written to the online store, used + to show progress. + """ + pass + def ingest_df( self, feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], @@ -157,6 +183,22 @@ def ingest_df( """ pass + async def ingest_df_async( + self, + feature_view: Union[BaseFeatureView, FeatureView, OnDemandFeatureView], + df: pd.DataFrame, + field_mapping: Optional[Dict] = None, + ): + """ + Persists a dataframe to the online store asynchronously. + + Args: + feature_view: The feature view to which the dataframe corresponds. + df: The dataframe to be persisted. + field_mapping: A dictionary mapping dataframe column names to feature names. + """ + pass + def ingest_df_to_offline_store( self, feature_view: FeatureView, diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 6fe6d15150..bc30d3ef88 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -22,6 +22,10 @@ from feast.infra.offline_stores.offline_store import RetrievalJob from feast.infra.provider import Provider from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.supported_async_methods import ( + ProviderAsyncMethods, + SupportedAsyncMethods, +) from feast.online_response import OnlineResponse from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import RepeatedValue @@ -30,6 +34,20 @@ class FooProvider(Provider): + @staticmethod + def with_async_support(online_read=False, online_write=False): + class _FooProvider(FooProvider): + @property + def async_supported(self): + return ProviderAsyncMethods( + online=SupportedAsyncMethods( + read=online_read, + write=online_write, + ) + ) + + return _FooProvider(None) + def __init__(self, config: RepoConfig): pass @@ -184,3 +202,14 @@ async def get_online_features_async( full_feature_names: bool = False, ) -> OnlineResponse: pass + + async def online_write_batch_async( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + pass diff --git a/sdk/python/tests/unit/test_feature_server.py b/sdk/python/tests/unit/test_feature_server.py new file mode 100644 index 0000000000..34c3fc4068 --- /dev/null +++ b/sdk/python/tests/unit/test_feature_server.py @@ -0,0 +1,47 @@ +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from feast import FeatureStore +from feast.data_source import PushMode +from feast.feature_server import get_app +from feast.utils import _utc_now +from tests.foo_provider import FooProvider + + +@pytest.mark.parametrize( + "online_write,push_mode,async_count", + [ + (True, PushMode.ONLINE_AND_OFFLINE, 1), + (True, PushMode.OFFLINE, 0), + (True, PushMode.ONLINE, 1), + (False, PushMode.ONLINE_AND_OFFLINE, 0), + (False, PushMode.OFFLINE, 0), + (False, PushMode.ONLINE, 0), + ], +) +def test_push_online_async_supported(online_write, push_mode, async_count, environment): + push_payload = json.dumps( + { + "push_source_name": "location_stats_push_source", + "df": { + "location_id": [1], + "temperature": [100], + "event_timestamp": [str(_utc_now())], + "created": [str(_utc_now())], + }, + "to": push_mode.name.lower(), + } + ) + + provider = FooProvider.with_async_support(online_write=online_write) + with patch.object(FeatureStore, "_get_provider", return_value=provider): + fs = environment.feature_store + fs.push = MagicMock() + fs.push_async = AsyncMock() + client = TestClient(get_app(fs)) + client.post("/push", data=push_payload) + assert fs.push.call_count == 1 - async_count + assert fs.push_async.await_count == async_count