diff --git a/sdk/trace/provider.go b/sdk/trace/provider.go index 324b686f4b8..0a018c14ded 100644 --- a/sdk/trace/provider.go +++ b/sdk/trace/provider.go @@ -120,7 +120,7 @@ func NewTracerProvider(opts ...TracerProviderOption) *TracerProvider { } global.Info("TracerProvider created", "config", o) - spss := spanProcessorStates{} + spss := make(spanProcessorStates, 0, len(o.processors)) for _, sp := range o.processors { spss = append(spss, newSpanProcessorState(sp)) } @@ -192,8 +192,10 @@ func (p *TracerProvider) RegisterSpanProcessor(sp SpanProcessor) { if p.isShutdown.Load() { return } - newSPS := spanProcessorStates{} - newSPS = append(newSPS, *(p.spanProcessors.Load())...) + + current := p.getSpanProcessors() + newSPS := make(spanProcessorStates, 0, len(current)+1) + newSPS = append(newSPS, current...) newSPS = append(newSPS, newSpanProcessorState(sp)) p.spanProcessors.Store(&newSPS) } @@ -210,12 +212,12 @@ func (p *TracerProvider) UnregisterSpanProcessor(sp SpanProcessor) { if p.isShutdown.Load() { return } - old := *(p.spanProcessors.Load()) + old := p.getSpanProcessors() if len(old) == 0 { return } - spss := spanProcessorStates{} - spss = append(spss, old...) + spss := make(spanProcessorStates, len(old)) + copy(spss, old) // stop the span processor if it is started and remove it from the list var stopOnce *spanProcessorState @@ -245,7 +247,7 @@ func (p *TracerProvider) UnregisterSpanProcessor(sp SpanProcessor) { // ForceFlush immediately exports all spans that have not yet been exported for // all the registered span processors. func (p *TracerProvider) ForceFlush(ctx context.Context) error { - spss := *(p.spanProcessors.Load()) + spss := p.getSpanProcessors() if len(spss) == 0 { return nil } @@ -278,10 +280,9 @@ func (p *TracerProvider) Shutdown(ctx context.Context) error { if !p.isShutdown.CompareAndSwap(false, true) { // did toggle? return nil } - spss := *(p.spanProcessors.Load()) var retErr error - for _, sps := range spss { + for _, sps := range p.getSpanProcessors() { select { case <-ctx.Done(): return ctx.Err() @@ -305,6 +306,10 @@ func (p *TracerProvider) Shutdown(ctx context.Context) error { return retErr } +func (p *TracerProvider) getSpanProcessors() spanProcessorStates { + return *(p.spanProcessors.Load()) +} + // TracerProviderOption configures a TracerProvider. type TracerProviderOption interface { apply(tracerProviderConfig) tracerProviderConfig diff --git a/sdk/trace/provider_test.go b/sdk/trace/provider_test.go index 282cd16ee67..8df3f1a4bd7 100644 --- a/sdk/trace/provider_test.go +++ b/sdk/trace/provider_test.go @@ -80,6 +80,57 @@ func TestForceFlushAndShutdownTraceProviderWithoutProcessor(t *testing.T) { assert.True(t, stp.isShutdown.Load()) } +func TestUnregisterFirst(t *testing.T) { + stp := NewTracerProvider() + sp1 := &basicSpanProcessor{} + sp2 := &basicSpanProcessor{} + sp3 := &basicSpanProcessor{} + stp.RegisterSpanProcessor(sp1) + stp.RegisterSpanProcessor(sp2) + stp.RegisterSpanProcessor(sp3) + + stp.UnregisterSpanProcessor(sp1) + + sps := stp.getSpanProcessors() + require.Len(t, sps, 2) + assert.Same(t, sp2, sps[0].sp) + assert.Same(t, sp3, sps[1].sp) +} + +func TestUnregisterMiddle(t *testing.T) { + stp := NewTracerProvider() + sp1 := &basicSpanProcessor{} + sp2 := &basicSpanProcessor{} + sp3 := &basicSpanProcessor{} + stp.RegisterSpanProcessor(sp1) + stp.RegisterSpanProcessor(sp2) + stp.RegisterSpanProcessor(sp3) + + stp.UnregisterSpanProcessor(sp2) + + sps := stp.getSpanProcessors() + require.Len(t, sps, 2) + assert.Same(t, sp1, sps[0].sp) + assert.Same(t, sp3, sps[1].sp) +} + +func TestUnregisterLast(t *testing.T) { + stp := NewTracerProvider() + sp1 := &basicSpanProcessor{} + sp2 := &basicSpanProcessor{} + sp3 := &basicSpanProcessor{} + stp.RegisterSpanProcessor(sp1) + stp.RegisterSpanProcessor(sp2) + stp.RegisterSpanProcessor(sp3) + + stp.UnregisterSpanProcessor(sp3) + + sps := stp.getSpanProcessors() + require.Len(t, sps, 2) + assert.Same(t, sp1, sps[0].sp) + assert.Same(t, sp2, sps[1].sp) +} + func TestShutdownTraceProvider(t *testing.T) { stp := NewTracerProvider() sp := &basicSpanProcessor{} @@ -162,7 +213,7 @@ func TestRegisterAfterShutdownWithoutProcessors(t *testing.T) { sp := &basicSpanProcessor{} stp.RegisterSpanProcessor(sp) // no-op - assert.Empty(t, stp.spanProcessors.Load()) + assert.Empty(t, stp.getSpanProcessors()) } func TestRegisterAfterShutdownWithProcessors(t *testing.T) { @@ -173,11 +224,11 @@ func TestRegisterAfterShutdownWithProcessors(t *testing.T) { err := stp.Shutdown(context.Background()) assert.NoError(t, err) assert.True(t, stp.isShutdown.Load()) - assert.Empty(t, stp.spanProcessors.Load()) + assert.Empty(t, stp.getSpanProcessors()) sp2 := &basicSpanProcessor{} stp.RegisterSpanProcessor(sp2) // no-op - assert.Empty(t, stp.spanProcessors.Load()) + assert.Empty(t, stp.getSpanProcessors()) } func TestTracerProviderSamplerConfigFromEnv(t *testing.T) { diff --git a/sdk/trace/span.go b/sdk/trace/span.go index 81e89122fa2..8ec7da7f744 100644 --- a/sdk/trace/span.go +++ b/sdk/trace/span.go @@ -410,7 +410,7 @@ func (s *recordingSpan) End(options ...trace.SpanEndOption) { } s.mu.Unlock() - sps := *(s.tracer.provider.spanProcessors.Load()) + sps := s.tracer.provider.getSpanProcessors() if len(sps) == 0 { return } diff --git a/sdk/trace/span_processor.go b/sdk/trace/span_processor.go index e6ae1935219..9c53657a719 100644 --- a/sdk/trace/span_processor.go +++ b/sdk/trace/span_processor.go @@ -62,11 +62,11 @@ type SpanProcessor interface { type spanProcessorState struct { sp SpanProcessor - state *sync.Once + state sync.Once } func newSpanProcessorState(sp SpanProcessor) *spanProcessorState { - return &spanProcessorState{sp: sp, state: &sync.Once{}} + return &spanProcessorState{sp: sp} } type spanProcessorStates []*spanProcessorState diff --git a/sdk/trace/tracer.go b/sdk/trace/tracer.go index 16e0108f085..85a71227f3f 100644 --- a/sdk/trace/tracer.go +++ b/sdk/trace/tracer.go @@ -51,7 +51,7 @@ func (tr *tracer) Start(ctx context.Context, name string, options ...trace.SpanS s := tr.newSpan(ctx, name, &config) if rw, ok := s.(ReadWriteSpan); ok && s.IsRecording() { - sps := *(tr.provider.spanProcessors.Load()) + sps := tr.provider.getSpanProcessors() for _, sp := range sps { sp.sp.OnStart(ctx, rw) }