From c85aa209460aee93d74fb99bb9d47abdd3722bb9 Mon Sep 17 00:00:00 2001 From: Anthony J Mirabella Date: Tue, 7 Jul 2020 23:42:37 -0400 Subject: [PATCH] Avoid replacing existing correlation map data in context when correlation context extractor does not find any valid data --- CHANGELOG.md | 1 + .../correlation_context_propagator.go | 14 +++++-- .../correlation_context_propagator_test.go | 37 +++++++++++++++++-- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9dcd8db84a7..24537098c4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - Ensure span status is not set to `Unknown` when no HTTP status code is provided as it is assumed to be `200 OK`. (#908) - Ensure `httptrace.clientTracer` closes `http.headers` span. (#912) - Prometheus exporter will not apply stale updates or forget inactive metrics. (#903) +- Correlation Context extractor will no longer insert an empty map into the returned context when no valid values are extracted. (#923) ## [0.7.0] - 2020-06-26 diff --git a/api/correlation/correlation_context_propagator.go b/api/correlation/correlation_context_propagator.go index 172a3b902c5..fec83ed1416 100644 --- a/api/correlation/correlation_context_propagator.go +++ b/api/correlation/correlation_context_propagator.go @@ -65,7 +65,7 @@ func (CorrelationContext) Inject(ctx context.Context, supplier propagation.HTTPS func (CorrelationContext) Extract(ctx context.Context, supplier propagation.HTTPSupplier) context.Context { correlationContext := supplier.Get(correlationContextHeader) if correlationContext == "" { - return ContextWithMap(ctx, NewEmptyMap()) + return ctx } contextValues := strings.Split(correlationContext, ",") @@ -101,9 +101,15 @@ func (CorrelationContext) Extract(ctx context.Context, supplier propagation.HTTP keyValues = append(keyValues, kv.Key(trimmedName).String(trimmedValueWithProps.String())) } - return ContextWithMap(ctx, NewMap(MapUpdate{ - MultiKV: keyValues, - })) + + if len(keyValues) > 0 { + // Only update the context if valid values were found + return ContextWithMap(ctx, NewMap(MapUpdate{ + MultiKV: keyValues, + })) + } + + return ctx } // GetAllKeys implements HTTPPropagator. diff --git a/api/correlation/correlation_context_propagator_test.go b/api/correlation/correlation_context_propagator_test.go index 5a88133eb5e..e5569c0ade2 100644 --- a/api/correlation/correlation_context_propagator_test.go +++ b/api/correlation/correlation_context_propagator_test.go @@ -123,11 +123,28 @@ func TestExtractInvalidDistributedContextFromHTTPReq(t *testing.T) { tests := []struct { name string header string + hasKVs []kv.KeyValue }{ { name: "no key values", header: "header1", }, + { + name: "invalid header with existing context", + header: "header2", + hasKVs: []kv.KeyValue{ + kv.Key("key1").String("val1"), + kv.Key("key2").String("val2"), + }, + }, + { + name: "empty header value", + header: "", + hasKVs: []kv.KeyValue{ + kv.Key("key1").String("val1"), + kv.Key("key2").String("val2"), + }, + }, } for _, tt := range tests { @@ -135,12 +152,26 @@ func TestExtractInvalidDistributedContextFromHTTPReq(t *testing.T) { req, _ := http.NewRequest("GET", "http://example.com", nil) req.Header.Set("otcorrelations", tt.header) - ctx := context.Background() + ctx := correlation.NewContext(context.Background(), tt.hasKVs...) + wantCorCtx := correlation.MapFromContext(ctx) ctx = propagation.ExtractHTTP(ctx, props, req.Header) gotCorCtx := correlation.MapFromContext(ctx) - if gotCorCtx.Len() != 0 { - t.Errorf("Got and Want CorCtx are not the same size %d != %d", gotCorCtx.Len(), 0) + if gotCorCtx.Len() != wantCorCtx.Len() { + t.Errorf( + "Got and Want CorCtx are not the same size %d != %d", + gotCorCtx.Len(), + wantCorCtx.Len(), + ) } + totalDiff := "" + wantCorCtx.Foreach(func(keyValue kv.KeyValue) bool { + val, _ := gotCorCtx.Value(keyValue.Key) + diff := cmp.Diff(keyValue, kv.KeyValue{Key: keyValue.Key, Value: val}, cmp.AllowUnexported(value.Value{})) + if diff != "" { + totalDiff += diff + "\n" + } + return true + }) }) } }