diff --git a/module/apmotel/wrapper.go b/module/apmotel/wrapper.go index ee4cc34c9..ba81a562b 100644 --- a/module/apmotel/wrapper.go +++ b/module/apmotel/wrapper.go @@ -20,6 +20,7 @@ package apmotel // import "go.elastic.co/apm/module/apmotel/v2" import ( "context" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" "go.elastic.co/apm/v2" @@ -42,19 +43,42 @@ func init() { } func contextWithSpan(ctx context.Context, apmSpan *apm.Span) context.Context { + var provider *tracerProvider + if p, ok := otel.GetTracerProvider().(*tracerProvider); ok { + provider = p + } + ctx = oldOverrideContextWithSpan(ctx, apmSpan) - return trace.ContextWithSpanContext(ctx, trace.NewSpanContext(trace.SpanContextConfig{ - TraceID: trace.TraceID(apmSpan.TraceContext().Trace), - SpanID: trace.SpanID(apmSpan.TraceContext().Span), - TraceFlags: trace.TraceFlags(0).WithSampled(!apmSpan.Dropped()), - })) + + return trace.ContextWithSpan(ctx, &span{ + provider: provider, + + spanContext: trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: trace.TraceID(apmSpan.TraceContext().Trace), + SpanID: trace.SpanID(apmSpan.TraceContext().Span), + TraceFlags: trace.TraceFlags(0).WithSampled(!apmSpan.Dropped()), + }), + + span: apmSpan, + }) } func contextWithTransaction(ctx context.Context, apmTransaction *apm.Transaction) context.Context { + var provider *tracerProvider + if p, ok := otel.GetTracerProvider().(*tracerProvider); ok { + provider = p + } ctx = oldOverrideContextWithTransaction(ctx, apmTransaction) - return trace.ContextWithSpanContext(ctx, trace.NewSpanContext(trace.SpanContextConfig{ - TraceID: trace.TraceID(apmTransaction.TraceContext().Trace), - SpanID: trace.SpanID(apmTransaction.TraceContext().Span), - TraceFlags: trace.TraceFlags(0).WithSampled(apmTransaction.Sampled()), - })) + + return trace.ContextWithSpan(ctx, &span{ + provider: provider, + + spanContext: trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: trace.TraceID(apmTransaction.TraceContext().Trace), + SpanID: trace.SpanID(apmTransaction.TraceContext().Span), + TraceFlags: trace.TraceFlags(0).WithSampled(apmTransaction.Sampled()), + }), + + tx: apmTransaction, + }) } diff --git a/module/apmotel/wrapper_test.go b/module/apmotel/wrapper_test.go index 4a803164b..63a693bed 100644 --- a/module/apmotel/wrapper_test.go +++ b/module/apmotel/wrapper_test.go @@ -42,6 +42,7 @@ func TestLinkAgentToOtel(t *testing.T) { assert.Equal(t, [16]byte(apmTx.TraceContext().Trace), [16]byte(otelSpan.SpanContext().TraceID())) assert.Equal(t, [8]byte(apmTx.TraceContext().Span), [8]byte(otelSpan.SpanContext().SpanID())) + assert.Equal(t, apmTx.Sampled(), otelSpan.IsRecording()) } func TestLinkOtelToAgent(t *testing.T) { @@ -57,4 +58,5 @@ func TestLinkOtelToAgent(t *testing.T) { assert.Equal(t, [16]byte(apmTx.TraceContext().Trace), [16]byte(otelSpan.SpanContext().TraceID())) assert.Equal(t, [8]byte(apmTx.TraceContext().Span), [8]byte(otelSpan.SpanContext().SpanID())) + assert.Equal(t, apmTx.Sampled(), otelSpan.IsRecording()) }