From d4b0b1a5045d04b4031f01c320d810b13180e64c Mon Sep 17 00:00:00 2001 From: Oleksii Moskalenko Date: Wed, 11 May 2022 11:15:32 -0700 Subject: [PATCH] chore: Generate environments for each individual test based on its markers/fixtures (#2648) * generating test environments based on test markers and fixtures Signed-off-by: Oleksii Moskalenko * remove "universal" marker Signed-off-by: Oleksii Moskalenko --- Makefile | 10 +- .../contrib/contrib_repo_configuration.py | 13 +- .../contrib/postgres_repo_configuration.py | 13 +- .../feast/infra/online_stores/datastore.py | 22 +- ...st_benchmark_universal_online_retrieval.py | 1 + sdk/python/tests/conftest.py | 223 ++++++++++++------ .../integration/e2e/test_go_feature_server.py | 68 ++---- .../e2e/test_python_feature_server.py | 4 +- .../integration/e2e/test_universal_e2e.py | 2 +- .../tests/integration/e2e/test_validation.py | 3 +- .../integration_test_repo_config.py | 19 ++ .../feature_repos/repo_configuration.py | 123 +++------- .../feature_repos/universal/feature_views.py | 3 +- .../offline_store/test_feature_logging.py | 2 +- .../test_universal_historical_retrieval.py | 15 +- .../test_push_online_retrieval.py | 2 +- .../online_store/test_universal_online.py | 24 +- .../integration/registration/test_cli.py | 2 +- .../registration/test_inference.py | 1 - .../test_universal_odfv_feature_inference.py | 3 +- .../registration/test_universal_types.py | 183 +++++++------- 21 files changed, 368 insertions(+), 368 deletions(-) diff --git a/Makefile b/Makefile index 41e9166756..03a627edf5 100644 --- a/Makefile +++ b/Makefile @@ -76,7 +76,7 @@ test-python-universal-contrib: FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.offline_stores.contrib.contrib_repo_configuration \ PYTEST_PLUGINS=feast.infra.offline_stores.contrib.trino_offline_store.tests \ FEAST_USAGE=False IS_TEST=True \ - python -m pytest -n 8 --integration --universal \ + python -m pytest -n 8 --integration \ -k "not test_historical_retrieval_fails_on_validation and \
not test_historical_retrieval_with_validation and \ not test_historical_features_persisting and \ @@ -93,7 +93,7 @@ test-python-universal-postgres: PYTEST_PLUGINS=sdk.python.feast.infra.offline_stores.contrib.postgres_offline_store.tests \ FEAST_USAGE=False \ IS_TEST=True \ - python -m pytest -x --integration --universal \ + python -m pytest -x --integration \ -k "not test_historical_retrieval_fails_on_validation and \ not test_historical_retrieval_with_validation and \ not test_historical_features_persisting and \ @@ -105,10 +105,10 @@ test-python-universal-postgres: sdk/python/tests test-python-universal-local: - FEAST_USAGE=False IS_TEST=True FEAST_IS_LOCAL_TEST=True python -m pytest -n 8 --integration --universal sdk/python/tests + FEAST_USAGE=False IS_TEST=True FEAST_IS_LOCAL_TEST=True python -m pytest -n 8 --integration sdk/python/tests test-python-universal: - FEAST_USAGE=False IS_TEST=True python -m pytest -n 8 --integration --universal sdk/python/tests + FEAST_USAGE=False IS_TEST=True python -m pytest -n 8 --integration sdk/python/tests test-python-go-server: compile-go-lib FEAST_USAGE=False IS_TEST=True FEAST_GO_FEATURE_RETRIEVAL=True pytest --integration --goserver sdk/python/tests @@ -158,7 +158,7 @@ start-trino-locally: sleep 15 test-trino-plugin-locally: - cd ${ROOT_DIR}/sdk/python; FULL_REPO_CONFIGS_MODULE=feast.infra.offline_stores.contrib.trino_offline_store.test_config.manual_tests FEAST_USAGE=False IS_TEST=True python -m pytest --integration --universal tests/ + cd ${ROOT_DIR}/sdk/python; FULL_REPO_CONFIGS_MODULE=feast.infra.offline_stores.contrib.trino_offline_store.test_config.manual_tests FEAST_USAGE=False IS_TEST=True python -m pytest --integration tests/ kill-trino-locally: cd ${ROOT_DIR}; docker stop trino diff --git a/sdk/python/feast/infra/offline_stores/contrib/contrib_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/contrib_repo_configuration.py index 0346f6cf3a..083ec2b210 100644 --- 
a/sdk/python/feast/infra/offline_stores/contrib/contrib_repo_configuration.py +++ b/sdk/python/feast/infra/offline_stores/contrib/contrib_repo_configuration.py @@ -4,11 +4,14 @@ from feast.infra.offline_stores.contrib.trino_offline_store.tests.data_source import ( TrinoSourceCreator, ) -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, +from tests.integration.feature_repos.repo_configuration import REDIS_CONFIG +from tests.integration.feature_repos.universal.online_store.redis import ( + RedisOnlineStoreCreator, ) -FULL_REPO_CONFIGS = [ - IntegrationTestRepoConfig(offline_store_creator=SparkDataSourceCreator), - IntegrationTestRepoConfig(offline_store_creator=TrinoSourceCreator), +AVAILABLE_OFFLINE_STORES = [ + ("local", SparkDataSourceCreator), + ("local", TrinoSourceCreator), ] + +AVAILABLE_ONLINE_STORES = {"redis": (REDIS_CONFIG, RedisOnlineStoreCreator)} diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_repo_configuration.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_repo_configuration.py index 288c0574b1..9b107aa7a3 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_repo_configuration.py +++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_repo_configuration.py @@ -1,14 +1,7 @@ from feast.infra.offline_stores.contrib.postgres_offline_store.tests.data_source import ( PostgreSQLDataSourceCreator, ) -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, -) -FULL_REPO_CONFIGS = [ - IntegrationTestRepoConfig( - provider="local", - offline_store_creator=PostgreSQLDataSourceCreator, - online_store_creator=PostgreSQLDataSourceCreator, - ), -] +AVAILABLE_OFFLINE_STORES = [("local", PostgreSQLDataSourceCreator)] + +AVAILABLE_ONLINE_STORES = {"postgres": (None, PostgreSQLDataSourceCreator)} diff --git a/sdk/python/feast/infra/online_stores/datastore.py b/sdk/python/feast/infra/online_stores/datastore.py index 
e975ce138c..fc3659ea1a 100644 --- a/sdk/python/feast/infra/online_stores/datastore.py +++ b/sdk/python/feast/infra/online_stores/datastore.py @@ -15,7 +15,7 @@ import logging from datetime import datetime from multiprocessing.pool import ThreadPool -from queue import Queue +from queue import Empty, Queue from threading import Lock, Thread from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple @@ -292,22 +292,24 @@ def increment(self): def worker(shared_counter): while True: - client.delete_multi(deletion_queue.get()) + try: + job = deletion_queue.get(block=False) + except Empty: + return + + client.delete_multi(job) shared_counter.increment() LOGGER.debug( f"batch deletions completed: {shared_counter.value} ({shared_counter.value * BATCH_SIZE} total entries) & outstanding queue size: {deletion_queue.qsize()}" ) deletion_queue.task_done() - for _ in range(NUM_THREADS): - Thread(target=worker, args=(status_info_counter,), daemon=True).start() - query = client.query(kind="Row", ancestor=key) - while True: - entities = list(query.fetch(limit=BATCH_SIZE)) - if not entities: - break - deletion_queue.put([entity.key for entity in entities]) + for page in query.fetch().pages: + deletion_queue.put([entity.key for entity in page]) + + for _ in range(NUM_THREADS): + Thread(target=worker, args=(status_info_counter,)).start() deletion_queue.join() diff --git a/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py b/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py index a29383a5c9..6e22c93e5f 100644 --- a/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py +++ b/sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py @@ -17,6 +17,7 @@ @pytest.mark.benchmark @pytest.mark.integration +@pytest.mark.universal_online_stores def test_online_retrieval(environment, universal_data_sources, benchmark): fs = environment.feature_store entities, datasets, data_sources = 
universal_data_sources diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index 7c934d620e..d492c7ba84 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -13,30 +13,34 @@ # limitations under the License. import logging import multiprocessing -import time +import socket +from contextlib import closing from datetime import datetime, timedelta from multiprocessing import Process from sys import platform -from typing import List +from typing import Any, Dict, List import pandas as pd import pytest from _pytest.nodes import Item from feast import FeatureStore +from feast.wait import wait_retry_backoff from tests.data.data_creator import create_dataset from tests.integration.feature_repos.integration_test_repo_config import ( IntegrationTestRepoConfig, ) from tests.integration.feature_repos.repo_configuration import ( - FULL_REPO_CONFIGS, - REDIS_CLUSTER_CONFIG, - REDIS_CONFIG, + AVAILABLE_OFFLINE_STORES, + AVAILABLE_ONLINE_STORES, Environment, TestData, construct_test_environment, construct_universal_test_data, ) +from tests.integration.feature_repos.universal.data_sources.file import ( + FileDataSourceCreator, +) logger = logging.getLogger(__name__) @@ -50,9 +54,6 @@ def pytest_configure(config): "markers", "integration: mark test that has external dependencies" ) config.addinivalue_line("markers", "benchmark: mark benchmarking tests") - config.addinivalue_line( - "markers", "universal: mark tests that use the universal feature repo" - ) config.addinivalue_line( "markers", "goserver: mark tests that use the go feature server" ) @@ -68,9 +69,6 @@ def pytest_addoption(parser): parser.addoption( "--benchmark", action="store_true", default=False, help="Run benchmark tests", ) - parser.addoption( - "--universal", action="store_true", default=False, help="Run universal tests", - ) parser.addoption( "--goserver", action="store_true", @@ -82,7 +80,6 @@ def pytest_addoption(parser): def pytest_collection_modifyitems(config, 
items: List[Item]): should_run_integration = config.getoption("--integration") is True should_run_benchmark = config.getoption("--benchmark") is True - should_run_universal = config.getoption("--universal") is True should_run_goserver = config.getoption("--goserver") is True integration_tests = [t for t in items if "integration" in t.keywords] @@ -103,12 +100,6 @@ def pytest_collection_modifyitems(config, items: List[Item]): for t in benchmark_tests: items.append(t) - universal_tests = [t for t in items if "universal" in t.keywords] - if should_run_universal: - items.clear() - for t in universal_tests: - items.append(t) - goserver_tests = [t for t in items if "goserver" in t.keywords] if should_run_goserver: items.clear() @@ -161,86 +152,168 @@ def start_test_local_server(repo_path: str, port: int): fs.serve("localhost", port, no_access_log=True) -@pytest.fixture( - params=FULL_REPO_CONFIGS, scope="session", ids=[str(c) for c in FULL_REPO_CONFIGS] -) -def environment(request, worker_id: str): +@pytest.fixture(scope="session") +def environment(request, worker_id): e = construct_test_environment( request.param, worker_id=worker_id, fixture_request=request ) + + yield e + + e.feature_store.teardown() + e.data_source_creator.teardown() + if e.online_store_creator: + e.online_store_creator.teardown() + + +_config_cache = {} + + +def pytest_generate_tests(metafunc: pytest.Metafunc): + """ + This function receives each test function (wrapped in Metafunc) + at the collection stage (before tests started). + Here we can access all fixture requests made by the test as well as its markers. + That allows us to dynamically parametrize the test based on markers and fixtures + by calling metafunc.parametrize(...). + + See more examples at https://docs.pytest.org/en/6.2.x/example/parametrize.html#paramexamples + + We also utilize indirect parametrization here. 
Since `environment` is a fixture, + when we call metafunc.parametrize("environment", ..., indirect=True) we actually + parametrizing this "environment" fixture and not the test itself. + Moreover, by utilizing `_config_cache` we are able to share `environment` fixture between different tests. + In order for pytest to group tests together (and share environment fixture) + parameter should point to the same Python object (hence, we use _config_cache dict to store those objects). + """ + if "environment" in metafunc.fixturenames: + markers = {m.name: m for m in metafunc.definition.own_markers} + + if "universal_offline_stores" in markers: + offline_stores = AVAILABLE_OFFLINE_STORES + else: + # default offline store for testing online store dimension + offline_stores = [("local", FileDataSourceCreator)] + + online_stores = None + if "universal_online_stores" in markers: + # Online stores are explicitly requested + if "only" in markers["universal_online_stores"].kwargs: + online_stores = [ + AVAILABLE_ONLINE_STORES.get(store_name) + for store_name in markers["universal_online_stores"].kwargs["only"] + if store_name in AVAILABLE_ONLINE_STORES + ] + else: + online_stores = AVAILABLE_ONLINE_STORES.values() + + if online_stores is None: + # No online stores requested -> setting the default or first available + online_stores = [ + AVAILABLE_ONLINE_STORES.get( + "redis", + AVAILABLE_ONLINE_STORES.get( + "sqlite", next(iter(AVAILABLE_ONLINE_STORES.values())) + ), + ) + ] + + extra_dimensions: List[Dict[str, Any]] = [{}] + + if "python_server" in metafunc.fixturenames: + extra_dimensions.extend( + [ + {"python_feature_server": True}, + {"python_feature_server": True, "provider": "aws"}, + ] + ) + + if "goserver" in markers: + extra_dimensions.append({"go_feature_retrieval": True}) + + configs = [] + for provider, offline_store_creator in offline_stores: + for online_store, online_store_creator in online_stores: + for dim in extra_dimensions: + config = { + "provider": provider, 
+ "offline_store_creator": offline_store_creator, + "online_store": online_store, + "online_store_creator": online_store_creator, + **dim, + } + # temporary Go works only with redis + if config.get("go_feature_retrieval") and ( + not isinstance(online_store, dict) + or online_store["type"] != "redis" + ): + continue + + # aws lambda works only with dynamo + if ( + config.get("python_feature_server") + and config.get("provider") == "aws" + and ( + not isinstance(online_store, dict) + or online_store["type"] != "dynamodb" + ) + ): + continue + + c = IntegrationTestRepoConfig(**config) + + if c not in _config_cache: + _config_cache[c] = c + + configs.append(_config_cache[c]) + + metafunc.parametrize( + "environment", configs, indirect=True, ids=[str(c) for c in configs] + ) + + +@pytest.fixture(scope="session") +def python_server(environment): proc = Process( target=start_test_local_server, - args=(e.feature_store.repo_path, e.get_local_server_port()), + args=(environment.feature_store.repo_path, environment.get_local_server_port()), daemon=True, ) - if e.python_feature_server and e.test_repo_config.provider == "local": + if ( + environment.python_feature_server + and environment.test_repo_config.provider == "local" + ): proc.start() # Wait for server to start - time.sleep(3) + wait_retry_backoff( + lambda: ( + None, + _check_port_open("localhost", environment.get_local_server_port()), + ), + timeout_secs=10, + ) - def cleanup(): - e.feature_store.teardown() - if proc.is_alive(): - proc.kill() - if e.online_store_creator: - e.online_store_creator.teardown() + yield - request.addfinalizer(cleanup) + if proc.is_alive(): + proc.kill() - return e - -@pytest.fixture( - params=[REDIS_CONFIG, REDIS_CLUSTER_CONFIG], - scope="session", - ids=[str(c) for c in [REDIS_CONFIG, REDIS_CLUSTER_CONFIG]], -) -def local_redis_environment(request, worker_id): - e = construct_test_environment( - IntegrationTestRepoConfig(online_store=request.param), - worker_id=worker_id, - 
fixture_request=request, - ) - - def cleanup(): - e.feature_store.teardown() - - request.addfinalizer(cleanup) - return e +def _check_port_open(host, port) -> bool: + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + return sock.connect_ex((host, port)) == 0 @pytest.fixture(scope="session") -def universal_data_sources(request, environment) -> TestData: - def cleanup(): - # logger.info("Running cleanup in %s, Request: %s", worker_id, request.param) - environment.data_source_creator.teardown() - - request.addfinalizer(cleanup) +def universal_data_sources(environment) -> TestData: return construct_universal_test_data(environment) @pytest.fixture(scope="session") -def redis_universal_data_sources(request, local_redis_environment): - def cleanup(): - # logger.info("Running cleanup in %s, Request: %s", worker_id, request.param) - local_redis_environment.data_source_creator.teardown() - - request.addfinalizer(cleanup) - return construct_universal_test_data(local_redis_environment) - - -@pytest.fixture(scope="session") -def e2e_data_sources(environment: Environment, request): +def e2e_data_sources(environment: Environment): df = create_dataset() data_source = environment.data_source_creator.create_data_source( df, environment.feature_store.project, field_mapping={"ts_1": "ts"}, ) - def cleanup(): - environment.data_source_creator.teardown() - if environment.online_store_creator: - environment.online_store_creator.teardown() - - request.addfinalizer(cleanup) - return df, data_source diff --git a/sdk/python/tests/integration/e2e/test_go_feature_server.py b/sdk/python/tests/integration/e2e/test_go_feature_server.py index 6d165b86a5..e469c90c11 100644 --- a/sdk/python/tests/integration/e2e/test_go_feature_server.py +++ b/sdk/python/tests/integration/e2e/test_go_feature_server.py @@ -24,16 +24,8 @@ from feast.protos.feast.types.Value_pb2 import RepeatedValue from feast.type_map import python_values_to_proto_values from feast.wait import 
wait_retry_backoff -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, -) from tests.integration.feature_repos.repo_configuration import ( - AVAILABLE_OFFLINE_STORES, - AVAILABLE_ONLINE_STORES, - REDIS_CONFIG, - construct_test_environment, construct_universal_feature_views, - construct_universal_test_data, ) from tests.integration.feature_repos.universal.entities import ( customer, @@ -41,53 +33,23 @@ location, ) -LOCAL_REPO_CONFIGS = [ - IntegrationTestRepoConfig(online_store=REDIS_CONFIG, go_feature_retrieval=True), -] -LOCAL_REPO_CONFIGS = [ - c - for c in LOCAL_REPO_CONFIGS - if c.offline_store_creator in AVAILABLE_OFFLINE_STORES - and c.online_store in AVAILABLE_ONLINE_STORES -] - NANOSECOND = 1 MILLISECOND = 1000_000 * NANOSECOND SECOND = 1000 * MILLISECOND -@pytest.fixture( - params=LOCAL_REPO_CONFIGS, - ids=[str(c) for c in LOCAL_REPO_CONFIGS], - scope="session", -) -def local_environment(request): - e = construct_test_environment(request.param, fixture_request=request) - - def cleanup(): - e.feature_store.teardown() - - request.addfinalizer(cleanup) - return e - - -@pytest.fixture(scope="session") -def test_data(local_environment): - return construct_universal_test_data(local_environment) - - @pytest.fixture(scope="session") -def initialized_registry(local_environment, test_data): - fs = local_environment.feature_store +def initialized_registry(environment, universal_data_sources): + fs = environment.feature_store - _, _, data_sources = test_data + _, _, data_sources = universal_data_sources feature_views = construct_universal_feature_views(data_sources) feature_service = FeatureService( name="driver_features", features=[feature_views.driver], logging_config=LoggingConfig( - destination=local_environment.data_source_creator.create_logged_features_destination(), + destination=environment.data_source_creator.create_logged_features_destination(), sample_rate=1.0, ), ) @@ -96,12 +58,15 @@ def 
initialized_registry(local_environment, test_data): feast_objects.extend([driver(), customer(), location()]) fs.apply(feast_objects) - fs.materialize(local_environment.start_date, local_environment.end_date) + fs.materialize(environment.start_date, environment.end_date) @pytest.fixture -def grpc_server_port(local_environment, initialized_registry): - fs = local_environment.feature_store +def grpc_server_port(environment, initialized_registry): + if not environment.test_repo_config.go_feature_retrieval: + pytest.skip("Only for Go path") + + fs = environment.feature_store embedded = EmbeddedOnlineFeatureServer( repo_path=str(fs.repo_path.absolute()), repo_config=fs.config, feature_store=fs, @@ -140,7 +105,7 @@ def grpc_client(grpc_server_port): @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.goserver def test_go_grpc_server(grpc_client): resp: GetOnlineFeaturesResponse = grpc_client.GetOnlineFeatures( GetOnlineFeaturesRequest( @@ -166,10 +131,13 @@ def test_go_grpc_server(grpc_client): @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.goserver +@pytest.mark.universal_offline_stores @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) -def test_feature_logging(grpc_client, local_environment, test_data, full_feature_names): - fs = local_environment.feature_store +def test_feature_logging( + grpc_client, environment, universal_data_sources, full_feature_names +): + fs = environment.feature_store feature_service = fs.get_feature_service("driver_features") log_start_date = datetime.now().astimezone(pytz.UTC) driver_ids = list(range(5001, 5011)) @@ -192,7 +160,7 @@ def test_feature_logging(grpc_client, local_environment, test_data, full_feature # with some pause time.sleep(0.1) - _, datasets, _ = test_data + _, datasets, _ = universal_data_sources latest_rows = get_latest_rows(datasets.driver_df, "driver_id", driver_ids) features = [ feature.name diff --git 
a/sdk/python/tests/integration/e2e/test_python_feature_server.py b/sdk/python/tests/integration/e2e/test_python_feature_server.py index 36edc8df44..ea4c35a1ca 100644 --- a/sdk/python/tests/integration/e2e/test_python_feature_server.py +++ b/sdk/python/tests/integration/e2e/test_python_feature_server.py @@ -23,7 +23,7 @@ @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores def test_get_online_features(python_fs_client): request_data_dict = { "features": [ @@ -61,7 +61,7 @@ def test_get_online_features(python_fs_client): @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores def test_push(python_fs_client): initial_temp = get_temperatures(python_fs_client, location_ids=[1])[0] json_data = json.dumps( diff --git a/sdk/python/tests/integration/e2e/test_universal_e2e.py b/sdk/python/tests/integration/e2e/test_universal_e2e.py index 84c10bda84..a42a96e594 100644 --- a/sdk/python/tests/integration/e2e/test_universal_e2e.py +++ b/sdk/python/tests/integration/e2e/test_universal_e2e.py @@ -12,7 +12,7 @@ @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores @pytest.mark.parametrize("infer_features", [True, False]) def test_e2e_consistency(environment, e2e_data_sources, infer_features): fs = environment.feature_store diff --git a/sdk/python/tests/integration/e2e/test_validation.py b/sdk/python/tests/integration/e2e/test_validation.py index 76bbe152c5..e434f1a133 100644 --- a/sdk/python/tests/integration/e2e/test_validation.py +++ b/sdk/python/tests/integration/e2e/test_validation.py @@ -58,7 +58,7 @@ def profiler_with_unrealistic_expectations(dataset: PandasDataset) -> Expectatio @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores def test_historical_retrieval_with_validation(environment, universal_data_sources): store = environment.feature_store (entities, datasets, data_sources) = universal_data_sources @@ -88,7 +88,6 @@ def 
test_historical_retrieval_with_validation(environment, universal_data_source @pytest.mark.integration -@pytest.mark.universal def test_historical_retrieval_fails_on_validation(environment, universal_data_sources): store = environment.feature_store diff --git a/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py b/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py index d014719bd0..61920bb03f 100644 --- a/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py +++ b/sdk/python/tests/integration/feature_repos/integration_test_repo_config.py @@ -1,3 +1,4 @@ +import hashlib from dataclasses import dataclass from typing import Dict, Optional, Type, Union @@ -50,5 +51,23 @@ def __repr__(self) -> str: f"{self.provider.upper()}", f"{self.offline_store_creator.__name__.split('.')[-1].replace('DataSourceCreator', '')}", online_store_type, + f"python_fs={self.python_feature_server}", + f"go_fs={self.go_feature_retrieval}", ] ) + + def __hash__(self): + return int(hashlib.sha1(repr(self).encode()).hexdigest(), 16) + + def __eq__(self, other): + if not isinstance(other, IntegrationTestRepoConfig): + return False + + return ( + self.provider == other.provider + and self.online_store == other.online_store + and self.offline_store_creator == other.offline_store_creator + and self.online_store_creator == other.online_store_creator + and self.go_feature_retrieval == other.go_feature_retrieval + and self.python_feature_server == other.python_feature_server + ) diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 7b596bd764..a4f740c734 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path -from typing import Any, 
List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import pandas as pd import pytest @@ -73,114 +73,54 @@ "connection_string": "127.0.0.1:6001,127.0.0.1:6002,127.0.0.1:6003", } -AVAILABLE_OFFLINE_STORES: List[Any] = [ - FileDataSourceCreator, +AVAILABLE_OFFLINE_STORES: List[Tuple[str, Type[DataSourceCreator]]] = [ + ("local", FileDataSourceCreator), ] -AVAILABLE_ONLINE_STORES: List[Any] = [ - "sqlite", -] +AVAILABLE_ONLINE_STORES: Dict[ + str, Tuple[Union[str, Dict[str, str]], Optional[Type[OnlineStoreCreator]]] +] = { + "sqlite": ({"type": "sqlite"}, None), +} if os.getenv("FEAST_IS_LOCAL_TEST", "False") != "True": AVAILABLE_OFFLINE_STORES.extend( [ - BigQueryDataSourceCreator, - RedshiftDataSourceCreator, - SnowflakeDataSourceCreator, + ("gcp", BigQueryDataSourceCreator), + ("aws", RedshiftDataSourceCreator), + ("aws", SnowflakeDataSourceCreator), ] ) - AVAILABLE_ONLINE_STORES.extend([REDIS_CONFIG, DYNAMO_CONFIG]) - -# FULL_REPO_CONFIGS contains the repo configurations (e.g. provider, offline store, -# online store, test data, and more parameters) that most integration tests will test -# against. By default, FULL_REPO_CONFIGS uses the three providers (local, GCP, and AWS) -# with their default offline and online stores; it also tests the providers with the -# Redis online store. It can be overwritten by specifying a Python module through the -# FULL_REPO_CONFIGS_MODULE_ENV_NAME environment variable. In this case, that Python -# module will be imported and FULL_REPO_CONFIGS will be extracted from the file. 
-DEFAULT_FULL_REPO_CONFIGS: List[IntegrationTestRepoConfig] = [ - # Local configurations - IntegrationTestRepoConfig(), - IntegrationTestRepoConfig(python_feature_server=True), - IntegrationTestRepoConfig(online_store=REDIS_CONFIG), - # GCP configurations - IntegrationTestRepoConfig( - provider="gcp", - offline_store_creator=BigQueryDataSourceCreator, - online_store="datastore", - ), - IntegrationTestRepoConfig( - provider="gcp", - offline_store_creator=BigQueryDataSourceCreator, - online_store=REDIS_CONFIG, - ), - # AWS configurations - IntegrationTestRepoConfig( - provider="aws", - offline_store_creator=RedshiftDataSourceCreator, - online_store=DYNAMO_CONFIG, - python_feature_server=True, - ), - IntegrationTestRepoConfig( - provider="aws", - offline_store_creator=RedshiftDataSourceCreator, - online_store=REDIS_CONFIG, - ), - # Snowflake configurations - IntegrationTestRepoConfig( - provider="aws", # no list features, no feature server - offline_store_creator=SnowflakeDataSourceCreator, - online_store=REDIS_CONFIG, - ), - # Go implementation for online retrieval - IntegrationTestRepoConfig(online_store=REDIS_CONFIG, go_feature_retrieval=True,), - # TODO(felixwang9817): Enable this test once https://github.com/feast-dev/feast/issues/2544 is resolved. 
- # IntegrationTestRepoConfig( - # online_store=REDIS_CONFIG, - # python_feature_server=True, - # go_feature_retrieval=True, - # ), -] + AVAILABLE_ONLINE_STORES["redis"] = (REDIS_CONFIG, None) + AVAILABLE_ONLINE_STORES["dynamodb"] = (DYNAMO_CONFIG, None) + AVAILABLE_ONLINE_STORES["datastore"] = ("datastore", None) -DEFAULT_FULL_REPO_CONFIGS = [ - c - for c in DEFAULT_FULL_REPO_CONFIGS - if c.online_store in AVAILABLE_ONLINE_STORES - and c.offline_store_creator in AVAILABLE_OFFLINE_STORES -] -if os.getenv("FEAST_GO_FEATURE_RETRIEVAL", "False") == "True": - DEFAULT_FULL_REPO_CONFIGS = [ - IntegrationTestRepoConfig( - online_store=REDIS_CONFIG, go_feature_retrieval=True, - ), - ] full_repo_configs_module = os.environ.get(FULL_REPO_CONFIGS_MODULE_ENV_NAME) if full_repo_configs_module is not None: try: module = importlib.import_module(full_repo_configs_module) - FULL_REPO_CONFIGS = getattr(module, "FULL_REPO_CONFIGS") + AVAILABLE_ONLINE_STORES = getattr(module, "AVAILABLE_ONLINE_STORES") + AVAILABLE_OFFLINE_STORES = getattr(module, "AVAILABLE_OFFLINE_STORES") except Exception as e: raise FeastModuleImportError( "FULL_REPO_CONFIGS", full_repo_configs_module ) from e -else: - FULL_REPO_CONFIGS = DEFAULT_FULL_REPO_CONFIGS + if os.getenv("FEAST_LOCAL_ONLINE_CONTAINER", "False").lower() == "true": - replacements = {"datastore": DatastoreOnlineStoreCreator} - replacement_dicts = [ - (REDIS_CONFIG, RedisOnlineStoreCreator), - (DYNAMO_CONFIG, DynamoDBOnlineStoreCreator), - ] - for c in FULL_REPO_CONFIGS: - if isinstance(c.online_store, dict): - for _replacement in replacement_dicts: - if c.online_store == _replacement[0]: - c.online_store_creator = _replacement[1] - elif c.online_store in replacements: - c.online_store_creator = replacements[c.online_store] + replacements: Dict[ + str, Tuple[Union[str, Dict[str, str]], Optional[Type[OnlineStoreCreator]]] + ] = { + "redis": (REDIS_CONFIG, RedisOnlineStoreCreator), + "dynamodb": (DYNAMO_CONFIG, DynamoDBOnlineStoreCreator), + 
"datastore": ("datastore", DatastoreOnlineStoreCreator), + } + + for key, replacement in replacements.items(): + if key in AVAILABLE_ONLINE_STORES: + AVAILABLE_ONLINE_STORES[key] = replacement @dataclass @@ -364,10 +304,15 @@ class Environment: worker_id: str online_store_creator: Optional[OnlineStoreCreator] = None + next_id = 0 + def __post_init__(self): self.end_date = datetime.utcnow().replace(microsecond=0, second=0, minute=0) self.start_date: datetime = self.end_date - timedelta(days=3) + Environment.next_id += 1 + self.id = Environment.next_id + def get_feature_server_endpoint(self) -> str: if self.python_feature_server and self.test_repo_config.provider == "local": return f"http://localhost:{self.get_local_server_port()}" @@ -380,7 +325,7 @@ def get_local_server_port(self) -> int: worker_id_num = int(parsed_worker_id[0]) else: worker_id_num = 0 - return 6566 + worker_id_num + return 6000 + 100 * worker_id_num + self.id def table_name_from_data_source(ds: DataSource) -> Optional[str]: diff --git a/sdk/python/tests/integration/feature_repos/universal/feature_views.py b/sdk/python/tests/integration/feature_repos/universal/feature_views.py index 26c2513995..3e05f5d7e5 100644 --- a/sdk/python/tests/integration/feature_repos/universal/feature_views.py +++ b/sdk/python/tests/integration/feature_repos/universal/feature_views.py @@ -23,10 +23,11 @@ def driver_feature_view( name="test_correctness", infer_features: bool = False, dtype: FeastType = Float32, + entities: Optional[List[str]] = None, ) -> FeatureView: return FeatureView( name=name, - entities=["driver"], + entities=entities or ["driver"], schema=None if infer_features else [Field(name="value", dtype=dtype)], ttl=timedelta(days=5), source=data_source, diff --git a/sdk/python/tests/integration/offline_store/test_feature_logging.py b/sdk/python/tests/integration/offline_store/test_feature_logging.py index 6dda2e63a9..8e7e9d68be 100644 --- a/sdk/python/tests/integration/offline_store/test_feature_logging.py 
+++ b/sdk/python/tests/integration/offline_store/test_feature_logging.py @@ -31,7 +31,7 @@ @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores @pytest.mark.parametrize("pass_as_path", [True, False], ids=lambda v: str(v)) def test_feature_service_logging(environment, universal_data_sources, pass_as_path): store = environment.feature_store diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index b62f7cda24..1b7dab2110 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -280,7 +280,7 @@ def get_expected_training_df( @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) def test_historical_features(environment, universal_data_sources, full_feature_names): store = environment.feature_store @@ -411,10 +411,9 @@ def test_historical_features(environment, universal_data_sources, full_feature_n @pytest.mark.integration -@pytest.mark.universal -@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) +@pytest.mark.universal_offline_stores def test_historical_features_with_missing_request_data( - environment, universal_data_sources, full_feature_names + environment, universal_data_sources ): store = environment.feature_store @@ -437,12 +436,12 @@ def test_historical_features_with_missing_request_data( "global_stats:avg_ride_length", "field_mapping:feature_name", ], - full_feature_names=full_feature_names, + full_feature_names=True, ) @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) def 
test_historical_features_with_entities_from_query( environment, universal_data_sources, full_feature_names @@ -542,7 +541,7 @@ def test_historical_features_with_entities_from_query( @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) def test_historical_features_persisting( environment, universal_data_sources, full_feature_names @@ -621,7 +620,7 @@ def test_historical_features_persisting( @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores def test_historical_features_from_bigquery_sources_containing_backfills(environment): store = environment.feature_store diff --git a/sdk/python/tests/integration/online_store/test_push_online_retrieval.py b/sdk/python/tests/integration/online_store/test_push_online_retrieval.py index 9e9ec953c7..aa7e3e7f53 100644 --- a/sdk/python/tests/integration/online_store/test_push_online_retrieval.py +++ b/sdk/python/tests/integration/online_store/test_push_online_retrieval.py @@ -14,7 +14,7 @@ @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores def test_push_features_and_read(environment, universal_data_sources): store = environment.feature_store diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index f4440dbfbc..ababb25c39 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -38,13 +38,14 @@ @pytest.mark.integration -def test_entity_ttl_online_store(local_redis_environment, redis_universal_data_sources): +@pytest.mark.universal_online_stores(only=["redis"]) +def test_entity_ttl_online_store(environment, universal_data_sources): if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": return - fs = local_redis_environment.feature_store + fs = 
environment.feature_store # setting ttl setting in online store to 1 second fs.config.online_store.key_ttl_seconds = 1 - entities, datasets, data_sources = redis_universal_data_sources + entities, datasets, data_sources = universal_data_sources driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver) driver_entity = driver() @@ -98,10 +99,11 @@ def test_entity_ttl_online_store(local_redis_environment, redis_universal_data_s # TODO: make this work with all universal (all online store types) @pytest.mark.integration -def test_write_to_online_store_event_check(local_redis_environment): +@pytest.mark.universal_online_stores(only=["redis"]) +def test_write_to_online_store_event_check(environment): if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True": return - fs = local_redis_environment.feature_store + fs = environment.feature_store # write same data points 3 with different timestamps now = pd.Timestamp(datetime.datetime.utcnow()).round("ms") @@ -198,7 +200,7 @@ def test_write_to_online_store_event_check(local_redis_environment): @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores def test_write_to_online_store(environment, universal_data_sources): fs = environment.feature_store entities, datasets, data_sources = universal_data_sources @@ -323,7 +325,7 @@ def get_online_features_dict( @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) def test_online_retrieval_with_event_timestamps( environment, universal_data_sources, full_feature_names @@ -387,10 +389,12 @@ def test_online_retrieval_with_event_timestamps( @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores @pytest.mark.goserver @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) -def test_online_retrieval(environment, universal_data_sources, full_feature_names): +def 
test_online_retrieval( + environment, universal_data_sources, python_server, full_feature_names +): fs = environment.feature_store entities, datasets, data_sources = universal_data_sources feature_views = construct_universal_feature_views(data_sources) @@ -610,7 +614,7 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_online_stores(only=["redis"]) def test_online_store_cleanup(environment, universal_data_sources): """ Some online store implementations (like Redis) keep features from different features views diff --git a/sdk/python/tests/integration/registration/test_cli.py b/sdk/python/tests/integration/registration/test_cli.py index 655e53e759..ce23ed66a6 100644 --- a/sdk/python/tests/integration/registration/test_cli.py +++ b/sdk/python/tests/integration/registration/test_cli.py @@ -32,7 +32,7 @@ @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores def test_universal_cli(environment: Environment): project = f"test_universal_cli_{str(uuid.uuid4()).replace('-', '')[:8]}" runner = CliRunner() diff --git a/sdk/python/tests/integration/registration/test_inference.py b/sdk/python/tests/integration/registration/test_inference.py index b1fb509d75..6cf49c31db 100644 --- a/sdk/python/tests/integration/registration/test_inference.py +++ b/sdk/python/tests/integration/registration/test_inference.py @@ -145,7 +145,6 @@ def test_update_file_data_source_with_inferred_event_timestamp_col(simple_datase @pytest.mark.integration -@pytest.mark.universal def test_update_data_sources_with_inferred_event_timestamp_col(universal_data_sources): (_, _, data_sources) = universal_data_sources data_sources_copy = deepcopy(data_sources) diff --git a/sdk/python/tests/integration/registration/test_universal_odfv_feature_inference.py b/sdk/python/tests/integration/registration/test_universal_odfv_feature_inference.py index 04cb56367b..b7a9a571af 100644 --- 
a/sdk/python/tests/integration/registration/test_universal_odfv_feature_inference.py +++ b/sdk/python/tests/integration/registration/test_universal_odfv_feature_inference.py @@ -19,7 +19,7 @@ @pytest.mark.integration -@pytest.mark.universal +@pytest.mark.universal_offline_stores @pytest.mark.parametrize("infer_features", [True, False], ids=lambda v: str(v)) def test_infer_odfv_features(environment, universal_data_sources, infer_features): store = environment.feature_store @@ -72,7 +72,6 @@ def test_infer_odfv_list_features(environment, infer_features, tmp_path): @pytest.mark.integration -@pytest.mark.universal def test_infer_odfv_features_with_error(environment, universal_data_sources): store = environment.feature_store diff --git a/sdk/python/tests/integration/registration/test_universal_types.py b/sdk/python/tests/integration/registration/test_universal_types.py index b95745f997..6d016e3e85 100644 --- a/sdk/python/tests/integration/registration/test_universal_types.py +++ b/sdk/python/tests/integration/registration/test_universal_types.py @@ -8,16 +8,11 @@ import pyarrow as pa import pytest +from feast.entity import Entity from feast.infra.offline_stores.offline_store import RetrievalJob from feast.types import Array, Bool, Float32, Int32, Int64, UnixTimestamp from feast.value_type import ValueType from tests.data.data_creator import create_dataset -from tests.integration.feature_repos.repo_configuration import ( - FULL_REPO_CONFIGS, - REDIS_CONFIG, - IntegrationTestRepoConfig, - construct_test_environment, -) from tests.integration.feature_repos.universal.entities import driver from tests.integration.feature_repos.universal.feature_views import driver_feature_view @@ -25,46 +20,36 @@ def populate_test_configs(offline: bool): - entity_type_feature_dtypes = [ - (ValueType.INT32, "int32"), - (ValueType.INT64, "int64"), - (ValueType.STRING, "float"), - (ValueType.STRING, "bool"), - (ValueType.INT32, "datetime"), + feature_dtypes = [ + "int32", + "int64", + "float", 
+ "bool", + "datetime", ] configs: List[TypeTestConfig] = [] - for test_repo_config in FULL_REPO_CONFIGS: - for entity_type, feature_dtype in entity_type_feature_dtypes: - for feature_is_list in [True, False]: - # Redshift doesn't support list features - if test_repo_config.provider == "aws" and feature_is_list is True: - continue - # For offline tests, don't need to vary for online store - if offline and test_repo_config.online_store == REDIS_CONFIG: + for feature_dtype in feature_dtypes: + for feature_is_list in [True, False]: + for has_empty_list in [True, False]: + # For non list features `has_empty_list` does nothing + if feature_is_list is False and has_empty_list is True: continue - for has_empty_list in [True, False]: - # For non list features `has_empty_list` does nothing - if feature_is_list is False and has_empty_list is True: - continue - configs.append( - TypeTestConfig( - entity_type=entity_type, - feature_dtype=feature_dtype, - feature_is_list=feature_is_list, - has_empty_list=has_empty_list, - test_repo_config=test_repo_config, - ) + + configs.append( + TypeTestConfig( + feature_dtype=feature_dtype, + feature_is_list=feature_is_list, + has_empty_list=has_empty_list, ) + ) return configs @dataclass(frozen=True, repr=True) class TypeTestConfig: - entity_type: ValueType feature_dtype: str feature_is_list: bool has_empty_list: bool - test_repo_config: IntegrationTestRepoConfig OFFLINE_TYPE_TEST_CONFIGS: List[TypeTestConfig] = populate_test_configs(offline=True) @@ -76,8 +61,15 @@ class TypeTestConfig: scope="session", ids=[str(c) for c in OFFLINE_TYPE_TEST_CONFIGS], ) -def offline_types_test_fixtures(request): - return get_fixtures(request) +def offline_types_test_fixtures(request, environment): + config: TypeTestConfig = request.param + if ( + environment.test_repo_config.provider == "aws" + and config.feature_is_list is True + ): + pytest.skip("Redshift doesn't support list features") + + return get_fixtures(request, environment) @pytest.fixture( @@ 
-85,87 +77,91 @@ def offline_types_test_fixtures(request): scope="session", ids=[str(c) for c in ONLINE_TYPE_TEST_CONFIGS], ) -def online_types_test_fixtures(request): - return get_fixtures(request) +def online_types_test_fixtures(request, environment): + return get_fixtures(request, environment) -def get_fixtures(request): +def get_fixtures(request, environment): config: TypeTestConfig = request.param # Lower case needed because Redshift lower-cases all table names - test_project_id = f"{config.entity_type}{config.feature_dtype}{config.feature_is_list}".replace( + destination_name = f"feature_type_{config.feature_dtype}{config.feature_is_list}".replace( ".", "" ).lower() - type_test_environment = construct_test_environment( - test_repo_config=config.test_repo_config, - test_suite_name=f"test_{test_project_id}", - fixture_request=request, - ) config = request.param df = create_dataset( - config.entity_type, + ValueType.INT64, config.feature_dtype, config.feature_is_list, config.has_empty_list, ) - data_source = type_test_environment.data_source_creator.create_data_source( - df, - destination_name=type_test_environment.feature_store.project, - field_mapping={"ts_1": "ts"}, + data_source = environment.data_source_creator.create_data_source( + df, destination_name=destination_name, field_mapping={"ts_1": "ts"}, ) fv = create_feature_view( - request.fixturename, + destination_name, config.feature_dtype, config.feature_is_list, config.has_empty_list, data_source, ) - def cleanup(): - try: - type_test_environment.data_source_creator.teardown() - except Exception: # noqa - logger.exception("DataSourceCreator teardown has failed") - - type_test_environment.feature_store.teardown() - - request.addfinalizer(cleanup) - - return type_test_environment, config, data_source, fv + return config, data_source, fv @pytest.mark.integration -@pytest.mark.universal -def test_entity_inference_types_match(offline_types_test_fixtures): - environment, config, data_source, fv = 
offline_types_test_fixtures +@pytest.mark.universal_offline_stores +@pytest.mark.parametrize( + "entity_type", [ValueType.INT32, ValueType.INT64, ValueType.STRING] +) +def test_entity_inference_types_match(environment, entity_type): fs = environment.feature_store # Don't specify value type in entity to force inference - entity = driver(value_type=ValueType.UNKNOWN) + entity = Entity( + name=f"driver_{entity_type.name.lower()}", + value_type=ValueType.UNKNOWN, + join_key="driver_id", + ) + df = create_dataset(entity_type, feature_dtype="int32",) + data_source = environment.data_source_creator.create_data_source( + df, + destination_name=f"entity_type_{entity_type.name.lower()}", + field_mapping={"ts_1": "ts"}, + ) + fv = create_feature_view( + f"fv_entity_type_{entity_type.name.lower()}", + feature_dtype="int32", + feature_is_list=False, + has_empty_list=False, + data_source=data_source, + entity=entity.name, + ) fs.apply([fv, entity]) - entities = fs.list_entities() + inferred_entity = fs.get_entity(entity.name) entity_type_to_expected_inferred_entity_type = { - ValueType.INT32: ValueType.INT64, - ValueType.INT64: ValueType.INT64, - ValueType.FLOAT: ValueType.DOUBLE, - ValueType.STRING: ValueType.STRING, + ValueType.INT32: {ValueType.INT32, ValueType.INT64}, + ValueType.INT64: {ValueType.INT32, ValueType.INT64}, + ValueType.FLOAT: {ValueType.DOUBLE}, + ValueType.STRING: {ValueType.STRING}, } - for entity in entities: - assert ( - entity.value_type - == entity_type_to_expected_inferred_entity_type[config.entity_type] - ) + assert ( + inferred_entity.value_type + in entity_type_to_expected_inferred_entity_type[entity_type] + ) @pytest.mark.integration -@pytest.mark.universal -def test_feature_get_historical_features_types_match(offline_types_test_fixtures): +@pytest.mark.universal_offline_stores +def test_feature_get_historical_features_types_match( + offline_types_test_fixtures, environment +): """ Note: to make sure this test works, we need to ensure that 
get_historical_features returns at least one non-null row to make sure type inferral works. This can only be achieved by carefully matching entity_df to the data fixtures. """ - environment, config, data_source, fv = offline_types_test_fixtures + config, data_source, fv = offline_types_test_fixtures fs = environment.feature_store entity = driver() fv = create_feature_view( @@ -178,9 +174,7 @@ def test_feature_get_historical_features_types_match(offline_types_test_fixtures fs.apply([fv, entity]) entity_df = pd.DataFrame() - entity_df["driver_id"] = ( - ["1", "3"] if config.entity_type == ValueType.STRING else [1, 3] - ) + entity_df["driver_id"] = [1, 3] ts = pd.Timestamp(datetime.utcnow()).round("ms") entity_df["ts"] = [ ts - timedelta(hours=4), @@ -214,9 +208,12 @@ def test_feature_get_historical_features_types_match(offline_types_test_fixtures @pytest.mark.integration -@pytest.mark.universal -def test_feature_get_online_features_types_match(online_types_test_fixtures): - environment, config, data_source, fv = online_types_test_fixtures +@pytest.mark.universal_online_stores(only=["sqlite"]) +def test_feature_get_online_features_types_match( + online_types_test_fixtures, environment +): + config, data_source, fv = online_types_test_fixtures + entity = driver() fv = create_feature_view( "get_online_features_types_match", config.feature_dtype, @@ -226,7 +223,6 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures): ) fs = environment.feature_store features = [fv.name + ":value"] - entity = driver(value_type=config.entity_type) fs.apply([fv, entity]) fs.materialize( environment.start_date, @@ -235,9 +231,8 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures): # we can successfully infer type even from all empty values ) - driver_id_value = "1" if config.entity_type == ValueType.STRING else 1 online_features = fs.get_online_features( - features=features, entity_rows=[{"driver_id": driver_id_value}], + 
features=features, entity_rows=[{"driver_id": 1}], ).to_dict() feature_list_dtype_to_expected_online_response_value_type = { @@ -268,7 +263,7 @@ def test_feature_get_online_features_types_match(online_types_test_fixtures): def create_feature_view( - name, feature_dtype, feature_is_list, has_empty_list, data_source + name, feature_dtype, feature_is_list, has_empty_list, data_source, entity="driver" ): if feature_is_list is True: if feature_dtype == "int32": @@ -293,7 +288,7 @@ def create_feature_view( elif feature_dtype == "datetime": dtype = UnixTimestamp - return driver_feature_view(data_source, name=name, dtype=dtype,) + return driver_feature_view(data_source, name=name, dtype=dtype, entities=[entity]) def assert_expected_historical_feature_types( @@ -302,7 +297,7 @@ def assert_expected_historical_feature_types( print("Asserting historical feature types") feature_dtype_to_expected_historical_feature_dtype = { "int32": (pd.api.types.is_integer_dtype,), - "int64": (pd.api.types.is_int64_dtype,), + "int64": (pd.api.types.is_integer_dtype,), "float": (pd.api.types.is_float_dtype,), "string": (pd.api.types.is_string_dtype,), "bool": (pd.api.types.is_bool_dtype, pd.api.types.is_object_dtype), @@ -311,7 +306,7 @@ def assert_expected_historical_feature_types( dtype_checkers = feature_dtype_to_expected_historical_feature_dtype[feature_dtype] assert any( check(historical_features_df.dtypes["value"]) for check in dtype_checkers - ) + ), f"Failed to match feature type {historical_features_df.dtypes['value']} with checkers {dtype_checkers}" def assert_feature_list_types( @@ -357,8 +352,8 @@ def assert_expected_arrow_types( historical_features_arrow = historical_features.to_arrow() print(historical_features_arrow) feature_list_dtype_to_expected_historical_feature_arrow_type = { - "int32": pa.types.is_int64, - "int64": pa.types.is_int64, + "int32": pa.types.is_signed_integer, # different offline stores could interpret integers differently + "int64": pa.types.is_signed_integer, 
# eg, Snowflake chooses the smallest possible (like int8) "float": pa.types.is_float64, "string": pa.types.is_string, "bool": pa.types.is_boolean,