forked from feast-dev/feast
Commit
feat: Implement spark materialization engine (feast-dev#3184)
* implement spark materialization engine
* remove redundant code
* make function private
* refactor serializing into a class
* switch to using `foreachPartition`
* remove batch_size parameter
* add partitions parameter
* linting
* rename spark to spark.offline and spark.engine
* fix to test
* forgot to stage
* revert spark.offline to spark to ensure backward compatibility
* fix import
* remove code from testing a large data set
* linting
* test without repartition
* test alternate connection string
* use redis online creator

Signed-off-by: niklasvm <[email protected]>
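For orientation, the new engine is selected through the repo config's `batch_engine` block and then driven by the normal materialization flow. Below is a minimal sketch, assuming a repo whose feature_store.yaml already uses the Spark offline store and opts into the new `spark.engine` type; the repo path and date range are illustrative and not part of this change.

from datetime import datetime, timedelta

from feast import FeatureStore

# Assumed feature_store.yaml (illustrative):
#   offline_store:
#     type: spark
#   batch_engine:
#     type: spark.engine
#     partitions: 10
store = FeatureStore(repo_path=".")

# Materialization is routed to SparkMaterializationEngine, which pulls the latest
# rows from the Spark offline store and writes each partition to the online store.
store.materialize(
    start_date=datetime.utcnow() - timedelta(days=1),
    end_date=datetime.utcnow(),
)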
Showing 3 changed files with 343 additions and 0 deletions.
sdk/python/feast/infra/materialization/contrib/spark/spark_materialization_engine.py (265 additions, 0 deletions)
@@ -0,0 +1,265 @@
import tempfile
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, List, Literal, Optional, Sequence, Union

import dill
import pandas as pd
import pyarrow
from tqdm import tqdm

from feast.batch_feature_view import BatchFeatureView
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.infra.materialization.batch_materialization_engine import (
    BatchMaterializationEngine,
    MaterializationJob,
    MaterializationJobStatus,
    MaterializationTask,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStore,
    SparkRetrievalJob,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.passthrough_provider import PassthroughProvider
from feast.infra.registry.base_registry import BaseRegistry
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.stream_feature_view import StreamFeatureView
from feast.utils import (
    _convert_arrow_to_proto,
    _get_column_names,
    _run_pyarrow_field_mapping,
)


class SparkMaterializationEngineConfig(FeastConfigBaseModel):
    """Batch Materialization Engine config for spark engine"""

    type: Literal["spark.engine"] = "spark.engine"
    """ Type selector"""

    partitions: int = 0
    """Number of partitions to use when writing data to online store. If 0, no repartitioning is done"""


@dataclass
class SparkMaterializationJob(MaterializationJob):
    def __init__(
        self,
        job_id: str,
        status: MaterializationJobStatus,
        error: Optional[BaseException] = None,
    ) -> None:
        super().__init__()
        self._job_id: str = job_id
        self._status: MaterializationJobStatus = status
        self._error: Optional[BaseException] = error

    def status(self) -> MaterializationJobStatus:
        return self._status

    def error(self) -> Optional[BaseException]:
        return self._error

    def should_be_retried(self) -> bool:
        return False

    def job_id(self) -> str:
        return self._job_id

    def url(self) -> Optional[str]:
        return None


class SparkMaterializationEngine(BatchMaterializationEngine):
    def update(
        self,
        project: str,
        views_to_delete: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        views_to_keep: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
    ):
        # Nothing to set up.
        pass

    def teardown_infra(
        self,
        project: str,
        fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
        entities: Sequence[Entity],
    ):
        # Nothing to tear down.
        pass

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: SparkOfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        if not isinstance(offline_store, SparkOfflineStore):
            raise TypeError(
                "SparkMaterializationEngine is only compatible with the SparkOfflineStore"
            )
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )

    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        return [
            self._materialize_one(
                registry,
                task.feature_view,
                task.start_time,
                task.end_time,
                task.project,
                task.tqdm_builder,
            )
            for task in tasks
        ]

    def _materialize_one(
        self,
        registry: BaseRegistry,
        feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
        start_date: datetime,
        end_date: datetime,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ):
        entities = []
        for entity_name in feature_view.entities:
            entities.append(registry.get_entity(entity_name, project))

        (
            join_key_columns,
            feature_name_columns,
            timestamp_field,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        job_id = f"{feature_view.name}-{start_date}-{end_date}"

        try:
            offline_job: SparkRetrievalJob = (
                self.offline_store.pull_latest_from_table_or_query(
                    config=self.repo_config,
                    data_source=feature_view.batch_source,
                    join_key_columns=join_key_columns,
                    feature_name_columns=feature_name_columns,
                    timestamp_field=timestamp_field,
                    created_timestamp_column=created_timestamp_column,
                    start_date=start_date,
                    end_date=end_date,
                )
            )

            spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
                feature_view=feature_view, repo_config=self.repo_config
            )

            spark_df = offline_job.to_spark_df()
            if self.repo_config.batch_engine.partitions != 0:
                spark_df = spark_df.repartition(
                    self.repo_config.batch_engine.partitions
                )

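            # Each executor writes its partition straight to the online store; only the
            # pickled `spark_serialized_artifacts` travel to the workers, so no feature
            # data is collected back to the driver.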
            spark_df.foreachPartition(
                lambda x: _process_by_partition(x, spark_serialized_artifacts)
            )

            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
            )
        except BaseException as e:
            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
            )


@dataclass
class _SparkSerializedArtifacts:
    """Class to assist with serializing unpicklable artifacts to the spark workers"""

    feature_view_proto: str
    repo_config_file: str

    @classmethod
    def serialize(cls, feature_view, repo_config):

        # serialize to proto
        feature_view_proto = feature_view.to_proto().SerializeToString()

        # serialize repo_config to disk. Will be used to instantiate the online store
        repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
        with open(repo_config_file, "wb") as f:
            dill.dump(repo_config, f)

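        # Only the temp file path travels to the workers; unserialize() re-opens it there,
        # so the path has to be reachable from the executors (e.g. local mode or a shared
        # filesystem).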
        return _SparkSerializedArtifacts(
            feature_view_proto=feature_view_proto, repo_config_file=repo_config_file
        )

    def unserialize(self):
        # unserialize
        proto = FeatureViewProto()
        proto.ParseFromString(self.feature_view_proto)
        feature_view = FeatureView.from_proto(proto)

        # load
        with open(self.repo_config_file, "rb") as f:
            repo_config = dill.load(f)

        provider = PassthroughProvider(repo_config)
        online_store = provider.online_store
        return feature_view, online_store, repo_config


def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
    """Load pandas df to online store"""

    # convert to pyarrow table
    dicts = []
    for row in rows:
        dicts.append(row.asDict())

    df = pd.DataFrame.from_records(dicts)
    if df.shape[0] == 0:
        print("Skipping")
        return

    table = pyarrow.Table.from_pandas(df)

    # unserialize artifacts
    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()

    if feature_view.batch_source.field_mapping is not None:
        table = _run_pyarrow_field_mapping(
            table, feature_view.batch_source.field_mapping
        )

    join_key_to_value_type = {
        entity.name: entity.dtype.to_value_type()
        for entity in feature_view.entity_columns
    }

    rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
    online_store.online_write_batch(
        repo_config,
        feature_view,
        rows_to_write,
        lambda x: None,
    )
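As a side note on the worker-side flow: foreachPartition hands _process_by_partition a plain iterator of pyspark Row objects, which the function turns into a pandas DataFrame and then a pyarrow Table before writing. The standalone sketch below (local Spark session and toy columns, not part of this change) shows that shape.

import pandas as pd
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master("local[1]").appName("partition-shape-demo").getOrCreate()
df = spark.createDataFrame([Row(driver_id=1, value=0.5), Row(driver_id=2, value=0.7)])

def show_partition(rows):
    # `rows` is an iterator of Row objects, the same shape _process_by_partition receives.
    records = [row.asDict() for row in rows]
    if not records:
        return  # empty partitions are skipped, mirroring the engine's early return
    print(pd.DataFrame.from_records(records))

df.foreachPartition(show_partition)
spark.stop()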
sdk/python/tests/integration/materialization/contrib/spark/test_spark.py (77 additions, 0 deletions)
@@ -0,0 +1,77 @@
from datetime import timedelta

import pytest

from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.field import Field
from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import (
    SparkDataSourceCreator,
)
from feast.types import Float32
from tests.data.data_creator import create_basic_driver_dataset
from tests.integration.feature_repos.integration_test_repo_config import (
    IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.repo_configuration import (
    construct_test_environment,
)
from tests.integration.feature_repos.universal.online_store.redis import (
    RedisOnlineStoreCreator,
)
from tests.utils.e2e_test_validation import validate_offline_online_store_consistency


@pytest.mark.integration
def test_spark_materialization_consistency():
    spark_config = IntegrationTestRepoConfig(
        provider="local",
        online_store_creator=RedisOnlineStoreCreator,
        offline_store_creator=SparkDataSourceCreator,
        batch_engine={"type": "spark.engine", "partitions": 10},
    )
    spark_environment = construct_test_environment(
        spark_config, None, entity_key_serialization_version=1
    )

    df = create_basic_driver_dataset()

    ds = spark_environment.data_source_creator.create_data_source(
        df,
        spark_environment.feature_store.project,
        field_mapping={"ts_1": "ts"},
    )

    fs = spark_environment.feature_store
    driver = Entity(
        name="driver_id",
        join_keys=["driver_id"],
    )

    driver_stats_fv = FeatureView(
        name="driver_hourly_stats",
        entities=[driver],
        ttl=timedelta(weeks=52),
        schema=[Field(name="value", dtype=Float32)],
        source=ds,
    )

    try:

        fs.apply([driver, driver_stats_fv])

        print(df)

        # materialization is run in two steps and
        # we use timestamp from generated dataframe as a split point
        split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)

        print(f"Split datetime: {split_dt}")

        validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt)
    finally:
        fs.teardown()


if __name__ == "__main__":
    test_spark_materialization_consistency()