diff --git a/processor/samplingprocessor/tailsamplingprocessor/processor.go b/processor/samplingprocessor/tailsamplingprocessor/processor.go index f9ded9e1cd5..7bbd1c926b4 100644 --- a/processor/samplingprocessor/tailsamplingprocessor/processor.go +++ b/processor/samplingprocessor/tailsamplingprocessor/processor.go @@ -167,9 +167,16 @@ func (tsp *tailSamplingSpanProcessor) samplingPolicyOnTick() { trace.Unlock() if decision == sampling.Sampled { + + // Combine all individual batches into a single batch so + // consumers may operate on the entire trace + allSpans := pdata.NewTraces() for j := 0; j < len(traceBatches); j++ { - _ = tsp.nextConsumer.ConsumeTraces(policy.ctx, internaldata.OCToTraceData(traceBatches[j])) + batch := internaldata.OCToTraceData(traceBatches[j]) + batch.ResourceSpans().MoveAndAppendTo(allSpans.ResourceSpans()) } + + _ = tsp.nextConsumer.ConsumeTraces(policy.ctx, allSpans) } } diff --git a/processor/samplingprocessor/tailsamplingprocessor/processor_test.go b/processor/samplingprocessor/tailsamplingprocessor/processor_test.go index 75fae492209..b30a3016c08 100644 --- a/processor/samplingprocessor/tailsamplingprocessor/processor_test.go +++ b/processor/samplingprocessor/tailsamplingprocessor/processor_test.go @@ -15,8 +15,10 @@ package tailsamplingprocessor import ( + "bytes" "context" "errors" + "sort" "sync" "testing" "time" @@ -331,8 +333,103 @@ func TestSamplingPolicyDecisionNotSampled(t *testing.T) { require.Equal(t, 2, mpe.LateArrivingSpansCount, "policy was not notified of the late span") } +func TestMultipleBatchesAreCombinedIntoOne(t *testing.T) { + const maxSize = 100 + const decisionWaitSeconds = 1 + // For this test explicitly control the timer calls and batcher, and set a mock + // sampling policy evaluator. + msp := new(exportertest.SinkTraceExporter) + mpe := &mockPolicyEvaluator{} + mtt := &manualTTicker{} + tsp := &tailSamplingSpanProcessor{ + ctx: context.Background(), + nextConsumer: msp, + maxNumTraces: maxSize, + logger: zap.NewNop(), + decisionBatcher: newSyncIDBatcher(decisionWaitSeconds), + policies: []*Policy{{Name: "mock-policy", Evaluator: mpe, ctx: context.TODO()}}, + deleteChan: make(chan traceKey, maxSize), + policyTicker: mtt, + } + + mpe.NextDecision = sampling.Sampled + + traceIds, batches := generateIdsAndBatches(3) + for _, batch := range batches { + require.NoError(t, tsp.ConsumeTraces(context.Background(), batch)) + } + + tsp.samplingPolicyOnTick() + tsp.samplingPolicyOnTick() + + require.EqualValues(t, 3, len(msp.AllTraces()), "There should be three batches, one for each trace") + + expectedSpanIds := make(map[int][]pdata.SpanID) + expectedSpanIds[0] = []pdata.SpanID{ + pdata.NewSpanID(tracetranslator.UInt64ToByteSpanID(uint64(1))), + } + expectedSpanIds[1] = []pdata.SpanID{ + pdata.NewSpanID(tracetranslator.UInt64ToByteSpanID(uint64(2))), + pdata.NewSpanID(tracetranslator.UInt64ToByteSpanID(uint64(3))), + } + expectedSpanIds[2] = []pdata.SpanID{ + pdata.NewSpanID(tracetranslator.UInt64ToByteSpanID(uint64(4))), + pdata.NewSpanID(tracetranslator.UInt64ToByteSpanID(uint64(5))), + pdata.NewSpanID(tracetranslator.UInt64ToByteSpanID(uint64(6))), + } + + receivedTraces := msp.AllTraces() + for i, traceID := range traceIds { + trace := findTrace(receivedTraces, traceID) + require.NotNil(t, trace, "Trace was not received. TraceId %s", traceID.HexString()) + require.EqualValues(t, i+1, trace.SpanCount(), "The trace should have all of its spans in a single batch") + + expected := expectedSpanIds[i] + got := collectSpanIds(trace) + + // might have received out of order, sort for comparison + sort.Slice(got, func(i, j int) bool { + a, _ := tracetranslator.BytesToInt64SpanID(got[i]) + b, _ := tracetranslator.BytesToInt64SpanID(got[j]) + return a < b + }) + + require.EqualValues(t, expected, got) + } +} + +func collectSpanIds(trace *pdata.Traces) []pdata.SpanID { + spanIDs := make([]pdata.SpanID, 0) + + for i := 0; i < trace.ResourceSpans().Len(); i++ { + ilss := trace.ResourceSpans().At(i).InstrumentationLibrarySpans() + + for j := 0; j < ilss.Len(); j++ { + ils := ilss.At(j) + + for k := 0; k < ils.Spans().Len(); k++ { + span := ils.Spans().At(k) + spanIDs = append(spanIDs, span.SpanID()) + } + } + } + + return spanIDs +} + +func findTrace(a []pdata.Traces, traceID pdata.TraceID) *pdata.Traces { + for _, batch := range a { + id := batch.ResourceSpans().At(0).InstrumentationLibrarySpans().At(0).Spans().At(0).TraceID() + if bytes.Equal(traceID.Bytes(), id.Bytes()) { + return &batch + } + } + return nil +} + func generateIdsAndBatches(numIds int) ([]pdata.TraceID, []pdata.Traces) { traceIds := make([]pdata.TraceID, numIds) + spanID := 0 var tds []pdata.Traces for i := 0; i < numIds; i++ { traceIds[i] = tracetranslator.UInt64ToTraceID(1, uint64(i+1)) @@ -341,7 +438,9 @@ func generateIdsAndBatches(numIds int) ([]pdata.TraceID, []pdata.Traces) { td := testdata.GenerateTraceDataOneSpan() span := td.ResourceSpans().At(0).InstrumentationLibrarySpans().At(0).Spans().At(0) span.SetTraceID(traceIds[i]) - span.SetSpanID(tracetranslator.UInt64ToByteSpanID(uint64(i + 1))) + + spanID++ + span.SetSpanID(tracetranslator.UInt64ToByteSpanID(uint64(spanID))) tds = append(tds, td) } }