From a59c33ac10760b4029fadd8e377eb36a2c458583 Mon Sep 17 00:00:00 2001
From: Niklas von Maltzahn
Date: Thu, 15 Sep 2022 16:45:22 +0200
Subject: [PATCH] feat: Implement spark materialization engine (#3184)

* implement spark materialization engine
Signed-off-by: niklasvm
* remove redundant code
Signed-off-by: niklasvm
* make function private
Signed-off-by: niklasvm
* refactor serializing into a class
Signed-off-by: niklasvm
* switch to using `foreachPartition`
Signed-off-by: niklasvm
* remove batch_size parameter
Signed-off-by: niklasvm
* add partitions parameter
Signed-off-by: niklasvm
* linting
Signed-off-by: niklasvm
* rename spark to spark.offline and spark.engine
Signed-off-by: niklasvm
* fix to test
Signed-off-by: niklasvm
* forgot to stage
Signed-off-by: niklasvm
* revert spark.offline to spark to ensure backward compatibility
Signed-off-by: niklasvm
* fix import
Signed-off-by: niklasvm
* remove code from testing a large data set
Signed-off-by: niklasvm
* linting
Signed-off-by: niklasvm
* test without repartition
Signed-off-by: niklasvm
* test alternate connection string
Signed-off-by: niklasvm
* use redis online creator
Signed-off-by: niklasvm

Signed-off-by: niklasvm
---
 .../spark/spark_materialization_engine.py    | 265 ++++++++++++++++++
 sdk/python/feast/repo_config.py              |   1 +
 .../contrib/spark/test_spark.py              |  77 +++++
 3 files changed, 343 insertions(+)
 create mode 100644 sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py
 create mode 100644 sdk/python/tests/integration/materialization/contrib/spark/test_spark.py

diff --git a/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py
new file mode 100644
index 0000000000..66eb97bca7
--- /dev/null
+++ b/sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py
@@ -0,0 +1,265 @@
+import tempfile
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Callable, List, Literal, Optional, Sequence, Union
+
+import dill
+import pandas as pd
+import pyarrow
+from tqdm import tqdm
+
+from feast.batch_feature_view import BatchFeatureView
+from feast.entity import Entity
+from feast.feature_view import FeatureView
+from feast.infra.materialization.batch_materialization_engine import (
+    BatchMaterializationEngine,
+    MaterializationJob,
+    MaterializationJobStatus,
+    MaterializationTask,
+)
+from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
+    SparkOfflineStore,
+    SparkRetrievalJob,
+)
+from feast.infra.online_stores.online_store import OnlineStore
+from feast.infra.passthrough_provider import PassthroughProvider
+from feast.infra.registry.base_registry import BaseRegistry
+from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
+from feast.repo_config import FeastConfigBaseModel, RepoConfig
+from feast.stream_feature_view import StreamFeatureView
+from feast.utils import (
+    _convert_arrow_to_proto,
+    _get_column_names,
+    _run_pyarrow_field_mapping,
+)
+
+
+class SparkMaterializationEngineConfig(FeastConfigBaseModel):
+    """Batch Materialization Engine config for spark engine"""
+
+    type: Literal["spark.engine"] = "spark.engine"
+    """ Type selector"""
+
+    partitions: int = 0
+    """Number of partitions to use when writing data to online store. If 0, no repartitioning is done"""
+
+
+@dataclass
+class SparkMaterializationJob(MaterializationJob):
+    def __init__(
+        self,
+        job_id: str,
+        status: MaterializationJobStatus,
+        error: Optional[BaseException] = None,
+    ) -> None:
+        super().__init__()
+        self._job_id: str = job_id
+        self._status: MaterializationJobStatus = status
+        self._error: Optional[BaseException] = error
+
+    def status(self) -> MaterializationJobStatus:
+        return self._status
+
+    def error(self) -> Optional[BaseException]:
+        return self._error
+
+    def should_be_retried(self) -> bool:
+        return False
+
+    def job_id(self) -> str:
+        return self._job_id
+
+    def url(self) -> Optional[str]:
+        return None
+
+
+class SparkMaterializationEngine(BatchMaterializationEngine):
+    def update(
+        self,
+        project: str,
+        views_to_delete: Sequence[
+            Union[BatchFeatureView, StreamFeatureView, FeatureView]
+        ],
+        views_to_keep: Sequence[
+            Union[BatchFeatureView, StreamFeatureView, FeatureView]
+        ],
+        entities_to_delete: Sequence[Entity],
+        entities_to_keep: Sequence[Entity],
+    ):
+        # Nothing to set up.
+        pass
+
+    def teardown_infra(
+        self,
+        project: str,
+        fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
+        entities: Sequence[Entity],
+    ):
+        # Nothing to tear down.
+        pass
+
+    def __init__(
+        self,
+        *,
+        repo_config: RepoConfig,
+        offline_store: SparkOfflineStore,
+        online_store: OnlineStore,
+        **kwargs,
+    ):
+        if not isinstance(offline_store, SparkOfflineStore):
+            raise TypeError(
+                "SparkMaterializationEngine is only compatible with the SparkOfflineStore"
+            )
+        super().__init__(
+            repo_config=repo_config,
+            offline_store=offline_store,
+            online_store=online_store,
+            **kwargs,
+        )
+
+    def materialize(
+        self, registry, tasks: List[MaterializationTask]
+    ) -> List[MaterializationJob]:
+        return [
+            self._materialize_one(
+                registry,
+                task.feature_view,
+                task.start_time,
+                task.end_time,
+                task.project,
+                task.tqdm_builder,
+            )
+            for task in tasks
+        ]
+
+    def _materialize_one(
+        self,
+        registry: BaseRegistry,
+        feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
+        start_date: datetime,
+        end_date: datetime,
+        project: str,
+        tqdm_builder: Callable[[int], tqdm],
+    ):
+        entities = []
+        for entity_name in feature_view.entities:
+            entities.append(registry.get_entity(entity_name, project))
+
+        (
+            join_key_columns,
+            feature_name_columns,
+            timestamp_field,
+            created_timestamp_column,
+        ) = _get_column_names(feature_view, entities)
+
+        job_id = f"{feature_view.name}-{start_date}-{end_date}"
+
+        try:
+            offline_job: SparkRetrievalJob = (
+                self.offline_store.pull_latest_from_table_or_query(
+                    config=self.repo_config,
+                    data_source=feature_view.batch_source,
+                    join_key_columns=join_key_columns,
+                    feature_name_columns=feature_name_columns,
+                    timestamp_field=timestamp_field,
+                    created_timestamp_column=created_timestamp_column,
+                    start_date=start_date,
+                    end_date=end_date,
+                )
+            )
+
+            spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
+                feature_view=feature_view, repo_config=self.repo_config
+            )
+
+            spark_df = offline_job.to_spark_df()
+            if self.repo_config.batch_engine.partitions != 0:
+                spark_df = spark_df.repartition(
+                    self.repo_config.batch_engine.partitions
+                )
+
+            spark_df.foreachPartition(
+                lambda x: _process_by_partition(x, spark_serialized_artifacts)
+            )
+
+            return SparkMaterializationJob(
+                job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
+            )
+        except BaseException as e:
+            return SparkMaterializationJob(
+                job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
+            )
+
+
+@dataclass
+class _SparkSerializedArtifacts:
+    """Class to assist with serializing unpicklable artifacts to the spark workers"""
+
+    feature_view_proto: str
+    repo_config_file: str
+
+    @classmethod
+    def serialize(cls, feature_view, repo_config):
+
+        # serialize to proto
+        feature_view_proto = feature_view.to_proto().SerializeToString()
+
+        # serialize repo_config to disk. Will be used to instantiate the online store
+        repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
+        with open(repo_config_file, "wb") as f:
+            dill.dump(repo_config, f)
+
+        return _SparkSerializedArtifacts(
+            feature_view_proto=feature_view_proto, repo_config_file=repo_config_file
+        )
+
+    def unserialize(self):
+        # unserialize
+        proto = FeatureViewProto()
+        proto.ParseFromString(self.feature_view_proto)
+        feature_view = FeatureView.from_proto(proto)
+
+        # load
+        with open(self.repo_config_file, "rb") as f:
+            repo_config = dill.load(f)
+
+        provider = PassthroughProvider(repo_config)
+        online_store = provider.online_store
+        return feature_view, online_store, repo_config
+
+
+def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
+    """Load pandas df to online store"""
+
+    # convert to pyarrow table
+    dicts = []
+    for row in rows:
+        dicts.append(row.asDict())
+
+    df = pd.DataFrame.from_records(dicts)
+    if df.shape[0] == 0:
+        print("Skipping")
+        return
+
+    table = pyarrow.Table.from_pandas(df)
+
+    # unserialize artifacts
+    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()
+
+    if feature_view.batch_source.field_mapping is not None:
+        table = _run_pyarrow_field_mapping(
+            table, feature_view.batch_source.field_mapping
+        )
+
+    join_key_to_value_type = {
+        entity.name: entity.dtype.to_value_type()
+        for entity in feature_view.entity_columns
+    }
+
+    rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
+    online_store.online_write_batch(
+        repo_config,
+        feature_view,
+        rows_to_write,
+        lambda x: None,
+    )
diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py
index 47a5ae321d..d5f68a8db6 100644
--- a/sdk/python/feast/repo_config.py
+++ b/sdk/python/feast/repo_config.py
@@ -39,6 +39,7 @@
     "snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine",
     "lambda": "feast.infra.materialization.aws_lambda.lambda_engine.LambdaMaterializationEngine",
     "bytewax": "feast.infra.materialization.contrib.bytewax.bytewax_materialization_engine.BytewaxMaterializationEngine",
+    "spark.engine": "feast.infra.materialization.contrib.spark.spark_materialization_engine.SparkMaterializationEngine",
 }
 
 ONLINE_STORE_CLASS_FOR_TYPE = {
diff --git a/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py b/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py
new file mode 100644
index 0000000000..c7028a09ef
--- /dev/null
+++ b/sdk/python/tests/integration/materialization/contrib/spark/test_spark.py
@@ -0,0 +1,77 @@
+from datetime import timedelta
+
+import pytest
+
+from feast.entity import Entity
+from feast.feature_view import FeatureView
+from feast.field import Field
+from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import (
+    SparkDataSourceCreator,
+)
+from feast.types import Float32
+from tests.data.data_creator import create_basic_driver_dataset
+from tests.integration.feature_repos.integration_test_repo_config import (
+    IntegrationTestRepoConfig,
+)
+from tests.integration.feature_repos.repo_configuration import (
+    construct_test_environment,
+)
+from tests.integration.feature_repos.universal.online_store.redis import (
+    RedisOnlineStoreCreator,
+)
+from tests.utils.e2e_test_validation import validate_offline_online_store_consistency
+
+
+@pytest.mark.integration
+def test_spark_materialization_consistency():
+    spark_config = IntegrationTestRepoConfig(
+        provider="local",
+        online_store_creator=RedisOnlineStoreCreator,
+        offline_store_creator=SparkDataSourceCreator,
+        batch_engine={"type": "spark.engine", "partitions": 10},
+    )
+    spark_environment = construct_test_environment(
+        spark_config, None, entity_key_serialization_version=1
+    )
+
+    df = create_basic_driver_dataset()
+
+    ds = spark_environment.data_source_creator.create_data_source(
+        df,
+        spark_environment.feature_store.project,
+        field_mapping={"ts_1": "ts"},
+    )
+
+    fs = spark_environment.feature_store
+    driver = Entity(
+        name="driver_id",
+        join_keys=["driver_id"],
+    )
+
+    driver_stats_fv = FeatureView(
+        name="driver_hourly_stats",
+        entities=[driver],
+        ttl=timedelta(weeks=52),
+        schema=[Field(name="value", dtype=Float32)],
+        source=ds,
+    )
+
+    try:
+
+        fs.apply([driver, driver_stats_fv])
+
+        print(df)
+
+        # materialization is run in two steps and
+        # we use timestamp from generated dataframe as a split point
+        split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)
+
+        print(f"Split datetime: {split_dt}")
+
+        validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt)
+    finally:
+        fs.teardown()
+
+
+if __name__ == "__main__":
+    test_spark_materialization_consistency()
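
Usage sketch (illustrative only): the engine is selected through the batch_engine section of a repo's feature_store.yaml next to the Spark offline store, and then driven by the normal materialization API. The repo path, the store settings shown in the comments, and the date window below are assumptions for illustration, not values from this change.

    # Minimal sketch of running materialization through the new Spark engine.
    # Assumes an existing feature repo whose feature_store.yaml roughly contains
    # (illustrative values):
    #   offline_store:
    #     type: spark
    #   online_store:
    #     type: redis
    #     connection_string: "localhost:6379"
    #   batch_engine:
    #     type: spark.engine
    #     partitions: 10   # 0 (the default) skips the repartition step
    from datetime import datetime, timedelta

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical repo location

    # Materialization over this window is routed through SparkMaterializationEngine:
    # pull_latest_from_table_or_query -> optional repartition -> foreachPartition,
    # where each partition is converted to Arrow and written via online_write_batch.
    store.materialize(
        start_date=datetime.utcnow() - timedelta(days=1),
        end_date=datetime.utcnow(),
    )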