diff --git a/azure/defaults.go b/azure/defaults.go index 29efbbad360..fb2d439e8c7 100644 --- a/azure/defaults.go +++ b/azure/defaults.go @@ -18,6 +18,7 @@ package azure import ( "fmt" + "net/http" "github.com/Azure/go-autorest/autorest/azure" @@ -26,6 +27,7 @@ import ( "github.com/pkg/errors" infrav1 "sigs.k8s.io/cluster-api-provider-azure/api/v1alpha4" + "sigs.k8s.io/cluster-api-provider-azure/util/tele" "sigs.k8s.io/cluster-api-provider-azure/version" ) @@ -322,6 +324,10 @@ func UserAgent() string { // SetAutoRestClientDefaults set authorizer and user agent for autorest client. func SetAutoRestClientDefaults(c *autorest.Client, auth autorest.Authorizer) { c.Authorizer = auth + // Wrap the original Sender on the autorest.Client c. + // The wrapped Sender should set the x-ms-correlation-request-id on the given + // request, then pass the new request to the underlying Sender. + c.Sender = autorest.DecorateSender(c.Sender, msCorrelationIDSendDecorator) AutoRestClientAppendUserAgent(c, UserAgent()) } @@ -329,3 +335,14 @@ func SetAutoRestClientDefaults(c *autorest.Client, auth autorest.Authorizer) { func AutoRestClientAppendUserAgent(c *autorest.Client, extension string) { _ = c.AddToUserAgent(extension) // intentionally ignore error as it doesn't matter } + +func msCorrelationIDSendDecorator(snd autorest.Sender) autorest.Sender { + return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) { + // if the correlation ID was found in the request context, set + // it in the header + if corrID, ok := tele.CorrIDFromCtx(r.Context()); ok { + r.Header.Set(string(tele.CorrIDKeyVal), string(corrID)) + } + return snd.Do(r) + }) +} diff --git a/azure/defaults_test.go b/azure/defaults_test.go index c18f1335bee..f098fa3095c 100644 --- a/azure/defaults_test.go +++ b/azure/defaults_test.go @@ -17,11 +17,16 @@ limitations under the License. package azure import ( + "context" "fmt" + "net/http" + "net/http/httptest" + "sync" "testing" "github.com/Azure/go-autorest/autorest" . "github.com/onsi/gomega" + "sigs.k8s.io/cluster-api-provider-azure/util/tele" ) func TestGetDefaultImageSKUID(t *testing.T) { @@ -235,3 +240,51 @@ func TestGetDefaultUbuntuImage(t *testing.T) { }) } } + +func TestMSCorrelationIDSendDecorator(t *testing.T) { + g := NewWithT(t) + const corrID tele.CorrID = "TestMSCorrelationIDSendDecoratorCorrID" + ctx := context.WithValue(context.Background(), tele.CorrIDKeyVal, corrID) + + // create a fake server so that the sender can send to + // somewhere + var wg sync.WaitGroup + receivedReqs := []*http.Request{} + wg.Add(1) + originHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedReqs = append(receivedReqs, r) + wg.Done() + }) + + testSrv := httptest.NewServer(originHandler) + defer testSrv.Close() + + // create a sender that sends to the fake server, then + // decorate the sender with the msCorrelationIDSendDecorator + origSender := autorest.SenderFunc(func(r *http.Request) (*http.Response, error) { + // preserve the incoming headers to the fake server, so that + // we can test that the fake server received the right + // correlation ID header. + req, err := http.NewRequest("GET", testSrv.URL, nil) + if err != nil { + return nil, err + } + req.Header = r.Header + return testSrv.Client().Do(req) + }) + newSender := autorest.DecorateSender(origSender, msCorrelationIDSendDecorator) + + // create a new HTTP request and send it via the new decorated sender + req, err := http.NewRequest("GET", "/abc", nil) + g.Expect(err).NotTo(HaveOccurred()) + + req = req.WithContext(ctx) + _, err = newSender.Do(req) + g.Expect(err).NotTo(HaveOccurred()) + wg.Wait() + g.Expect(len(receivedReqs)).To(Equal(1)) + receivedReq := receivedReqs[0] + g.Expect( + receivedReq.Header.Get(string(tele.CorrIDKeyVal)), + ).To(Equal(string(corrID))) +} diff --git a/util/tele/corr_id.go b/util/tele/corr_id.go index fa93ea4c5d2..b3254da7b06 100644 --- a/util/tele/corr_id.go +++ b/util/tele/corr_id.go @@ -22,7 +22,13 @@ import ( "github.com/google/uuid" ) -type corrIDKey string +// CorrIDKey is the type of the key used to store correlation +// IDs in context.Contexts. +type CorrIDKey string + +// CorrIDKeyVal is the key used to store the correlation ID in +// context.Contexts, HTTP headers, and other similar locations. +const CorrIDKeyVal CorrIDKey = "x-ms-correlation-request-id" // CorrID is a correlation ID that the cluster API provider // sends with all API requests to Azure. Do not create one @@ -30,8 +36,6 @@ type corrIDKey string // to create one of these within a context.Context. type CorrID string -const corrIDKeyVal corrIDKey = "x-ms-correlation-id" - // ctxWithCorrID creates a CorrID and creates a new context.Context // with the new CorrID in it. It returns the _new_ context and the // newly created CorrID. If there was a problem creating the correlation @@ -41,11 +45,11 @@ const corrIDKeyVal corrIDKey = "x-ms-correlation-id" // below: // // ctx := context.Background() -// ctx, newCorrID := CtxWithCorrID(ctx) +// ctx, newCorrID := ctxWithCorrID(ctx) // fmt.Println("new corr ID: ", newCorrID) // doSomething(ctx) func ctxWithCorrID(ctx context.Context) (context.Context, CorrID) { - currentCorrIDIface := ctx.Value(corrIDKeyVal) + currentCorrIDIface := ctx.Value(CorrIDKeyVal) if currentCorrIDIface != nil { currentCorrID, ok := currentCorrIDIface.(CorrID) if ok { @@ -57,7 +61,7 @@ func ctxWithCorrID(ctx context.Context) (context.Context, CorrID) { return nil, CorrID("") } newCorrID := CorrID(uid.String()) - ctx = context.WithValue(ctx, corrIDKeyVal, newCorrID) + ctx = context.WithValue(ctx, CorrIDKeyVal, newCorrID) return ctx, newCorrID } @@ -65,7 +69,7 @@ func ctxWithCorrID(ctx context.Context) (context.Context, CorrID) { // context.Context. If none exists, returns an empty CorrID and false. // Otherwise returns the CorrID value and true. func CorrIDFromCtx(ctx context.Context) (CorrID, bool) { - currentCorrIDIface := ctx.Value(corrIDKeyVal) + currentCorrIDIface := ctx.Value(CorrIDKeyVal) if currentCorrIDIface == nil { return CorrID(""), false } diff --git a/util/tele/tele.go b/util/tele/tele.go index ec183efe6d9..a6679e18411 100644 --- a/util/tele/tele.go +++ b/util/tele/tele.go @@ -20,6 +20,7 @@ import ( "context" "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) @@ -35,7 +36,15 @@ func (t tracer) Start( op string, opts ...trace.SpanOption, ) (context.Context, trace.Span) { - ctx, _ = ctxWithCorrID(ctx) + ctx, corrID := ctxWithCorrID(ctx) + opts = append( + opts, + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(attribute.String( + string(CorrIDKeyVal), + string(corrID), + )), + ) return t.Tracer.Start(ctx, op, opts...) }