From fd1243b769e910c161b56b4650116925788180a5 Mon Sep 17 00:00:00 2001 From: lokeshrangineni <19699092+lokeshrangineni@users.noreply.github.com> Date: Sat, 10 Aug 2024 23:30:39 -0400 Subject: [PATCH] Added the arrow flight interceptor to inject the auth header. (#68) * * Added the arrow flight interceptor to inject the auth header. * Injecting grpc interceptor if it is needed when auth type is not NO_AUTH. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> * Fixing the failing integration test cases by setting the header in binary format. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> * Refactored method and moved to factory class to incorporate code review comment. Fixed lint error by removing the type of port. and other minor changes. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> * Incorproating code review comments from Daniel. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> --------- Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> Signed-off-by: Abdul Hameed --- .../feast/infra/offline_stores/remote.py | 61 +++++++------------ sdk/python/feast/infra/registry/remote.py | 10 +-- .../client/arrow_flight_auth_interceptor.py | 38 ++++++++++++ .../permissions/client/auth_client_manager.py | 30 --------- .../client/auth_client_manager_factory.py | 30 +++++++++ .../client/grpc_client_auth_interceptor.py | 22 +++---- .../client/http_auth_requests_wrapper.py | 5 +- sdk/python/feast/permissions/client/utils.py | 21 ------- .../offline_stores/test_offline_store.py | 2 - 9 files changed, 106 insertions(+), 113 deletions(-) create mode 100644 sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py create mode 100644 sdk/python/feast/permissions/client/auth_client_manager_factory.py delete mode 100644 sdk/python/feast/permissions/client/utils.py diff --git a/sdk/python/feast/infra/offline_stores/remote.py b/sdk/python/feast/infra/offline_stores/remote.py index 2545d47734..40239c8950 100644 --- a/sdk/python/feast/infra/offline_stores/remote.py +++ b/sdk/python/feast/infra/offline_stores/remote.py @@ -27,7 +27,9 @@ RetrievalMetadata, ) from feast.infra.registry.base_registry import BaseRegistry -from feast.permissions.client.utils import create_flight_call_options +from feast.permissions.client.arrow_flight_auth_interceptor import ( + build_arrow_flight_client, +) from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage @@ -47,7 +49,6 @@ class RemoteRetrievalJob(RetrievalJob): def __init__( self, client: fl.FlightClient, - options: pa.flight.FlightCallOptions, api: str, api_parameters: Dict[str, Any], entity_df: Union[pd.DataFrame, str] = None, @@ -56,7 +57,6 @@ def __init__( ): # Initialize the client connection self.client = client - self.options = options self.api = api self.api_parameters = api_parameters self.entity_df = entity_df @@ -77,7 +77,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table: self.entity_df, self.table, self.client, - self.options, ) @property @@ -118,7 +117,6 @@ def persist( api=RemoteRetrievalJob.persist.__name__, api_parameters=api_parameters, client=self.client, - options=self.options, table=self.table, entity_df=self.entity_df, ) @@ -137,9 +135,9 @@ def get_historical_features( ) -> RemoteRetrievalJob: assert isinstance(config.offline_store, RemoteOfflineStoreConfig) - # Initialize the client connection - client = RemoteOfflineStore.init_client(config) - options = create_flight_call_options(config.auth_config) + client = build_arrow_flight_client( + config.offline_store.host, config.offline_store.port, config.auth_config + ) feature_view_names = [fv.name for fv in feature_views] name_aliases = [fv.projection.name_alias for fv in feature_views] @@ -154,7 +152,6 @@ def get_historical_features( return RemoteRetrievalJob( client=client, - options=options, api=OfflineStore.get_historical_features.__name__, api_parameters=api_parameters, entity_df=entity_df, @@ -174,8 +171,9 @@ def pull_all_from_table_or_query( assert isinstance(config.offline_store, RemoteOfflineStoreConfig) # Initialize the client connection - client = RemoteOfflineStore.init_client(config) - options = create_flight_call_options(config.auth_config) + client = build_arrow_flight_client( + config.offline_store.host, config.offline_store.port, config.auth_config + ) api_parameters = { "data_source_name": data_source.name, @@ -188,7 +186,6 @@ def pull_all_from_table_or_query( return RemoteRetrievalJob( client=client, - options=options, api=OfflineStore.pull_all_from_table_or_query.__name__, api_parameters=api_parameters, ) @@ -207,8 +204,9 @@ def pull_latest_from_table_or_query( assert isinstance(config.offline_store, RemoteOfflineStoreConfig) # Initialize the client connection - client = RemoteOfflineStore.init_client(config) - options = create_flight_call_options(config.auth_config) + client = build_arrow_flight_client( + config.offline_store.host, config.offline_store.port, config.auth_config + ) api_parameters = { "data_source_name": data_source.name, @@ -222,7 +220,6 @@ def pull_latest_from_table_or_query( return RemoteRetrievalJob( client=client, - options=options, api=OfflineStore.pull_latest_from_table_or_query.__name__, api_parameters=api_parameters, ) @@ -242,8 +239,9 @@ def write_logged_features( data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False) # Initialize the client connection - client = RemoteOfflineStore.init_client(config) - options = create_flight_call_options(config.auth_config) + client = build_arrow_flight_client( + config.offline_store.host, config.offline_store.port, config.auth_config + ) api_parameters = { "feature_service_name": source._feature_service.name, @@ -253,7 +251,6 @@ def write_logged_features( api=OfflineStore.write_logged_features.__name__, api_parameters=api_parameters, client=client, - options=options, table=data, entity_df=None, ) @@ -268,8 +265,9 @@ def offline_write_batch( assert isinstance(config.offline_store, RemoteOfflineStoreConfig) # Initialize the client connection - client = RemoteOfflineStore.init_client(config) - options = create_flight_call_options(config.auth_config) + client = build_arrow_flight_client( + config.offline_store.host, config.offline_store.port, config.auth_config + ) feature_view_names = [feature_view.name] name_aliases = [feature_view.projection.name_alias] @@ -284,18 +282,10 @@ def offline_write_batch( api=OfflineStore.offline_write_batch.__name__, api_parameters=api_parameters, client=client, - options=options, table=table, entity_df=None, ) - @staticmethod - def init_client(config): - location = f"grpc://{config.offline_store.host}:{config.offline_store.port}" - client = fl.connect(location=location) - logger.info(f"Connecting FlightClient at {location}") - return client - def _create_retrieval_metadata(feature_refs: List[str], entity_df: pd.DataFrame): entity_schema = _get_entity_schema( @@ -349,27 +339,24 @@ def _send_retrieve_remote( entity_df: Union[pd.DataFrame, str], table: pa.Table, client: fl.FlightClient, - options: pa.flight.FlightCallOptions, ): command_descriptor = _call_put( api, api_parameters, client, - options, entity_df, table, ) - return _call_get(client, options, command_descriptor) + return _call_get(client, command_descriptor) def _call_get( client: fl.FlightClient, - options: pa.flight.FlightCallOptions, command_descriptor: fl.FlightDescriptor, ): - flight = client.get_flight_info(command_descriptor, options) + flight = client.get_flight_info(command_descriptor) ticket = flight.endpoints[0].ticket - reader = client.do_get(ticket, options) + reader = client.do_get(ticket) return reader.read_all() @@ -377,7 +364,6 @@ def _call_put( api: str, api_parameters: Dict[str, Any], client: fl.FlightClient, - options: pa.flight.FlightCallOptions, entity_df: Union[pd.DataFrame, str], table: pa.Table, ): @@ -397,7 +383,7 @@ def _call_put( ) ) - _put_parameters(command_descriptor, entity_df, table, client, options) + _put_parameters(command_descriptor, entity_df, table, client) return command_descriptor @@ -406,7 +392,6 @@ def _put_parameters( entity_df: Union[pd.DataFrame, str], table: pa.Table, client: fl.FlightClient, - options: pa.flight.FlightCallOptions, ): updatedTable: pa.Table @@ -417,7 +402,7 @@ def _put_parameters( else: updatedTable = _create_empty_table() - writer, _ = client.do_put(command_descriptor, updatedTable.schema, options) + writer, _ = client.do_put(command_descriptor, updatedTable.schema) writer.write_table(updatedTable) writer.close() diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index 9a4a8eadb8..b8665d0cf6 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -15,6 +15,7 @@ from feast.infra.infra_object import Infra from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView +from feast.permissions.auth.auth_type import AuthType from feast.permissions.auth_model import ( AuthConfig, NoAuthConfig, @@ -48,13 +49,12 @@ def __init__( repo_path: Optional[Path], auth_config: AuthConfig = NoAuthConfig(), ): - auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config) self.auth_config = auth_config channel = grpc.insecure_channel(registry_config.path) - self.intercepted_channel = grpc.intercept_channel( - channel, auth_header_interceptor - ) - self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.intercepted_channel) + if self.auth_config.type != AuthType.NONE.value: + auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config) + channel = grpc.intercept_channel(channel, auth_header_interceptor) + self.stub = RegistryServer_pb2_grpc.RegistryServerStub(channel) def apply_entity(self, entity: Entity, project: str, commit: bool = True): request = RegistryServer_pb2.ApplyEntityRequest( diff --git a/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py b/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py new file mode 100644 index 0000000000..724c7df5ca --- /dev/null +++ b/sdk/python/feast/permissions/client/arrow_flight_auth_interceptor.py @@ -0,0 +1,38 @@ +import pyarrow.flight as fl + +from feast.permissions.auth.auth_type import AuthType +from feast.permissions.auth_model import AuthConfig +from feast.permissions.client.auth_client_manager_factory import get_auth_token + + +class FlightBearerTokenInterceptor(fl.ClientMiddleware): + def __init__(self, auth_config: AuthConfig): + super().__init__() + self.auth_config = auth_config + + def call_completed(self, exception): + pass + + def received_headers(self, headers): + pass + + def sending_headers(self): + access_token = get_auth_token(self.auth_config) + return {b"authorization": b"Bearer " + access_token.encode("utf-8")} + + +class FlightAuthInterceptorFactory(fl.ClientMiddlewareFactory): + def __init__(self, auth_config: AuthConfig): + super().__init__() + self.auth_config = auth_config + + def start_call(self, info): + return FlightBearerTokenInterceptor(self.auth_config) + + +def build_arrow_flight_client(host: str, port, auth_config: AuthConfig): + if auth_config.type != AuthType.NONE.value: + middleware_factory = FlightAuthInterceptorFactory(auth_config) + return fl.FlightClient(f"grpc://{host}:{port}", middleware=[middleware_factory]) + else: + return fl.FlightClient(f"grpc://{host}:{port}") diff --git a/sdk/python/feast/permissions/client/auth_client_manager.py b/sdk/python/feast/permissions/client/auth_client_manager.py index ef8a9715ec..82f9b7433e 100644 --- a/sdk/python/feast/permissions/client/auth_client_manager.py +++ b/sdk/python/feast/permissions/client/auth_client_manager.py @@ -1,38 +1,8 @@ from abc import ABC, abstractmethod -from feast.permissions.auth.auth_type import AuthType -from feast.permissions.auth_model import ( - AuthConfig, - KubernetesAuthConfig, - OidcAuthConfig, -) - class AuthenticationClientManager(ABC): @abstractmethod def get_token(self) -> str: """Retrieves the token based on the authentication type configuration""" pass - - -def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager: - if auth_config.type == AuthType.OIDC.value: - assert isinstance(auth_config, OidcAuthConfig) - - from feast.permissions.client.oidc_authentication_client_manager import ( - OidcAuthClientManager, - ) - - return OidcAuthClientManager(auth_config) - elif auth_config.type == AuthType.KUBERNETES.value: - assert isinstance(auth_config, KubernetesAuthConfig) - - from feast.permissions.client.kubernetes_auth_client_manager import ( - KubernetesAuthClientManager, - ) - - return KubernetesAuthClientManager(auth_config) - else: - raise RuntimeError( - f"No Auth client manager implemented for the auth type:${auth_config.type}" - ) diff --git a/sdk/python/feast/permissions/client/auth_client_manager_factory.py b/sdk/python/feast/permissions/client/auth_client_manager_factory.py new file mode 100644 index 0000000000..4e49802047 --- /dev/null +++ b/sdk/python/feast/permissions/client/auth_client_manager_factory.py @@ -0,0 +1,30 @@ +from feast.permissions.auth.auth_type import AuthType +from feast.permissions.auth_model import ( + AuthConfig, + KubernetesAuthConfig, + OidcAuthConfig, +) +from feast.permissions.client.auth_client_manager import AuthenticationClientManager +from feast.permissions.client.kubernetes_auth_client_manager import ( + KubernetesAuthClientManager, +) +from feast.permissions.client.oidc_authentication_client_manager import ( + OidcAuthClientManager, +) + + +def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager: + if auth_config.type == AuthType.OIDC.value: + assert isinstance(auth_config, OidcAuthConfig) + return OidcAuthClientManager(auth_config) + elif auth_config.type == AuthType.KUBERNETES.value: + assert isinstance(auth_config, KubernetesAuthConfig) + return KubernetesAuthClientManager(auth_config) + else: + raise RuntimeError( + f"No Auth client manager implemented for the auth type:${auth_config.type}" + ) + + +def get_auth_token(auth_config: AuthConfig) -> str: + return get_auth_client_manager(auth_config).get_token() diff --git a/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py b/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py index f10b8d911a..98cc445c7b 100644 --- a/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py +++ b/sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py @@ -2,9 +2,8 @@ import grpc -from feast.permissions.auth.auth_type import AuthType from feast.permissions.auth_model import AuthConfig -from feast.permissions.client.auth_client_manager import get_auth_client_manager +from feast.permissions.client.auth_client_manager_factory import get_auth_token logger = logging.getLogger(__name__) @@ -43,16 +42,11 @@ def intercept_stream_stream( return continuation(client_call_details, request_iterator) def _append_auth_header_metadata(self, client_call_details): - if self._auth_type.type is not AuthType.NONE.value: - logger.info( - f"Intercepted the grpc api method {client_call_details.method} call to inject Authorization header " - f"token. " - ) - metadata = client_call_details.metadata or [] - auth_client_manager = get_auth_client_manager(self._auth_type) - access_token = auth_client_manager.get_token() - metadata.append( - (b"authorization", b"Bearer " + access_token.encode("utf-8")) - ) - client_call_details = client_call_details._replace(metadata=metadata) + logger.debug( + "Intercepted the grpc api method call to inject Authorization header " + ) + metadata = client_call_details.metadata or [] + access_token = get_auth_token(self._auth_type) + metadata.append((b"authorization", b"Bearer " + access_token.encode("utf-8"))) + client_call_details = client_call_details._replace(metadata=metadata) return client_call_details diff --git a/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py b/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py index eba9af5ead..3232e25025 100644 --- a/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py +++ b/sdk/python/feast/permissions/client/http_auth_requests_wrapper.py @@ -5,7 +5,7 @@ from feast.permissions.auth_model import ( AuthConfig, ) -from feast.permissions.client.auth_client_manager import get_auth_client_manager +from feast.permissions.client.auth_client_manager_factory import get_auth_token class AuthenticatedRequestsSession(Session): @@ -18,6 +18,5 @@ def get_http_auth_requests_session(auth_config: AuthConfig) -> Session: if auth_config.type == AuthType.NONE.value: request_session = requests.session() else: - auth_client_manager = get_auth_client_manager(auth_config) - request_session = AuthenticatedRequestsSession(auth_client_manager.get_token()) + request_session = AuthenticatedRequestsSession(get_auth_token(auth_config)) return request_session diff --git a/sdk/python/feast/permissions/client/utils.py b/sdk/python/feast/permissions/client/utils.py deleted file mode 100644 index 02c69c9d6e..0000000000 --- a/sdk/python/feast/permissions/client/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -import pyarrow.flight as fl - -from feast.permissions.auth.auth_type import AuthType -from feast.permissions.auth_model import AuthConfig -from feast.permissions.client.auth_client_manager import get_auth_client_manager - - -def create_auth_header( - auth_config: AuthConfig, -) -> list[tuple[bytes, bytes]]: - auth_client_manager = get_auth_client_manager(auth_config) - token = auth_client_manager.get_token() - - return [(b"authorization", b"Bearer " + token.encode("utf-8"))] - - -def create_flight_call_options(auth_config: AuthConfig) -> fl.FlightCallOptions: - if auth_config.type != AuthType.NONE.value: - headers = create_auth_header(auth_config) - return fl.FlightCallOptions(headers=headers) - return fl.FlightCallOptions() diff --git a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py index dea5304ff4..6d5eeb90c7 100644 --- a/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py +++ b/sdk/python/tests/unit/infra/offline_stores/test_offline_store.py @@ -3,7 +3,6 @@ import pandas as pd import pyarrow -import pyarrow.flight as fl import pytest from feast.infra.offline_stores.contrib.athena_offline_store.athena import ( @@ -216,7 +215,6 @@ def retrieval_job(request, environment): return RemoteRetrievalJob( client=MagicMock(), - options=fl.FlightCallOptions(), api_parameters={ "str": "str", },