diff --git a/executor/api/rest/client.go b/executor/api/rest/client.go index a70a867316..ed20df2864 100644 --- a/executor/api/rest/client.go +++ b/executor/api/rest/client.go @@ -4,6 +4,15 @@ import ( "bytes" "context" "encoding/json" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + http2 "github.com/cloudevents/sdk-go/pkg/bindings/http" "github.com/go-logr/logr" "github.com/golang/protobuf/jsonpb" @@ -19,15 +28,7 @@ import ( "github.com/seldonio/seldon-core/executor/api/util" "github.com/seldonio/seldon-core/executor/k8s" v1 "github.com/seldonio/seldon-core/operator/apis/machinelearning.seldon.io/v1" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" logf "sigs.k8s.io/controller-runtime/pkg/log" - "strconv" - "strings" - "time" ) const ( @@ -50,7 +51,13 @@ func (smc *JSONRestClient) IsGrpc() bool { } func (smc *JSONRestClient) CreateErrorPayload(err error) payload.SeldonPayload { - respFailed := proto.SeldonMessage{Status: &proto.Status{Code: http.StatusInternalServerError, Info: err.Error()}} + respFailed := proto.SeldonMessage{ + Status: &proto.Status{ + Code: http.StatusInternalServerError, + Info: err.Error(), + Status: proto.Status_FAILURE, + }, + } m := jsonpb.Marshaler{} jStr, _ := m.MarshalToString(&respFailed) res := payload.BytesPayload{Msg: []byte(jStr)} @@ -291,7 +298,16 @@ func (smc *JSONRestClient) call(ctx context.Context, modelName string, method st contentType = req.GetContentType() contentEncoding = req.GetContentEncoding() } + sm, contentType, contentEncoding, err := smc.doHttp(ctx, modelName, method, &url, bytes, meta, contentType, contentEncoding) + + // Check if a httpStatusError was returned. + if err != nil { + if _, ok := err.(*httpStatusError); !ok { + return smc.CreateErrorPayload(err), err + } + } + res := payload.BytesPayload{Msg: sm, ContentType: contentType, ContentEncoding: contentEncoding} return &res, err } diff --git a/executor/api/rest/client_test.go b/executor/api/rest/client_test.go index 372a64858f..3ad7616c4c 100644 --- a/executor/api/rest/client_test.go +++ b/executor/api/rest/client_test.go @@ -329,8 +329,16 @@ func TestTimeout(t *testing.T) { seldonRestClient, err := NewJSONRestClient(api.ProtocolSeldon, "test", &predictor, annotations) g.Expect(err).To(BeNil()) - _, err = seldonRestClient.Status(createTestContext(), "model", host, int32(port), nil, map[string][]string{}) + r, err := seldonRestClient.Status(createTestContext(), "model", host, int32(port), nil, map[string][]string{}) g.Expect(err).ToNot(BeNil()) + g.Expect(r).ToNot((BeNil())) + + data := r.GetPayload().([]byte) + var smRes proto.SeldonMessage + err = jsonpb.UnmarshalString(string(data), &smRes) + g.Expect(err).Should(BeNil()) + g.Expect(smRes.GetStatus().GetCode()).Should(BeEquivalentTo(int32(500))) + g.Expect(smRes.GetStatus().GetInfo()).Should(ContainSubstring("Client.Timeout exceeded while awaiting headers")) } func TestMarshall(t *testing.T) {