chore: Add support for push sources in feature views #2452

Merged Mar 25, 2022 (6 commits)

Changes from 4 commits
2 changes: 1 addition & 1 deletion go/internal/feast/featurestore.go
@@ -733,7 +733,7 @@ func groupFeatureRefs(requestedFeatureViews []*featureViewAndRefs,
joinKeys := make([]string, 0)
fv := featuresAndView.view
featureNames := featuresAndView.featureRefs
- for entity, _ := range fv.entities {
+ for entity := range fv.entities {
joinKeys = append(joinKeys, entityNameToJoinKeyMap[entity])
}

3 changes: 2 additions & 1 deletion sdk/python/feast/__init__.py
@@ -7,7 +7,7 @@
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource

- from .data_source import KafkaSource, KinesisSource, SourceType
+ from .data_source import KafkaSource, KinesisSource, PushSource, SourceType
from .entity import Entity
from .feature import Feature
from .feature_service import FeatureService
@@ -47,4 +47,5 @@
"RedshiftSource",
"RequestFeatureView",
"SnowflakeSource",
"PushSource",
]
55 changes: 37 additions & 18 deletions sdk/python/feast/data_source.py
@@ -396,8 +396,8 @@ def get_table_column_names_and_types(
def from_proto(data_source: DataSourceProto):
schema_pb = data_source.request_data_options.schema
schema = {}
- for key in schema_pb.keys():
-     schema[key] = ValueType(schema_pb.get(key))
+ for key, val in schema_pb.items():
+     schema[key] = ValueType(val)
return RequestDataSource(name=data_source.name, schema=schema)

def to_proto(self) -> DataSourceProto:
@@ -510,27 +510,41 @@ def to_proto(self) -> DataSourceProto:

class PushSource(DataSource):
"""
- PushSource that can be used to ingest features on request
-
- Args:
-     name: Name of the push source
-     schema: Schema mapping from the input feature name to a ValueType
+ A source that can be used to ingest features on request
"""

name: str
schema: Dict[str, ValueType]
- batch_source: Optional[DataSource]
+ batch_source: DataSource
+ event_timestamp_column: str

def __init__(
self,
name: str,
schema: Dict[str, ValueType],
- batch_source: Optional[DataSource] = None,
+ batch_source: DataSource,
+ event_timestamp_column="timestamp",
):
"""Creates a PushSource object."""
"""
Creates a PushSource object.
Args:
name: Name of the push source
schema: Schema mapping from the input feature name to a ValueType
batch_source: The batch source that backs this push source. It's used when materializing from the offline
store to the online store, and when retrieving historical features.
event_timestamp_column (optional): Event timestamp column used for point in time
joins of feature values.
"""
super().__init__(name)
self.schema = schema
self.batch_source = batch_source
+ if not self.batch_source:
+     raise ValueError(f"batch_source is needed for push source {self.name}")
+ self.event_timestamp_column = event_timestamp_column
+ if not self.event_timestamp_column:
+     raise ValueError(
+         f"event_timestamp_column is needed for push source {self.name}"
+     )

def validate(self, config: RepoConfig):
pass
@@ -544,21 +558,23 @@ def get_table_column_names_and_types(
def from_proto(data_source: DataSourceProto):
schema_pb = data_source.push_options.schema
schema = {}
- for key, value in schema_pb.items():
-     schema[key] = value
+ for key, val in schema_pb.items():
+     schema[key] = ValueType(val)

- batch_source = None
- if data_source.push_options.HasField("batch_source"):
-     batch_source = DataSource.from_proto(data_source.push_options.batch_source)
+ assert data_source.push_options.HasField("batch_source")
+ batch_source = DataSource.from_proto(data_source.push_options.batch_source)

return PushSource(
- name=data_source.name, schema=schema, batch_source=batch_source
+ name=data_source.name,
+ schema=schema,
+ batch_source=batch_source,
+ event_timestamp_column=data_source.event_timestamp_column,
)

def to_proto(self) -> DataSourceProto:
schema_pb = {}
for key, value in self.schema.items():
- schema_pb[key] = value
+ schema_pb[key] = value.value
batch_source_proto = None
if self.batch_source:
batch_source_proto = self.batch_source.to_proto()
@@ -567,7 +583,10 @@ def to_proto(self) -> DataSourceProto:
schema=schema_pb, batch_source=batch_source_proto
)
data_source_proto = DataSourceProto(
- name=self.name, type=DataSourceProto.PUSH_SOURCE, push_options=options,
+ name=self.name,
+ type=DataSourceProto.PUSH_SOURCE,
+ push_options=options,
+ event_timestamp_column=self.event_timestamp_column,
)

return data_source_proto
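Taken together, `to_proto` now stores schema values as their enum integers plus the event timestamp column, and `from_proto` restores them as `ValueType` members and requires a batch source. A minimal round-trip sketch (names and the file path are illustrative, not from this PR):

```python
from feast import FileSource, PushSource, ValueType

# Illustrative batch source backing the push source.
driver_stats = FileSource(
    path="data/driver_stats.parquet",
    event_timestamp_column="event_timestamp",
)

push_source = PushSource(
    name="driver_stats_push",
    schema={"driver_id": ValueType.INT64, "conv_rate": ValueType.FLOAT},
    batch_source=driver_stats,  # now required; None raises ValueError
)

# Schema values survive the round trip as ValueType enums, not raw ints.
restored = PushSource.from_proto(push_source.to_proto())
assert restored.schema["conv_rate"] == ValueType.FLOAT
assert restored.event_timestamp_column == "timestamp"  # the default
```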
43 changes: 34 additions & 9 deletions sdk/python/feast/feature_store.py
@@ -282,20 +282,23 @@ def list_data_sources(self, allow_cache: bool = False) -> List[DataSource]:
return self._registry.list_data_sources(self.project, allow_cache=allow_cache)

@log_exceptions_and_usage
- def get_entity(self, name: str) -> Entity:
+ def get_entity(self, name: str, allow_registry_cache: bool = False) -> Entity:
"""
Retrieves an entity.

Args:
name: Name of entity.
+ allow_registry_cache: (Optional) Whether to allow returning this entity from a cached registry

Returns:
The specified entity.

Raises:
EntityNotFoundException: The entity could not be found.
"""
- return self._registry.get_entity(name, self.project)
+ return self._registry.get_entity(
+     name, self.project, allow_cache=allow_registry_cache
+ )

@log_exceptions_and_usage
def get_feature_service(
@@ -317,25 +320,33 @@ def get_feature_service(
return self._registry.get_feature_service(name, self.project, allow_cache)

@log_exceptions_and_usage
- def get_feature_view(self, name: str) -> FeatureView:
+ def get_feature_view(
+     self, name: str, allow_registry_cache: bool = False
+ ) -> FeatureView:
"""
Retrieves a feature view.

Args:
name: Name of feature view.
+ allow_registry_cache: (Optional) Whether to allow returning this feature view from a cached registry

Returns:
The specified feature view.

Raises:
FeatureViewNotFoundException: The feature view could not be found.
"""
- return self._get_feature_view(name)
+ return self._get_feature_view(name, allow_registry_cache=allow_registry_cache)

def _get_feature_view(
- self, name: str, hide_dummy_entity: bool = True
+ self,
+ name: str,
+ hide_dummy_entity: bool = True,
+ allow_registry_cache: bool = False,
) -> FeatureView:
- feature_view = self._registry.get_feature_view(name, self.project)
+ feature_view = self._registry.get_feature_view(
+     name, self.project, allow_cache=allow_registry_cache
+ )
if hide_dummy_entity and feature_view.entities[0] == DUMMY_ENTITY_NAME:
feature_view.entities = []
return feature_view
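The new `allow_registry_cache` flag simply threads through to the registry's existing `allow_cache` option. A hedged usage sketch (repo path and object names are illustrative):

```python
from feast import FeatureStore

store = FeatureStore(repo_path=".")

# Opt into the cached registry to skip a refresh on hot paths; omitting
# the flag keeps the previous behavior of always hitting the registry.
driver = store.get_entity("driver", allow_registry_cache=True)
fv = store.get_feature_view("driver_locations", allow_registry_cache=True)
```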
@@ -1144,6 +1155,18 @@ def tqdm_builder(length):
feature_view, self.project, start_date, end_date,
)

+ @log_exceptions_and_usage
+ def push(self, push_source_name: str, df: pd.DataFrame):
+ push_source = self.get_data_source(push_source_name)
+
+ all_fvs = self.list_feature_views(allow_cache=True)
+ fvs_with_push_sources = {
+     fv for fv in all_fvs if fv.stream_source == push_source
+ }
+
+ for fv in fvs_with_push_sources:
+     self.write_to_online_store(fv.name, df, allow_registry_cache=True)
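The method resolves the named push source, collects every feature view whose stream source is that push source, and writes the dataframe to the online store via `write_to_online_store`. Roughly how a caller would use it, with names borrowed from the example repo below (a sketch, not validated against this exact commit):

```python
import pandas as pd

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # illustrative repo path

# One fresh row keyed by the entity join key; "timestamp" matches the
# push source's default event_timestamp_column.
event_df = pd.DataFrame(
    {
        "driver_id": ["d-123"],
        "driver_lat": [37.77],
        "driver_long": ["-122.42"],
        "timestamp": [pd.Timestamp.utcnow()],
    }
)

store.push("driver_locations_push", event_df)
```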

@log_exceptions_and_usage
def write_to_online_store(
self,
@@ -1155,12 +1178,14 @@ def write_to_online_store(
ingests data directly into the Online store
"""
# TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type
- feature_view = self._registry.get_feature_view(
-     feature_view_name, self.project, allow_cache=allow_registry_cache
+ feature_view = self.get_feature_view(
+     feature_view_name, allow_registry_cache=allow_registry_cache
)
entities = []
for entity_name in feature_view.entities:
- entities.append(self._registry.get_entity(entity_name, self.project))
+ entities.append(
+     self.get_entity(entity_name, allow_registry_cache=allow_registry_cache)
+ )
provider = self._get_provider()
provider.ingest_df(feature_view, entities, df)

34 changes: 25 additions & 9 deletions sdk/python/feast/feature_view.py
@@ -20,7 +20,7 @@

from feast import utils
from feast.base_feature_view import BaseFeatureView
- from feast.data_source import DataSource
+ from feast.data_source import DataSource, PushSource
from feast.entity import Entity
from feast.feature import Feature
from feast.feature_view_projection import FeatureViewProjection
@@ -58,7 +58,9 @@ class FeatureView(BaseFeatureView):
ttl: The amount of time this group of features lives. A ttl of 0 indicates that
this group of features lives forever. Note that large ttl's or a ttl of 0
can result in extremely computationally intensive queries.
- batch_source: The batch source of data where this group of features is stored.
+ batch_source (optional): The batch source of data where this group of features is stored.
+     This is optional ONLY if a push source is specified as the stream_source, since push sources
+     contain their own batch sources.
stream_source (optional): The stream source of data where this group of features
is stored.
features: The list of features defined as part of this feature view.
@@ -88,7 +90,7 @@ def __init__(
name: str,
entities: List[str],
ttl: Union[Duration, timedelta],
- batch_source: DataSource,
+ batch_source: Optional[DataSource] = None,
stream_source: Optional[DataSource] = None,
features: Optional[List[Feature]] = None,
online: bool = True,
@@ -121,15 +123,30 @@
"""
_features = features or []

+ if stream_source is not None and isinstance(stream_source, PushSource):
+     if stream_source.batch_source is None or not isinstance(
+         stream_source.batch_source, DataSource
+     ):
+         raise ValueError(
+             f"A batch_source needs to be specified for feature view `{name}`"
+         )
+     self.batch_source = stream_source.batch_source
+ else:
+     if batch_source is None:
+         raise ValueError(
+             f"A batch_source needs to be specified for feature view `{name}`"
+         )
+     self.batch_source = batch_source

cols = [entity for entity in entities] + [feat.name for feat in _features]
for col in cols:
if (
- batch_source.field_mapping is not None
- and col in batch_source.field_mapping.keys()
+ self.batch_source.field_mapping is not None
+ and col in self.batch_source.field_mapping.keys()
):
raise ValueError(
f"The field {col} is mapped to {batch_source.field_mapping[col]} for this data source. "
f"Please either remove this field mapping or use {batch_source.field_mapping[col]} as the "
f"The field {col} is mapped to {self.batch_source.field_mapping[col]} for this data source. "
f"Please either remove this field mapping or use {self.batch_source.field_mapping[col]} as the "
f"Entity or Feature name."
)

@@ -149,9 +166,8 @@ def __init__(
else:
self.ttl = ttl

- self.batch_source = batch_source
- self.stream_source = stream_source
self.online = online
+ self.stream_source = stream_source
self.materialization_intervals = []

# Note: Python requires redefining hash in child classes that override __eq__
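Net effect of the constructor change: a feature view given a `PushSource` as its `stream_source` inherits that source's batch source, and supplying neither source is an explicit error. A small sketch under those assumptions (source names and path are illustrative):

```python
from datetime import timedelta

from feast import FeatureView, FileSource, PushSource, ValueType

batch = FileSource(path="data/driver_stats.parquet")
push = PushSource(
    name="driver_push",
    schema={"driver_id": ValueType.STRING},
    batch_source=batch,
)

# batch_source may be omitted because the push source carries one.
fv = FeatureView(
    name="pushed_driver_stats",
    entities=["driver"],
    ttl=timedelta(days=1),
    stream_source=push,
)
assert fv.batch_source is batch

# FeatureView(name="no_source", entities=["driver"], ttl=timedelta(days=1))
# would raise ValueError: A batch_source needs to be specified ...
```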
3 changes: 3 additions & 0 deletions sdk/python/feast/repo_operations.py
@@ -12,6 +12,7 @@
import click
from click.exceptions import BadParameter

+ from feast import PushSource
from feast.data_source import DataSource
from feast.diff.registry_diff import extract_objects_for_keep_delete_update_add
from feast.entity import Entity
@@ -112,6 +113,8 @@ def parse_repo(repo_root: Path) -> RepoContents:
res.data_sources.add(obj)
if isinstance(obj, FeatureView):
res.feature_views.add(obj)
+ if isinstance(obj.stream_source, PushSource):
+     res.data_sources.add(obj.stream_source.batch_source)
elif isinstance(obj, Entity):
res.entities.add(obj)
elif isinstance(obj, FeatureService):
24 changes: 24 additions & 0 deletions sdk/python/tests/example_repos/example_feature_repo_1.py
@@ -6,6 +6,7 @@
Feature,
FeatureService,
FeatureView,
+ PushSource,
ValueType,
)

@@ -26,6 +27,16 @@
event_timestamp_column="event_timestamp",
)

+ driver_locations_push_source = PushSource(
+     name="driver_locations_push",
+     schema={
+         "driver_id": ValueType.STRING,
+         "driver_lat": ValueType.FLOAT,
+         "driver_long": ValueType.STRING,
+     },
+     batch_source=driver_locations_source,
+ )

driver = Entity(
name="driver", # The name is derived from this argument, not object name.
join_key="driver_id",
@@ -53,6 +64,19 @@
tags={},
)

+ pushed_driver_locations = FeatureView(
+     name="pushed_driver_locations",
+     entities=["driver"],
+     ttl=timedelta(days=1),
+     features=[
+         Feature(name="driver_lat", dtype=ValueType.FLOAT),
+         Feature(name="driver_long", dtype=ValueType.STRING),
+     ],
+     online=True,
+     stream_source=driver_locations_push_source,
+     tags={},
+ )

customer_profile = FeatureView(
name="customer_profile",
entities=["customer"],
4 changes: 2 additions & 2 deletions sdk/python/tests/integration/registration/test_cli.py
@@ -84,12 +84,12 @@ def test_universal_cli(environment: Environment):
cwd=repo_path,
)
assertpy.assert_that(result.returncode).is_equal_to(0)
- assertpy.assert_that(fs.list_feature_views()).is_length(3)
+ assertpy.assert_that(fs.list_feature_views()).is_length(4)
result = runner.run(
["data-sources", "describe", "customer_profile_source"], cwd=repo_path,
)
assertpy.assert_that(result.returncode).is_equal_to(0)
- assertpy.assert_that(fs.list_data_sources()).is_length(3)
+ assertpy.assert_that(fs.list_data_sources()).is_length(4)

# entity & feature view describe commands should fail when objects don't exist
result = runner.run(["entities", "describe", "foo"], cwd=repo_path)