From 8dc9960ef8fc8af114db7d3c013b643476689dd1 Mon Sep 17 00:00:00 2001 From: Greg Weber Date: Wed, 4 Jul 2018 01:02:33 -0500 Subject: [PATCH] send back HTTP 400 where there is a client error (#1131) * server: use HTTP 400 where there is a client error We often send a 500 error where it should be 400 or 404. It is important to get these right. A client will assume it did nothing wrong when a 500 is received and perhaps automatically retry the request. This is a conservative change: there are many more places that should be changed from 500. A helper function jsonRespondError is added --- pkg/apiutil/apiutil.go | 41 ++++++++++++++++++++--- server/api/admin.go | 2 +- server/api/config.go | 23 +++++-------- server/api/log.go | 2 +- server/api/member.go | 13 ++++---- server/api/operator.go | 9 +++-- server/api/region.go | 6 ++-- server/api/scheduler.go | 3 +- server/api/store.go | 16 ++++----- server/api/util.go | 16 +++++++++ server/api/util_test.go | 67 ++++++++++++++++++++++++++++++++++++++ server/util_test.go | 2 +- table/namespace_handler.go | 26 ++++++++++----- 13 files changed, 169 insertions(+), 57 deletions(-) create mode 100644 server/api/util_test.go diff --git a/pkg/apiutil/apiutil.go b/pkg/apiutil/apiutil.go index 82072790d51..795ddc9a007 100644 --- a/pkg/apiutil/apiutil.go +++ b/pkg/apiutil/apiutil.go @@ -21,10 +21,41 @@ import ( "github.com/juju/errors" ) -// ReadJSON reads a JSON data from r and then close it. -func ReadJSON(r io.ReadCloser, data interface{}) error { - defer r.Close() +// DeferClose captures the error returned from closing (if an error occurs). +// This is designed to be used in a defer statement. +func DeferClose(c io.Closer, err *error) { + if cerr := c.Close(); cerr != nil && *err == nil { + *err = errors.Trace(cerr) + } +} + +// JSONError lets callers check for just one error type +type JSONError struct { + err error +} + +func (e JSONError) Error() string { + return e.err.Error() +} +// Cause for compatibility with the errors package +func (e JSONError) Cause() error { + return e.err +} + +func tagJSONError(err error) error { + switch err.(type) { + case *json.SyntaxError, *json.UnmarshalTypeError: + return JSONError{err} + } + return err +} + +// ReadJSON reads a JSON data from r and then closes it. +// An error due to invalid json will be returned as a JSONError +func ReadJSON(r io.ReadCloser, data interface{}) error { + var err error + defer DeferClose(r, &err) b, err := ioutil.ReadAll(r) if err != nil { return errors.Trace(err) @@ -32,8 +63,8 @@ func ReadJSON(r io.ReadCloser, data interface{}) error { err = json.Unmarshal(b, data) if err != nil { - return errors.Trace(err) + return tagJSONError(err) } - return nil + return err } diff --git a/server/api/admin.go b/server/api/admin.go index a5f820b5eab..3e3df32eeda 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -45,7 +45,7 @@ func (h *adminHandler) HandleDropCacheRegion(w http.ResponseWriter, r *http.Requ regionIDStr := vars["id"] regionID, err := strconv.ParseUint(regionIDStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } cluster.DropCacheRegion(regionID) diff --git a/server/api/config.go b/server/api/config.go index a66b66a3d4b..f81a7920d0c 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -74,9 +74,7 @@ func (h *confHandler) GetSchedule(w http.ResponseWriter, r *http.Request) { func (h *confHandler) SetSchedule(w http.ResponseWriter, r *http.Request) { config := h.svr.GetScheduleConfig() - err := readJSON(r.Body, config) - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &config); err != nil { return } @@ -93,9 +91,7 @@ func (h *confHandler) GetReplication(w http.ResponseWriter, r *http.Request) { func (h *confHandler) SetReplication(w http.ResponseWriter, r *http.Request) { config := h.svr.GetReplicationConfig() - err := readJSON(r.Body, config) - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &config); err != nil { return } @@ -108,7 +104,7 @@ func (h *confHandler) GetNamespace(w http.ResponseWriter, r *http.Request) { name := vars["name"] if !h.svr.IsNamespaceExist(name) { - h.rd.JSON(w, http.StatusInternalServerError, fmt.Sprintf("invalid namespace Name %s, not found", name)) + h.rd.JSON(w, http.StatusNotFound, fmt.Sprintf("invalid namespace Name %s, not found", name)) return } @@ -122,14 +118,12 @@ func (h *confHandler) SetNamespace(w http.ResponseWriter, r *http.Request) { name := vars["name"] if !h.svr.IsNamespaceExist(name) { - h.rd.JSON(w, http.StatusInternalServerError, fmt.Sprintf("invalid namespace Name %s, not found", name)) + h.rd.JSON(w, http.StatusNotFound, fmt.Sprintf("invalid namespace Name %s, not found", name)) return } config := h.svr.GetNamespaceConfig(name) - err := readJSON(r.Body, config) - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &config); err != nil { return } @@ -142,7 +136,7 @@ func (h *confHandler) DeleteNamespace(w http.ResponseWriter, r *http.Request) { name := vars["name"] if !h.svr.IsNamespaceExist(name) { - h.rd.JSON(w, http.StatusInternalServerError, fmt.Sprintf("invalid namespace Name %s, not found", name)) + h.rd.JSON(w, http.StatusNotFound, fmt.Sprintf("invalid namespace Name %s, not found", name)) return } h.svr.DeleteNamespaceConfig(name) @@ -156,11 +150,10 @@ func (h *confHandler) GetLabelProperty(w http.ResponseWriter, r *http.Request) { func (h *confHandler) SetLabelProperty(w http.ResponseWriter, r *http.Request) { input := make(map[string]string) - err := readJSON(r.Body, &input) - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } + var err error switch input["action"] { case "set": err = h.svr.SetLabelProperty(input["type"], input["label-key"], input["label-value"]) diff --git a/server/api/log.go b/server/api/log.go index cab2e8abce4..714909f9f18 100644 --- a/server/api/log.go +++ b/server/api/log.go @@ -46,7 +46,7 @@ func (h *logHandler) Handle(w http.ResponseWriter, r *http.Request) { } err = json.Unmarshal(data, &level) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } diff --git a/server/api/member.go b/server/api/member.go index de31c163595..6fa331b6c47 100644 --- a/server/api/member.go +++ b/server/api/member.go @@ -96,7 +96,7 @@ func (h *memberHandler) DeleteByID(w http.ResponseWriter, r *http.Request) { idStr := mux.Vars(r)["id"] id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } @@ -117,9 +117,9 @@ func (h *memberHandler) DeleteByID(w http.ResponseWriter, r *http.Request) { } func (h *memberHandler) SetMemberPropertyByName(w http.ResponseWriter, r *http.Request) { - members, err := h.listMembers() - if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + members, membersErr := h.listMembers() + if membersErr != nil { + h.rd.JSON(w, http.StatusInternalServerError, membersErr.Error()) return } @@ -137,8 +137,7 @@ func (h *memberHandler) SetMemberPropertyByName(w http.ResponseWriter, r *http.R } var input map[string]interface{} - if err = readJSON(r.Body, &input); err != nil { - h.rd.JSON(w, http.StatusBadRequest, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } for k, v := range input { @@ -149,7 +148,7 @@ func (h *memberHandler) SetMemberPropertyByName(w http.ResponseWriter, r *http.R h.rd.JSON(w, http.StatusBadRequest, "bad format leader priority") return } - err = h.svr.SetMemberLeaderPriority(memberID, int(priority)) + err := h.svr.SetMemberLeaderPriority(memberID, int(priority)) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) return diff --git a/server/api/operator.go b/server/api/operator.go index 2be9c52e7a8..72daa91b812 100644 --- a/server/api/operator.go +++ b/server/api/operator.go @@ -40,7 +40,7 @@ func (h *operatorHandler) Get(w http.ResponseWriter, r *http.Request) { regionID, err := strconv.ParseUint(id, 10, 64) if err != nil { - h.r.JSON(w, http.StatusInternalServerError, err.Error()) + h.r.JSON(w, http.StatusBadRequest, err.Error()) return } @@ -90,14 +90,13 @@ func (h *operatorHandler) List(w http.ResponseWriter, r *http.Request) { func (h *operatorHandler) Post(w http.ResponseWriter, r *http.Request) { var input map[string]interface{} - if err := readJSON(r.Body, &input); err != nil { - h.r.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.r, w, r.Body, &input); err != nil { return } name, ok := input["name"].(string) if !ok { - h.r.JSON(w, http.StatusInternalServerError, "missing operator name") + h.r.JSON(w, http.StatusBadRequest, "missing operator name") return } @@ -234,7 +233,7 @@ func (h *operatorHandler) Delete(w http.ResponseWriter, r *http.Request) { regionID, err := strconv.ParseUint(id, 10, 64) if err != nil { - h.r.JSON(w, http.StatusInternalServerError, err.Error()) + h.r.JSON(w, http.StatusBadRequest, err.Error()) return } diff --git a/server/api/region.go b/server/api/region.go index ffa8fb98f3f..7a936001dd4 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -92,7 +92,7 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { regionIDStr := vars["id"] regionID, err := strconv.ParseUint(regionIDStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } @@ -203,12 +203,12 @@ func (h *regionsHandler) GetRegionSiblings(w http.ResponseWriter, r *http.Reques vars := mux.Vars(r) id, err := strconv.Atoi(vars["id"]) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } region := cluster.GetRegionInfoByID(uint64(id)) if region == nil { - h.rd.JSON(w, http.StatusInternalServerError, server.ErrRegionNotFound(uint64(id)).Error()) + h.rd.JSON(w, http.StatusNotFound, server.ErrRegionNotFound(uint64(id)).Error()) return } diff --git a/server/api/scheduler.go b/server/api/scheduler.go index 4a0422c9c8d..f7e77dcf385 100644 --- a/server/api/scheduler.go +++ b/server/api/scheduler.go @@ -44,8 +44,7 @@ func (h *schedulerHandler) List(w http.ResponseWriter, r *http.Request) { func (h *schedulerHandler) Post(w http.ResponseWriter, r *http.Request) { var input map[string]interface{} - if err := readJSON(r.Body, &input); err != nil { - h.r.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.r, w, r.Body, &input); err != nil { return } diff --git a/server/api/store.go b/server/api/store.go index 48db0d3f2af..6f19d3ee382 100644 --- a/server/api/store.go +++ b/server/api/store.go @@ -141,7 +141,7 @@ func (h *storeHandler) Get(w http.ResponseWriter, r *http.Request) { storeIDStr := vars["id"] storeID, err := strconv.ParseUint(storeIDStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } @@ -166,7 +166,7 @@ func (h *storeHandler) Delete(w http.ResponseWriter, r *http.Request) { storeIDStr := vars["id"] storeID, err := strconv.ParseUint(storeIDStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } @@ -196,7 +196,7 @@ func (h *storeHandler) SetState(w http.ResponseWriter, r *http.Request) { storeIDStr := vars["id"] storeID, err := strconv.ParseUint(storeIDStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } @@ -227,13 +227,12 @@ func (h *storeHandler) SetLabels(w http.ResponseWriter, r *http.Request) { storeIDStr := vars["id"] storeID, err := strconv.ParseUint(storeIDStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } var input map[string]string - if err := readJSON(r.Body, &input); err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } @@ -264,13 +263,12 @@ func (h *storeHandler) SetWeight(w http.ResponseWriter, r *http.Request) { storeIDStr := vars["id"] storeID, err := strconv.ParseUint(storeIDStr, 10, 64) if err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } var input map[string]interface{} - if err := readJSON(r.Body, &input); err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } diff --git a/server/api/util.go b/server/api/util.go index e97fcc0526b..ac081107a45 100644 --- a/server/api/util.go +++ b/server/api/util.go @@ -21,9 +21,25 @@ import ( "net/http" "github.com/juju/errors" + "github.com/pingcap/pd/pkg/apiutil" "github.com/pingcap/pd/server" + "github.com/unrolled/render" ) +func readJSONRespondError(r *render.Render, w http.ResponseWriter, body io.ReadCloser, data interface{}) error { + err := apiutil.ReadJSON(body, data) + if err == nil { + return nil + } + switch err.(type) { + case apiutil.JSONError: + r.JSON(w, http.StatusBadRequest, err.Error()) + default: + r.JSON(w, http.StatusInternalServerError, err.Error()) + } + return err +} + func readJSON(r io.ReadCloser, data interface{}) error { defer r.Close() diff --git a/server/api/util_test.go b/server/api/util_test.go new file mode 100644 index 00000000000..1a5dde4a4d6 --- /dev/null +++ b/server/api/util_test.go @@ -0,0 +1,67 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "bytes" + "io/ioutil" + "net/http/httptest" + + . "github.com/pingcap/check" + "github.com/unrolled/render" +) + +var _ = Suite(&testUtilSuite{}) + +type testUtilSuite struct{} + +func (s *testUtilSuite) TestJsonRespondErrorOk(c *C) { + rd := render.New(render.Options{ + IndentJSON: true, + }) + response := httptest.NewRecorder() + body := ioutil.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\", \"host\":\"local\"}")) + var input map[string]string + output := map[string]string{"zone": "cn", "host": "local"} + err := readJSONRespondError(rd, response, body, &input) + c.Assert(err, IsNil) + c.Assert(input["zone"], Equals, output["zone"]) + c.Assert(input["host"], Equals, output["host"]) + result := response.Result() + c.Assert(result.StatusCode, Equals, 200) +} + +func (s *testUtilSuite) TestJsonRespondErrorBadInput(c *C) { + rd := render.New(render.Options{ + IndentJSON: true, + }) + response := httptest.NewRecorder() + body := ioutil.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\", \"host\":\"local\"}")) + var input []string + err := readJSONRespondError(rd, response, body, &input) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "json: cannot unmarshal object into Go value of type []string") + result := response.Result() + c.Assert(result.StatusCode, Equals, 400) + + { + body := ioutil.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\",")) + var input []string + err := readJSONRespondError(rd, response, body, &input) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "unexpected end of JSON input") + result := response.Result() + c.Assert(result.StatusCode, Equals, 400) + } +} diff --git a/server/util_test.go b/server/util_test.go index 59261448b84..159157c0f2b 100644 --- a/server/util_test.go +++ b/server/util_test.go @@ -1,4 +1,4 @@ -// Copyright 2016 PingCAP, Inc. +// Copyright 2018 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/table/namespace_handler.go b/table/namespace_handler.go index b355c975f72..e924fb9ddf8 100644 --- a/table/namespace_handler.go +++ b/table/namespace_handler.go @@ -14,6 +14,7 @@ package table import ( + "io" "net/http" "strconv" @@ -56,11 +57,23 @@ func (h *tableNamespaceHandler) Get(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, nsInfo) } +func readJSONRespondError(r *render.Render, w http.ResponseWriter, body io.ReadCloser, data interface{}) (err error) { + if err = apiutil.ReadJSON(body, &data); err != nil { + switch err.(type) { + case apiutil.JSONError: + r.JSON(w, http.StatusBadRequest, err.Error()) + default: + r.JSON(w, http.StatusInternalServerError, err.Error()) + + } + } + return +} + // Post creates a namespace. func (h *tableNamespaceHandler) Post(w http.ResponseWriter, r *http.Request) { var input map[string]string - if err := apiutil.ReadJSON(r.Body, &input); err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } ns := input["namespace"] @@ -76,8 +89,7 @@ func (h *tableNamespaceHandler) Post(w http.ResponseWriter, r *http.Request) { func (h *tableNamespaceHandler) Update(w http.ResponseWriter, r *http.Request) { var input map[string]string - if err := apiutil.ReadJSON(r.Body, &input); err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } tableIDStr := input["table_id"] @@ -116,8 +128,7 @@ func (h *tableNamespaceHandler) Update(w http.ResponseWriter, r *http.Request) { func (h *tableNamespaceHandler) SetMetaNamespace(w http.ResponseWriter, r *http.Request) { var input map[string]string - if err := apiutil.ReadJSON(r.Body, &input); err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } ns := input["namespace"] @@ -149,8 +160,7 @@ func (h *tableNamespaceHandler) SetNamespace(w http.ResponseWriter, r *http.Requ } var input map[string]string - if err := apiutil.ReadJSON(r.Body, &input); err != nil { - h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + if err := readJSONRespondError(h.rd, w, r.Body, &input); err != nil { return } ns := input["namespace"]