diff --git a/propagation/http_trace_context_propagator.go b/propagation/http_trace_context_propagator.go index 3f92725ebc72..e00031df5d41 100644 --- a/propagation/http_trace_context_propagator.go +++ b/propagation/http_trace_context_propagator.go @@ -18,6 +18,7 @@ import ( "context" "encoding/hex" "fmt" + "regexp" "strconv" "strings" @@ -36,6 +37,7 @@ const ( type httpTraceContextPropagator struct{} var _ apipropagation.TextFormatPropagator = httpTraceContextPropagator{} +var traceCtxRegExp = regexp.MustCompile("^[0-9a-f]{2}-[a-f0-9]{32}-[a-f0-9]{16}-[a-f0-9]{2}") func (hp httpTraceContextPropagator) Inject(ctx context.Context, supplier apipropagation.Supplier) { sc := trace.CurrentSpan(ctx).SpanContext() @@ -56,6 +58,11 @@ func (hp httpTraceContextPropagator) Extract(ctx context.Context, supplier apipr return core.EmptySpanContext() } + if !traceCtxRegExp.MatchString(h) { + fmt.Printf("header does not match regex %s\n", h) + return core.EmptySpanContext() + } + sections := strings.Split(h, "-") if len(sections) < 4 { return core.EmptySpanContext() @@ -104,11 +111,14 @@ func (hp httpTraceContextPropagator) Extract(ctx context.Context, supplier apipr } sc.SpanID = result + if len(sections[3]) != 2 { + return core.EmptySpanContext() + } opts, err := hex.DecodeString(sections[3]) - if err != nil || len(opts) < 1 { + if err != nil || len(opts) < 1 || (version == 0 && opts[0] > 2) { return core.EmptySpanContext() } - sc.TraceOptions = opts[0] + sc.TraceOptions = opts[0] &^ core.TraceOptionUnused if !sc.IsValid() { return core.EmptySpanContext() diff --git a/propagation/http_trace_context_propagator_test.go b/propagation/http_trace_context_propagator_test.go index fc2fe34311d4..ea48ddca9269 100644 --- a/propagation/http_trace_context_propagator_test.go +++ b/propagation/http_trace_context_propagator_test.go @@ -33,7 +33,7 @@ var ( spanID = uint64(0x00f067aa0ba902b7) ) -func TestExtractTraceContextFromHTTPReq(t *testing.T) { +func TestExtractValidTraceContextFromHTTPReq(t *testing.T) { trace.SetGlobalTracer(&mocktrace.MockTracer{}) propagator := propagation.HttpTraceContextPropagator() tests := []struct { @@ -41,6 +41,23 @@ func TestExtractTraceContextFromHTTPReq(t *testing.T) { header string wantSc core.SpanContext }{ + { + name: "valid header", + header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", + wantSc: core.SpanContext{ + TraceID: traceID, + SpanID: spanID, + }, + }, + { + name: "valid header and sampled", + header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", + wantSc: core.SpanContext{ + TraceID: traceID, + SpanID: spanID, + TraceOptions: core.TraceOptionSampled, + }, + }, { name: "future version", header: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", @@ -51,28 +68,118 @@ func TestExtractTraceContextFromHTTPReq(t *testing.T) { }, }, { - name: "zero trace ID and span ID", - header: "00-00000000000000000000000000000000-0000000000000000-01", - wantSc: core.EmptySpanContext(), + name: "future options with sampled bit set", + header: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09", + wantSc: core.SpanContext{ + TraceID: traceID, + SpanID: spanID, + TraceOptions: core.TraceOptionSampled, + }, }, { - name: "valid header", - header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", + name: "future options with sampled bit cleared", + header: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-08", + wantSc: core.SpanContext{ + TraceID: traceID, + SpanID: spanID, + }, + }, + { + name: "future additional data", + header: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09-XYZxsf09", wantSc: core.SpanContext{ TraceID: traceID, SpanID: spanID, TraceOptions: core.TraceOptionSampled, }, }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + req.Header.Set("traceparent", tt.header) + + ctx := context.Background() + gotSc := propagator.Extract(ctx, req.Header) + if diff := cmp.Diff(gotSc, tt.wantSc); diff != "" { + t.Errorf("Extract Tracecontext: %s: -got +want %s", tt.name, diff) + } + }) + } +} + +func TestExtractInvalidTraceContextFromHTTPReq(t *testing.T) { + trace.SetGlobalTracer(&mocktrace.MockTracer{}) + propagator := propagation.HttpTraceContextPropagator() + wantSc := core.EmptySpanContext() + tests := []struct { + name string + header string + }{ + { + name: "wrong version length", + header: "0000-00000000000000000000000000000000-0000000000000000-01", + }, + { + name: "wrong trace ID length", + header: "00-ab00000000000000000000000000000000-cd00000000000000-01", + }, + { + name: "wrong span ID length", + header: "00-ab000000000000000000000000000000-cd0000000000000000-01", + }, + { + name: "wrong trace flag length", + header: "00-ab000000000000000000000000000000-cd00000000000000-0100", + }, + { + name: "bogus version length", + header: "qw-00000000000000000000000000000000-0000000000000000-01", + }, + { + name: "bogus trace ID length", + header: "00-qw000000000000000000000000000000-cd00000000000000-01", + }, + { + name: "bogus span ID length", + header: "00-ab000000000000000000000000000000-qw00000000000000-01", + }, + { + name: "bogus trace flag length", + header: "00-ab000000000000000000000000000000-cd00000000000000-qw", + }, + { + name: "upper case version length", + header: "A0-00000000000000000000000000000000-0000000000000000-01", + }, + { + name: "upper case trace ID length", + header: "00-AB000000000000000000000000000000-cd00000000000000-01", + }, + { + name: "upper case span ID length", + header: "00-ab000000000000000000000000000000-CD00000000000000-01", + }, + { + name: "upper case trace flag length", + header: "00-ab000000000000000000000000000000-cd00000000000000-A1", + }, + { + name: "zero trace ID and span ID", + header: "00-00000000000000000000000000000000-0000000000000000-01", + }, + { + name: "trace-flag unused bits set", + header: "00-ab000000000000000000000000000000-cd00000000000000-09", + }, { name: "missing options", header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7", - wantSc: core.EmptySpanContext(), }, { name: "empty options", header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-", - wantSc: core.EmptySpanContext(), }, } @@ -83,7 +190,7 @@ func TestExtractTraceContextFromHTTPReq(t *testing.T) { ctx := context.Background() gotSc := propagator.Extract(ctx, req.Header) - if diff := cmp.Diff(gotSc, tt.wantSc); diff != "" { + if diff := cmp.Diff(gotSc, wantSc); diff != "" { t.Errorf("Extract Tracecontext: %s: -got +want %s", tt.name, diff) } }) @@ -149,4 +256,4 @@ func TestHttpTraceContextPropagator_GetAllKeys(t *testing.T) { if diff := cmp.Diff(got, want); diff != "" { t.Errorf("GetAllKeys: -got +want %s", diff) } -} \ No newline at end of file +}