From 18279e982f86442613844a3d60cc5366bf9f092f Mon Sep 17 00:00:00 2001 From: cliveseldon Date: Sat, 15 Feb 2020 10:05:16 +0000 Subject: [PATCH 1/2] Allow REST error payloads to be returned --- executor/api/rest/client.go | 18 ++++++------ executor/api/rest/client_test.go | 34 +++++++++++++++++++++- executor/api/rest/server.go | 34 +++++++++++----------- executor/api/rest/server_test.go | 49 ++++++++++++++++++++++++++++++++ executor/logger/worker.go | 4 +-- 5 files changed, 110 insertions(+), 29 deletions(-) diff --git a/executor/api/rest/client.go b/executor/api/rest/client.go index 817f080598..2f58bade94 100644 --- a/executor/api/rest/client.go +++ b/executor/api/rest/client.go @@ -149,11 +149,6 @@ func (smc *JSONRestClient) doHttp(ctx context.Context, modelName string, method return nil, "", err } - if response.StatusCode != http.StatusOK { - smc.Log.Info("httpPost failed", "response code", response.StatusCode) - return nil, "", errors.Errorf("Internal service call failed calling %s status code %d", url, response.StatusCode) - } - //Read response b, err := ioutil.ReadAll(response.Body) if err != nil { @@ -163,7 +158,13 @@ func (smc *JSONRestClient) doHttp(ctx context.Context, modelName string, method contentType := response.Header.Get("Content-Type") - return b, contentType, nil + err = nil + if response.StatusCode != http.StatusOK { + smc.Log.Info("httpPost failed", "response code", response.StatusCode) + err = errors.Errorf("Internal service call from executor failed calling %s status code %d", url, response.StatusCode) + } + + return b, contentType, err } func (smc *JSONRestClient) modifyMethod(method string, modelName string) string { @@ -197,11 +198,8 @@ func (smc *JSONRestClient) call(ctx context.Context, modelName string, method st bytes = req.GetPayload().([]byte) } sm, contentType, err := smc.doHttp(ctx, modelName, method, &url, bytes, meta) - if err != nil { - return nil, err - } res := payload.BytesPayload{Msg: sm, ContentType: contentType} - return &res, nil + return &res, err } func (smc *JSONRestClient) Status(ctx context.Context, modelName string, host string, port int32, msg payload.SeldonPayload, meta map[string][]string) (payload.SeldonPayload, error) { diff --git a/executor/api/rest/client_test.go b/executor/api/rest/client_test.go index 878fb7d951..8a70566ec1 100644 --- a/executor/api/rest/client_test.go +++ b/executor/api/rest/client_test.go @@ -3,6 +3,7 @@ package rest import ( "context" "crypto/tls" + "encoding/json" "github.com/golang/protobuf/jsonpb" . "github.com/onsi/gomega" "github.com/prometheus/client_golang/prometheus" @@ -41,6 +42,9 @@ const ( "name":"mymodel" } }` + errorPredictResponse = `{ + "status":"failed" + }` ) func testingHTTPClient(g *GomegaWithT, handler http.Handler) (string, int, *http.Client, func()) { @@ -108,7 +112,6 @@ func TestSimpleMethods(t *testing.T) { g.Expect(smRes.GetData().GetNdarray().Values[0].GetListValue().Values[0].GetNumberValue()).Should(Equal(0.9)) g.Expect(smRes.GetData().GetNdarray().Values[0].GetListValue().Values[1].GetNumberValue()).Should(Equal(0.1)) } - } func TestRouter(t *testing.T) { @@ -266,3 +269,32 @@ func TestClientMetrics(t *testing.T) { } } + +func TestErrorResponse(t *testing.T) { + t.Logf("Started") + g := NewGomegaWithT(t) + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(errorPredictResponse)) + }) + host, port, httpClient, teardown := testingHTTPClient(g, h) + defer teardown() + predictor := v1.PredictorSpec{ + Name: "test", + Annotations: map[string]string{}, + } + seldonRestClient := NewJSONRestClient(api.ProtocolSeldon, "test", &predictor, SetHTTPClient(httpClient)) + + methods := []func(context.Context, string, string, int32, payload.SeldonPayload, map[string][]string) (payload.SeldonPayload, error){seldonRestClient.Predict} + for _, method := range methods { + resPayload, err := method(createTestContext(), "model", host, int32(port), createPayload(g), map[string][]string{}) + g.Expect(err).ToNot(BeNil()) + + data := resPayload.GetPayload().([]byte) + var objmap map[string]interface{} + err = json.Unmarshal(data, &objmap) + g.Expect(err).To(BeNil()) + g.Expect(string(data)).To(Equal(errorPredictResponse)) + } + +} diff --git a/executor/api/rest/server.go b/executor/api/rest/server.go index 37e38da5dc..b3c24d4c75 100644 --- a/executor/api/rest/server.go +++ b/executor/api/rest/server.go @@ -66,14 +66,21 @@ func (r *SeldonRestApi) respondWithSuccess(w http.ResponseWriter, code int, payl } } -func (r *SeldonRestApi) respondWithError(w http.ResponseWriter, err error) { +func (r *SeldonRestApi) respondWithError(w http.ResponseWriter, payload payload.SeldonPayload, err error) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusInternalServerError) - errPayload := r.Client.CreateErrorPayload(err) - err = r.Client.Marshall(w, errPayload) - if err != nil { - r.Log.Error(err, "Failed to write error payload") + if payload != nil && payload.GetPayload() != nil { + err := r.Client.Marshall(w, payload) + if err != nil { + r.Log.Error(err, "Failed to write response") + } + } else { + errPayload := r.Client.CreateErrorPayload(err) + err = r.Client.Marshall(w, errPayload) + if err != nil { + r.Log.Error(err, "Failed to write error payload") + } } } @@ -140,11 +147,6 @@ func (r *SeldonRestApi) alive(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) } -func (r *SeldonRestApi) failWithError(w http.ResponseWriter, err error) { - r.Log.Error(err, "Failed") - r.respondWithError(w, err) -} - func getGraphNodeForModelName(req *http.Request, graph *v1.PredictiveUnit) (*v1.PredictiveUnit, error) { vars := mux.Vars(req) modelName := vars[ModelHttpPathVariable] @@ -179,7 +181,7 @@ func (r *SeldonRestApi) metadata(w http.ResponseWriter, req *http.Request) { seldonPredictorProcess := predictor.NewPredictorProcess(ctx, r.Client, logf.Log.WithName(LoggingRestClientName), r.ServerUrl, r.Namespace, req.Header) resPayload, err := seldonPredictorProcess.Metadata(r.predictor.Graph, modelName, nil) if err != nil { - r.failWithError(w, err) + r.respondWithError(w, resPayload, err) return } r.respondWithSuccess(w, http.StatusOK, resPayload) @@ -201,7 +203,7 @@ func (r *SeldonRestApi) status(w http.ResponseWriter, req *http.Request) { seldonPredictorProcess := predictor.NewPredictorProcess(ctx, r.Client, logf.Log.WithName(LoggingRestClientName), r.ServerUrl, r.Namespace, req.Header) resPayload, err := seldonPredictorProcess.Status(r.predictor.Graph, modelName, nil) if err != nil { - r.failWithError(w, err) + r.respondWithError(w, resPayload, err) return } r.respondWithSuccess(w, http.StatusOK, resPayload) @@ -223,7 +225,7 @@ func (r *SeldonRestApi) predictions(w http.ResponseWriter, req *http.Request) { bodyBytes, err := ioutil.ReadAll(req.Body) if err != nil { - r.failWithError(w, err) + r.respondWithError(w, nil, err) return } @@ -231,7 +233,7 @@ func (r *SeldonRestApi) predictions(w http.ResponseWriter, req *http.Request) { reqPayload, err := seldonPredictorProcess.Client.Unmarshall(bodyBytes) if err != nil { - r.failWithError(w, err) + r.respondWithError(w, nil, err) return } @@ -239,7 +241,7 @@ func (r *SeldonRestApi) predictions(w http.ResponseWriter, req *http.Request) { if r.Protocol == api.ProtocolTensorflow { graphNode, err = getGraphNodeForModelName(req, r.predictor.Graph) if err != nil { - r.failWithError(w, err) + r.respondWithError(w, nil, err) return } } else { @@ -247,7 +249,7 @@ func (r *SeldonRestApi) predictions(w http.ResponseWriter, req *http.Request) { } resPayload, err := seldonPredictorProcess.Predict(graphNode, reqPayload) if err != nil { - r.failWithError(w, err) + r.respondWithError(w, resPayload, err) return } r.respondWithSuccess(w, http.StatusOK, resPayload) diff --git a/executor/api/rest/server_test.go b/executor/api/rest/server_test.go index fa3b23de1f..2a56596ec4 100644 --- a/executor/api/rest/server_test.go +++ b/executor/api/rest/server_test.go @@ -266,3 +266,52 @@ func TestTensorflowMetadata(t *testing.T) { g.Expect(res.Code).To(Equal(200)) g.Expect(res.Body.String()).To(Equal(test.TestClientMetadataResponse)) } + +func TestPredictErrorWithServer(t *testing.T) { + t.Logf("Started") + g := NewGomegaWithT(t) + called := false + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := ioutil.ReadAll(r.Body) + g.Expect(err).To(BeNil()) + g.Expect(r.Header.Get(payload.SeldonPUIDHeader)).To(Equal(TestSeldonPuid)) + called = true + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(errorPredictResponse)) + }) + server := httptest.NewServer(handler) + defer server.Close() + url, err := url.Parse(server.URL) + g.Expect(err).Should(BeNil()) + urlParts := strings.Split(url.Host, ":") + port, err := strconv.Atoi(urlParts[1]) + g.Expect(err).Should(BeNil()) + + model := v1.MODEL + p := v1.PredictorSpec{ + Name: "p", + Graph: &v1.PredictiveUnit{ + Type: &model, + Endpoint: &v1.Endpoint{ + ServiceHost: urlParts[0], + ServicePort: int32(port), + Type: v1.REST, + }, + }, + } + + r := NewServerRestApi(&p, NewJSONRestClient(api.ProtocolSeldon, "dep", &p), false, url, "default", api.ProtocolSeldon, "test", "/metrics") + r.Initialise() + var data = ` {"data":{"ndarray":[1.1,2.0]}}` + + req, _ := http.NewRequest("POST", "/api/v0.1/predictions", strings.NewReader(data)) + req.Header = map[string][]string{"Content-Type": []string{"application/json"}, payload.SeldonPUIDHeader: []string{TestSeldonPuid}} + res := httptest.NewRecorder() + r.Router.ServeHTTP(res, req) + g.Expect(res.Code).To(Equal(http.StatusInternalServerError)) + g.Expect(called).To(Equal(true)) + b, err := ioutil.ReadAll(res.Body) + g.Expect(err).Should(BeNil()) + g.Expect(string(b)).To(Equal(errorPredictResponse)) +} diff --git a/executor/logger/worker.go b/executor/logger/worker.go index 39df50f2a1..839d58068e 100644 --- a/executor/logger/worker.go +++ b/executor/logger/worker.go @@ -11,8 +11,8 @@ import ( ) const ( - CEInferenceRequest = "io.seldon.serving.inference.request" - CEInferenceResponse = "io.seldon.serving.inference.response" + CEInferenceRequest = "io.seldon.serving.inference.request" + CEInferenceResponse = "io.seldon.serving.inference.response" // cloud events extension attributes have to be lowercase alphanumeric RequestIdAttr = "requestid" ModelIdAttr = "modelid" From ad3042c899c434c1e2f0b324ad5bd915fdf7ba85 Mon Sep 17 00:00:00 2001 From: cliveseldon Date: Tue, 3 Mar 2020 09:14:40 +0000 Subject: [PATCH 2/2] fix unnecessary err set to nil --- executor/api/rest/client.go | 1 - 1 file changed, 1 deletion(-) diff --git a/executor/api/rest/client.go b/executor/api/rest/client.go index 6788ddfd67..8d2b394235 100644 --- a/executor/api/rest/client.go +++ b/executor/api/rest/client.go @@ -187,7 +187,6 @@ func (smc *JSONRestClient) doHttp(ctx context.Context, modelName string, method contentType := response.Header.Get("Content-Type") - err = nil if response.StatusCode != http.StatusOK { smc.Log.Info("httpPost failed", "response code", response.StatusCode) err = errors.Errorf("Internal service call from executor failed calling %s status code %d", url, response.StatusCode)