feat: Implement spark materialization engine #3184

Merged: 19 commits, Sep 15, 2022
Changes from 11 commits
@@ -0,0 +1,265 @@
import tempfile
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, List, Literal, Optional, Sequence, Union

import dill
import pandas as pd
import pyarrow
from tqdm import tqdm

from feast.batch_feature_view import BatchFeatureView
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.infra.materialization.batch_materialization_engine import (
BatchMaterializationEngine,
MaterializationJob,
MaterializationJobStatus,
MaterializationTask,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkOfflineStore,
SparkRetrievalJob,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.passthrough_provider import PassthroughProvider
from feast.infra.registry.base_registry import BaseRegistry
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.stream_feature_view import StreamFeatureView
from feast.utils import (
_convert_arrow_to_proto,
_get_column_names,
_run_pyarrow_field_mapping,
)


class SparkMaterializationEngineConfig(FeastConfigBaseModel):
Collaborator: probably makes sense to throw an error somewhere if the offline store is not the SparkOfflineStore?

"""Batch Materialization Engine config for spark engine"""

type: Literal["spark.engine"] = "spark.engine"
""" Type selector"""

partitions: int = 0
"""Number of partitions to use when writing data to online store. If 0, no repartitioning is done"""


@dataclass
class SparkMaterializationJob(MaterializationJob):
def __init__(
self,
job_id: str,
status: MaterializationJobStatus,
error: Optional[BaseException] = None,
) -> None:
super().__init__()
self._job_id: str = job_id
self._status: MaterializationJobStatus = status
self._error: Optional[BaseException] = error

def status(self) -> MaterializationJobStatus:
return self._status

def error(self) -> Optional[BaseException]:
return self._error

def should_be_retried(self) -> bool:
return False

def job_id(self) -> str:
return self._job_id

def url(self) -> Optional[str]:
return None
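
A hypothetical usage sketch of the job class above (names and values are illustrative): the except branch of _materialize_one below returns exactly this shape when a materialization fails, and should_be_retried() is always False for this engine.

from feast.infra.materialization.batch_materialization_engine import MaterializationJobStatus
from feast.infra.materialization.contrib.spark.spark_materialization_engine import SparkMaterializationJob

# Hypothetical failed job: carries the original exception, never retried.
job = SparkMaterializationJob(
    job_id="driver_hourly_stats-2022-09-01-2022-09-02",
    status=MaterializationJobStatus.ERROR,
    error=ValueError("example failure"),
)
assert job.status() == MaterializationJobStatus.ERROR
assert not job.should_be_retried()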


class SparkMaterializationEngine(BatchMaterializationEngine):
def update(
self,
project: str,
views_to_delete: Sequence[
Union[BatchFeatureView, StreamFeatureView, FeatureView]
],
views_to_keep: Sequence[
Union[BatchFeatureView, StreamFeatureView, FeatureView]
],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
):
# Nothing to set up.
pass

def teardown_infra(
self,
project: str,
fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
entities: Sequence[Entity],
):
# Nothing to tear down.
pass

def __init__(
self,
*,
repo_config: RepoConfig,
offline_store: SparkOfflineStore,
online_store: OnlineStore,
**kwargs,
):
if not isinstance(offline_store, SparkOfflineStore):
raise TypeError(
"SparkMaterializationEngine is only compatible with the SparkOfflineStore"
)
super().__init__(
repo_config=repo_config,
offline_store=offline_store,
online_store=online_store,
**kwargs,
)

def materialize(
self, registry, tasks: List[MaterializationTask]
) -> List[MaterializationJob]:
return [
self._materialize_one(
registry,
task.feature_view,
task.start_time,
task.end_time,
task.project,
task.tqdm_builder,
)
for task in tasks
]

def _materialize_one(
self,
registry: BaseRegistry,
feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
start_date: datetime,
end_date: datetime,
project: str,
tqdm_builder: Callable[[int], tqdm],
):
entities = []
for entity_name in feature_view.entities:
entities.append(registry.get_entity(entity_name, project))

(
join_key_columns,
feature_name_columns,
timestamp_field,
created_timestamp_column,
) = _get_column_names(feature_view, entities)

job_id = f"{feature_view.name}-{start_date}-{end_date}"

try:
offline_job: SparkRetrievalJob = (
self.offline_store.pull_latest_from_table_or_query(
config=self.repo_config,
data_source=feature_view.batch_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)
)

spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
feature_view=feature_view, repo_config=self.repo_config
)

spark_df = offline_job.to_spark_df()
if self.repo_config.batch_engine.partitions != 0:
spark_df = spark_df.repartition(
self.repo_config.batch_engine.partitions
)

spark_df.foreachPartition(
lambda x: _process_by_partition(x, spark_serialized_artifacts)
)
Member, on lines +175 to +183: 💯


return SparkMaterializationJob(
job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
)
except BaseException as e:
return SparkMaterializationJob(
job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
)


@dataclass
class _SparkSerializedArtifacts:
"""Class to assist with serializing unpicklable artifacts to the spark workers"""

feature_view_proto: bytes
repo_config_file: str

@classmethod
def serialize(cls, feature_view, repo_config):

# serialize to proto
feature_view_proto = feature_view.to_proto().SerializeToString()

# serialize repo_config to disk. Will be used to instantiate the online store
repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
with open(repo_config_file, "wb") as f:
dill.dump(repo_config, f)

return _SparkSerializedArtifacts(
feature_view_proto=feature_view_proto, repo_config_file=repo_config_file
)

def unserialize(self):
# unserialize
proto = FeatureViewProto()
proto.ParseFromString(self.feature_view_proto)
feature_view = FeatureView.from_proto(proto)

# load
with open(self.repo_config_file, "rb") as f:
repo_config = dill.load(f)

provider = PassthroughProvider(repo_config)
online_store = provider.online_store
return feature_view, online_store, repo_config


def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
"""Load pandas df to online store"""

# convert to pyarrow table
dicts = []
for row in rows:
dicts.append(row.asDict())

df = pd.DataFrame.from_records(dicts)
if df.shape[0] == 0:
print("Skipping")
return

table = pyarrow.Table.from_pandas(df)

# unserialize artifacts
feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()

if feature_view.batch_source.field_mapping is not None:
Collaborator: since lines 249 to 257 are also used in feature_store.write_to_online_store, maybe it makes sense to refactor this into a util method?

table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
)

join_key_to_value_type = {
entity.name: entity.dtype.to_value_type()
for entity in feature_view.entity_columns
}

rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
online_store.online_write_batch(
repo_config,
feature_view,
rows_to_write,
lambda x: None,
)
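
For intuition, each Spark partition reaches _process_by_partition as an iterator of pyspark.sql.Row objects, which the function converts to pandas and then pyarrow before writing to the online store. A toy, self-contained illustration of that conversion step (column names are made up):

from pyspark.sql import Row
import pandas as pd
import pyarrow

rows = [Row(driver_id=1001, conv_rate=0.25), Row(driver_id=1002, conv_rate=0.31)]
df = pd.DataFrame.from_records([row.asDict() for row in rows])
table = pyarrow.Table.from_pandas(df)  # same shape the online write path consumes
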
@@ -39,7 +39,7 @@


class SparkOfflineStoreConfig(FeastConfigBaseModel):
- type: StrictStr = "spark"
+ type: StrictStr = "spark.offline"
""" Offline store type selector"""

spark_conf: Optional[Dict[str, str]] = None
@@ -49,7 +49,7 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):

class SparkOfflineStore(OfflineStore):
@staticmethod
- @log_exceptions_and_usage(offline_store="spark")
+ @log_exceptions_and_usage(offline_store="spark.offline")
def pull_latest_from_table_or_query(
config: RepoConfig,
data_source: DataSource,
@@ -247,7 +247,7 @@ def offline_write_batch(
)

@staticmethod
- @log_exceptions_and_usage(offline_store="spark")
+ @log_exceptions_and_usage(offline_store="spark.offline")
def pull_all_from_table_or_query(
config: RepoConfig,
data_source: DataSource,
@@ -56,7 +56,7 @@ def teardown(self):

def create_offline_store_config(self):
self.spark_offline_store_config = SparkOfflineStoreConfig()
- self.spark_offline_store_config.type = "spark"
+ self.spark_offline_store_config.type = "spark.offline"
self.spark_offline_store_config.spark_conf = self.spark_conf
return self.spark_offline_store_config

3 changes: 2 additions & 1 deletion sdk/python/feast/repo_config.py
@@ -39,6 +39,7 @@
"snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine",
"lambda": "feast.infra.materialization.aws_lambda.lambda_engine.LambdaMaterializationEngine",
"bytewax": "feast.infra.materialization.contrib.bytewax.bytewax_materialization_engine.BytewaxMaterializationEngine",
"spark.engine": "feast.infra.materialization.contrib.spark.spark_materialization_engine.SparkMaterializationEngine",
}

ONLINE_STORE_CLASS_FOR_TYPE = {
@@ -57,7 +58,7 @@
"bigquery": "feast.infra.offline_stores.bigquery.BigQueryOfflineStore",
"redshift": "feast.infra.offline_stores.redshift.RedshiftOfflineStore",
"snowflake.offline": "feast.infra.offline_stores.snowflake.SnowflakeOfflineStore",
"spark": "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore",
"spark.offline": "feast.infra.offline_stores.contrib.spark_offline_store.spark.SparkOfflineStore",
"trino": "feast.infra.offline_stores.contrib.trino_offline_store.trino.TrinoOfflineStore",
"postgres": "feast.infra.offline_stores.contrib.postgres_offline_store.postgres.PostgreSQLOfflineStore",
"athena": "feast.infra.offline_stores.contrib.athena_offline_store.athena.AthenaOfflineStore",