Skip to content

Commit

Permalink
Explore extending Mux.
Browse files Browse the repository at this point in the history
  • Loading branch information
tamalsaha committed Mar 23, 2017
1 parent 15d0787 commit 5ce0d94
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 41 deletions.
39 changes: 2 additions & 37 deletions runtime/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,49 +37,14 @@ var (
DefaultContextTimeout = 0 * time.Second
)

type Matcher func(string) bool

// EqualMatcher performs a case-sensitive equality match for request metadata keys
func EqualMatcher(h string) Matcher {
return func(key string) bool {
return key == h
}
}

// EqualFoldMatcher performs a case-insensitive equality match for request metadata keys
func EqualFoldMatcher(h string) Matcher {
return func(key string) bool {
return strings.EqualFold(key, h)
}
}

// PrefixMatcher performs a case-sensitive prefix match for request metadata keys
func PrefixMatcher(h string) Matcher {
return func(key string) bool {
return strings.HasPrefix(key, h)
}
}

// PrefixFoldMatcher performs a case-insensitive prefix match for request metadata keys
func PrefixFoldMatcher(h string) Matcher {
h = strings.ToLower(h)
return func(key string) bool {
return strings.HasPrefix(strings.ToLower(key), h)
}
}

// HeaderMatchers are checked against context metadata keys to forward or return request headers between
// grpc service and gateway.
var HeaderMatchers []Matcher

/*
AnnotateContext adds context information such as metadata from the request.
At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
except that the forwarded destination is not another HTTP service but rather
a gRPC service.
*/
func AnnotateContext(ctx context.Context, req *http.Request) (context.Context, error) {
func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request) (context.Context, error) {
var pairs []string
timeout := DefaultContextTimeout
if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
Expand All @@ -97,7 +62,7 @@ func AnnotateContext(ctx context.Context, req *http.Request) (context.Context, e
if strings.ToLower(key) == "authorization" {
pairs = append(pairs, "authorization", val)
}
for _, m := range HeaderMatchers {
for _, m := range mux.headerMatchers {
if m(key) {
pairs = append(pairs, key, val)
continue nextval
Expand Down
4 changes: 2 additions & 2 deletions runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (*errorBody) ProtoMessage() {}
//
// The response body returned by this function is a JSON object,
// which contains a member whose key is "error" and whose value is err.Error().
func DefaultHTTPError(ctx context.Context, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) {
func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) {
const fallback = `{"error": "failed to marshal error message"}`

w.Header().Del("Trailer")
Expand All @@ -103,7 +103,7 @@ func DefaultHTTPError(ctx context.Context, marshaler Marshaler, w http.ResponseW
grpclog.Printf("Failed to extract ServerMetadata from context")
}

handleForwardResponseServerMetadata(w, md)
handleForwardResponseServerMetadata(w, mux, md)
handleForwardResponseTrailerHeader(w, md)
st := HTTPStatusFromCode(grpc.Code(err))
w.WriteHeader(st)
Expand Down
4 changes: 2 additions & 2 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ func ForwardResponseStream(ctx context.Context, marshaler Marshaler, w http.Resp
}
}

func handleForwardResponseServerMetadata(w http.ResponseWriter, md ServerMetadata) {
func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
for k, vs := range md.HeaderMD {
hKey := fmt.Sprintf("%s%s", MetadataHeaderPrefix, k)
for _, m := range HeaderMatchers {
for _, m := range mux.headerMatchers {
if m(k) {
hKey = k
break
Expand Down
40 changes: 40 additions & 0 deletions runtime/matchers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package runtime

import "strings"

// EqualMatcher performs a case-sensitive equality match for request metadata keys
func EqualMatcher(h string) ServeMuxOption {
return func(mux *ServeMux) {
mux.headerMatchers = append(mux.headerMatchers, func(key string) bool {
return key == h
})
}
}

// EqualFoldMatcher performs a case-insensitive equality match for request metadata keys
func EqualFoldMatcher(h string) ServeMuxOption {
return func(mux *ServeMux) {
mux.headerMatchers = append(mux.headerMatchers, func(key string) bool {
return strings.EqualFold(key, h)
})
}
}

// PrefixMatcher performs a case-sensitive prefix match for request metadata keys
func PrefixMatcher(h string) ServeMuxOption {
return func(mux *ServeMux) {
mux.headerMatchers = append(mux.headerMatchers, func(key string) bool {
return strings.HasPrefix(key, h)
})
}
}

// PrefixFoldMatcher performs a case-insensitive prefix match for request metadata keys
func PrefixFoldMatcher(h string) ServeMuxOption {
h = strings.ToLower(h)
return func(mux *ServeMux) {
mux.headerMatchers = append(mux.headerMatchers, func(key string) bool {
return strings.HasPrefix(strings.ToLower(key), h)
})
}
}
5 changes: 5 additions & 0 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ import (
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)

// Matcher matches metadata keys against user provides values to allow forwarding headers.
type Matcher func(string) bool

// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
headerMatchers []Matcher
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand All @@ -42,6 +46,7 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {
handlers: make(map[string][]handler),
forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
marshalers: makeMarshalerMIMERegistry(),
headerMatchers: make([]Matcher, 0),
}

for _, opt := range opts {
Expand Down

0 comments on commit 5ce0d94

Please sign in to comment.