Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow REST error payloads to be returned #1446

Merged
merged 6 commits into from
Mar 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions executor/api/rest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,6 @@ func (smc *JSONRestClient) doHttp(ctx context.Context, modelName string, method
return nil, "", err
}

if response.StatusCode != http.StatusOK {
smc.Log.Info("httpPost failed", "response code", response.StatusCode)
return nil, "", errors.Errorf("Internal service call failed calling %s status code %d", url, response.StatusCode)
}

//Read response
b, err := ioutil.ReadAll(response.Body)
if err != nil {
Expand All @@ -192,7 +187,12 @@ func (smc *JSONRestClient) doHttp(ctx context.Context, modelName string, method

contentType := response.Header.Get("Content-Type")

return b, contentType, nil
if response.StatusCode != http.StatusOK {
smc.Log.Info("httpPost failed", "response code", response.StatusCode)
err = errors.Errorf("Internal service call from executor failed calling %s status code %d", url, response.StatusCode)
}

return b, contentType, err
}

func (smc *JSONRestClient) modifyMethod(method string, modelName string) string {
Expand Down Expand Up @@ -226,11 +226,8 @@ func (smc *JSONRestClient) call(ctx context.Context, modelName string, method st
bytes = req.GetPayload().([]byte)
}
sm, contentType, err := smc.doHttp(ctx, modelName, method, &url, bytes, meta)
if err != nil {
return nil, err
}
res := payload.BytesPayload{Msg: sm, ContentType: contentType}
return &res, nil
return &res, err
}

func (smc *JSONRestClient) Status(ctx context.Context, modelName string, host string, port int32, msg payload.SeldonPayload, meta map[string][]string) (payload.SeldonPayload, error) {
Expand Down
37 changes: 36 additions & 1 deletion executor/api/rest/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rest
import (
"context"
"crypto/tls"
"encoding/json"
"github.com/golang/protobuf/jsonpb"
. "github.com/onsi/gomega"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -43,6 +44,9 @@ const (
"name":"mymodel"
}
}`
errorPredictResponse = `{
"status":"failed"
}`
)

func testingHTTPClient(g *GomegaWithT, handler http.Handler) (string, int, *http.Client, func()) {
Expand Down Expand Up @@ -111,7 +115,6 @@ func TestSimpleMethods(t *testing.T) {
g.Expect(smRes.GetData().GetNdarray().Values[0].GetListValue().Values[0].GetNumberValue()).Should(Equal(0.9))
g.Expect(smRes.GetData().GetNdarray().Values[0].GetListValue().Values[1].GetNumberValue()).Should(Equal(0.1))
}

}

func TestRouter(t *testing.T) {
Expand Down Expand Up @@ -275,6 +278,37 @@ func TestClientMetrics(t *testing.T) {

}

func TestErrorResponse(t *testing.T) {
t.Logf("Started")
g := NewGomegaWithT(t)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(errorPredictResponse))
})
host, port, _, teardown := testingHTTPClient(g, h)

defer teardown()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gomega's BeforeEach and AfterEach are another option for how to manage test setup and teardown

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As its not for all test might leave for now?

predictor := v1.PredictorSpec{
Name: "test",
Annotations: map[string]string{},
}

seldonRestClient, err := NewJSONRestClient(api.ProtocolSeldon, "test", &predictor, nil)
g.Expect(err).To(BeNil())

methods := []func(context.Context, string, string, int32, payload.SeldonPayload, map[string][]string) (payload.SeldonPayload, error){seldonRestClient.Predict}
for _, method := range methods {
resPayload, err := method(createTestContext(), "model", host, int32(port), createPayload(g), map[string][]string{})
g.Expect(err).ToNot(BeNil())

data := resPayload.GetPayload().([]byte)
var objmap map[string]interface{}
err = json.Unmarshal(data, &objmap)
g.Expect(err).To(BeNil())
g.Expect(string(data)).To(Equal(errorPredictResponse))
}
}

func TestTimeout(t *testing.T) {
t.Logf("Started")
g := NewGomegaWithT(t)
Expand All @@ -283,6 +317,7 @@ func TestTimeout(t *testing.T) {
w.Write([]byte(okStatusResponse))
})
host, port, _, teardown := testingHTTPClient(g, h)

defer teardown()
predictor := v1.PredictorSpec{
Name: "test",
Expand Down
34 changes: 18 additions & 16 deletions executor/api/rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,21 @@ func (r *SeldonRestApi) respondWithSuccess(w http.ResponseWriter, code int, payl
}
}

func (r *SeldonRestApi) respondWithError(w http.ResponseWriter, err error) {
func (r *SeldonRestApi) respondWithError(w http.ResponseWriter, payload payload.SeldonPayload, err error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)

errPayload := r.Client.CreateErrorPayload(err)
err = r.Client.Marshall(w, errPayload)
if err != nil {
r.Log.Error(err, "Failed to write error payload")
if payload != nil && payload.GetPayload() != nil {
err := r.Client.Marshall(w, payload)
if err != nil {
r.Log.Error(err, "Failed to write response")
}
} else {
errPayload := r.Client.CreateErrorPayload(err)
err = r.Client.Marshall(w, errPayload)
if err != nil {
r.Log.Error(err, "Failed to write error payload")
}
}
}

Expand Down Expand Up @@ -155,11 +162,6 @@ func (r *SeldonRestApi) alive(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusOK)
}

func (r *SeldonRestApi) failWithError(w http.ResponseWriter, err error) {
r.Log.Error(err, "Failed")
r.respondWithError(w, err)
}

func getGraphNodeForModelName(req *http.Request, graph *v1.PredictiveUnit) (*v1.PredictiveUnit, error) {
vars := mux.Vars(req)
modelName := vars[ModelHttpPathVariable]
Expand Down Expand Up @@ -194,7 +196,7 @@ func (r *SeldonRestApi) metadata(w http.ResponseWriter, req *http.Request) {
seldonPredictorProcess := predictor.NewPredictorProcess(ctx, r.Client, logf.Log.WithName(LoggingRestClientName), r.ServerUrl, r.Namespace, req.Header)
resPayload, err := seldonPredictorProcess.Metadata(r.predictor.Graph, modelName, nil)
if err != nil {
r.failWithError(w, err)
r.respondWithError(w, resPayload, err)
return
}
r.respondWithSuccess(w, http.StatusOK, resPayload)
Expand All @@ -216,7 +218,7 @@ func (r *SeldonRestApi) status(w http.ResponseWriter, req *http.Request) {
seldonPredictorProcess := predictor.NewPredictorProcess(ctx, r.Client, logf.Log.WithName(LoggingRestClientName), r.ServerUrl, r.Namespace, req.Header)
resPayload, err := seldonPredictorProcess.Status(r.predictor.Graph, modelName, nil)
if err != nil {
r.failWithError(w, err)
r.respondWithError(w, resPayload, err)
return
}
r.respondWithSuccess(w, http.StatusOK, resPayload)
Expand All @@ -238,31 +240,31 @@ func (r *SeldonRestApi) predictions(w http.ResponseWriter, req *http.Request) {

bodyBytes, err := ioutil.ReadAll(req.Body)
if err != nil {
r.failWithError(w, err)
r.respondWithError(w, nil, err)
return
}

seldonPredictorProcess := predictor.NewPredictorProcess(ctx, r.Client, logf.Log.WithName(LoggingRestClientName), r.ServerUrl, r.Namespace, req.Header)

reqPayload, err := seldonPredictorProcess.Client.Unmarshall(bodyBytes)
if err != nil {
r.failWithError(w, err)
r.respondWithError(w, nil, err)
return
}

var graphNode *v1.PredictiveUnit
if r.Protocol == api.ProtocolTensorflow {
graphNode, err = getGraphNodeForModelName(req, r.predictor.Graph)
if err != nil {
r.failWithError(w, err)
r.respondWithError(w, nil, err)
return
}
} else {
graphNode = r.predictor.Graph
}
resPayload, err := seldonPredictorProcess.Predict(graphNode, reqPayload)
if err != nil {
r.failWithError(w, err)
r.respondWithError(w, resPayload, err)
return
}
r.respondWithSuccess(w, http.StatusOK, resPayload)
Expand Down
50 changes: 50 additions & 0 deletions executor/api/rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,53 @@ func TestTensorflowMetadata(t *testing.T) {
g.Expect(res.Code).To(Equal(200))
g.Expect(res.Body.String()).To(Equal(test.TestClientMetadataResponse))
}

func TestPredictErrorWithServer(t *testing.T) {
t.Logf("Started")
g := NewGomegaWithT(t)
called := false

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := ioutil.ReadAll(r.Body)
g.Expect(err).To(BeNil())
g.Expect(r.Header.Get(payload.SeldonPUIDHeader)).To(Equal(TestSeldonPuid))
called = true
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(errorPredictResponse))
})
server := httptest.NewServer(handler)
defer server.Close()
url, err := url.Parse(server.URL)
g.Expect(err).Should(BeNil())
urlParts := strings.Split(url.Host, ":")
port, err := strconv.Atoi(urlParts[1])
g.Expect(err).Should(BeNil())

model := v1.MODEL
p := v1.PredictorSpec{
Name: "p",
Graph: &v1.PredictiveUnit{
Type: &model,
Endpoint: &v1.Endpoint{
ServiceHost: urlParts[0],
ServicePort: int32(port),
Type: v1.REST,
},
},
}
client, err := NewJSONRestClient(api.ProtocolSeldon, "dep", &p, nil)
g.Expect(err).Should(BeNil())
r := NewServerRestApi(&p, client, false, url, "default", api.ProtocolSeldon, "test", "/metrics")
r.Initialise()
var data = ` {"data":{"ndarray":[1.1,2.0]}}`

req, _ := http.NewRequest("POST", "/api/v0.1/predictions", strings.NewReader(data))
req.Header = map[string][]string{"Content-Type": []string{"application/json"}, payload.SeldonPUIDHeader: []string{TestSeldonPuid}}
res := httptest.NewRecorder()
r.Router.ServeHTTP(res, req)
g.Expect(res.Code).To(Equal(http.StatusInternalServerError))
g.Expect(called).To(Equal(true))
b, err := ioutil.ReadAll(res.Body)
g.Expect(err).Should(BeNil())
g.Expect(string(b)).To(Equal(errorPredictResponse))
}