
Commit

Uncomment
Signed-off-by: Kevin Zhang <[email protected]>
kevjumba committed Jun 22, 2022
1 parent b48d377 commit 2e1ddb1
Showing 2 changed files with 143 additions and 143 deletions.
148 changes: 74 additions & 74 deletions sdk/python/tests/integration/online_store/test_universal_online.py
@@ -441,80 +441,80 @@ def test_online_retrieval_with_event_timestamps(
)


# @pytest.mark.integration
# @pytest.mark.universal_online_stores
# # @pytest.mark.goserver Disabling because the go fs tests are flaking in CI. TODO(achals): uncomment after fixed.
# @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
# def test_stream_feature_view_online_retrieval(
# environment, universal_data_sources, feature_server_endpoint, full_feature_names
# ):
# """
# Tests materialization and online retrieval for stream feature views.

# This test is separate from test_online_retrieval since combining feature views and
# stream feature views into a single test resulted in test flakiness. This is tech
# debt that should be resolved soon.
# """
# # Set up feature store.
# fs = environment.feature_store
# entities, datasets, data_sources = universal_data_sources
# feature_views = construct_universal_feature_views(data_sources)
# pushable_feature_view = feature_views.pushed_locations
# fs.apply([location(), pushable_feature_view])

# # Materialize.
# fs.materialize(
# environment.start_date - timedelta(days=1),
# environment.end_date + timedelta(days=1),
# )

# # Get online features by randomly sampling 10 entities that exist in the batch source.
# sample_locations = datasets.location_df.sample(10)["location_id"]
# entity_rows = [
# {"location_id": sample_location} for sample_location in sample_locations
# ]

# feature_refs = [
# "pushable_location_stats:temperature",
# ]
# unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs if ":" in f]

# online_features_dict = get_online_features_dict(
# environment=environment,
# endpoint=feature_server_endpoint,
# features=feature_refs,
# entity_rows=entity_rows,
# full_feature_names=full_feature_names,
# )

# # Check that the response has the expected set of keys.
# keys = set(online_features_dict.keys())
# expected_keys = set(
# f.replace(":", "__") if full_feature_names else f.split(":")[-1]
# for f in feature_refs
# ) | {"location_id"}
# assert (
# keys == expected_keys
# ), f"Response keys are different from expected: {keys - expected_keys} (extra) and {expected_keys - keys} (missing)"

# # Check that the feature values match.
# tc = unittest.TestCase()
# for i, entity_row in enumerate(entity_rows):
# df_features = get_latest_feature_values_from_location_df(
# entity_row, datasets.location_df
# )

# assert df_features["location_id"] == online_features_dict["location_id"][i]
# for unprefixed_feature_ref in unprefixed_feature_refs:
# tc.assertAlmostEqual(
# df_features[unprefixed_feature_ref],
# online_features_dict[
# response_feature_name(
# unprefixed_feature_ref, feature_refs, full_feature_names
# )
# ][i],
# delta=0.0001,
# )
@pytest.mark.integration
@pytest.mark.universal_online_stores
# @pytest.mark.goserver Disabling because the go fs tests are flaking in CI. TODO(achals): uncomment after fixed.
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_stream_feature_view_online_retrieval(
    environment, universal_data_sources, feature_server_endpoint, full_feature_names
):
    """
    Tests materialization and online retrieval for stream feature views.

    This test is separate from test_online_retrieval since combining feature views and
    stream feature views into a single test resulted in test flakiness. This is tech
    debt that should be resolved soon.
    """
    # Set up feature store.
    fs = environment.feature_store
    entities, datasets, data_sources = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)
    pushable_feature_view = feature_views.pushed_locations
    fs.apply([location(), pushable_feature_view])

    # Materialize.
    fs.materialize(
        environment.start_date - timedelta(days=1),
        environment.end_date + timedelta(days=1),
    )

    # Get online features by randomly sampling 10 entities that exist in the batch source.
    sample_locations = datasets.location_df.sample(10)["location_id"]
    entity_rows = [
        {"location_id": sample_location} for sample_location in sample_locations
    ]

    feature_refs = [
        "pushable_location_stats:temperature",
    ]
    unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs if ":" in f]

    online_features_dict = get_online_features_dict(
        environment=environment,
        endpoint=feature_server_endpoint,
        features=feature_refs,
        entity_rows=entity_rows,
        full_feature_names=full_feature_names,
    )

    # Check that the response has the expected set of keys.
    keys = set(online_features_dict.keys())
    expected_keys = set(
        f.replace(":", "__") if full_feature_names else f.split(":")[-1]
        for f in feature_refs
    ) | {"location_id"}
    assert (
        keys == expected_keys
    ), f"Response keys are different from expected: {keys - expected_keys} (extra) and {expected_keys - keys} (missing)"

    # Check that the feature values match.
    tc = unittest.TestCase()
    for i, entity_row in enumerate(entity_rows):
        df_features = get_latest_feature_values_from_location_df(
            entity_row, datasets.location_df
        )

        assert df_features["location_id"] == online_features_dict["location_id"][i]
        for unprefixed_feature_ref in unprefixed_feature_refs:
            tc.assertAlmostEqual(
                df_features[unprefixed_feature_ref],
                online_features_dict[
                    response_feature_name(
                        unprefixed_feature_ref, feature_refs, full_feature_names
                    )
                ][i],
                delta=0.0001,
            )
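The expected_keys check above relies on Feast's feature-reference naming rule: with full_feature_names=True a reference "view:feature" comes back as "view__feature", otherwise only the bare feature name is returned. A minimal standalone sketch of that rule (not part of this commit; the helper name expected_response_keys is made up for illustration):

    feature_refs = ["pushable_location_stats:temperature"]

    def expected_response_keys(refs, full_feature_names):
        # Mirrors the set comprehension in the test; "location_id" is the entity join key.
        return {
            f.replace(":", "__") if full_feature_names else f.split(":")[-1] for f in refs
        } | {"location_id"}

    assert expected_response_keys(feature_refs, True) == {
        "pushable_location_stats__temperature",
        "location_id",
    }
    assert expected_response_keys(feature_refs, False) == {"temperature", "location_id"}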


@pytest.mark.integration
@@ -77,72 +77,72 @@ def simple_sfv(df):
assert features["dummy_field"] == [None]


# @pytest.mark.integration
# def test_stream_feature_view_udf(simple_dataset_1) -> None:
# """
# Test apply of StreamFeatureView udfs are serialized correctly and usable.
# """
# runner = CliRunner()
# with runner.local_repo(
# get_example_repo("example_feature_repo_1.py"), "bigquery"
# ) as fs, prep_file_source(
# df=simple_dataset_1, timestamp_field="ts_1"
# ) as file_source:
# entity = Entity(name="driver_entity", join_keys=["test_key"])

# stream_source = KafkaSource(
# name="kafka",
# timestamp_field="event_timestamp",
# kafka_bootstrap_servers="",
# message_format=AvroFormat(""),
# topic="topic",
# batch_source=file_source,
# watermark_delay_threshold=timedelta(days=1),
# )

# @stream_feature_view(
# entities=[entity],
# ttl=timedelta(days=30),
# owner="[email protected]",
# online=True,
# schema=[Field(name="dummy_field", dtype=Float32)],
# description="desc",
# aggregations=[
# Aggregation(
# column="dummy_field", function="max", time_window=timedelta(days=1),
# ),
# Aggregation(
# column="dummy_field2",
# function="count",
# time_window=timedelta(days=24),
# ),
# ],
# timestamp_field="event_timestamp",
# mode="spark",
# source=stream_source,
# tags={},
# )
# def pandas_view(pandas_df):
# import pandas as pd

# assert type(pandas_df) == pd.DataFrame
# df = pandas_df.transform(lambda x: x + 10, axis=1)
# df.insert(2, "C", [20.2, 230.0, 34.0], True)
# return df

# import pandas as pd

# fs.apply([entity, pandas_view])

# stream_feature_views = fs.list_stream_feature_views()
# assert len(stream_feature_views) == 1
# assert stream_feature_views[0] == pandas_view

# sfv = stream_feature_views[0]

# df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
# new_df = sfv.udf(df)
# expected_df = pd.DataFrame(
# {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]}
# )
# assert new_df.equals(expected_df)
@pytest.mark.integration
def test_stream_feature_view_udf(simple_dataset_1) -> None:
    """
    Test that StreamFeatureView UDFs are serialized correctly by apply and remain usable.
    """
    runner = CliRunner()
    with runner.local_repo(
        get_example_repo("example_feature_repo_1.py"), "bigquery"
    ) as fs, prep_file_source(
        df=simple_dataset_1, timestamp_field="ts_1"
    ) as file_source:
        entity = Entity(name="driver_entity", join_keys=["test_key"])

        stream_source = KafkaSource(
            name="kafka",
            timestamp_field="event_timestamp",
            kafka_bootstrap_servers="",
            message_format=AvroFormat(""),
            topic="topic",
            batch_source=file_source,
            watermark_delay_threshold=timedelta(days=1),
        )

        @stream_feature_view(
            entities=[entity],
            ttl=timedelta(days=30),
            owner="[email protected]",
            online=True,
            schema=[Field(name="dummy_field", dtype=Float32)],
            description="desc",
            aggregations=[
                Aggregation(
                    column="dummy_field", function="max", time_window=timedelta(days=1),
                ),
                Aggregation(
                    column="dummy_field2",
                    function="count",
                    time_window=timedelta(days=24),
                ),
            ],
            timestamp_field="event_timestamp",
            mode="spark",
            source=stream_source,
            tags={},
        )
        def pandas_view(pandas_df):
            import pandas as pd

            assert type(pandas_df) == pd.DataFrame
            df = pandas_df.transform(lambda x: x + 10, axis=1)
            df.insert(2, "C", [20.2, 230.0, 34.0], True)
            return df

        import pandas as pd

        fs.apply([entity, pandas_view])

        stream_feature_views = fs.list_stream_feature_views()
        assert len(stream_feature_views) == 1
        assert stream_feature_views[0] == pandas_view

        sfv = stream_feature_views[0]

        df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
        new_df = sfv.udf(df)
        expected_df = pd.DataFrame(
            {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]}
        )
        assert new_df.equals(expected_df)
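For reference, the behavior the final assertion checks can be reproduced without a feature store: applying the same transform directly to the sample frame yields expected_df. A standalone sketch (not part of this commit; assumes only that pandas is installed):

    import pandas as pd

    df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
    # Same logic as pandas_view: add 10 to every value row-wise, then insert column "C".
    out = df.transform(lambda x: x + 10, axis=1)
    out.insert(2, "C", [20.2, 230.0, 34.0], True)
    assert out.equals(
        pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]})
    )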
