diff --git a/sdk/python/feast/infra/aws.py b/sdk/python/feast/infra/aws.py index 14301faf19..4109856e60 100644 --- a/sdk/python/feast/infra/aws.py +++ b/sdk/python/feast/infra/aws.py @@ -106,6 +106,15 @@ def update_infra( self._deploy_feature_server(project, image_uri) + if self.batch_engine: + self.batch_engine.update( + project, + tables_to_delete, + tables_to_keep, + entities_to_delete, + entities_to_keep, + ) + def _deploy_feature_server(self, project: str, image_uri: str): _logger.info("Deploying feature server...") @@ -198,8 +207,7 @@ def _deploy_feature_server(self, project: str, image_uri: str): def teardown_infra( self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], ) -> None: - if self.online_store: - self.online_store.teardown(self.repo_config, tables, entities) + super(AwsProvider, self).teardown_infra(project, tables, entities) if ( self.repo_config.feature_server is not None diff --git a/sdk/python/feast/infra/materialization/lambda/Dockerfile b/sdk/python/feast/infra/materialization/lambda/Dockerfile new file mode 100644 index 0000000000..bbdb74bdfe --- /dev/null +++ b/sdk/python/feast/infra/materialization/lambda/Dockerfile @@ -0,0 +1,25 @@ +FROM public.ecr.aws/lambda/python:3.9 + +RUN yum install -y git + + +# Copy app handler code +COPY sdk/python/feast/infra/materialization/lambda/app.py ${LAMBDA_TASK_ROOT} + +# Copy necessary parts of the Feast codebase +COPY sdk/python sdk/python +COPY protos protos +COPY go go +COPY setup.py setup.py +COPY pyproject.toml pyproject.toml +COPY README.md README.md + +# Install Feast for AWS with Lambda dependencies +# We need this mount thingy because setuptools_scm needs access to the +# git dir to infer the version of feast we're installing. +# https://github.com/pypa/setuptools_scm#usage-from-docker +# I think it also assumes that this dockerfile is being built from the root of the directory. 
def handler(event, context):
    """Materialize one Parquet file into the online store (AWS Lambda entry point).

    ``event`` must contain the following keys:

        - the key named by ``FEATURE_STORE_YAML_ENV_NAME``: base64-encoded
          contents of the ``feature_store.yaml`` to run with
        - ``view_name``: name of the feature view to materialize
        - ``view_type``: ``"batch"`` to look the view up as a feature view;
          any other value looks it up as a stream feature view
        - ``path``: location of the Parquet file to ingest (``s3://bucket/key``)

    ``context`` is the Lambda context object; it is not used.

    Returns ``{"written_rows": <int>}`` on success. On any exception the
    traceback is printed and the process exits with status 1, so the
    invocation is reported as failed to the caller.
    """
    # NOTE(review): this logs the whole event, including the base64-encoded
    # feature_store.yaml -- confirm that file never carries secrets.
    print("Received event: " + json.dumps(event, indent=2), flush=True)

    try:
        config_bytes = base64.b64decode(event[FEATURE_STORE_YAML_ENV_NAME])

        # A fresh directory per invocation holds feature_store.yaml. It is
        # removed when the block exits so that warm (re-used) Lambda
        # containers do not accumulate one directory per call.
        with tempfile.TemporaryDirectory() as repo_dir:
            repo_path = Path(repo_dir)
            (repo_path / "feature_store.yaml").write_bytes(config_bytes)

            # Initialize the feature store
            store = FeatureStore(repo_path=str(repo_path.resolve()))

            view_name = event["view_name"]
            view_type = event["view_type"]
            path = event["path"]

            # Only used for the log line below; partition() never raises,
            # even when the path has no "/" after the bucket name.
            bucket, _, key = path[len("s3://") :].partition("/")
            print(f"Inferred Bucket: `{bucket}` Key: `{key}`", flush=True)

            if view_type == "batch":
                # TODO: This probably needs to be become `store.get_batch_feature_view` at some point.
                feature_view = store.get_feature_view(view_name)
            else:
                feature_view = store.get_stream_feature_view(view_name)

            print(f"Got Feature View: `{feature_view}`", flush=True)

            table = pq.read_table(path)
            if feature_view.batch_source.field_mapping is not None:
                table = _run_pyarrow_field_mapping(
                    table, feature_view.batch_source.field_mapping
                )

            join_key_to_value_type = {
                entity.name: entity.dtype.to_value_type()
                for entity in feature_view.entity_columns
            }

            written_rows = 0
            for batch in table.to_batches(DEFAULT_BATCH_SIZE):
                rows_to_write = _convert_arrow_to_proto(
                    batch, feature_view, join_key_to_value_type
                )
                # The last argument is the progress callback; nothing to report here.
                store._provider.online_write_batch(
                    store.config, feature_view, rows_to_write, lambda x: None,
                )
                written_rows += len(rows_to_write)
            return {"written_rows": written_rows}
    except Exception as e:
        # Top-level boundary: report everything, then fail the invocation.
        print(f"Exception: {e}", flush=True)
        print("Traceback:", flush=True)
        print(traceback.format_exc(), flush=True)
        sys.exit(1)
DEFAULT_BATCH_SIZE = 10_000

# Upper bound on concurrent Lambda invocations issued for one feature view.
_MAX_INVOCATION_WORKERS = 20

# Function names are cut to this many characters.
_LAMBDA_NAME_MAX_LENGTH = 64

logger = logging.getLogger(__name__)


class LambdaMaterializationEngineConfig(FeastConfigBaseModel):
    """Batch Materialization Engine config for lambda based engine"""

    type: Literal["lambda"] = "lambda"
    """ Type selector"""

    materialization_image: StrictStr
    """ The URI of a container image in the Amazon ECR registry, which should be used for materialization. """

    lambda_role: StrictStr
    """ Role that should be used by the materialization lambda """


@dataclass
class LambdaMaterializationJob(MaterializationJob):
    """Outcome of materializing one feature view through the Lambda engine.

    Args:
        job_id: Identifier of the job.
        status: Final status of the job.
        error: The failure that caused an ``ERROR`` status, if any. Optional,
            so existing two-argument construction keeps working.
    """

    def __init__(
        self,
        job_id: str,
        status: MaterializationJobStatus,
        error: Optional[BaseException] = None,
    ) -> None:
        super().__init__()
        self._job_id: str = job_id
        self._status = status
        self._error: Optional[BaseException] = error

    def status(self) -> MaterializationJobStatus:
        return self._status

    def error(self) -> Optional[BaseException]:
        return self._error

    def should_be_retried(self) -> bool:
        return False

    def job_id(self) -> str:
        return self._job_id

    def url(self) -> Optional[str]:
        return None


class LambdaMaterializationEngine(BatchMaterializationEngine):
    """
    WARNING: This engine should be considered "Alpha" functionality.

    Materializes feature views by exporting the offline data to remote
    storage and invoking one AWS Lambda per exported file.
    """

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: OfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )
        repo_path = self.repo_config.repo_path
        assert repo_path
        feature_store_path = repo_path / "feature_store.yaml"
        # feature_store.yaml is shipped base64-encoded inside every
        # invocation payload so the Lambda can rebuild the same repo config.
        self.feature_store_base64 = str(
            base64.b64encode(bytes(feature_store_path.read_text(), "UTF-8")), "UTF-8"
        )

        # Slicing is a no-op for names that are already short enough.
        self.lambda_name = f"feast-materialize-{self.repo_config.project}"[
            :_LAMBDA_NAME_MAX_LENGTH
        ]
        self.lambda_client = boto3.client("lambda")

    def update(
        self,
        project: str,
        views_to_delete: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        views_to_keep: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
    ):
        """Make sure the materialization Lambda exists and is active.

        The view and entity arguments are accepted for interface
        compatibility and are not used. Calling this repeatedly is safe: an
        already existing function is reused instead of failing the call.
        """
        try:
            r = self.lambda_client.create_function(
                FunctionName=self.lambda_name,
                PackageType="Image",
                Role=self.repo_config.batch_engine.lambda_role,
                Code={"ImageUri": self.repo_config.batch_engine.materialization_image},
                Timeout=600,
                Tags={
                    "feast-owned": "True",
                    "project": project,
                    "feast-sdk-version": get_version(),
                },
            )
            logger.info("Creating lambda function %s, %s", self.lambda_name, r)
        except self.lambda_client.exceptions.ResourceConflictException:
            # The function is left over from an earlier update().
            # NOTE(review): the existing function keeps its current image and
            # role; a changed config is not pushed to it here.
            logger.info(
                "Lambda function %s already exists; reusing it", self.lambda_name
            )

        logger.info("Waiting for function %s to be active", self.lambda_name)
        waiter = self.lambda_client.get_waiter("function_active")
        waiter.wait(FunctionName=self.lambda_name)

    def teardown_infra(
        self,
        project: str,
        fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
        entities: Sequence[Entity],
    ):
        """Delete the materialization Lambda. The arguments are not used."""
        logger.info("Tearing down lambda %s", self.lambda_name)
        r = self.lambda_client.delete_function(FunctionName=self.lambda_name)
        logger.info("Finished tearing down lambda %s: %s", self.lambda_name, r)

    def materialize(
        self, registry: BaseRegistry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        """Run every task in order and return one job per task."""
        return [
            self._materialize_one(
                registry,
                task.feature_view,
                task.start_time,
                task.end_time,
                task.project,
                task.tqdm_builder,
            )
            for task in tasks
        ]

    def _materialize_one(
        self,
        registry: BaseRegistry,
        feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
        start_date: datetime,
        end_date: datetime,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ) -> LambdaMaterializationJob:
        """Materialize one feature view for ``[start_date, end_date]``.

        Pulls the latest rows from the offline store, exports them to remote
        storage and invokes the Lambda once per exported file, at most
        ``_MAX_INVOCATION_WORKERS`` at a time.

        Returns a job with status ``SUCCEEDED`` when every invocation
        succeeded (or there was nothing to ingest). Otherwise the status is
        ``ERROR`` and ``error()`` returns the first failure; failures are
        logged and reported through the job rather than raised from here.
        ``tqdm_builder`` is currently unused.
        """
        entities = [
            registry.get_entity(entity_name, project)
            for entity_name in feature_view.entities
        ]

        (
            join_key_columns,
            feature_name_columns,
            timestamp_field,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        job_id = f"{feature_view.name}-{start_date}-{end_date}"

        offline_job = self.offline_store.pull_latest_from_table_or_query(
            config=self.repo_config,
            data_source=feature_view.batch_source,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
        )

        paths = offline_job.to_remote_storage()
        if not paths:
            # ThreadPoolExecutor rejects max_workers=0, and there is no work anyway.
            logger.info("No files to materialize for %s", feature_view.name)
            return LambdaMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
            )

        # The materialization handler treats every value other than "batch"
        # as a stream feature view lookup.
        view_type = "stream" if isinstance(feature_view, StreamFeatureView) else "batch"

        # The context manager shuts the pool down even if submitting fails.
        with ThreadPoolExecutor(
            max_workers=min(len(paths), _MAX_INVOCATION_WORKERS)
        ) as executor:
            path_by_future = {}
            for path in paths:
                payload = {
                    FEATURE_STORE_YAML_ENV_NAME: self.feature_store_base64,
                    "view_name": feature_view.name,
                    "view_type": view_type,
                    "path": path,
                }
                # Invoke a lambda to materialize this file.
                logger.info("Invoking materialization for %s", path)
                future = executor.submit(
                    self.lambda_client.invoke,
                    FunctionName=self.lambda_name,
                    InvocationType="RequestResponse",
                    Payload=json.dumps(payload),
                )
                path_by_future[future] = path

            # No timeout: blocks until every invocation has finished.
            wait(path_by_future)

        errors: List[BaseException] = []
        for future, path in path_by_future.items():
            error = future.exception()
            if error is None:
                response = future.result()
                output = json.loads(response["Payload"].read())
                # A failure inside the function still yields a normal
                # response; it is only flagged through "FunctionError".
                if response.get("FunctionError"):
                    error = RuntimeError(
                        f"Lambda materialization of {path} failed "
                        f"({response['FunctionError']}): {output}"
                    )
                else:
                    logger.info(
                        "Ingested task; request id %s, rows written: %s",
                        response["ResponseMetadata"]["RequestId"],
                        output.get("written_rows"),
                    )
            if error is not None:
                logger.error("Ingestion failed for %s: %s", path, error)
                errors.append(error)

        if errors:
            return LambdaMaterializationJob(
                job_id=job_id,
                status=MaterializationJobStatus.ERROR,
                error=errors[0],
            )
        return LambdaMaterializationJob(
            job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
        )
a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -22,7 +22,7 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.registry import BaseRegistry -from feast.repo_config import RepoConfig +from feast.repo_config import BATCH_ENGINE_CLASS_FOR_TYPE, RepoConfig from feast.saved_dataset import SavedDataset from feast.stream_feature_view import StreamFeatureView from feast.usage import RatioSampler, log_exceptions_and_usage, set_usage_attribute @@ -34,10 +34,6 @@ DEFAULT_BATCH_SIZE = 10_000 -BATCH_ENGINE_CLASS_FOR_TYPE = { - "local": "feast.infra.materialization.LocalMaterializationEngine", -} - class PassthroughProvider(Provider): """ @@ -73,7 +69,7 @@ def batch_engine(self) -> BatchMaterializationEngine: if self._batch_engine: return self._batch_engine else: - engine_config = self.repo_config.batch_engine_config + engine_config = self.repo_config._batch_engine_config config_is_dict = False if isinstance(engine_config, str): engine_config_type = engine_config @@ -129,6 +125,14 @@ def update_infra( entities_to_delete=entities_to_delete, partial=partial, ) + if self.batch_engine: + self.batch_engine.update( + project, + tables_to_delete, + tables_to_keep, + entities_to_delete, + entities_to_keep, + ) def teardown_infra( self, project: str, tables: Sequence[FeatureView], entities: Sequence[Entity], @@ -136,6 +140,8 @@ def teardown_infra( set_usage_attribute("provider", self.__class__.__name__) if self.online_store: self.online_store.teardown(self.repo_config, tables, entities) + if self.batch_engine: + self.batch_engine.teardown_infra(project, tables, entities) def online_write_batch( self, diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index f315023ee1..f7f564df6f 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -30,6 +30,12 @@ # These dict 
exists so that: # - existing values for the online store type in featurestore.yaml files continue to work in a backwards compatible way # - first party and third party implementations can use the same class loading code path. +BATCH_ENGINE_CLASS_FOR_TYPE = { + "local": "feast.infra.materialization.LocalMaterializationEngine", + "lambda": "feast.infra.materialization.lambda.lambda_engine.LambdaMaterializationEngine", +} + + ONLINE_STORE_CLASS_FOR_TYPE = { "sqlite": "feast.infra.online_stores.sqlite.SqliteOnlineStore", "datastore": "feast.infra.online_stores.datastore.DatastoreOnlineStore", @@ -120,7 +126,7 @@ class RepoConfig(FeastBaseModel): _offline_config: Any = Field(alias="offline_store") """ OfflineStoreConfig: Offline store configuration (optional depending on provider) """ - batch_engine_config: Any = Field(alias="batch_engine") + _batch_engine_config: Any = Field(alias="batch_engine") """ BatchMaterializationEngine: Batch materialization configuration (optional depending on provider)""" feature_server: Optional[Any] @@ -160,10 +166,12 @@ def __init__(self, **data: Any): self._batch_engine = None if "batch_engine" in data: - self.batch_engine_config = data["batch_engine"] + self._batch_engine_config = data["batch_engine"] + elif "batch_engine_config" in data: + self._batch_engine_config = data["batch_engine_config"] else: # Defaults to using local in-process materialization engine. 
# RepoConfig.batch_engine (a method of RepoConfig).
@property
def batch_engine(self):
    """Return the batch engine config object, building it on first access.

    ``_batch_engine_config`` may be:

    - a dict: its ``"type"`` entry selects the config class, which is
      instantiated with the dict's items as keyword arguments;
    - a str: it selects the config class, which is instantiated with no
      arguments;
    - any other truthy value: used as the config as-is.

    The result is cached in ``_batch_engine``. ``None`` is returned when no
    config is set.
    """
    if not self._batch_engine:
        if isinstance(self._batch_engine_config, Dict):
            self._batch_engine = get_batch_engine_config_from_type(
                self._batch_engine_config["type"]
            )(**self._batch_engine_config)
        elif isinstance(self._batch_engine_config, str):
            self._batch_engine = get_batch_engine_config_from_type(
                self._batch_engine_config
            )()
        elif self._batch_engine_config:
            # Neither a dict nor a type name: use the given config directly.
            # (This used to assign _batch_engine to itself, so such a config
            # was silently dropped and the property returned None.)
            self._batch_engine = self._batch_engine_config

    return self._batch_engine


# Module-level helper of repo_config.py.
def get_batch_engine_config_from_type(batch_engine_type: str):
    """Return the config class that belongs to a batch engine type.

    ``batch_engine_type`` is either a short name registered in
    ``BATCH_ENGINE_CLASS_FOR_TYPE`` or a fully qualified engine class path
    ending in ``"Engine"``. The config class is looked up in the engine's
    module under the engine class name with ``"Config"`` appended.
    """
    if batch_engine_type in BATCH_ENGINE_CLASS_FOR_TYPE:
        batch_engine_type = BATCH_ENGINE_CLASS_FOR_TYPE[batch_engine_type]
    else:
        assert batch_engine_type.endswith("Engine")
    module_name, batch_engine_class_type = batch_engine_type.rsplit(".", 1)
    config_class_name = f"{batch_engine_class_type}Config"

    return import_class(module_name, config_class_name, config_class_name)
class RegistryLocation(Enum):
    """Selects where an integration-test repo keeps its registry."""

    # Registry file on the local filesystem.
    Local = 1
    # Registry object in an S3 bucket.
    S3 = 2
b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -22,6 +22,7 @@ from feast.repo_config import RegistryConfig, RepoConfig from tests.integration.feature_repos.integration_test_repo_config import ( IntegrationTestRepoConfig, + RegistryLocation, ) from tests.integration.feature_repos.universal.data_source_creator import ( DataSourceCreator, @@ -381,8 +382,6 @@ def construct_test_environment( online_creator = None online_store = test_repo_config.online_store - repo_dir_name = tempfile.mkdtemp() - if test_repo_config.python_feature_server and test_repo_config.provider == "aws": from feast.infra.feature_servers.aws_lambda.config import ( AwsLambdaFeatureServerConfig, @@ -393,22 +392,30 @@ def construct_test_environment( execution_role_name="arn:aws:iam::402087665549:role/lambda_execution_role", ) - registry = ( - f"s3://feast-integration-tests/registries/{project}/registry.db" - ) # type: Union[str, RegistryConfig] else: feature_server = LocalFeatureServerConfig( feature_logging=FeatureLoggingConfig(enabled=True) ) + + repo_dir_name = tempfile.mkdtemp() + if ( + test_repo_config.python_feature_server and test_repo_config.provider == "aws" + ) or test_repo_config.registry_location == RegistryLocation.S3: + registry: Union[str, RegistryConfig] = ( + f"s3://feast-integration-tests/registries/{project}/registry.db" + ) + else: registry = RegistryConfig( path=str(Path(repo_dir_name) / "registry.db"), cache_ttl_seconds=1, ) + config = RepoConfig( registry=registry, project=project, provider=test_repo_config.provider, offline_store=offline_store_config, online_store=online_store, + batch_engine=test_repo_config.batch_engine, repo_path=repo_dir_name, feature_server=feature_server, go_feature_retrieval=test_repo_config.go_feature_retrieval, diff --git a/sdk/python/tests/integration/materialization/test_lambda.py b/sdk/python/tests/integration/materialization/test_lambda.py new file mode 100644 index 0000000000..66cd2c5eb9 --- /dev/null +++ 
@pytest.mark.integration
def test_lambda_materialization():
    """Materialize a small driver dataset through the Lambda engine.

    Uses Redshift as the offline store, DynamoDB as the online store and an
    S3-hosted registry, then checks that both stores agree.
    """
    lambda_config = IntegrationTestRepoConfig(
        provider="aws",
        online_store={"type": "dynamodb", "region": "us-west-2"},
        offline_store_creator=RedshiftDataSourceCreator,
        batch_engine={
            "type": "lambda",
            "materialization_image": "402087665549.dkr.ecr.us-west-2.amazonaws.com/feast-lambda-consumer:v1",
            "lambda_role": "arn:aws:iam::402087665549:role/lambda_execution_role",
        },
        registry_location=RegistryLocation.S3,
    )
    lambda_environment = construct_test_environment(lambda_config, None)

    df = create_basic_driver_dataset()
    ds = lambda_environment.data_source_creator.create_data_source(
        df, lambda_environment.feature_store.project, field_mapping={"ts_1": "ts"},
    )

    fs = lambda_environment.feature_store
    driver = Entity(name="driver_id", join_key="driver_id", value_type=ValueType.INT64,)

    driver_stats_fv = FeatureView(
        name="driver_hourly_stats",
        entities=["driver_id"],
        ttl=timedelta(weeks=52),
        features=[Feature(name="value", dtype=ValueType.FLOAT)],
        batch_source=ds,
    )

    try:
        fs.apply([driver, driver_stats_fv])

        print(df)

        # materialization is run in two steps and
        # we use timestamp from generated dataframe as a split point
        split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)

        print(f"Split datetime: {split_dt}")

        run_offline_online_store_consistency_test(fs, driver_stats_fv, split_dt)
    finally:
        # Always remove the infrastructure created by apply().
        fs.teardown()


def check_offline_and_online_features(
    fs: FeatureStore,
    fv: FeatureView,
    driver_id: int,
    event_timestamp: datetime,
    expected_value: Optional[float],
    full_feature_names: bool,
    check_offline_store: bool = True,
) -> None:
    """Assert that the ``value`` feature of ``fv`` matches ``expected_value``.

    Args:
        fs: Feature store to query.
        fv: Feature view that owns the ``value`` feature.
        driver_id: Entity key to look up.
        event_timestamp: Timestamp used for the historical lookup.
        expected_value: Expected feature value, compared with a 1e-6
            tolerance. ``None`` means no value is expected: the online store
            must return ``None`` and the offline store nothing or NaN.
        full_feature_names: Whether results are keyed ``<view>__value``
            instead of ``value``.
        check_offline_store: Also verify the offline store when true.
    """
    feature_ref = f"{fv.name}:value"
    column = f"{fv.name}__value" if full_feature_names else "value"

    # Check online store
    response_dict = fs.get_online_features(
        [feature_ref],
        [{"driver_id": driver_id}],
        full_feature_names=full_feature_names,
    ).to_dict()
    online_value = response_dict[column][0]

    # Compare against None explicitly so that an expected 0 / 0.0 is still
    # treated as a real value rather than as "no value expected".
    if expected_value is not None:
        assert online_value is not None, f"Response: {response_dict}"
        assert (
            abs(online_value - expected_value) < 1e-6
        ), f"Response: {response_dict}, Expected: {expected_value}"
    else:
        assert online_value is None

    if not check_offline_store:
        return

    # Check offline store
    df = fs.get_historical_features(
        entity_df=pd.DataFrame.from_dict(
            {"driver_id": [driver_id], "event_timestamp": [event_timestamp]}
        ),
        features=[feature_ref],
        full_feature_names=full_feature_names,
    ).to_df()
    offline_values = df.to_dict(orient="list")[column]

    if expected_value is not None:
        assert abs(offline_values[0] - expected_value) < 1e-6
    else:
        assert not offline_values or math.isnan(offline_values[0])


def run_offline_online_store_consistency_test(
    fs: FeatureStore, fv: FeatureView, split_dt: datetime
) -> None:
    """Materialize ``fv`` in two steps and verify the stores after each.

    The first step covers the last five hours up to ``split_dt``; the second
    runs ``materialize_incremental`` up to now.
    """
    now = datetime.utcnow()

    full_feature_names = True
    check_offline_store: bool = True

    # Run materialize()
    # use both tz-naive & tz-aware timestamps to test that they're both correctly handled
    start_date = (now - timedelta(hours=5)).replace(tzinfo=utc)
    end_date = split_dt
    fs.materialize(feature_views=[fv.name], start_date=start_date, end_date=end_date)

    time.sleep(10)

    # Driver 1 checks the result of materialize(); driver 2 must have no
    # value; driver 3 holds the prior value that materialize_incremental()
    # is expected to replace below.
    for driver_id, expected_value in ((1, 0.3), (2, None), (3, 4)):
        check_offline_and_online_features(
            fs=fs,
            fv=fv,
            driver_id=driver_id,
            event_timestamp=end_date,
            expected_value=expected_value,
            full_feature_names=full_feature_names,
            check_offline_store=check_offline_store,
        )

    # run materialize_incremental()
    fs.materialize_incremental(feature_views=[fv.name], end_date=now)

    # check result of materialize_incremental()
    check_offline_and_online_features(
        fs=fs,
        fv=fv,
        driver_id=3,
        event_timestamp=now,
        expected_value=5,
        full_feature_names=full_feature_names,
        check_offline_store=check_offline_store,
    )