From 3156ca6b04b05e72794685b353c744016e8732a7 Mon Sep 17 00:00:00 2001 From: William Langford Date: Wed, 4 Mar 2020 12:06:35 -0500 Subject: [PATCH] Use an interface for ContentTypeFromMessage This allows custom Marshaler implementations to vary their Content-Type based on the response --- runtime/errors.go | 4 +-- runtime/errors_test.go | 43 +++++++++++++++++--------- runtime/handler.go | 6 ++-- runtime/handler_test.go | 67 ++++++++++++++++++++++++++++++++++++----- runtime/marshaler.go | 7 +++++ runtime/proto_errors.go | 4 +-- 6 files changed, 103 insertions(+), 28 deletions(-) diff --git a/runtime/errors.go b/runtime/errors.go index 8ec2fc0f1c7..65a5dc690d4 100644 --- a/runtime/errors.go +++ b/runtime/errors.go @@ -123,9 +123,9 @@ func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w // Check marshaler on run time in order to keep backwards compatability // An interface param needs to be added to the ContentType() function on // the Marshal interface to be able to remove this check - if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok { + if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { pb := s.Proto() - contentType = httpBodyMarshaler.ContentTypeFromMessage(pb) + contentType = typeMarshaler.ContentTypeFromMessage(pb) } w.Header().Set("Content-Type", contentType) diff --git a/runtime/errors_test.go b/runtime/errors_test.go index 6d684d4ed89..b4ce93f3f16 100644 --- a/runtime/errors_test.go +++ b/runtime/errors_test.go @@ -23,26 +23,41 @@ func TestDefaultHTTPError(t *testing.T) { ) for _, spec := range []struct { - err error - status int - msg string - details string + err error + status int + msg string + marshaler runtime.Marshaler + contentType string + details string }{ { - err: fmt.Errorf("example error"), - status: http.StatusInternalServerError, - msg: "example error", + err: fmt.Errorf("example error"), + status: http.StatusInternalServerError, + marshaler: &runtime.JSONPb{}, + contentType: "application/json", + msg: "example error", }, { - err: status.Error(codes.NotFound, "no such resource"), - status: http.StatusNotFound, - msg: "no such resource", + err: status.Error(codes.NotFound, "no such resource"), + status: http.StatusNotFound, + marshaler: &runtime.JSONPb{}, + contentType: "application/json", + msg: "no such resource", }, { - err: statusWithDetails.Err(), - status: http.StatusBadRequest, - msg: "failed precondition", - details: "type.googleapis.com/google.rpc.PreconditionFailure", + err: statusWithDetails.Err(), + status: http.StatusBadRequest, + marshaler: &runtime.JSONPb{}, + contentType: "application/json", + msg: "failed precondition", + details: "type.googleapis.com/google.rpc.PreconditionFailure", + }, + { + err: fmt.Errorf("example error"), + status: http.StatusInternalServerError, + marshaler: &CustomMarshaler{&runtime.JSONPb{}}, + contentType: "Custom-Content-Type", + msg: "example error", }, } { w := httptest.NewRecorder() diff --git a/runtime/handler.go b/runtime/handler.go index 2af900650dc..b894da86bf8 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -1,13 +1,13 @@ package runtime import ( + "context" "errors" "fmt" "io" "net/http" "net/textproto" - "context" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/internal" "google.golang.org/grpc/grpclog" @@ -126,8 +126,8 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha // Check marshaler on run time in order to keep backwards compatability // An interface param needs to be added to the ContentType() function on // the Marshal interface to be able to remove this check - if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok { - contentType = httpBodyMarshaler.ContentTypeFromMessage(resp) + if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { + contentType = typeMarshaler.ContentTypeFromMessage(resp) } w.Header().Set("Content-Type", contentType) diff --git a/runtime/handler_test.go b/runtime/handler_test.go index f9a17916da1..912cb3acfa5 100644 --- a/runtime/handler_test.go +++ b/runtime/handler_test.go @@ -1,14 +1,13 @@ package runtime_test import ( + "context" "io" "io/ioutil" "net/http" "net/http/httptest" "testing" - "context" - "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/internal" "github.com/grpc-ecosystem/grpc-gateway/runtime" @@ -134,11 +133,12 @@ type CustomMarshaler struct { m *runtime.JSONPb } -func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error) { return c.m.Marshal(v) } -func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) } -func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c.m.NewDecoder(r) } -func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) } -func (c *CustomMarshaler) ContentType() string { return c.m.ContentType() } +func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error) { return c.m.Marshal(v) } +func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) } +func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c.m.NewDecoder(r) } +func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) } +func (c *CustomMarshaler) ContentType() string { return c.m.ContentType() } +func (c *CustomMarshaler) ContentTypeFromMessage(v interface{}) string { return "Custom-Content-Type" } func TestForwardResponseStreamCustomMarshaler(t *testing.T) { type msg struct { @@ -227,3 +227,56 @@ func TestForwardResponseStreamCustomMarshaler(t *testing.T) { }) } } + +func TestForwardResponseMessage(t *testing.T) { + msg := &pb.SimpleMessage{Id: "One"} + tests := []struct { + name string + marshaler runtime.Marshaler + contentType string + }{{ + name: "standard marshaler", + marshaler: &runtime.JSONPb{}, + contentType: "application/json", + }, { + name: "httpbody marshaler", + marshaler: &runtime.HTTPBodyMarshaler{&runtime.JSONPb{}}, + contentType: "application/json", + }, { + name: "custom marshaler", + marshaler: &CustomMarshaler{&runtime.JSONPb{}}, + contentType: "Custom-Content-Type", + }} + + ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{}) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + resp := httptest.NewRecorder() + + runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(), tt.marshaler, resp, req, msg) + + w := resp.Result() + if w.StatusCode != http.StatusOK { + t.Errorf("StatusCode %d want %d", w.StatusCode, http.StatusOK) + } + if h := w.Header.Get("Content-Type"); h != tt.contentType { + t.Errorf("Content-Type %v want %v", h, tt.contentType) + } + body, err := ioutil.ReadAll(w.Body) + if err != nil { + t.Errorf("Failed to read response body with %v", err) + } + w.Body.Close() + + want, err := tt.marshaler.Marshal(msg) + if err != nil { + t.Errorf("marshaler.Marshal() failed %v", err) + } + + if string(body) != string(want) { + t.Errorf("ForwardResponseMessage() = \"%s\" want \"%s\"", body, want) + } + }) + } +} diff --git a/runtime/marshaler.go b/runtime/marshaler.go index 98fe6e88ac5..3fdf9fd8738 100644 --- a/runtime/marshaler.go +++ b/runtime/marshaler.go @@ -19,6 +19,13 @@ type Marshaler interface { ContentType() string } +// Marshalers that implement contentTypeMarshaler will have their ContentTypeFromMessage method called +// to set the Content-Type header on the response +type contentTypeMarshaler interface { + // ContentTypeFromMessage returns the Content-Type this marshaler produces from the provided message + ContentTypeFromMessage(v interface{}) string +} + // Decoder decodes a byte sequence type Decoder interface { Decode(v interface{}) error diff --git a/runtime/proto_errors.go b/runtime/proto_errors.go index ca76324efb1..ea75565868d 100644 --- a/runtime/proto_errors.go +++ b/runtime/proto_errors.go @@ -47,9 +47,9 @@ func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler // Check marshaler on run time in order to keep backwards compatability // An interface param needs to be added to the ContentType() function on // the Marshal interface to be able to remove this check - if httpBodyMarshaler, ok := marshaler.(*HTTPBodyMarshaler); ok { + if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { pb := s.Proto() - contentType = httpBodyMarshaler.ContentTypeFromMessage(pb) + contentType = typeMarshaler.ContentTypeFromMessage(pb) } w.Header().Set("Content-Type", contentType)