Skip to content

Commit

Permalink
Refactor Environment class and DataSourceCreator API, and use fixture…
Browse files Browse the repository at this point in the history
…s for datasets and data sources (#1790)

* Fix API cruft from DataSourceCreator

Signed-off-by: Achal Shah <[email protected]>

* Remove the need for get_prefixed_table_name

Signed-off-by: Achal Shah <[email protected]>

* major refactor

Signed-off-by: Achal Shah <[email protected]>

* move start time

Signed-off-by: Achal Shah <[email protected]>

* Remove one dimension of variation to be added in later

Signed-off-by: Achal Shah <[email protected]>

* Fix default

Signed-off-by: Achal Shah <[email protected]>

* Fixups

Signed-off-by: Achal Shah <[email protected]>

* Fixups

Signed-off-by: Achal Shah <[email protected]>

* Fix up tests

Signed-off-by: Achal Shah <[email protected]>

* Add retries to execute_redshift_statement_async

Signed-off-by: Achal Shah <[email protected]>

* Add retries to execute_redshift_statement_async

Signed-off-by: Achal Shah <[email protected]>

* refactoooor

Signed-off-by: Achal Shah <[email protected]>

* remove retries

Signed-off-by: Achal Shah <[email protected]>

* Remove provider variation since they don't really play a big role

Signed-off-by: Achal Shah <[email protected]>

* Session scoped cache for test datasets and skipping older tests whose functionality is present in other universal tests

Signed-off-by: Achal Shah <[email protected]>

* make format

Signed-off-by: Achal Shah <[email protected]>

* make format

Signed-off-by: Achal Shah <[email protected]>

* remove import

Signed-off-by: Achal Shah <[email protected]>

* fix merge

Signed-off-by: Achal Shah <[email protected]>

* Use an enum for the stopping procedure instead of the bools

Signed-off-by: Achal Shah <[email protected]>

* Fix refs

Signed-off-by: Achal Shah <[email protected]>

* fix step

Signed-off-by: Achal Shah <[email protected]>

* WIP fixes

Signed-off-by: Achal Shah <[email protected]>

* Fix for feature inferencing

Signed-off-by: Achal Shah <[email protected]>

* C901 '_python_value_to_proto_value' is too complex :(

Signed-off-by: Achal Shah <[email protected]>

* Split out construct_test_repo and construct_universal_test_repo

Signed-off-by: Achal Shah <[email protected]>

* remove import

Signed-off-by: Achal Shah <[email protected]>

* add unsafe_hash

Signed-off-by: Achal Shah <[email protected]>

* Update testrepoconfig

Signed-off-by: Achal Shah <[email protected]>

* Update testrepoconfig

Signed-off-by: Achal Shah <[email protected]>

* Remove kwargs from construct_universal_test_environment

Signed-off-by: Achal Shah <[email protected]>

* Remove unneeded method

Signed-off-by: Achal Shah <[email protected]>

* Docs

Signed-off-by: Achal Shah <[email protected]>

* Kill skipped tests

Signed-off-by: Achal Shah <[email protected]>

* reorder

Signed-off-by: Achal Shah <[email protected]>

* add todo

Signed-off-by: Achal Shah <[email protected]>

* Split universal vs non data_source_cache

Signed-off-by: Achal Shah <[email protected]>

* make format

Signed-off-by: Achal Shah <[email protected]>

* WIP fixtures

Signed-off-by: Achal Shah <[email protected]>

* WIP Trying fixtures more effectively

Signed-off-by: Achal Shah <[email protected]>

* fix refs

Signed-off-by: Achal Shah <[email protected]>

* Fix refs

Signed-off-by: Achal Shah <[email protected]>

* Fix refs

Signed-off-by: Achal Shah <[email protected]>

* Fix refs

Signed-off-by: Achal Shah <[email protected]>

* fix historical tests

Signed-off-by: Achal Shah <[email protected]>

* renames

Signed-off-by: Achal Shah <[email protected]>

* CR updates

Signed-off-by: Achal Shah <[email protected]>

* use the actual ref to data source creators

Signed-off-by: Achal Shah <[email protected]>

* format

Signed-off-by: Achal Shah <[email protected]>

* unused imports'

Signed-off-by: Achal Shah <[email protected]>

* Add ids for pytest params

Signed-off-by: Achal Shah <[email protected]>
  • Loading branch information
achals authored Sep 1, 2021
1 parent ef7200a commit 66cf6a4
Show file tree
Hide file tree
Showing 20 changed files with 629 additions and 1,195 deletions.
24 changes: 23 additions & 1 deletion sdk/python/feast/feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,28 @@ def infer_features_from_batch_source(self, config: RepoConfig):
self.batch_source.created_timestamp_column,
} | set(self.entities)

if (
self.batch_source.event_timestamp_column
in self.batch_source.field_mapping
):
columns_to_exclude.add(
self.batch_source.field_mapping[
self.batch_source.event_timestamp_column
]
)
if (
self.batch_source.created_timestamp_column
in self.batch_source.field_mapping
):
columns_to_exclude.add(
self.batch_source.field_mapping[
self.batch_source.created_timestamp_column
]
)
for e in self.entities:
if e in self.batch_source.field_mapping:
columns_to_exclude.add(self.batch_source.field_mapping[e])

for (
col_name,
col_datatype,
Expand All @@ -335,7 +357,7 @@ def infer_features_from_batch_source(self, config: RepoConfig):
):
feature_name = (
self.batch_source.field_mapping[col_name]
if col_name in self.batch_source.field_mapping.keys()
if col_name in self.batch_source.field_mapping
else col_name
)
self.features.append(
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def pull_latest_from_table_or_query(
)
WHERE _feast_row = 1
"""

return BigQueryRetrievalJob(query=query, client=client, config=config)

@staticmethod
Expand Down
6 changes: 6 additions & 0 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ def _get_column_names(
reverse_field_mapping[col] if col in reverse_field_mapping.keys() else col
for col in feature_names
]

# We need to exclude join keys and timestamp columns from the list of features, after they are mapped to
# their final column names via the `field_mapping` field of the source.
_feature_names = set(feature_names) - set(join_keys)
_feature_names = _feature_names - {event_timestamp_column, created_timestamp_column}
feature_names = list(_feature_names)
return (
join_keys,
feature_names,
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/utils/aws_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class RedshiftStatementNotFinishedError(Exception):


@retry(
wait=wait_exponential(multiplier=0.1, max=30),
wait=wait_exponential(multiplier=1, max=30),
retry=retry_if_exception_type(RedshiftStatementNotFinishedError),
)
def wait_for_redshift_statement(redshift_data_client, statement: dict) -> None:
Expand Down
11 changes: 10 additions & 1 deletion sdk/python/feast/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import re
from datetime import datetime
from typing import Any, Dict, Union

import numpy as np
import pandas as pd
from google.protobuf.json_format import MessageToDict
from google.protobuf.timestamp_pb2 import Timestamp

from feast.protos.feast.types.Value_pb2 import (
BoolList,
Expand Down Expand Up @@ -104,6 +106,8 @@ def python_type_to_feast_value_type(
"int8": ValueType.INT32,
"bool": ValueType.BOOL,
"timedelta": ValueType.UNIX_TIMESTAMP,
"Timestamp": ValueType.UNIX_TIMESTAMP,
"datetime": ValueType.UNIX_TIMESTAMP,
"datetime64[ns]": ValueType.UNIX_TIMESTAMP,
"datetime64[ns, tz]": ValueType.UNIX_TIMESTAMP,
"category": ValueType.STRING,
Expand Down Expand Up @@ -160,7 +164,8 @@ def _type_err(item, dtype):
raise ValueError(f'Value "{item}" is of type {type(item)} not of type {dtype}')


def _python_value_to_proto_value(feast_value_type, value) -> ProtoValue:
# TODO(achals): Simplify this method and remove the noqa.
def _python_value_to_proto_value(feast_value_type, value) -> ProtoValue: # noqa: C901
"""
Converts a Python (native, pandas) value to a Feast Proto Value based
on a provided value type
Expand Down Expand Up @@ -281,6 +286,10 @@ def _python_value_to_proto_value(feast_value_type, value) -> ProtoValue:
elif feast_value_type == ValueType.INT64:
return ProtoValue(int64_val=int(value))
elif feast_value_type == ValueType.UNIX_TIMESTAMP:
if isinstance(value, datetime):
return ProtoValue(int64_val=int(value.timestamp()))
elif isinstance(value, Timestamp):
return ProtoValue(int64_val=int(value.ToSeconds()))
return ProtoValue(int64_val=int(value))
elif feast_value_type == ValueType.FLOAT:
return ProtoValue(float_val=float(value))
Expand Down
51 changes: 51 additions & 0 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
import pandas as pd
import pytest

from tests.data.data_creator import create_dataset
from tests.integration.feature_repos.repo_configuration import (
FULL_REPO_CONFIGS,
Environment,
construct_test_environment,
construct_universal_data_sources,
construct_universal_datasets,
construct_universal_entities,
)


def pytest_configure(config):
if platform in ["darwin", "windows"]:
Expand Down Expand Up @@ -87,3 +97,44 @@ def simple_dataset_2() -> pd.DataFrame:
],
}
return pd.DataFrame.from_dict(data)


@pytest.fixture(
params=FULL_REPO_CONFIGS, scope="session", ids=[str(c) for c in FULL_REPO_CONFIGS]
)
def environment(request):
with construct_test_environment(request.param) as e:
yield e


@pytest.fixture(scope="session")
def universal_data_sources(environment):
entities = construct_universal_entities()
datasets = construct_universal_datasets(
entities, environment.start_date, environment.end_date
)
datasources = construct_universal_data_sources(
datasets, environment.data_source_creator
)

yield entities, datasets, datasources

environment.data_source_creator.teardown()


@pytest.fixture(scope="session")
def e2e_data_sources(environment: Environment):
df = create_dataset()
data_source = environment.data_source_creator.create_data_source(
df, environment.feature_store.project, field_mapping={"ts_1": "ts"},
)

yield df, data_source

environment.data_source_creator.teardown()


@pytest.fixture(params=FULL_REPO_CONFIGS, scope="session")
def type_test_environment(request):
with construct_test_environment(request.param) as e:
yield e
18 changes: 8 additions & 10 deletions sdk/python/tests/integration/e2e/test_universal_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,21 @@
from typing import Optional

import pandas as pd
import pytest
from pytz import utc

from feast import FeatureStore, FeatureView
from tests.integration.feature_repos.test_repo_configuration import (
Environment,
parametrize_e2e_test,
)
from tests.integration.feature_repos.universal.entities import driver
from tests.integration.feature_repos.universal.feature_views import driver_feature_view


@parametrize_e2e_test
def test_e2e_consistency(test_environment: Environment):
fs, fv = (
test_environment.feature_store,
driver_feature_view(test_environment.data_source),
)
@pytest.mark.integration
@pytest.mark.parametrize("infer_features", [True, False])
def test_e2e_consistency(environment, e2e_data_sources, infer_features):
fs = environment.feature_store
df, data_source = e2e_data_sources
fv = driver_feature_view(data_source=data_source, infer_features=infer_features)

entity = driver()
fs.apply([fv, entity])

Expand Down
192 changes: 192 additions & 0 deletions sdk/python/tests/integration/feature_repos/repo_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import tempfile
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union

import pandas as pd

from feast import FeatureStore, FeatureView, RepoConfig, driver_test_data
from feast.data_source import DataSource
from tests.integration.feature_repos.universal.data_source_creator import (
DataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.bigquery import (
BigQueryDataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.file import (
FileDataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.redshift import (
RedshiftDataSourceCreator,
)
from tests.integration.feature_repos.universal.feature_views import (
create_customer_daily_profile_feature_view,
create_driver_hourly_stats_feature_view,
)


@dataclass(frozen=True, repr=True)
class IntegrationTestRepoConfig:
"""
This class should hold all possible parameters that may need to be varied by individual tests.
"""

provider: str = "local"
online_store: Union[str, Dict] = "sqlite"

offline_store_creator: Type[DataSourceCreator] = FileDataSourceCreator

full_feature_names: bool = True
infer_event_timestamp_col: bool = True
infer_features: bool = False


DYNAMO_CONFIG = {"type": "dynamodb", "region": "us-west-2"}
REDIS_CONFIG = {"type": "redis", "connection_string": "localhost:6379,db=0"}
FULL_REPO_CONFIGS: List[IntegrationTestRepoConfig] = [
# Local configurations
IntegrationTestRepoConfig(),
IntegrationTestRepoConfig(online_store=REDIS_CONFIG),
# GCP configurations
IntegrationTestRepoConfig(
provider="gcp",
offline_store_creator=BigQueryDataSourceCreator,
online_store="datastore",
),
IntegrationTestRepoConfig(
provider="gcp",
offline_store_creator=BigQueryDataSourceCreator,
online_store=REDIS_CONFIG,
),
# AWS configurations
IntegrationTestRepoConfig(
provider="aws",
offline_store_creator=RedshiftDataSourceCreator,
online_store=DYNAMO_CONFIG,
),
IntegrationTestRepoConfig(
provider="aws",
offline_store_creator=RedshiftDataSourceCreator,
online_store=REDIS_CONFIG,
),
]


def construct_universal_entities() -> Dict[str, List[Any]]:
return {"customer": list(range(1001, 1110)), "driver": list(range(5001, 5110))}


def construct_universal_datasets(
entities: Dict[str, List[Any]], start_time: datetime, end_time: datetime
) -> Dict[str, pd.DataFrame]:
customer_df = driver_test_data.create_customer_daily_profile_df(
entities["customer"], start_time, end_time
)
driver_df = driver_test_data.create_driver_hourly_stats_df(
entities["driver"], start_time, end_time
)
orders_df = driver_test_data.create_orders_df(
customers=entities["customer"],
drivers=entities["driver"],
start_date=end_time - timedelta(days=365),
end_date=end_time + timedelta(days=365),
order_count=1000,
)

return {"customer": customer_df, "driver": driver_df, "orders": orders_df}


def construct_universal_data_sources(
datasets: Dict[str, pd.DataFrame], data_source_creator: DataSourceCreator
) -> Dict[str, DataSource]:
customer_ds = data_source_creator.create_data_source(
datasets["customer"],
destination_name="customer_profile",
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
)
driver_ds = data_source_creator.create_data_source(
datasets["driver"],
destination_name="driver_hourly",
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
)
orders_ds = data_source_creator.create_data_source(
datasets["orders"],
destination_name="orders",
event_timestamp_column="event_timestamp",
created_timestamp_column="created",
)
return {"customer": customer_ds, "driver": driver_ds, "orders": orders_ds}


def construct_universal_feature_views(
data_sources: Dict[str, DataSource],
) -> Dict[str, FeatureView]:
return {
"customer": create_customer_daily_profile_feature_view(
data_sources["customer"]
),
"driver": create_driver_hourly_stats_feature_view(data_sources["driver"]),
}


@dataclass
class Environment:
name: str
test_repo_config: IntegrationTestRepoConfig
feature_store: FeatureStore
data_source_creator: DataSourceCreator

end_date: datetime = field(
default=datetime.now().replace(microsecond=0, second=0, minute=0)
)

def __post_init__(self):
self.start_date: datetime = self.end_date - timedelta(days=7)


def table_name_from_data_source(ds: DataSource) -> Optional[str]:
if hasattr(ds, "table_ref"):
return ds.table_ref
elif hasattr(ds, "table"):
return ds.table
return None


@contextmanager
def construct_test_environment(
test_repo_config: IntegrationTestRepoConfig,
test_suite_name: str = "integration_test",
) -> Environment:
project = f"{test_suite_name}_{str(uuid.uuid4()).replace('-', '')[:8]}"

offline_creator: DataSourceCreator = test_repo_config.offline_store_creator(project)

offline_store_config = offline_creator.create_offline_store_config()
online_store = test_repo_config.online_store

with tempfile.TemporaryDirectory() as repo_dir_name:
config = RepoConfig(
registry=str(Path(repo_dir_name) / "registry.db"),
project=project,
provider=test_repo_config.provider,
offline_store=offline_store_config,
online_store=online_store,
repo_path=repo_dir_name,
)
fs = FeatureStore(config=config)
environment = Environment(
name=project,
test_repo_config=test_repo_config,
feature_store=fs,
data_source_creator=offline_creator,
)

try:
yield environment
finally:
fs.teardown()
Loading

0 comments on commit 66cf6a4

Please sign in to comment.