Skip to content

Commit

Permalink
add sync.RWMutex to headerforwarder map (#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
potterbm-cb authored Nov 26, 2024
1 parent 945092d commit 07dc44a
Showing 1 changed file with 39 additions and 5 deletions.
44 changes: 39 additions & 5 deletions headerforwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
Expand All @@ -36,7 +37,9 @@ import (
// TODO: this should expire entries after a certain amount of time
type HeaderForwarder struct {
incomingHeaders map[string]http.Header
incomingHeaderLock sync.RWMutex
outgoingHeaders map[string]http.Header
outgoingHeaderLock sync.RWMutex
interestingHeaders []string
actualTransport http.RoundTripper
}
Expand Down Expand Up @@ -66,14 +69,20 @@ func (hf *HeaderForwarder) captureOutgoingHeaders(req *http.Request) {
ctx := req.Context()
rosettaRequestID := RosettaIDFromContext(ctx)

hf.outgoingHeaders[rosettaRequestID] = make(http.Header)
// We don't worry about overwriting headers here because we only handle "outgoing" headers
// once: when the rosetta request is made
outgoingRequestHeaders := make(http.Header)

// Only capture interesting headers
for _, interestingHeader := range hf.interestingHeaders {
if _, requestHasHeader := req.Header[http.CanonicalHeaderKey(interestingHeader)]; requestHasHeader {
hf.outgoingHeaders[rosettaRequestID].Set(interestingHeader, req.Header.Get(interestingHeader))
outgoingRequestHeaders.Set(interestingHeader, req.Header.Get(interestingHeader))
}
}

hf.outgoingHeaderLock.Lock()
hf.outgoingHeaders[rosettaRequestID] = outgoingRequestHeaders
hf.outgoingHeaderLock.Unlock()
}

// shouldRememberHeaders reports whether response headers should be remembered for a
Expand Down Expand Up @@ -113,7 +122,10 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons

// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
hf.incomingHeaderLock.RLock()
headersToRemember, exists := hf.incomingHeaders[rosettaRequestID]
hf.incomingHeaderLock.RUnlock()

if !exists {
headersToRemember = make(http.Header)
}
Expand All @@ -123,7 +135,9 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons
headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader))
}

hf.incomingHeaderLock.Lock()
hf.incomingHeaders[rosettaRequestID] = headersToRemember
hf.incomingHeaderLock.Unlock()
}

// shouldRememberMetadata reports whether response metadata should be remembered for a grpc unary
Expand Down Expand Up @@ -152,7 +166,10 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M

// For multiple requests with the same rosetta ID, we want to remember all of the headers
// For repeated response headers, later values will overwrite earlier ones
hf.incomingHeaderLock.RLock()
headersToRemember, exists := hf.incomingHeaders[rosettaID]
hf.incomingHeaderLock.RUnlock()

if !exists {
headersToRemember = make(http.Header)
}
Expand All @@ -163,20 +180,28 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M
}
}

hf.incomingHeaderLock.Lock()
hf.incomingHeaders[rosettaID] = headersToRemember
hf.incomingHeaderLock.Unlock()
}

// GetResponseHeaders returns any headers that should be returned to a rosetta response. These
// consist of native node response headers/metadata that were remembered for a request ID.
func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) {
hf.incomingHeaderLock.RLock()
headers, ok := hf.incomingHeaders[rosettaRequestID]
hf.incomingHeaderLock.RUnlock()

// Delete the headers from the map after they are retrieved
// This is safe to call even if the key doesn't exist
hf.incomingHeaderLock.Lock()
delete(hf.incomingHeaders, rosettaRequestID)
hf.incomingHeaderLock.Unlock()

// Also delete the outgoing headers from the map since we are done with them
hf.outgoingHeaderLock.Lock()
delete(hf.outgoingHeaders, rosettaRequestID)
hf.outgoingHeaderLock.Unlock()

return headers, ok
}
Expand Down Expand Up @@ -209,8 +234,12 @@ func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handle
// RoundTrip implements http.RoundTripper and will be used to construct an http Client which
// saves the native node response headers if necessary.
func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) {
hf.outgoingHeaderLock.RLock()
outgoingHeaders, hasOutgoingHeaders := hf.outgoingHeaders[RosettaIDFromRequest(req)]
hf.outgoingHeaderLock.RUnlock()

// add outgoing headers to the request
if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromRequest(req)]; ok {
if hasOutgoingHeaders {
for header, values := range outgoingHeaders {
for _, value := range values {
req.Header.Add(header, value)
Expand All @@ -227,13 +256,18 @@ func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error)
return resp, err
}

func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// Capture incoming headers from the grpc call
var header metadata.MD
opts = append(opts, grpc.Header(&header))

// Get outgoing headers from the request ID in context
hf.outgoingHeaderLock.RLock()
outgoingHeaders, hasOutgoingHeaders := hf.outgoingHeaders[RosettaIDFromContext(ctx)]
hf.outgoingHeaderLock.RUnlock()

// Add outgoing headers to the context
if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromContext(ctx)]; ok {
if hasOutgoingHeaders {
for header, values := range outgoingHeaders {
for _, value := range values {
ctx = metadata.AppendToOutgoingContext(ctx, strings.ToLower(header), value)
Expand Down

0 comments on commit 07dc44a

Please sign in to comment.