added RemoteOfflineStoreDataSourceCreator,
use feature_view_names to transfer feature views and remove dummies
dmartinol committed May 27, 2024
1 parent d901ac2 commit a1f660e
Showing 5 changed files with 201 additions and 49 deletions.
39 changes: 27 additions & 12 deletions sdk/python/feast/infra/offline_stores/remote.py
@@ -1,4 +1,5 @@
import json
import logging
import uuid
from datetime import datetime
from pathlib import Path
@@ -20,7 +21,8 @@
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.repo_config import FeastConfigBaseModel, RepoConfig

logger = logging.getLogger(__name__)


class RemoteOfflineStoreConfig(FeastConfigBaseModel):
@@ -36,14 +38,23 @@ class RemoteRetrievalJob(RetrievalJob):
def __init__(
self,
client: fl.FlightClient,
feature_view_names: List[str],
feature_refs: List[str],
entity_df: Union[pd.DataFrame, str],
# TODO add missing parameters from the OfflineStore API
project: str,
full_feature_names: bool = False,
):
# Initialize the client connection
self.client = client
self.feature_view_names = feature_view_names
self.feature_refs = feature_refs
self.entity_df = entity_df
self.project = project
self._full_feature_names = full_feature_names

@property
def full_feature_names(self) -> bool:
return self._full_feature_names

# TODO add one specialized implementation for each OfflineStore API
# This can result in a dictionary of functions indexed by api (e.g., "get_historical_features")
@@ -71,7 +82,10 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
command = {
"command_id": command_id,
"api": "get_historical_features",
"features": self.feature_refs,
"feature_view_names": self.feature_view_names,
"feature_refs": self.feature_refs,
"project": self.project,
"full_feature_names": self._full_feature_names,
}
command_descriptor = fl.FlightDescriptor.for_command(
json.dumps(
@@ -93,7 +107,6 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:

class RemoteOfflineStore(OfflineStore):
@staticmethod
def get_historical_features(
config: RepoConfig,
feature_views: List[FeatureView],
@@ -108,16 +121,21 @@ def get_historical_features(
# TODO: extend RemoteRetrievalJob API with all method parameters

# Initialize the client connection
location = f"grpc://{config.offline_store.host}:{config.offline_store.port}"
client = fl.connect(location=location)
logger.info(f"Connecting FlightClient at {location}")

feature_view_names = [fv.name for fv in feature_views]
return RemoteRetrievalJob(
client=client,
feature_view_names=feature_view_names,
feature_refs=feature_refs,
entity_df=entity_df,
project=project,
full_feature_names=full_feature_names,
)

@staticmethod
def pull_all_from_table_or_query(
config: RepoConfig,
data_source: DataSource,
@@ -131,7 +149,6 @@
raise NotImplementedError

@staticmethod
def pull_latest_from_table_or_query(
config: RepoConfig,
data_source: DataSource,
@@ -146,7 +163,6 @@
raise NotImplementedError

@staticmethod
def write_logged_features(
config: RepoConfig,
data: Union[pyarrow.Table, Path],
@@ -158,7 +174,6 @@
raise NotImplementedError

@staticmethod
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
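For orientation, a minimal sketch of the Flight round trip that RemoteRetrievalJob performs against the server, assuming one is already listening; the host, port, project, and feature names below are illustrative assumptions, not part of this change:

import json
import uuid

import pandas as pd
import pyarrow as pa
import pyarrow.flight as fl

client = fl.connect("grpc://localhost:8815")  # host/port are assumptions

# Mirrors the command dict built in _to_arrow_internal above.
command = {
    "command_id": str(uuid.uuid4()),
    "api": "get_historical_features",
    "feature_view_names": ["driver_hourly_stats"],      # illustrative view name
    "feature_refs": ["driver_hourly_stats:conv_rate"],  # illustrative feature ref
    "project": "my_project",                            # illustrative project name
    "full_feature_names": False,
}
descriptor = fl.FlightDescriptor.for_command(json.dumps(command).encode())

# Upload the entity dataframe, then retrieve the computed training data.
entity_df = pd.DataFrame(
    {"driver_id": [1001], "event_timestamp": [pd.Timestamp.now(tz="UTC")]}
)
entity_table = pa.Table.from_pandas(entity_df)
writer, _ = client.do_put(descriptor, entity_table.schema)
writer.write_table(entity_table)
writer.close()

ticket = client.get_flight_info(descriptor).endpoints[0].ticket
training_table = client.do_get(ticket).read_all()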
109 changes: 90 additions & 19 deletions sdk/python/feast/offline_server.py
@@ -1,12 +1,16 @@
import ast
import json
import logging
import traceback
from typing import Any, Dict, List

import pyarrow as pa
import pyarrow.flight as fl

from feast import FeatureStore, FeatureView
from feast.feature_view import DUMMY_ENTITY_NAME

logger = logging.getLogger(__name__)


class OfflineServer(fl.FlightServerBase):
@@ -56,44 +60,101 @@ def do_put(self, context, descriptor, reader, writer):
command = json.loads(key[1])
if "api" in command:
data = reader.read_all()
logger.debug(f"do_put: command is{command}, data is {data}")
self.flights[key] = data
else:
print(f"No 'api' field in command: {command}")
logger.warning(f"No 'api' field in command: {command}")

def get_feature_view_by_name(self, fv_name: str, project: str) -> FeatureView:
"""
Retrieves a feature view by name, including all subclasses of FeatureView.
Args:
fv_name: Name of the feature view
project: Feast project that this feature view belongs to
Returns:
The specified feature view; raises an exception if none is found
"""
try:
return self.store.registry.get_feature_view(name=fv_name, project=project)
except Exception:
try:
return self.store.registry.get_stream_feature_view(
name=fv_name, project=project
)
except Exception as e:
logger.error(
f"Cannot find any FeatureView by name {fv_name} in project {project}"
)
raise e

def list_feature_views_by_name(
self, feature_view_names: List[str], project: str
) -> List[FeatureView]:
return [
remove_dummies(
self.get_feature_view_by_name(fv_name=fv_name, project=project)
)
for fv_name in feature_view_names
]

# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
# and returns the stream of data
def do_get(self, context, ticket):
key = ast.literal_eval(ticket.ticket.decode())
if key not in self.flights:
print(f"Unknown key {key}")
logger.error(f"Unknown key {key}")
return None

command = json.loads(key[1])
api = command["api"]
# print(f"get command is {command}")
# print(f"requested api is {api}")
logger.debug(f"get command is {command}")
logger.debug(f"requested api is {api}")
if api == "get_historical_features":
# Extract parameters from the internal flights dictionary
entity_df_value = self.flights[key]
entity_df = pa.Table.to_pandas(entity_df_value)
# print(f"entity_df is {entity_df}")

features = command["features"]
# print(f"features is {features}")
logger.debug(f"do_get: entity_df is {entity_df}")

feature_view_names = command["feature_view_names"]
feature_refs = command["feature_refs"]
logger.debug(f"do_get: feature_refs is {feature_refs}")
project = command["project"]
logger.debug(f"do_get: project is {project}")
full_feature_names = command["full_feature_names"]
feature_views = self.list_feature_views_by_name(
feature_view_names=feature_view_names, project=project
)
logger.debug(f"do_get: feature_views is {feature_views}")

logger.info(
f"get_historical_features for: entity_df from {entity_df.index[0]} to {entity_df.index[len(entity_df)-1]}, "
f"feature_views is {[(fv.name, fv.entities) for fv in feature_views]}, "
f"feature_refs is {feature_refs}"
)

# TODO define error handling
try:
training_df = (
self.store._get_provider()
.get_historical_features(
config=self.store.config,
feature_views=feature_views,
feature_refs=feature_refs,
entity_df=entity_df,
registry=self.store._registry,
project=project,
full_feature_names=full_feature_names,
)
.to_df()
)
logger.debug(f"Len of training_df is {len(training_df)}")
table = pa.Table.from_pandas(training_df)
except Exception as e:
logger.exception(e)
traceback.print_exc()
raise e

# Get service is consumed, so we clear the corresponding flight and data
del self.flights[key]
@@ -112,12 +173,22 @@ def do_drop_dataset(self, dataset):
pass


def remove_dummies(fv: FeatureView) -> FeatureView:
"""
Removes dummy IDs from FeatureView instances created with FeatureView.from_proto
"""
if DUMMY_ENTITY_NAME in fv.entities:
fv.entities = []
fv.entity_columns = []
return fv
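
As an illustration of the intent (fv below is a hypothetical entity-less view rebuilt via FeatureView.from_proto, which injects the placeholder entity):

# Hypothetical: fv was rebuilt with FeatureView.from_proto, so it carries
# the dummy entity until remove_dummies strips it again.
assert DUMMY_ENTITY_NAME in fv.entities
fv = remove_dummies(fv)
assert fv.entities == [] and fv.entity_columns == []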


def start_server(
store: FeatureStore,
host: str,
port: int,
):
location = "grpc+tcp://{}:{}".format(host, port)
server = OfflineServer(store, location)
print("Serving on", location)
logger.info(f"Offline store server serving on {location}")
server.serve()
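
A minimal sketch of standing up this server against an existing repository; the repo path and port are assumptions:

from feast import FeatureStore
from feast.offline_server import start_server

store = FeatureStore(repo_path=".")  # assumes a feature repo in the current directory
start_server(store, host="0.0.0.0", port=8815)  # port is illustrative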
@@ -36,6 +36,7 @@
DuckDBDeltaDataSourceCreator,
DuckDBDeltaS3DataSourceCreator,
FileDataSourceCreator,
RemoteOfflineStoreDataSourceCreator,
)
from tests.integration.feature_repos.universal.data_sources.redshift import (
RedshiftDataSourceCreator,
@@ -122,6 +123,7 @@
("local", FileDataSourceCreator),
("local", DuckDBDataSourceCreator),
("local", DuckDBDeltaDataSourceCreator),
("local", RemoteOfflineStoreDataSourceCreator),
]

if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True":
@@ -8,11 +8,12 @@
import pyarrow as pa
import pyarrow.parquet as pq
from minio import Minio
from multiprocess import Process
from testcontainers.core.generic import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.minio import MinioContainer

from feast import FeatureStore, FileSource, RepoConfig
from feast.data_format import DeltaFormat, ParquetFormat
from feast.data_source import DataSource
from feast.feature_logging import LoggingDestination
@@ -22,10 +23,14 @@
FileLoggingDestination,
SavedDatasetFileStorage,
)
from feast.infra.offline_stores.remote import RemoteOfflineStoreConfig
from feast.offline_server import start_server
from feast.repo_config import FeastConfigBaseModel, RegistryConfig
from feast.wait import wait_retry_backoff # noqa: E402
from tests.integration.feature_repos.universal.data_source_creator import (
DataSourceCreator,
)
from tests.utils.http_server import check_port_open, free_port # noqa: E402
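
The free_port and check_port_open helpers are imported here but not shown in this diff; a typical implementation, sketched as an assumption, would be:

import socket
from contextlib import closing


def free_port() -> int:
    # Bind to port 0 so the OS assigns a currently free ephemeral port.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def check_port_open(host: str, port: int) -> bool:
    # True if a TCP connection to host:port succeeds.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        return s.connect_ex((host, port)) == 0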


class FileDataSourceCreator(DataSourceCreator):
@@ -351,3 +356,56 @@ def create_offline_store_config(self):
staging_location_endpoint_override=self.endpoint_url,
)
return self.duckdb_offline_store_config


class RemoteOfflineStoreDataSourceCreator(FileDataSourceCreator):
def __init__(self, project_name: str, *args, **kwargs):
super().__init__(project_name)
self.server_port: int = 0
self.proc: Process = None

def setup(self, registry: RegistryConfig):
parent_offline_config = super().create_offline_store_config()

fs = FeatureStore(
config=RepoConfig(
project=self.project_name,
provider="local",
offline_store=parent_offline_config,
registry=registry.path,
entity_key_serialization_version=2,
)
)
self.server_port = free_port()
host = "0.0.0.0"
self.proc = Process(
target=start_server,
args=(fs, host, self.server_port),
)
self.proc.start()
# Wait for server to start
wait_retry_backoff(
lambda: (None, check_port_open(host, self.server_port)),
timeout_secs=10,
)
return "grpc+tcp://{}:{}".format(host, self.server_port)

def create_offline_store_config(self) -> FeastConfigBaseModel:
self.remote_offline_store_config = RemoteOfflineStoreConfig(
type="remote", host="0.0.0.0", port=self.server_port
)
return self.remote_offline_store_config

def teardown(self):
super().teardown()
if self.proc is not None and self.proc.is_alive():
self.proc.kill()

# wait for the server to free the port
wait_retry_backoff(
lambda: (
None,
not check_port_open("localhost", self.server_port),
),
timeout_secs=30,
)
