diff --git a/src/cmd/tools/dtest/docker/harness/carbon_test.go b/src/cmd/tools/dtest/docker/harness/carbon_test.go index 6bc35a9ee7..c353ed0b5f 100644 --- a/src/cmd/tools/dtest/docker/harness/carbon_test.go +++ b/src/cmd/tools/dtest/docker/harness/carbon_test.go @@ -32,7 +32,7 @@ import ( ) func findVerifier(expected string) resources.ResponseVerifier { - return func(status int, s string, err error) error { + return func(status int, _ map[string][]string, s string, err error) error { if err != nil { return err } @@ -55,7 +55,7 @@ func renderVerifier(v float64) resources.ResponseVerifier { Datapoints [][]float64 `json:"datapoints"` } - return func(status int, s string, err error) error { + return func(status int, _ map[string][]string, s string, err error) error { if err != nil { return err } diff --git a/src/cmd/tools/dtest/docker/harness/query_api_test.go b/src/cmd/tools/dtest/docker/harness/query_api_test.go index dd389afd66..5ef751b2ec 100644 --- a/src/cmd/tools/dtest/docker/harness/query_api_test.go +++ b/src/cmd/tools/dtest/docker/harness/query_api_test.go @@ -71,7 +71,7 @@ func testInvalidQueryReturns400(t *testing.T, tests []urlTest) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.NoError(t, coord.RunQuery(verifyStatus(400), tt.url), "for query '%v'", tt.url) + assert.NoError(t, coord.RunQuery(verifyResponse(400), tt.url), "for query '%v'", tt.url) }) } } @@ -129,8 +129,8 @@ func queryString(params map[string]string) string { return strings.Join(p, "&") } -func verifyStatus(expectedStatus int) resources.ResponseVerifier { - return func(status int, resp string, err error) error { +func verifyResponse(expectedStatus int) resources.ResponseVerifier { + return func(status int, headers map[string][]string, resp string, err error) error { if err != nil { return err } @@ -139,6 +139,12 @@ func verifyStatus(expectedStatus int) resources.ResponseVerifier { return fmt.Errorf("expeceted %v status code, got %v", expectedStatus, status) } + if contentType, ok := headers["Content-Type"]; !ok { + return fmt.Errorf("missing Content-Type header") + } else if len(contentType) != 1 || contentType[0] != "application/json" { + return fmt.Errorf("expected json content type, got %v", contentType) + } + return nil } } diff --git a/src/cmd/tools/dtest/docker/harness/resources/coordinator.go b/src/cmd/tools/dtest/docker/harness/resources/coordinator.go index d2d320a446..2bafb8506f 100644 --- a/src/cmd/tools/dtest/docker/harness/resources/coordinator.go +++ b/src/cmd/tools/dtest/docker/harness/resources/coordinator.go @@ -55,7 +55,7 @@ var ( ) // ResponseVerifier is a function that checks if the query response is valid. -type ResponseVerifier func(int, string, error) error +type ResponseVerifier func(int, map[string][]string, string, error) error // Coordinator is a wrapper for a coordinator. It provides a wrapper on HTTP // endpoints that expose cluster management APIs as well as read and write @@ -363,7 +363,7 @@ func (c *coordinator) query( defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) - return verifier(resp.StatusCode, string(b), err) + return verifier(resp.StatusCode, resp.Header, string(b), err) } func (c *coordinator) RunQuery( diff --git a/src/x/net/http/errors.go b/src/x/net/http/errors.go index fb90d73874..a0fbcde08d 100644 --- a/src/x/net/http/errors.go +++ b/src/x/net/http/errors.go @@ -92,25 +92,27 @@ func WriteError(w http.ResponseWriter, err error, opts ...WriteErrorOption) { fn(&o) } + statusCode := getStatusCode(err) + if o.response == nil { + w.Header().Set(HeaderContentType, ContentTypeJSON) + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(ErrorResponse{Error: err.Error()}) + } else { + w.WriteHeader(statusCode) + w.Write(o.response) + } +} + +func getStatusCode(err error) int { switch v := err.(type) { case Error: - w.WriteHeader(v.Code()) + return v.Code() case error: if xerrors.IsInvalidParams(v) { - w.WriteHeader(http.StatusBadRequest) + return http.StatusBadRequest } else if errors.Is(err, context.DeadlineExceeded) { - w.WriteHeader(http.StatusGatewayTimeout) - } else { - w.WriteHeader(http.StatusInternalServerError) + return http.StatusGatewayTimeout } - default: - w.WriteHeader(http.StatusInternalServerError) } - - if o.response != nil { - w.Write(o.response) - return - } - - json.NewEncoder(w).Encode(ErrorResponse{Error: err.Error()}) + return http.StatusInternalServerError }