Skip to content

Commit

Permalink
fix: Fix __hash__ methods (#2556)
Browse files (browse the repository at this point in the history)
* Fix __hash__ method for Entity

Signed-off-by: Felix Wang <[email protected]>

* Fix __hash__ method for FeatureService

Signed-off-by: Felix Wang <[email protected]>

* Remove references to PrimitiveFeastType

Signed-off-by: Felix Wang <[email protected]>

* Fix __hash__ method for DataSource and PushSource

Signed-off-by: Felix Wang <[email protected]>

* Fix __hash__ method for FeatureView and OnDemandFeatureView

Signed-off-by: Felix Wang <[email protected]>

* Fix __hash__ method for SavedDataset

Signed-off-by: Felix Wang <[email protected]>

* Fix bugs

Signed-off-by: Felix Wang <[email protected]>

* Fix push merge

Signed-off-by: Danny Chiao <[email protected]>

Co-authored-by: Danny Chiao <[email protected]>
  • Loading branch information
felixwang9817 and adchia authored Apr 19, 2022
1 parent 7076fe0 commit ebb7dfe
Show file tree
Hide file tree
Showing 20 changed files with 394 additions and 95 deletions.
3 changes: 2 additions & 1 deletion sdk/python/feast/base_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __str__(self):
return str(MessageToJson(self.to_proto()))

def __hash__(self):
return hash((id(self), self.name))
return hash((self.name))

def __getitem__(self, item):
assert isinstance(item, list)
Expand All @@ -134,6 +134,7 @@ def __eq__(self, other):
if (
self.name != other.name
or sorted(self.features) != sorted(other.features)
or self.projection != other.projection
or self.description != other.description
or self.tags != other.tags
or self.owner != other.owner
Expand Down
46 changes: 33 additions & 13 deletions sdk/python/feast/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def __init__(
self.owner = owner or ""

def __hash__(self):
return hash((id(self), self.name))
return hash((self.name, self.timestamp_field))

def __str__(self):
return str(MessageToJson(self.to_proto()))
Expand All @@ -263,9 +263,9 @@ def __eq__(self, other):
or self.created_timestamp_column != other.created_timestamp_column
or self.field_mapping != other.field_mapping
or self.date_partition_column != other.date_partition_column
or self.description != other.description
or self.tags != other.tags
or self.owner != other.owner
or self.description != other.description
):
return False

Expand Down Expand Up @@ -392,6 +392,9 @@ def __eq__(self, other):
"Comparisons should only involve KafkaSource class objects."
)

if not super().__eq__(other):
return False

if (
self.kafka_options.bootstrap_servers
!= other.kafka_options.bootstrap_servers
Expand All @@ -402,6 +405,9 @@ def __eq__(self, other):

return True

def __hash__(self):
return super().__hash__()

@staticmethod
def from_proto(data_source: DataSourceProto):
return KafkaSource(
Expand Down Expand Up @@ -507,13 +513,10 @@ def __eq__(self, other):
raise TypeError(
"Comparisons should only involve RequestSource class objects."
)
if (
self.name != other.name
or self.description != other.description
or self.owner != other.owner
or self.tags != other.tags
):

if not super().__eq__(other):
return False

if isinstance(self.schema, List) and isinstance(other.schema, List):
for field1, field2 in zip(self.schema, other.schema):
if field1 != field2:
Expand Down Expand Up @@ -671,24 +674,26 @@ def __init__(
)

def __eq__(self, other):
if other is None:
return False

if not isinstance(other, KinesisSource):
raise TypeError(
"Comparisons should only involve KinesisSource class objects."
)

if not super().__eq__(other):
return False

if (
self.name != other.name
or self.kinesis_options.record_format != other.kinesis_options.record_format
self.kinesis_options.record_format != other.kinesis_options.record_format
or self.kinesis_options.region != other.kinesis_options.region
or self.kinesis_options.stream_name != other.kinesis_options.stream_name
):
return False

return True

def __hash__(self):
return super().__hash__()

def to_proto(self) -> DataSourceProto:
data_source_proto = DataSourceProto(
name=self.name,
Expand Down Expand Up @@ -744,6 +749,21 @@ def __init__(
if not self.batch_source:
raise ValueError(f"batch_source is needed for push source {self.name}")

def __eq__(self, other):
if not isinstance(other, PushSource):
raise TypeError("Comparisons should only involve PushSource class objects.")

if not super().__eq__(other):
return False

if self.batch_source != other.batch_source:
return False

return True

def __hash__(self):
return super().__hash__()

def validate(self, config: RepoConfig):
pass

Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/diff/registry_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def extract_objects_for_keep_delete_update_add(
FeastObjectType, List[Any]
] = FeastObjectType.get_objects_from_registry(registry, current_project)
registry_object_type_to_repo_contents: Dict[
FeastObjectType, Set[Any]
FeastObjectType, List[Any]
] = FeastObjectType.get_objects_from_repo_contents(desired_repo_contents)

for object_type in FEAST_OBJECT_TYPES:
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(
self.last_updated_timestamp = None

def __hash__(self) -> int:
return hash((id(self), self.name))
return hash((self.name, self.join_key))

def __eq__(self, other):
if not isinstance(other, Entity):
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __str__(self):
return str(MessageToJson(self.to_proto()))

def __hash__(self):
return hash((id(self), self.name))
return hash((self.name))

def __eq__(self, other):
if not isinstance(other, FeatureService):
Expand Down
28 changes: 14 additions & 14 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,25 +533,25 @@ def _plan(
... batch_source=driver_hourly_stats,
... )
>>> registry_diff, infra_diff, new_infra = fs._plan(RepoContents(
... data_sources={driver_hourly_stats},
... feature_views={driver_hourly_stats_view},
... on_demand_feature_views=set(),
... request_feature_views=set(),
... entities={driver},
... feature_services=set())) # register entity and feature view
... data_sources=[driver_hourly_stats],
... feature_views=[driver_hourly_stats_view],
... on_demand_feature_views=list(),
... request_feature_views=list(),
... entities=[driver],
... feature_services=list())) # register entity and feature view
"""
# Validate and run inference on all the objects to be registered.
self._validate_all_feature_views(
list(desired_repo_contents.feature_views),
list(desired_repo_contents.on_demand_feature_views),
list(desired_repo_contents.request_feature_views),
desired_repo_contents.feature_views,
desired_repo_contents.on_demand_feature_views,
desired_repo_contents.request_feature_views,
)
_validate_data_sources(list(desired_repo_contents.data_sources))
_validate_data_sources(desired_repo_contents.data_sources)
self._make_inferences(
list(desired_repo_contents.data_sources),
list(desired_repo_contents.entities),
list(desired_repo_contents.feature_views),
list(desired_repo_contents.on_demand_feature_views),
desired_repo_contents.data_sources,
desired_repo_contents.entities,
desired_repo_contents.feature_views,
desired_repo_contents.on_demand_feature_views,
)

# Compute the desired difference between the current objects in the registry and
Expand Down
13 changes: 4 additions & 9 deletions sdk/python/feast/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def _initialize_sources(self, name, batch_source, stream_source, source):
self.batch_source = batch_source
self.source = source

# Note: Python requires redefining hash in child classes that override __eq__
def __hash__(self):
return super().__hash__()

Expand Down Expand Up @@ -298,19 +297,15 @@ def __eq__(self, other):
return False

if (
self.tags != other.tags
sorted(self.entities) != sorted(other.entities)
or self.ttl != other.ttl
or self.online != other.online
or self.batch_source != other.batch_source
or self.stream_source != other.stream_source
or self.schema != other.schema
):
return False

if sorted(self.entities) != sorted(other.entities):
return False
if self.batch_source != other.batch_source:
return False
if self.stream_source != other.stream_source:
return False

return True

def ensure_valid(self):
Expand Down
13 changes: 9 additions & 4 deletions sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,19 @@ def __copy__(self):
return fv

def __eq__(self, other):
if not isinstance(other, OnDemandFeatureView):
raise TypeError(
"Comparisons should only involve OnDemandFeatureView class objects."
)

if not super().__eq__(other):
return False

if (
not self.source_feature_view_projections
== other.source_feature_view_projections
or not self.source_request_sources == other.source_request_sources
or not self.udf.__code__.co_code == other.udf.__code__.co_code
self.source_feature_view_projections
!= other.source_feature_view_projections
or self.source_request_sources != other.source_request_sources
or self.udf.__code__.co_code != other.udf.__code__.co_code
):
return False

Expand Down
4 changes: 2 additions & 2 deletions sdk/python/feast/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import dill
Expand Down Expand Up @@ -98,7 +98,7 @@ def get_objects_from_registry(
@staticmethod
def get_objects_from_repo_contents(
repo_contents: RepoContents,
) -> Dict["FeastObjectType", Set[Any]]:
) -> Dict["FeastObjectType", List[Any]]:
return {
FeastObjectType.DATA_SOURCE: repo_contents.data_sources,
FeastObjectType.ENTITY: repo_contents.entities,
Expand Down
14 changes: 7 additions & 7 deletions sdk/python/feast/repo_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple, Set
from typing import List, NamedTuple

from feast.data_source import DataSource
from feast.entity import Entity
Expand All @@ -27,12 +27,12 @@ class RepoContents(NamedTuple):
Represents the objects in a Feast feature repo.
"""

data_sources: Set[DataSource]
feature_views: Set[FeatureView]
on_demand_feature_views: Set[OnDemandFeatureView]
request_feature_views: Set[RequestFeatureView]
entities: Set[Entity]
feature_services: Set[FeatureService]
data_sources: List[DataSource]
feature_views: List[FeatureView]
on_demand_feature_views: List[OnDemandFeatureView]
request_feature_views: List[RequestFeatureView]
entities: List[Entity]
feature_services: List[FeatureService]

def to_registry_proto(self) -> RegistryProto:
registry_proto = RegistryProto()
Expand Down
64 changes: 42 additions & 22 deletions sdk/python/feast/repo_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,36 +94,56 @@ def get_repo_files(repo_root: Path) -> List[Path]:


def parse_repo(repo_root: Path) -> RepoContents:
"""Collect feature table definitions from feature repo"""
"""
Collects unique Feast object definitions from the given feature repo.
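Objects are deduplicated by identity, not by equality: if an object foo has already
been added, bar is skipped only if (bar is foo); it is still added if (bar == foo) but
bar is a distinct object. This ensures that importing the same object into several
modules does not create duplicates, while separately defining two equal objects does.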
"""
res = RepoContents(
data_sources=set(),
entities=set(),
feature_views=set(),
feature_services=set(),
on_demand_feature_views=set(),
request_feature_views=set(),
data_sources=[],
entities=[],
feature_views=[],
feature_services=[],
on_demand_feature_views=[],
request_feature_views=[],
)

for repo_file in get_repo_files(repo_root):
module_path = py_path_to_module(repo_file, repo_root)
module = importlib.import_module(module_path)
for attr_name in dir(module):
obj = getattr(module, attr_name)
if isinstance(obj, DataSource):
res.data_sources.add(obj)
if isinstance(obj, FeatureView):
res.feature_views.add(obj)
if isinstance(obj.stream_source, PushSource):
res.data_sources.add(obj.stream_source.batch_source)
elif isinstance(obj, Entity):
res.entities.add(obj)
elif isinstance(obj, FeatureService):
res.feature_services.add(obj)
elif isinstance(obj, OnDemandFeatureView):
res.on_demand_feature_views.add(obj)
elif isinstance(obj, RequestFeatureView):
res.request_feature_views.add(obj)
res.entities.add(DUMMY_ENTITY)
if isinstance(obj, DataSource) and not any(
(obj is ds) for ds in res.data_sources
):
res.data_sources.append(obj)
if isinstance(obj, FeatureView) and not any(
(obj is fv) for fv in res.feature_views
):
res.feature_views.append(obj)
if isinstance(obj.stream_source, PushSource) and not any(
(obj is ds) for ds in res.data_sources
):
res.data_sources.append(obj.stream_source.batch_source)
elif isinstance(obj, Entity) and not any(
(obj is entity) for entity in res.entities
):
res.entities.append(obj)
elif isinstance(obj, FeatureService) and not any(
(obj is fs) for fs in res.feature_services
):
res.feature_services.append(obj)
elif isinstance(obj, OnDemandFeatureView) and not any(
(obj is odfv) for odfv in res.on_demand_feature_views
):
res.on_demand_feature_views.append(obj)
elif isinstance(obj, RequestFeatureView) and not any(
(obj is rfv) for rfv in res.request_feature_views
):
res.request_feature_views.append(obj)
res.entities.append(DUMMY_ENTITY)
return res


Expand Down
Loading

0 comments on commit ebb7dfe

Please sign in to comment.