Skip to content

Commit

Permalink
proper header handling for grpc metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
potterbm-cb committed Nov 7, 2024
1 parent c3eb2ce commit b0d488a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
15 changes: 15 additions & 0 deletions headerforwarder/context_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type contextKey string

const requestIDKey = contextKey("request_id")

const outgoingHeadersKey = contextKey("outgoing_headers")

func ContextWithRosettaID(ctx context.Context) context.Context {
return context.WithValue(ctx, requestIDKey, uuid.NewString())
}
Expand All @@ -46,3 +48,16 @@ func RosettaIDFromRequest(r *http.Request) string {
return ""
}
}

func ContextWithOutgoingHeaders(ctx context.Context, headers http.Header) context.Context {
return context.WithValue(ctx, outgoingHeadersKey, headers)
}

func OutgoingHeadersFromContext(ctx context.Context) http.Header {
switch val := ctx.Value(outgoingHeadersKey).(type) {
case http.Header:
return val
default:
return nil
}
}
49 changes: 28 additions & 21 deletions headerforwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"net/http"
"strings"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
Expand All @@ -39,9 +40,11 @@ type HeaderForwarder struct {
actualTransport http.RoundTripper
}

// TODO: make transport an optional parameter, add "WithTransport" style functions to make it easier
// to add the actual RPC clients to this struct
func NewHeaderForwarder(interestingHeaders []string, transport http.RoundTripper) (*HeaderForwarder, error) {
func NewHeaderForwarder(
interestingHeaders []string,
transport http.RoundTripper,
// outgoingContextFromRequest func(r *http.Request) context.Context,
) (*HeaderForwarder, error) {
if len(interestingHeaders) == 0 {
return nil, fmt.Errorf("must provide at least one interesting header")
}
Expand Down Expand Up @@ -103,8 +106,14 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons
ctx := req.Context()
rosettaRequestID := RosettaIDFromContext(ctx)

// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
headersToRemember, exists := hf.requestHeaders[rosettaRequestID]
if !exists {
headersToRemember = make(http.Header)
}

// Only remember interesting headers
headersToRemember := make(http.Header)
for _, interestingHeader := range hf.interestingHeaders {
headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader))
}
Expand All @@ -121,8 +130,9 @@ func (hf *HeaderForwarder) shouldRememberMetadata(ctx context.Context, req any,
}

// If any of the interesting headers are in the response metadata, remember it
// grpc metadata uses lowercase keys rather than http canonicalized keys
for _, interestingHeader := range hf.interestingHeaders {
if _, responseHasHeader := resp[http.CanonicalHeaderKey(interestingHeader)]; responseHasHeader {
if _, responseHasHeader := resp[strings.ToLower(interestingHeader)]; responseHasHeader {
return true
}
}
Expand All @@ -135,7 +145,13 @@ func (hf *HeaderForwarder) shouldRememberMetadata(ctx context.Context, req any,
func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, req any, resp metadata.MD) {
rosettaID := RosettaIDFromContext(ctx)

headersToRemember := make(http.Header)
// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
headersToRemember, exists := hf.requestHeaders[rosettaID]
if !exists {
headersToRemember = make(http.Header)
}

for _, interestingHeader := range hf.interestingHeaders {
for _, value := range resp.Get(interestingHeader) {
headersToRemember.Set(interestingHeader, value)
Expand All @@ -162,7 +178,6 @@ func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Hea
// those headers on the response
func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println("HeaderForwarder Handler")
// add a unique ID to the request context, and make a new request for it
requestWithID := hf.RequestWithRequestID(r)

Expand All @@ -182,10 +197,9 @@ func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handle
// RoundTrip implements http.RoundTripper and will be used to construct an http Client which
// saves the native node response headers if necessary.
func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) {
fmt.Println("HeaderForwarder RoundTrip")
resp, err := hf.actualTransport.RoundTrip(req)
// TODO: add outgoing headers to the request

fmt.Println("HeaderForwarder RoundTrip: response headers", resp.Header)
resp, err := hf.actualTransport.RoundTrip(req)

if err == nil && hf.shouldRememberHeaders(req, resp) {
hf.rememberHeaders(req, resp)
Expand All @@ -195,22 +209,15 @@ func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error)
}

func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
fmt.Println("HeaderForwarder grpc interceptor")

fmt.Println("request id: ", RosettaIDFromContext(ctx))

// append a header DialOption to the request
var responseMD metadata.MD
opts = append(opts, grpc.Header(&responseMD))
var header metadata.MD
opts = append(opts, grpc.Header(&header))

err := invoker(ctx, method, req, reply, cc, opts...)

if hf.shouldRememberMetadata(ctx, req, responseMD) {
hf.rememberMetadata(ctx, req, responseMD)
if hf.shouldRememberMetadata(ctx, req, header) {
hf.rememberMetadata(ctx, req, header)
}

// get headers from response
fmt.Println("HeaderForwarder grpc interceptor: headers from response", responseMD)

return err
}

0 comments on commit b0d488a

Please sign in to comment.