Skip to content

Commit

Permalink
fix: Pass native input values to get_online_features from feature s…
Browse files Browse the repository at this point in the history
…erver (feast-dev#4117)

* fix: Pass native input values to get_online_features from feature server

Signed-off-by: tokoko <[email protected]>

* remove unnecessary type ignore hint

Signed-off-by: tokoko <[email protected]>

---------

Signed-off-by: tokoko <[email protected]>
  • Loading branch information
tokoko authored May 2, 2024
1 parent b8087f7 commit 60756cb
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.logger import logger
from fastapi.params import Depends
from google.protobuf.json_format import MessageToDict, Parse
from google.protobuf.json_format import MessageToDict
from pydantic import BaseModel

import feast
from feast import proto_json, utils
from feast.constants import DEFAULT_FEATURE_SERVER_REGISTRY_TTL
from feast.data_source import PushMode
from feast.errors import PushSourceNotFoundException
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest


# TODO: deprecate this in favor of push features
Expand Down Expand Up @@ -83,34 +82,25 @@ def shutdown_event():
@app.post("/get-online-features")
def get_online_features(body=Depends(get_body)):
try:
# Validate and parse the request data into GetOnlineFeaturesRequest Protobuf object
request_proto = GetOnlineFeaturesRequest()
Parse(body, request_proto)

body = json.loads(body)
# Initialize parameters for FeatureStore.get_online_features(...) call
if request_proto.HasField("feature_service"):
if "feature_service" in body:
features = store.get_feature_service(
request_proto.feature_service, allow_cache=True
body["feature_service"], allow_cache=True
)
else:
features = list(request_proto.features.val)

full_feature_names = request_proto.full_feature_names
features = body["features"]

batch_sizes = [len(v.val) for v in request_proto.entities.values()]
num_entities = batch_sizes[0]
if any(batch_size != num_entities for batch_size in batch_sizes):
raise HTTPException(status_code=500, detail="Uneven number of columns")
full_feature_names = body.get("full_feature_names", False)

response_proto = store._get_online_features(
features=features,
entity_values=request_proto.entities,
entity_values=body["entities"],
full_feature_names=full_feature_names,
native_entity_values=False,
).proto

# Convert the Protobuf object to JSON and return it
return MessageToDict( # type: ignore
return MessageToDict(
response_proto, preserving_proto_field_name=True, float_precision=18
)
except Exception as e:
Expand Down

0 comments on commit 60756cb

Please sign in to comment.