Skip to content

Commit

Permalink
* Adding TLS support for offline server.
Browse files Browse the repository at this point in the history
* Added test cases for the TLS offline server by creating RemoteOfflineTlsStoreDataSourceCreator

Signed-off-by: lrangine <[email protected]>
  • Loading branch information
lokeshrangineni committed Nov 6, 2024
1 parent 4a89252 commit b090e86
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 20 deletions.
26 changes: 25 additions & 1 deletion sdk/python/feast/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,16 +1114,40 @@ def serve_registry_command(
default=DEFAULT_OFFLINE_SERVER_PORT,
help="Specify a port for the server",
)
@click.option(
"--key",
"-k",
"tls_key_path",
type=click.STRING,
default="",
show_default=False,
help="path to TLS certificate private key. You need to pass --cert as well to start server in TLS mode",
)
@click.option(
"--cert",
"-c",
"tls_cert_path",
type=click.STRING,
default="",
show_default=False,
help="path to TLS certificate public key. You need to pass --key as well to start server in TLS mode",
)
@click.pass_context
def serve_offline_command(
    ctx: click.Context,
    host: str,
    port: int,
    tls_key_path: str,
    tls_cert_path: str,
):
    """Start a remote server locally on a given host, port."""
    # NOTE: the docstring above doubles as the click `--help` text, so the
    # details live in comments instead.
    #
    # TLS mode needs both halves of the key pair: supplying only --key or only
    # --cert is rejected up front rather than silently falling back to plain TCP.
    if bool(tls_key_path) != bool(tls_cert_path):
        raise click.BadParameter(
            "Please pass --cert and --key args to start the offline server in TLS mode."
        )
    store = create_feature_store(ctx)

    # Start the server exactly once, forwarding the TLS paths (empty strings
    # when TLS was not requested).
    store.serve_offline(host, port, tls_key_path, tls_cert_path)


@cli.command("validate")
Expand Down
10 changes: 8 additions & 2 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,11 +1963,17 @@ def serve_registry(
self, port=port, tls_key_path=tls_key_path, tls_cert_path=tls_cert_path
)

def serve_offline(
    self,
    host: str,
    port: int,
    tls_key_path: str = "",
    tls_cert_path: str = "",
) -> None:
    """Start offline server locally on a given port.

    Args:
        host: Host forwarded to the offline server.
        port: Port forwarded to the offline server.
        tls_key_path: Path to the TLS private key, forwarded unchanged.
            Defaults to an empty string.
        tls_cert_path: Path to the TLS certificate, forwarded unchanged.
            Defaults to an empty string.
    """
    # Imported inside the method, as in the original, so the offline server
    # module is only loaded when a server is actually started.
    from feast import offline_server

    offline_server.start_server(self, host, port, tls_key_path, tls_cert_path)

def serve_transformations(self, port: int) -> None:
"""Start the feature transformation server locally on a given port."""
Expand Down
27 changes: 23 additions & 4 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,38 @@ def list_actions(self, options: FlightCallOptions = None):
return super().list_actions(options)


def build_arrow_flight_client(scheme: str, host: str, port, auth_config: AuthConfig):
    """Create a ``FeastFlightClient`` for the remote offline server.

    Args:
        scheme: ``"https"`` selects the ``grpc+tls`` transport; any other value
            uses plain ``grpc+tcp``.
        host: Host of the remote offline server.
        port: Port of the remote offline server.
        auth_config: Auth settings; when its type is not ``AuthType.NONE`` an
            auth interceptor middleware is attached to the client.

    Returns:
        A ``FeastFlightClient`` connected to ``<transport>://<host>:<port>``.
    """
    if scheme == "https":
        logger.info(
            "Scheme is https so going to connect offline server in SSL(TLS) mode."
        )
        arrow_scheme = "grpc+tls"
    else:
        arrow_scheme = "grpc+tcp"

    # Build the location once so both return paths use the same URI.
    location = f"{arrow_scheme}://{host}:{port}"

    # NOTE(review): the `cert` setting on RemoteOfflineStoreConfig is not
    # consumed here; confirm where the TLS root certificate is supplied.
    if auth_config.type != AuthType.NONE.value:
        middlewares = [FlightAuthInterceptorFactory(auth_config)]
        return FeastFlightClient(location, middleware=middlewares)

    return FeastFlightClient(location)


class RemoteOfflineStoreConfig(FeastConfigBaseModel):
    """Client-side configuration for connecting to a remote offline store server."""

    type: Literal["remote"] = "remote"

    scheme: Literal["http", "https"] = "http"
    """ str: connection scheme; "https" makes the client connect over TLS (grpc+tls), "http" uses plain grpc+tcp. """

    host: StrictStr
    """ str: remote offline store server host, e.g. the host URL for offline store of arrow flight server. """

    port: Optional[StrictInt] = None
    """ int: remote offline store server port."""

    cert: StrictStr = ""
    """ str: Path to the public certificate when the offline server starts in TLS(SSL) mode. This may be needed if the offline server started with a self-signed certificate, typically this file ends with `*.crt`, `*.cer`, or `*.pem`.
    If type is 'remote', then this configuration is needed to connect to remote offline server in TLS mode. """


class RemoteRetrievalJob(RetrievalJob):
def __init__(
Expand Down Expand Up @@ -178,7 +194,10 @@ def get_historical_features(
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)

client = build_arrow_flight_client(
config.offline_store.host, config.offline_store.port, config.auth_config
scheme=config.offline_store.scheme,
host=config.offline_store.host,
port=config.offline_store.port,
auth_config=config.auth_config,
)

feature_view_names = [fv.name for fv in feature_views]
Expand Down
51 changes: 45 additions & 6 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,31 @@


class OfflineServer(fl.FlightServerBase):
def __init__(
    self,
    store: FeatureStore,
    location: str,
    host: str = "localhost",
    tls_certificates=None,
    verify_client=False,
    **kwargs,
):
    """Initialise the Arrow Flight server that backs the offline store.

    Args:
        store: Feature store whose offline store and auth config are served.
        location: Location URI the Flight server binds to.
        host: Host name kept on the instance; used when building endpoint
            locations for flight infos.
        tls_certificates: Certificate entries forwarded to the Flight server
            base class; ``None`` is treated as "no certificates".
        verify_client: Forwarded to the Flight server base class.
        **kwargs: Extra keyword arguments forwarded to the base class.
    """
    # Normalise once so later len()/truthiness checks never see None.
    certificates = [] if tls_certificates is None else tls_certificates

    super(OfflineServer, self).__init__(
        location=location,
        middleware=self.arrow_flight_auth_middleware(
            str_to_auth_manager_type(store.config.auth_config.type)
        ),
        tls_certificates=certificates,
        verify_client=verify_client,
        **kwargs,
    )
    self._location = location
    # A dictionary of configured flights, e.g. API calls received and not yet served
    self.flights: Dict[str, Any] = {}
    self.store = store
    self.offline_store = get_offline_store_from_config(store.config.offline_store)
    self.host = host
    self.tls_certificates = certificates

def arrow_flight_auth_middleware(
self,
Expand Down Expand Up @@ -81,8 +93,13 @@ def descriptor_to_key(self, descriptor: fl.FlightDescriptor):
)

def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor):
    """Build the ``FlightInfo`` advertised for a descriptor.

    The single endpoint points back at this server (``self.host`` /
    ``self.port``), using a TLS location when certificates are configured and
    a plain TCP location otherwise.

    Args:
        key: Flight key; its ``repr`` is used as the endpoint ticket.
        descriptor: Descriptor the info is produced for.

    Returns:
        A ``fl.FlightInfo`` with an empty schema and the totals passed as -1.
    """
    # Truthiness covers both an empty list and None (no TLS configured).
    if self.tls_certificates:
        location = fl.Location.for_grpc_tls(self.host, self.port)
    else:
        location = fl.Location.for_grpc_tcp(self.host, self.port)
    endpoints = [fl.FlightEndpoint(repr(key), [location])]
    # TODO calculate actual schema from the given features
    schema = pa.schema([])

    return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)
Expand Down Expand Up @@ -549,11 +566,33 @@ def start_server(
store: FeatureStore,
host: str,
port: int,
tls_key_path: str = "",
tls_cert_path: str = "",
):
_init_auth_manager(store)

location = "grpc+tcp://{}:{}".format(host, port)
server = OfflineServer(store, location)
tls_certificates = []
scheme = "grpc+tcp"
if tls_key_path and tls_cert_path:
logger.info(
"Found SSL certificates in the args so going to start offline server in TLS(SSL) mode."
)
scheme = "grpc+tls"
with open(tls_cert_path, "rb") as cert_file:
tls_cert_chain = cert_file.read()
with open(tls_key_path, "rb") as key_file:
tls_private_key = key_file.read()
tls_certificates.append((tls_cert_chain, tls_private_key))

location = "{}://{}:{}".format(scheme, host, port)
server = OfflineServer(
store,
location=location,
host=host,
port=port,
tls_certificates=tls_certificates,
verify_client=True,
)
try:
logger.info(f"Offline store server serving at: {location}")
server.serve()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
DuckDBDeltaDataSourceCreator,
FileDataSourceCreator,
RemoteOfflineOidcAuthStoreDataSourceCreator,
RemoteOfflineStoreDataSourceCreator,
RemoteOfflineStoreDataSourceCreator, RemoteOfflineTlsStoreDataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.redshift import (
RedshiftDataSourceCreator,
Expand Down Expand Up @@ -131,6 +131,7 @@
("local", DuckDBDeltaDataSourceCreator),
("local", RemoteOfflineStoreDataSourceCreator),
("local", RemoteOfflineOidcAuthStoreDataSourceCreator),
("local", RemoteOfflineTlsStoreDataSourceCreator),
]

if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
DataSourceCreator,
)
from tests.utils.auth_permissions_util import include_auth_config
from tests.utils.generate_self_signed_certifcate_util import generate_self_signed_cert
from tests.utils.http_server import check_port_open, free_port # noqa: E402

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -410,11 +411,67 @@ def setup(self, registry: RegistryConfig):
)
return "grpc+tcp://{}:{}".format(host, self.server_port)

class RemoteOfflineTlsStoreDataSourceCreator(FileDataSourceCreator):
def __init__(self, project_name: str, *args, **kwargs):
    """Create the data source creator; the TLS server itself is started in ``setup``.

    Args:
        project_name: Project name forwarded to the parent creator.
        *args: Accepted for signature compatibility; not used here.
        **kwargs: Accepted for signature compatibility; not used here.
    """
    super().__init__(project_name)
    self.server_port: int = 0
    self.proc: Optional[Popen[bytes]] = None
    # Declared here so create_offline_store_config() never hits a missing
    # attribute; setup() overwrites it with the generated certificate path.
    self.tls_cert_path: str = ""

def setup(self, registry: RegistryConfig):
    """Start a TLS-enabled ``feast serve_offline`` subprocess and wait for its port.

    A repo config backed by the parent's offline store is written to a
    temporary ``feature_store.yaml``, a self-signed key/certificate pair is
    generated, and the server is launched with ``--key``/``--cert``.

    Args:
        registry: Registry config whose ``path`` is written into the repo config.

    Returns:
        The ``grpc+tls://<host>:<port>`` location string of the started server.
    """
    server_repo_config = RepoConfig(
        project=self.project_name,
        provider="local",
        offline_store=super().create_offline_store_config(),
        registry=registry.path,
        entity_key_serialization_version=2,
    )

    # Key and certificate live in a throwaway directory; the certificate path
    # is kept on the instance because the client config needs it too.
    tls_dir = tempfile.mkdtemp()
    key_file = os.path.join(tls_dir, "key.pem")
    self.tls_cert_path = os.path.join(tls_dir, "cert.pem")
    generate_self_signed_cert(cert_path=self.tls_cert_path, key_path=key_file)

    repo_dir = Path(tempfile.mkdtemp())
    with open(repo_dir / "feature_store.yaml", "w") as yaml_file:
        yaml.dump(server_repo_config.model_dump(by_alias=True), yaml_file)
    repo_dir = repo_dir.resolve()

    self.server_port = free_port()
    host = "0.0.0.0"
    self.proc = subprocess.Popen(
        [
            "feast",
            "-c" + str(repo_dir),
            "serve_offline",
            "--host",
            host,
            "--port",
            str(self.server_port),
            "--key",
            str(key_file),
            "--cert",
            str(self.tls_cert_path),
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )

    # Poll until the server port accepts connections or the timeout elapses.
    startup_timeout_secs: int = 60
    wait_retry_backoff(
        lambda: (None, check_port_open(host, self.server_port)),
        timeout_secs=startup_timeout_secs,
        timeout_msg=f"Unable to start the feast remote offline server in {startup_timeout_secs} seconds at port={self.server_port}",
    )
    return "grpc+tls://{}:{}".format(host, self.server_port)


def create_offline_store_config(self) -> FeastConfigBaseModel:
    """Return the remote offline store client config for the TLS server.

    The config uses the ``https`` scheme and passes the certificate path held
    in ``self.tls_cert_path`` together with ``self.server_port``.
    """
    return RemoteOfflineStoreConfig(
        type="remote",
        host="0.0.0.0",
        port=self.server_port,
        scheme="https",
        cert=self.tls_cert_path,
    )

def teardown(self):
super().teardown()
Expand Down Expand Up @@ -499,10 +556,10 @@ def setup(self, registry: RegistryConfig):
return "grpc+tcp://{}:{}".format(host, self.server_port)

def create_offline_store_config(self) -> FeastConfigBaseModel:
    """Return the remote offline store client config pointing at ``self.server_port``."""
    return RemoteOfflineStoreConfig(
        type="remote",
        host="0.0.0.0",
        port=self.server_port,
    )

def get_keycloak_url(self):
    """Return the Keycloak URL stored on this creator as ``keycloak_url``."""
    keycloak_url = self.keycloak_url
    return keycloak_url
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from tests.integration.feature_repos.universal.data_sources.file import (
RemoteOfflineOidcAuthStoreDataSourceCreator,
RemoteOfflineStoreDataSourceCreator,
RemoteOfflineStoreDataSourceCreator, RemoteOfflineTlsStoreDataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.snowflake import (
SnowflakeDataSourceCreator,
Expand Down Expand Up @@ -166,6 +166,7 @@ def test_historical_features_main(
environment.data_source_creator,
(
RemoteOfflineStoreDataSourceCreator,
RemoteOfflineTlsStoreDataSourceCreator,
RemoteOfflineOidcAuthStoreDataSourceCreator,
),
):
Expand Down

0 comments on commit b090e86

Please sign in to comment.