diff --git a/headerforwarder/forwarder.go b/headerforwarder/forwarder.go index 75b7ae50..13576564 100644 --- a/headerforwarder/forwarder.go +++ b/headerforwarder/forwarder.go @@ -19,6 +19,7 @@ import ( "fmt" "net/http" "strings" + "sync" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -36,7 +37,9 @@ import ( // TODO: this should expire entries after a certain amount of time type HeaderForwarder struct { incomingHeaders map[string]http.Header + incomingHeaderLock sync.RWMutex outgoingHeaders map[string]http.Header + outgoingHeaderLock sync.RWMutex interestingHeaders []string actualTransport http.RoundTripper } @@ -66,14 +69,20 @@ func (hf *HeaderForwarder) captureOutgoingHeaders(req *http.Request) { ctx := req.Context() rosettaRequestID := RosettaIDFromContext(ctx) - hf.outgoingHeaders[rosettaRequestID] = make(http.Header) + // We don't worry about overwriting headers here because we only handle "outgoing" headers + // once: when the rosetta request is made + outgoingRequestHeaders := make(http.Header) // Only capture interesting headers for _, interestingHeader := range hf.interestingHeaders { if _, requestHasHeader := req.Header[http.CanonicalHeaderKey(interestingHeader)]; requestHasHeader { - hf.outgoingHeaders[rosettaRequestID].Set(interestingHeader, req.Header.Get(interestingHeader)) + outgoingRequestHeaders.Set(interestingHeader, req.Header.Get(interestingHeader)) } } + + hf.outgoingHeaderLock.Lock() + hf.outgoingHeaders[rosettaRequestID] = outgoingRequestHeaders + hf.outgoingHeaderLock.Unlock() } // shouldRememberHeaders reports whether response headers should be remembered for a @@ -113,7 +122,10 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons // For multiple requests with the same rosetta ID, we want to remember all of the headers // For repeated response headers, later values will overwrite earlier ones + hf.incomingHeaderLock.RLock() headersToRemember, exists := hf.incomingHeaders[rosettaRequestID] + hf.incomingHeaderLock.RUnlock() + if !exists { headersToRemember = make(http.Header) } @@ -123,7 +135,9 @@ func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Respons headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader)) } + hf.incomingHeaderLock.Lock() hf.incomingHeaders[rosettaRequestID] = headersToRemember + hf.incomingHeaderLock.Unlock() } // shouldRememberMetadata reports whether response metadata should be remembered for a grpc unary @@ -152,7 +166,10 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M // For multiple requests with the same rosetta ID, we want to remember all of the headers // For repeated response headers, later values will overwrite earlier ones + hf.incomingHeaderLock.RLock() headersToRemember, exists := hf.incomingHeaders[rosettaID] + hf.incomingHeaderLock.RUnlock() + if !exists { headersToRemember = make(http.Header) } @@ -163,20 +180,28 @@ func (hf *HeaderForwarder) rememberMetadata(ctx context.Context, resp metadata.M } } + hf.incomingHeaderLock.Lock() hf.incomingHeaders[rosettaID] = headersToRemember + hf.incomingHeaderLock.Unlock() } // GetResponseHeaders returns any headers that should be returned to a rosetta response. These // consist of native node response headers/metadata that were remembered for a request ID. func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) { + hf.incomingHeaderLock.RLock() headers, ok := hf.incomingHeaders[rosettaRequestID] + hf.incomingHeaderLock.RUnlock() // Delete the headers from the map after they are retrieved // This is safe to call even if the key doesn't exist + hf.incomingHeaderLock.Lock() delete(hf.incomingHeaders, rosettaRequestID) + hf.incomingHeaderLock.Unlock() // Also delete the outgoing headers from the map since we are done with them + hf.outgoingHeaderLock.Lock() delete(hf.outgoingHeaders, rosettaRequestID) + hf.outgoingHeaderLock.Unlock() return headers, ok } @@ -209,8 +234,12 @@ func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handle // RoundTrip implements http.RoundTripper and will be used to construct an http Client which // saves the native node response headers if necessary. func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) { + hf.outgoingHeaderLock.RLock() + outgoingHeaders, hasOutgoingHeaders := hf.outgoingHeaders[RosettaIDFromRequest(req)] + hf.outgoingHeaderLock.RUnlock() + // add outgoing headers to the request - if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromRequest(req)]; ok { + if hasOutgoingHeaders { for header, values := range outgoingHeaders { for _, value := range values { req.Header.Add(header, value) @@ -227,13 +256,18 @@ func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) return resp, err } -func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { +func (hf *HeaderForwarder) UnaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { // Capture incoming headers from the grpc call var header metadata.MD opts = append(opts, grpc.Header(&header)) + // Get outgoing headers from the request ID in context + hf.outgoingHeaderLock.RLock() + outgoingHeaders, hasOutgoingHeaders := hf.outgoingHeaders[RosettaIDFromContext(ctx)] + hf.outgoingHeaderLock.RUnlock() + // Add outgoing headers to the context - if outgoingHeaders, ok := hf.outgoingHeaders[RosettaIDFromContext(ctx)]; ok { + if hasOutgoingHeaders { for header, values := range outgoingHeaders { for _, value := range values { ctx = metadata.AppendToOutgoingContext(ctx, strings.ToLower(header), value)