diff --git a/go/embedded/online_features.go b/go/embedded/online_features.go index 7cd1e4ed81..b5ebcf8a96 100644 --- a/go/embedded/online_features.go +++ b/go/embedded/online_features.go @@ -31,6 +31,7 @@ import ( type OnlineFeatureService struct { fs *feast.FeatureStore grpcStopCh chan os.Signal + httpStopCh chan os.Signal } type OnlineFeatureServiceConfig struct { @@ -63,11 +64,13 @@ func NewOnlineFeatureService(conf *OnlineFeatureServiceConfig, transformationCal log.Fatalln(err) } - // Notify this channel when receiving interrupt or termination signals from OS - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + // Notify these channels when receiving interrupt or termination signals from OS + httpStopCh := make(chan os.Signal, 1) + grpcStopCh := make(chan os.Signal, 1) + signal.Notify(httpStopCh, syscall.SIGINT, syscall.SIGTERM) + signal.Notify(grpcStopCh, syscall.SIGINT, syscall.SIGTERM) - return &OnlineFeatureService{fs: fs, grpcStopCh: c} + return &OnlineFeatureService{fs: fs, httpStopCh: httpStopCh, grpcStopCh: grpcStopCh} } func (s *OnlineFeatureService) GetEntityTypesMap(featureRefs []string) (map[string]int32, error) { @@ -239,15 +242,12 @@ func (s *OnlineFeatureService) StartGprcServerWithLoggingDefaultOpts(host string return s.StartGprcServerWithLogging(host, port, writeLoggedFeaturesCallback, defaultOpts) } -// StartGprcServerWithLogging starts gRPC server with enabled feature logging -// Caller of this function must provide Python callback to flush buffered logs as well as logging configuration (loggingOpts) -func (s *OnlineFeatureService) StartGprcServerWithLogging(host string, port int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts LoggingOptions) error { +func (s *OnlineFeatureService) constructLoggingService(writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts LoggingOptions) (*logging.LoggingService, error) { var loggingService *logging.LoggingService = nil - var err error if writeLoggedFeaturesCallback != nil { sink, err := logging.NewOfflineStoreSink(writeLoggedFeaturesCallback) if err != nil { - return err + return nil, err } loggingService, err = logging.NewLoggingService(s.fs, sink, logging.LoggingOptions{ @@ -257,9 +257,19 @@ func (s *OnlineFeatureService) StartGprcServerWithLogging(host string, port int, FlushInterval: loggingOpts.FlushInterval, }) if err != nil { - return err + return nil, err } } + return loggingService, nil +} + +// StartGprcServerWithLogging starts gRPC server with enabled feature logging +// Caller of this function must provide Python callback to flush buffered logs as well as logging configuration (loggingOpts) +func (s *OnlineFeatureService) StartGprcServerWithLogging(host string, port int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts LoggingOptions) error { + loggingService, err := s.constructLoggingService(writeLoggedFeaturesCallback, loggingOpts) + if err != nil { + return err + } ser := server.NewGrpcServingServiceServer(s.fs, loggingService) log.Printf("Starting a gRPC server on host %s port %d\n", host, port) lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port)) @@ -288,7 +298,51 @@ func (s *OnlineFeatureService) StartGprcServerWithLogging(host string, port int, return nil } -func (s *OnlineFeatureService) Stop() { +// StartHttpServer starts HTTP server with disabled feature logging and blocks the thread +func (s *OnlineFeatureService) StartHttpServer(host string, port int) error { + return s.StartHttpServerWithLogging(host, port, nil, LoggingOptions{}) +} + +// StartHttpServerWithLoggingDefaultOpts starts HTTP server with enabled feature logging but default configuration for logging +// Caller of this function must provide Python callback to flush buffered logs +func (s *OnlineFeatureService) StartHttpServerWithLoggingDefaultOpts(host string, port int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback) error { + defaultOpts := LoggingOptions{ + ChannelCapacity: logging.DefaultOptions.ChannelCapacity, + EmitTimeout: logging.DefaultOptions.EmitTimeout, + WriteInterval: logging.DefaultOptions.WriteInterval, + FlushInterval: logging.DefaultOptions.FlushInterval, + } + return s.StartHttpServerWithLogging(host, port, writeLoggedFeaturesCallback, defaultOpts) +} + +// StartHttpServerWithLogging starts HTTP server with enabled feature logging +// Caller of this function must provide Python callback to flush buffered logs as well as logging configuration (loggingOpts) +func (s *OnlineFeatureService) StartHttpServerWithLogging(host string, port int, writeLoggedFeaturesCallback logging.OfflineStoreWriteCallback, loggingOpts LoggingOptions) error { + loggingService, err := s.constructLoggingService(writeLoggedFeaturesCallback, loggingOpts) + if err != nil { + return err + } + ser := server.NewHttpServer(s.fs, loggingService) + log.Printf("Starting a HTTP server on host %s port %d\n", host, port) + + go func() { + // As soon as these signals are received from OS, try to gracefully stop the gRPC server + <-s.httpStopCh + fmt.Println("Stopping the HTTP server...") + err := ser.Stop() + if err != nil { + fmt.Printf("Error when stopping the HTTP server: %v\n", err) + } + }() + + return ser.Serve(host, port) +} + +func (s *OnlineFeatureService) StopHttpServer() { + s.httpStopCh <- syscall.SIGINT +} + +func (s *OnlineFeatureService) StopGrpcServer() { s.grpcStopCh <- syscall.SIGINT } diff --git a/go/internal/feast/server/grpc_server.go b/go/internal/feast/server/grpc_server.go index 6040880959..c47d185d6c 100644 --- a/go/internal/feast/server/grpc_server.go +++ b/go/internal/feast/server/grpc_server.go @@ -86,7 +86,7 @@ func (s *grpcServingServiceServer) GetOnlineFeatures(ctx context.Context, reques fmt.Printf("Couldn't instantiate logger for feature service %s: %+v", featuresOrService.FeatureService.Name, err) } - err = logger.Log(entityValuesMap, resp.Results[len(request.Entities):], resp.Metadata.FeatureNames.Val[len(request.Entities):], request.RequestContext, requestId) + err = logger.Log(request.Entities, resp.Results[len(request.Entities):], resp.Metadata.FeatureNames.Val[len(request.Entities):], request.RequestContext, requestId) if err != nil { fmt.Printf("LoggerImpl error[%s]: %+v", featuresOrService.FeatureService.Name, err) } diff --git a/go/internal/feast/server/http_server.go b/go/internal/feast/server/http_server.go new file mode 100644 index 0000000000..75cdbe9929 --- /dev/null +++ b/go/internal/feast/server/http_server.go @@ -0,0 +1,270 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "github.com/feast-dev/feast/go/internal/feast" + "github.com/feast-dev/feast/go/internal/feast/model" + "github.com/feast-dev/feast/go/internal/feast/server/logging" + "github.com/feast-dev/feast/go/protos/feast/serving" + prototypes "github.com/feast-dev/feast/go/protos/feast/types" + "github.com/feast-dev/feast/go/types" + "net/http" + "time" +) + +type httpServer struct { + fs *feast.FeatureStore + loggingService *logging.LoggingService + server *http.Server +} + +// Some Feast types aren't supported during JSON conversion +type repeatedValue struct { + stringVal []string + int64Val []int64 + doubleVal []float64 + boolVal []bool + stringListVal [][]string + int64ListVal [][]int64 + doubleListVal [][]float64 + boolListVal [][]bool +} + +func (u *repeatedValue) UnmarshalJSON(data []byte) error { + isString := false + isDouble := false + isInt64 := false + isArray := false + openBraketCounter := 0 + for _, b := range data { + if b == '"' { + isString = true + } + if b == '.' { + isDouble = true + } + if b >= '0' && b <= '9' { + isInt64 = true + } + if b == '[' { + openBraketCounter++ + if openBraketCounter > 1 { + isArray = true + } + } + } + var err error + if !isArray { + if isString { + err = json.Unmarshal(data, &u.stringVal) + } else if isDouble { + err = json.Unmarshal(data, &u.doubleVal) + } else if isInt64 { + err = json.Unmarshal(data, &u.int64Val) + } else { + err = json.Unmarshal(data, &u.boolVal) + } + } else { + if isString { + err = json.Unmarshal(data, &u.stringListVal) + } else if isDouble { + err = json.Unmarshal(data, &u.doubleListVal) + } else if isInt64 { + err = json.Unmarshal(data, &u.int64ListVal) + } else { + err = json.Unmarshal(data, &u.boolListVal) + } + } + return err +} + +func (u *repeatedValue) ToProto() *prototypes.RepeatedValue { + proto := new(prototypes.RepeatedValue) + if u.stringVal != nil { + for _, val := range u.stringVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_StringVal{StringVal: val}}) + } + } + if u.int64Val != nil { + for _, val := range u.int64Val { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_Int64Val{Int64Val: val}}) + } + } + if u.doubleVal != nil { + for _, val := range u.doubleVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_DoubleVal{DoubleVal: val}}) + } + } + if u.boolVal != nil { + for _, val := range u.boolVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_BoolVal{BoolVal: val}}) + } + } + if u.stringListVal != nil { + for _, val := range u.stringListVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_StringListVal{StringListVal: &prototypes.StringList{Val: val}}}) + } + } + if u.int64ListVal != nil { + for _, val := range u.int64ListVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_Int64ListVal{Int64ListVal: &prototypes.Int64List{Val: val}}}) + } + } + if u.doubleListVal != nil { + for _, val := range u.doubleListVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_DoubleListVal{DoubleListVal: &prototypes.DoubleList{Val: val}}}) + } + } + if u.boolListVal != nil { + for _, val := range u.boolListVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_BoolListVal{BoolListVal: &prototypes.BoolList{Val: val}}}) + } + } + return proto +} + +type getOnlineFeaturesRequest struct { + FeatureService *string `json:"feature_service"` + Features []string `json:"features"` + Entities map[string]repeatedValue `json:"entities"` + FullFeatureNames bool `json:"full_feature_names"` + RequestContext map[string]repeatedValue `json:"request_context"` +} + +func NewHttpServer(fs *feast.FeatureStore, loggingService *logging.LoggingService) *httpServer { + return &httpServer{fs: fs, loggingService: loggingService} +} + +func (s *httpServer) getOnlineFeatures(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.NotFound(w, r) + return + } + + decoder := json.NewDecoder(r.Body) + var request getOnlineFeaturesRequest + err := decoder.Decode(&request) + if err != nil { + http.Error(w, fmt.Sprintf("Error decoding JSON request data: %+v", err), http.StatusInternalServerError) + return + } + var featureService *model.FeatureService + if request.FeatureService != nil { + featureService, err = s.fs.GetFeatureService(*request.FeatureService) + if err != nil { + http.Error(w, fmt.Sprintf("Error getting feature service from registry: %+v", err), http.StatusInternalServerError) + return + } + } + entitiesProto := make(map[string]*prototypes.RepeatedValue) + for key, value := range request.Entities { + entitiesProto[key] = value.ToProto() + } + requestContextProto := make(map[string]*prototypes.RepeatedValue) + for key, value := range request.RequestContext { + requestContextProto[key] = value.ToProto() + } + + featureVectors, err := s.fs.GetOnlineFeatures( + r.Context(), + request.Features, + featureService, + entitiesProto, + requestContextProto, + request.FullFeatureNames) + + if err != nil { + http.Error(w, fmt.Sprintf("Error getting feature vector: %+v", err), http.StatusInternalServerError) + return + } + + var featureNames []string + var results []map[string]interface{} + for _, vector := range featureVectors { + featureNames = append(featureNames, vector.Name) + result := make(map[string]interface{}) + var statuses []string + for _, status := range vector.Statuses { + statuses = append(statuses, status.String()) + } + var timestamps []string + for _, timestamp := range vector.Timestamps { + timestamps = append(timestamps, timestamp.AsTime().Format(time.RFC3339)) + } + + result["statuses"] = statuses + result["event_timestamps"] = timestamps + // Note, that vector.Values is an Arrow Array, but this type implements JSON Marshaller. + // So, it's not necessary to pre-process it in any way. + result["values"] = vector.Values + + results = append(results, result) + } + + response := map[string]interface{}{ + "metadata": map[string]interface{}{ + "feature_names": featureNames, + }, + "results": results, + } + + err = json.NewEncoder(w).Encode(response) + + if err != nil { + http.Error(w, fmt.Sprintf("Error encoding response: %+v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + + if featureService != nil && featureService.LoggingConfig != nil && s.loggingService != nil { + logger, err := s.loggingService.GetOrCreateLogger(featureService) + if err != nil { + http.Error(w, fmt.Sprintf("Couldn't instantiate logger for feature service %s: %+v", featureService.Name, err), http.StatusInternalServerError) + return + } + + requestId := GenerateRequestId() + + // Note: we're converting arrow to proto for feature logging. In the future we should + // base feature logging on arrow so that we don't have to do this extra conversion. + var featureVectorProtos []*serving.GetOnlineFeaturesResponse_FeatureVector + for _, vector := range featureVectors[len(request.Entities):] { + values, err := types.ArrowValuesToProtoValues(vector.Values) + if err != nil { + http.Error(w, fmt.Sprintf("Couldn't convert arrow values into protobuf: %+v", err), http.StatusInternalServerError) + return + } + featureVectorProtos = append(featureVectorProtos, &serving.GetOnlineFeaturesResponse_FeatureVector{ + Values: values, + Statuses: vector.Statuses, + EventTimestamps: vector.Timestamps, + }) + } + + err = logger.Log(entitiesProto, featureVectorProtos, featureNames[len(request.Entities):], requestContextProto, requestId) + if err != nil { + http.Error(w, fmt.Sprintf("LoggerImpl error[%s]: %+v", featureService.Name, err), http.StatusInternalServerError) + return + } + } +} + +func (s *httpServer) Serve(host string, port int) error { + s.server = &http.Server{Addr: fmt.Sprintf("%s:%d", host, port), Handler: nil} + http.HandleFunc("/get-online-features", s.getOnlineFeatures) + err := s.server.ListenAndServe() + // Don't return the error if it's caused by graceful shutdown using Stop() + if err == http.ErrServerClosed { + return nil + } + return err +} +func (s *httpServer) Stop() error { + if s.server != nil { + return s.server.Shutdown(context.Background()) + } + return nil +} diff --git a/go/internal/feast/server/http_server_test.go b/go/internal/feast/server/http_server_test.go new file mode 100644 index 0000000000..67ba1c60f9 --- /dev/null +++ b/go/internal/feast/server/http_server_test.go @@ -0,0 +1,40 @@ +package server + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUnmarshalJSON(t *testing.T) { + u := repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[1, 2, 3]"))) + assert.Equal(t, []int64{1, 2, 3}, u.int64Val) + + u = repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[1.2, 2.3, 3.4]"))) + assert.Equal(t, []float64{1.2, 2.3, 3.4}, u.doubleVal) + + u = repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[\"foo\", \"bar\"]"))) + assert.Equal(t, []string{"foo", "bar"}, u.stringVal) + + u = repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[true, false, true]"))) + assert.Equal(t, []bool{true, false, true}, u.boolVal) + + u = repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[[1, 2, 3], [4, 5, 6]]"))) + assert.Equal(t, [][]int64{{1, 2, 3}, {4, 5, 6}}, u.int64ListVal) + + u = repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[[1.2, 2.3, 3.4], [10.2, 20.3, 30.4]]"))) + assert.Equal(t, [][]float64{{1.2, 2.3, 3.4}, {10.2, 20.3, 30.4}}, u.doubleListVal) + + u = repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[[\"foo\", \"bar\"], [\"foo2\", \"bar2\"]]"))) + assert.Equal(t, [][]string{{"foo", "bar"}, {"foo2", "bar2"}}, u.stringListVal) + + u = repeatedValue{} + assert.Nil(t, u.UnmarshalJSON([]byte("[[true, false, true], [false, true, false]]"))) + assert.Equal(t, [][]bool{{true, false, true}, {false, true, false}}, u.boolListVal) +} diff --git a/go/internal/feast/server/logging/logger.go b/go/internal/feast/server/logging/logger.go index d7ed1fbe18..cbf1c3439a 100644 --- a/go/internal/feast/server/logging/logger.go +++ b/go/internal/feast/server/logging/logger.go @@ -42,7 +42,7 @@ type LogSink interface { } type Logger interface { - Log(joinKeyToEntityValues map[string][]*types.Value, featureVectors []*serving.GetOnlineFeaturesResponse_FeatureVector, featureNames []string, requestData map[string]*types.RepeatedValue, requestId string) error + Log(joinKeyToEntityValues map[string]*types.RepeatedValue, featureVectors []*serving.GetOnlineFeaturesResponse_FeatureVector, featureNames []string, requestData map[string]*types.RepeatedValue, requestId string) error } type LoggerImpl struct { @@ -207,7 +207,7 @@ func getFullFeatureName(featureViewName string, featureName string) string { return fmt.Sprintf("%s__%s", featureViewName, featureName) } -func (l *LoggerImpl) Log(joinKeyToEntityValues map[string][]*types.Value, featureVectors []*serving.GetOnlineFeaturesResponse_FeatureVector, featureNames []string, requestData map[string]*types.RepeatedValue, requestId string) error { +func (l *LoggerImpl) Log(joinKeyToEntityValues map[string]*types.RepeatedValue, featureVectors []*serving.GetOnlineFeaturesResponse_FeatureVector, featureNames []string, requestData map[string]*types.RepeatedValue, requestId string) error { if len(featureVectors) == 0 { return nil } @@ -250,7 +250,7 @@ func (l *LoggerImpl) Log(joinKeyToEntityValues map[string][]*types.Value, featur if !ok { return errors.Errorf("Missing join key %s in log data", joinKey) } - entityValues[idx] = rows[rowIdx] + entityValues[idx] = rows.Val[rowIdx] } requestDataValues := make([]*types.Value, len(l.schema.RequestData)) @@ -283,6 +283,6 @@ func (l *LoggerImpl) Log(joinKeyToEntityValues map[string][]*types.Value, featur type DummyLoggerImpl struct{} -func (l *DummyLoggerImpl) Log(joinKeyToEntityValues map[string][]*types.Value, featureVectors []*serving.GetOnlineFeaturesResponse_FeatureVector, featureNames []string, requestData map[string]*types.RepeatedValue, requestId string) error { +func (l *DummyLoggerImpl) Log(joinKeyToEntityValues map[string]*types.RepeatedValue, featureVectors []*serving.GetOnlineFeaturesResponse_FeatureVector, featureNames []string, requestData map[string]*types.RepeatedValue, requestId string) error { return nil } diff --git a/go/internal/feast/server/logging/logger_test.go b/go/internal/feast/server/logging/logger_test.go index 5625b05a76..4ce883c75b 100644 --- a/go/internal/feast/server/logging/logger_test.go +++ b/go/internal/feast/server/logging/logger_test.go @@ -90,7 +90,17 @@ func TestLogAndFlushToFile(t *testing.T) { assert.Nil(t, err) assert.Nil(t, logger.Log( - map[string][]*types.Value{"driver_id": {{Val: &types.Value_Int32Val{Int32Val: 111}}}}, + map[string]*types.RepeatedValue{ + "driver_id": { + Val: []*types.Value{ + { + Val: &types.Value_Int32Val{ + Int32Val: 111, + }, + }, + }, + }, + }, []*serving.GetOnlineFeaturesResponse_FeatureVector{ { Values: []*types.Value{{Val: &types.Value_DoubleVal{DoubleVal: 2.0}}}, diff --git a/sdk/python/feast/cli.py b/sdk/python/feast/cli.py index a4407132e4..b1281d297f 100644 --- a/sdk/python/feast/cli.py +++ b/sdk/python/feast/cli.py @@ -610,17 +610,27 @@ def init_command(project_directory, minimal: bool, template: str): default=6566, help="Specify a port for the server [default: 6566]", ) +@click.option( + "--type", + "-t", + "type_", + type=click.STRING, + default="http", + help="Specify a server type: 'http' or 'grpc' [default: http]", +) @click.option( "--no-access-log", is_flag=True, help="Disable the Uvicorn access log.", ) @click.pass_context -def serve_command(ctx: click.Context, host: str, port: int, no_access_log: bool): +def serve_command( + ctx: click.Context, host: str, port: int, type_: str, no_access_log: bool +): """Start a feature server locally on a given port.""" repo = ctx.obj["CHDIR"] cli_check_repo(repo) store = FeatureStore(repo_path=str(repo)) - store.serve(host, port, no_access_log) + store.serve(host, port, type_, no_access_log) @cli.command("serve_transformations") diff --git a/sdk/python/feast/embedded_go/online_features_service.py b/sdk/python/feast/embedded_go/online_features_service.py index 48e31766cb..8ec4410bde 100644 --- a/sdk/python/feast/embedded_go/online_features_service.py +++ b/sdk/python/feast/embedded_go/online_features_service.py @@ -158,8 +158,30 @@ def start_grpc_server( else: self._service.StartGprcServer(host, port) + def start_http_server( + self, + host: str, + port: int, + enable_logging: bool = True, + logging_options: Optional[LoggingOptions] = None, + ): + if enable_logging: + if logging_options: + self._service.StartHttpServerWithLogging( + host, port, self._logging_callback, logging_options + ) + else: + self._service.StartHttpServerWithLoggingDefaultOpts( + host, port, self._logging_callback + ) + else: + self._service.StartHttpServer(host, port) + def stop_grpc_server(self): - self._service.Stop() + self._service.StopGrpcServer() + + def stop_http_server(self): + self._service.StopHttpServer() def _to_arrow(value, type_hint: Optional[ValueType]) -> pa.Array: diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 73af4741ef..0e19de08e0 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1979,14 +1979,25 @@ def _get_feature_views_to_use( return views_to_use @log_exceptions_and_usage - def serve(self, host: str, port: int, no_access_log: bool) -> None: + def serve(self, host: str, port: int, type_: str, no_access_log: bool) -> None: """Start the feature consumption server locally on a given port.""" + type_ = type_.lower() if self.config.go_feature_retrieval: # Start go server instead of python if the flag is enabled self._lazy_init_go_server() - # TODO(tsotne) add http/grpc flag in CLI and call appropriate method here depending on that - self._go_server.start_grpc_server(host, port) + if type_ == "http": + self._go_server.start_http_server(host, port) + elif type_ == "grpc": + self._go_server.start_grpc_server(host, port) + else: + raise ValueError( + f"Unsupported server type '{type_}'. Must be one of 'http' or 'grpc'." + ) else: + if type_ != "http": + raise ValueError( + f"Python server only supports 'http'. Got '{type_}' instead." + ) # Start the python server if go server isn't enabled feature_server.start_server(self, host, port, no_access_log) diff --git a/sdk/python/tests/integration/e2e/test_go_feature_server.py b/sdk/python/tests/integration/e2e/test_go_feature_server.py index e469c90c11..3a00c68b2c 100644 --- a/sdk/python/tests/integration/e2e/test_go_feature_server.py +++ b/sdk/python/tests/integration/e2e/test_go_feature_server.py @@ -9,6 +9,7 @@ import pandas as pd import pytest import pytz +import requests from feast import FeatureService, ValueType from feast.embedded_go.lib.embedded import LoggingOptions @@ -61,8 +62,7 @@ def initialized_registry(environment, universal_data_sources): fs.materialize(environment.start_date, environment.end_date) -@pytest.fixture -def grpc_server_port(environment, initialized_registry): +def server_port(environment, server_type: str): if not environment.test_repo_config.go_feature_retrieval: pytest.skip("Only for Go path") @@ -72,9 +72,15 @@ def grpc_server_port(environment, initialized_registry): repo_path=str(fs.repo_path.absolute()), repo_config=fs.config, feature_store=fs, ) port = free_port() + if server_type == "grpc": + target = embedded.start_grpc_server + elif server_type == "http": + target = embedded.start_http_server + else: + raise ValueError("Server Type must be either 'http' or 'grpc'") t = threading.Thread( - target=embedded.start_grpc_server, + target=target, args=("127.0.0.1", port), kwargs=dict( enable_logging=True, @@ -93,11 +99,24 @@ def grpc_server_port(environment, initialized_registry): ) yield port - embedded.stop_grpc_server() + if server_type == "grpc": + embedded.stop_grpc_server() + else: + embedded.stop_http_server() # wait for graceful stop time.sleep(2) +@pytest.fixture +def grpc_server_port(environment, initialized_registry): + yield from server_port(environment, "grpc") + + +@pytest.fixture +def http_server_port(environment, initialized_registry): + yield from server_port(environment, "http") + + @pytest.fixture def grpc_client(grpc_server_port): ch = grpc.insecure_channel(f"localhost:{grpc_server_port}") @@ -130,6 +149,44 @@ def test_go_grpc_server(grpc_client): assert all([s == FieldStatus.PRESENT for s in vector.statuses]) +@pytest.mark.integration +@pytest.mark.goserver +def test_go_http_server(http_server_port): + response = requests.post( + f"http://localhost:{http_server_port}/get-online-features", + json={ + "feature_service": "driver_features", + "entities": {"driver_id": [5001, 5002]}, + "full_feature_names": True, + }, + ) + assert response.status_code == 200, response.text + response = response.json() + assert set(response.keys()) == {"metadata", "results"} + metadata = response["metadata"] + results = response["results"] + assert response["metadata"] == { + "feature_names": [ + "driver_id", + "driver_stats__conv_rate", + "driver_stats__acc_rate", + "driver_stats__avg_daily_trips", + ] + }, metadata + assert len(results) == 4, results + assert all( + set(result.keys()) == {"event_timestamps", "statuses", "values"} + for result in results + ), results + assert all( + result["statuses"] == ["PRESENT", "PRESENT"] for result in results + ), results + assert results[0]["values"] == [5001, 5002], results + for result in results[1:]: + assert len(result["values"]) == 2, result + assert all(value is not None for value in result["values"]), result + + @pytest.mark.integration @pytest.mark.goserver @pytest.mark.universal_offline_stores