diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 6112279a02..9a652fd060 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -602,16 +602,23 @@ def _make_inferences( # New feature views may reference previously applied entities. entities = self._list_entities() + provider = self._get_provider() update_feature_views_with_inferred_features_and_entities( - views_to_update, entities + entities_to_update, self.config + provider, + views_to_update, + entities + entities_to_update, + self.config, ) update_feature_views_with_inferred_features_and_entities( - sfvs_to_update, entities + entities_to_update, self.config + provider, + sfvs_to_update, + entities + entities_to_update, + self.config, ) # We need to attach the time stamp fields to the underlying data sources # and cascade the dependencies update_feature_views_with_inferred_features_and_entities( - odfvs_to_update, entities + entities_to_update, self.config + provider, odfvs_to_update, entities + entities_to_update, self.config ) # TODO(kevjumba): Update schema inference for sfv in sfvs_to_update: @@ -1529,9 +1536,12 @@ def write_to_offline_store( feature_view_name, allow_registry_cache=allow_registry_cache ) + provider = self._get_provider() # Get columns of the batch source and the input dataframe. column_names_and_types = ( - feature_view.batch_source.get_table_column_names_and_types(self.config) + provider.get_table_column_names_and_types_from_data_source( + self.config, feature_view.batch_source + ) ) source_columns = [column for column, _ in column_names_and_types] input_columns = df.columns.values.tolist() @@ -1545,7 +1555,6 @@ def write_to_offline_store( df = df.reindex(columns=source_columns) table = pa.Table.from_pandas(df) - provider = self._get_provider() provider.ingest_df_to_offline_store(feature_view, table) def get_online_features( diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py index 39782e1e31..f2a2ee637f 100644 --- a/sdk/python/feast/inference.py +++ b/sdk/python/feast/inference.py @@ -13,6 +13,7 @@ from feast.infra.offline_stores.file_source import FileSource from feast.infra.offline_stores.redshift_source import RedshiftSource from feast.infra.offline_stores.snowflake_source import SnowflakeSource +from feast.infra.provider import Provider from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import RepoConfig from feast.stream_feature_view import StreamFeatureView @@ -95,6 +96,7 @@ def update_data_sources_with_inferred_event_timestamp_col( def update_feature_views_with_inferred_features_and_entities( + provider: Provider, fvs: Union[List[FeatureView], List[StreamFeatureView], List[OnDemandFeatureView]], entities: List[Entity], config: RepoConfig, @@ -176,6 +178,7 @@ def update_feature_views_with_inferred_features_and_entities( if run_inference_for_entities or run_inference_for_features: _infer_features_and_entities( + provider, fv, join_keys, run_inference_for_features, @@ -193,6 +196,7 @@ def update_feature_views_with_inferred_features_and_entities( def _infer_features_and_entities( + provider: Provider, fv: Union[FeatureView, OnDemandFeatureView], join_keys: Set[Optional[str]], run_inference_for_features, @@ -222,8 +226,10 @@ def _infer_features_and_entities( columns_to_exclude.remove(mapped_col) columns_to_exclude.add(original_col) - table_column_names_and_types = fv.batch_source.get_table_column_names_and_types( - config + table_column_names_and_types = ( + provider.get_table_column_names_and_types_from_data_source( + config, fv.batch_source + ) ) for col_name, col_datatype in table_column_names_and_types: diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index d973844531..69d6bb278b 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -15,7 +15,16 @@ from abc import ABC from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + List, + Optional, + Tuple, + Union, +) import pandas as pd import pyarrow @@ -352,8 +361,8 @@ def offline_write_batch( """ raise NotImplementedError - @staticmethod def validate_data_source( + self, config: RepoConfig, data_source: DataSource, ): @@ -365,3 +374,17 @@ def validate_data_source( data_source: DataSource object that needs to be validated """ data_source.validate(config=config) + + def get_table_column_names_and_types_from_data_source( + self, + config: RepoConfig, + data_source: DataSource, + ) -> Iterable[Tuple[str, str]]: + """ + Returns the list of column names and raw column types for a DataSource. + + Args: + config: Configuration object used to configure a feature store. + data_source: DataSource object + """ + return data_source.get_table_column_names_and_types(config=config) diff --git a/sdk/python/feast/infra/offline_stores/remote.py b/sdk/python/feast/infra/offline_stores/remote.py index 8154f75f87..7ee018ac6d 100644 --- a/sdk/python/feast/infra/offline_stores/remote.py +++ b/sdk/python/feast/infra/offline_stores/remote.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd @@ -328,6 +328,57 @@ def offline_write_batch( entity_df=None, ) + def validate_data_source( + self, + config: RepoConfig, + data_source: DataSource, + ): + assert isinstance(config.offline_store, RemoteOfflineStoreConfig) + + client = build_arrow_flight_client( + config.offline_store.host, config.offline_store.port, config.auth_config + ) + + api_parameters = { + "data_source_proto": str(data_source), + } + logger.debug(f"validating DataSource {data_source.name}") + _call_put( + api=OfflineStore.validate_data_source.__name__, + api_parameters=api_parameters, + client=client, + table=None, + entity_df=None, + ) + + def get_table_column_names_and_types_from_data_source( + self, config: RepoConfig, data_source: DataSource + ) -> Iterable[Tuple[str, str]]: + assert isinstance(config.offline_store, RemoteOfflineStoreConfig) + + client = build_arrow_flight_client( + config.offline_store.host, config.offline_store.port, config.auth_config + ) + + api_parameters = { + "data_source_proto": str(data_source), + } + logger.debug( + f"Calling {OfflineStore.get_table_column_names_and_types_from_data_source.__name__} with {api_parameters}" + ) + table = _send_retrieve_remote( + api=OfflineStore.get_table_column_names_and_types_from_data_source.__name__, + api_parameters=api_parameters, + client=client, + table=None, + entity_df=None, + ) + + logger.debug( + f"get_table_column_names_and_types_from_data_source for {data_source.name}: {table}" + ) + return zip(table.column("name").to_pylist(), table.column("type").to_pylist()) + def _create_retrieval_metadata(feature_refs: List[str], entity_df: pd.DataFrame): entity_schema = _get_entity_schema( diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 11ccbb2450..5acfc0d6f3 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -1,5 +1,16 @@ from datetime import datetime, timedelta -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import pandas as pd import pyarrow as pa @@ -455,3 +466,10 @@ def validate_data_source( data_source: DataSource, ): self.offline_store.validate_data_source(config=config, data_source=data_source) + + def get_table_column_names_and_types_from_data_source( + self, config: RepoConfig, data_source: DataSource + ) -> Iterable[Tuple[str, str]]: + return self.offline_store.get_table_column_names_and_types_from_data_source( + config=config, data_source=data_source + ) diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 22c2088861..0723e0513f 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -1,7 +1,18 @@ from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import pandas as pd import pyarrow @@ -405,6 +416,19 @@ def validate_data_source( """ pass + @abstractmethod + def get_table_column_names_and_types_from_data_source( + self, config: RepoConfig, data_source: DataSource + ) -> Iterable[Tuple[str, str]]: + """ + Returns the list of column names and raw column types for a DataSource. + + Args: + config: Configuration object used to configure a feature store. + data_source: DataSource object + """ + pass + def get_provider(config: RepoConfig) -> Provider: if "." not in config.provider: diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index cdb45f0363..424c59c57d 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -15,7 +15,6 @@ from feast.infra.infra_object import Infra from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView -from feast.permissions.auth.auth_type import AuthType from feast.permissions.auth_model import AuthConfig, NoAuthConfig from feast.permissions.client.grpc_client_auth_interceptor import ( GrpcClientAuthHeaderInterceptor, @@ -67,9 +66,8 @@ def __init__( ): self.auth_config = auth_config self.channel = grpc.insecure_channel(registry_config.path) - if self.auth_config.type != AuthType.NONE.value: - auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config) - self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor) + auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config) + self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor) self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel) def close(self): diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index ff3db579d0..0cb40ad934 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -7,9 +7,11 @@ import pyarrow as pa import pyarrow.flight as fl +from google.protobuf.json_format import Parse from feast import FeatureStore, FeatureView, utils from feast.arrow_error_handler import arrow_server_error_handling_decorator +from feast.data_source import DataSource from feast.feature_logging import FeatureServiceLoggingSource from feast.feature_view import DUMMY_ENTITY_NAME from feast.infra.offline_stores.offline_utils import get_offline_store_from_config @@ -26,6 +28,7 @@ init_security_manager, str_to_auth_manager_type, ) +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.saved_dataset import SavedDatasetStorage logger = logging.getLogger(__name__) @@ -138,6 +141,9 @@ def _call_api(self, api: str, command: dict, key: str): elif api == OfflineServer.persist.__name__: self.persist(command, key) remove_data = True + elif api == OfflineServer.validate_data_source.__name__: + self.validate_data_source(command) + remove_data = True except Exception as e: remove_data = True logger.exception(e) @@ -224,6 +230,11 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket): table = self.pull_all_from_table_or_query(command).to_arrow() elif api == OfflineServer.pull_latest_from_table_or_query.__name__: table = self.pull_latest_from_table_or_query(command).to_arrow() + elif ( + api + == OfflineServer.get_table_column_names_and_types_from_data_source.__name__ + ): + table = self.get_table_column_names_and_types_from_data_source(command) else: raise NotImplementedError except Exception as e: @@ -457,6 +468,41 @@ def persist(self, command: dict, key: str): traceback.print_exc() raise e + @staticmethod + def _extract_data_source_from_command(command) -> DataSource: + data_source_proto_str = command["data_source_proto"] + logger.debug(f"Extracted data_source_proto {data_source_proto_str}") + data_source_proto = DataSourceProto() + Parse(data_source_proto_str, data_source_proto) + data_source = DataSource.from_proto(data_source_proto) + logger.debug(f"Converted to DataSource {data_source}") + return data_source + + def validate_data_source(self, command: dict): + data_source = OfflineServer._extract_data_source_from_command(command) + logger.debug(f"Validating data source {data_source.name}") + assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE]) + + self.offline_store.validate_data_source( + config=self.store.config, + data_source=data_source, + ) + + def get_table_column_names_and_types_from_data_source(self, command: dict): + data_source = OfflineServer._extract_data_source_from_command(command) + logger.debug(f"Fetching table columns metadata from {data_source.name}") + assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE]) + + column_names_and_types = data_source.get_table_column_names_and_types( + self.store.config + ) + + column_names, types = zip(*column_names_and_types) + logger.debug( + f"DataSource {data_source.name} has columns {column_names} with types {types}" + ) + return pa.table({"name": column_names, "type": types}) + def remove_dummies(fv: FeatureView) -> FeatureView: """ diff --git a/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py b/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py index 121735e351..9a6bef2c07 100644 --- a/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py +++ b/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py @@ -3,6 +3,7 @@ import grpc from feast.errors import FeastError +from feast.permissions.auth.auth_type import AuthType from feast.permissions.auth_model import AuthConfig from feast.permissions.client.client_auth_token import get_auth_token @@ -15,8 +16,8 @@ class GrpcClientAuthHeaderInterceptor( grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor, ): - def __init__(self, auth_type: AuthConfig): - self._auth_type = auth_type + def __init__(self, auth_config: AuthConfig): + self._auth_config = auth_config def intercept_unary_unary( self, continuation, client_call_details, request_iterator @@ -39,7 +40,8 @@ def intercept_stream_stream( return self._handle_call(continuation, client_call_details, request_iterator) def _handle_call(self, continuation, client_call_details, request_iterator): - client_call_details = self._append_auth_header_metadata(client_call_details) + if self._auth_config.type != AuthType.NONE.value: + client_call_details = self._append_auth_header_metadata(client_call_details) result = continuation(client_call_details, request_iterator) if result.exception() is not None: mapped_error = FeastError.from_error_detail(result.exception().details()) @@ -52,7 +54,7 @@ def _append_auth_header_metadata(self, client_call_details): "Intercepted the grpc api method call to inject Authorization header " ) metadata = client_call_details.metadata or [] - access_token = get_auth_token(self._auth_type) + access_token = get_auth_token(self._auth_config) metadata.append((b"authorization", b"Bearer " + access_token.encode("utf-8"))) client_call_details = client_call_details._replace(metadata=metadata) return client_call_details diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 8e8f54db24..6fe6d15150 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -1,6 +1,17 @@ from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import pandas import pyarrow @@ -141,6 +152,11 @@ def validate_data_source( ): pass + def get_table_column_names_and_types_from_data_source( + self, config: RepoConfig, data_source: DataSource + ) -> Iterable[Tuple[str, str]]: + return [] + def get_online_features( self, config: RepoConfig, diff --git a/sdk/python/tests/unit/infra/test_inference_unit_tests.py b/sdk/python/tests/unit/infra/test_inference_unit_tests.py index 3d8fe8c967..54488d4321 100644 --- a/sdk/python/tests/unit/infra/test_inference_unit_tests.py +++ b/sdk/python/tests/unit/infra/test_inference_unit_tests.py @@ -6,7 +6,10 @@ from feast import BigQuerySource, FileSource, RedshiftSource, SnowflakeSource from feast.data_source import RequestSource from feast.entity import Entity -from feast.errors import DataSourceNoNameException, SpecifiedFeaturesNotPresentError +from feast.errors import ( + DataSourceNoNameException, + SpecifiedFeaturesNotPresentError, +) from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.field import Field @@ -14,6 +17,7 @@ from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import ( SparkSource, ) +from feast.infra.provider import get_provider from feast.on_demand_feature_view import on_demand_feature_view from feast.repo_config import RepoConfig from feast.types import Float32, Float64, Int64, String, UnixTimestamp @@ -253,15 +257,18 @@ def test_feature_view_inference_respects_basic_inference(): assert len(feature_view_1.features) == 1 assert len(feature_view_1.entity_columns) == 1 + config = RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", + ) + provider = get_provider(config) update_feature_views_with_inferred_features_and_entities( + provider, [feature_view_1], [entity1], - RepoConfig( - provider="local", - project="test", - entity_key_serialization_version=2, - registry="dummy_registry.pb", - ), + config, ) assert len(feature_view_1.schema) == 2 assert len(feature_view_1.features) == 1 @@ -271,15 +278,18 @@ def test_feature_view_inference_respects_basic_inference(): assert len(feature_view_2.features) == 1 assert len(feature_view_2.entity_columns) == 2 + config = RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", + ) + provider = get_provider(config) update_feature_views_with_inferred_features_and_entities( + provider, [feature_view_2], [entity1, entity2], - RepoConfig( - provider="local", - project="test", - entity_key_serialization_version=2, - registry="dummy_registry.pb", - ), + config, ) assert len(feature_view_2.schema) == 3 assert len(feature_view_2.features) == 1 @@ -305,15 +315,18 @@ def test_feature_view_inference_on_entity_value_types(): assert len(feature_view_1.features) == 1 assert len(feature_view_1.entity_columns) == 0 + config = RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", + ) + provider = get_provider(config) update_feature_views_with_inferred_features_and_entities( + provider, [feature_view_1], [entity1], - RepoConfig( - provider="local", - project="test", - entity_key_serialization_version=2, - registry="dummy_registry.pb", - ), + config, ) # The schema must be entity and features @@ -378,15 +391,18 @@ def test_feature_view_inference_on_entity_columns(simple_dataset_1): assert len(feature_view_1.features) == 1 assert len(feature_view_1.entity_columns) == 0 + config = RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", + ) + provider = get_provider(config) update_feature_views_with_inferred_features_and_entities( + provider, [feature_view_1], [entity1], - RepoConfig( - provider="local", - project="test", - entity_key_serialization_version=2, - registry="dummy_registry.pb", - ), + config, ) # Since there is already a feature specified, additional features are not inferred. @@ -416,15 +432,18 @@ def test_feature_view_inference_on_feature_columns(simple_dataset_1): assert len(feature_view_1.features) == 0 assert len(feature_view_1.entity_columns) == 1 + config = RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", + ) + provider = get_provider(config) update_feature_views_with_inferred_features_and_entities( + provider, [feature_view_1], [entity1], - RepoConfig( - provider="local", - project="test", - entity_key_serialization_version=2, - registry="dummy_registry.pb", - ), + config, ) # The schema is a property concatenating features and entity_columns @@ -471,15 +490,18 @@ def test_update_feature_services_with_inferred_features(simple_dataset_1): assert len(feature_service.feature_view_projections[1].features) == 0 assert len(feature_service.feature_view_projections[1].desired_features) == 0 + config = RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", + ) + provider = get_provider(config) update_feature_views_with_inferred_features_and_entities( + provider, [feature_view_1, feature_view_2], [entity1], - RepoConfig( - provider="local", - project="test", - entity_key_serialization_version=2, - registry="dummy_registry.pb", - ), + config, ) feature_service.infer_features( fvs_to_update={ @@ -531,15 +553,18 @@ def test_update_feature_services_with_specified_features(simple_dataset_1): assert len(feature_service.feature_view_projections[1].features) == 1 assert len(feature_service.feature_view_projections[1].desired_features) == 0 + config = RepoConfig( + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", + ) + provider = get_provider(config) update_feature_views_with_inferred_features_and_entities( + provider, [feature_view_1, feature_view_2], [entity1], - RepoConfig( - provider="local", - project="test", - entity_key_serialization_version=2, - registry="dummy_registry.pb", - ), + config, ) assert len(feature_view_1.features) == 1 assert len(feature_view_2.features) == 1