Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

send traceparent header to ES #1002

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions module/apmelasticsearch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,23 @@ type roundTripper struct {
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
tx := apm.TransactionFromContext(ctx)
traceContext := tx.TraceContext()
if tx == nil || !tx.Sampled() {
stuartnelson3 marked this conversation as resolved.
Show resolved Hide resolved
apmhttp.SetHeaders(req, traceContext, false)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we default to propagating the legacy header?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can skip sending the legacy header. It's only for compatibility with old Elastic APM agents, so not relevant for the Elasticsearch use case where the plan is to extract the traceparent header and use the trace ID in logs.

return r.r.RoundTrip(req)
}

propagateLegacyHeader := tx.ShouldPropagateLegacyHeader()
name := requestName(req)
span := tx.StartSpan(name, "db.elasticsearch", apm.SpanFromContext(ctx))

if span.Dropped() {
span.End()
apmhttp.SetHeaders(req, traceContext, propagateLegacyHeader)
return r.r.RoundTrip(req)
}

traceContext = span.TraceContext()
statement, req := captureSearchStatement(req)
username, _, _ := req.BasicAuth()
ctx = apm.ContextWithSpan(ctx, span)
Expand All @@ -89,6 +95,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
User: username,
})

apmhttp.SetHeaders(req, traceContext, propagateLegacyHeader)
resp, err := r.r.RoundTrip(req)
if err != nil {
span.End()
Expand Down
113 changes: 113 additions & 0 deletions module/apmelasticsearch/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/context/ctxhttp"

"go.elastic.co/apm"
"go.elastic.co/apm/apmtest"
"go.elastic.co/apm/model"
"go.elastic.co/apm/module/apmelasticsearch"
"go.elastic.co/apm/module/apmhttp"
"go.elastic.co/apm/transport/transporttest"
)

func TestWrapRoundTripper(t *testing.T) {
Expand Down Expand Up @@ -303,6 +306,116 @@ func TestDestination(t *testing.T) {
test("http://[2001:db8::1]:80/_search", "2001:db8::1", 80)
}

func TestTraceHeaders(t *testing.T) {
headers := make(map[string]string)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
for k, vs := range req.Header {
headers[k] = strings.Join(vs, " ")
}
}))
defer server.Close()
client := &http.Client{Transport: apmelasticsearch.WrapRoundTripper(http.DefaultTransport)}

req, err := http.NewRequest("GET", server.URL, nil)
require.NoError(t, err)

_, _, _ = apmtest.WithTransaction(func(ctx context.Context) {
_, err := client.Do(req.WithContext(ctx))
assert.NoError(t, err)
})

assert.Contains(t, headers, apmhttp.ElasticTraceparentHeader)
assert.Contains(t, headers, apmhttp.W3CTraceparentHeader)
assert.Contains(t, headers, apmhttp.TracestateHeader)
}

func TestClientSpanDropped(t *testing.T) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@axw this and TestClientTransactionUnsampled are copied from apmhttp/client_test.go. They do verify that we're setting the headers when a span is dropped and when a transaction is not sampled, but what do you think about paring them back to just verifying we have the headers? I'm happy either way since this is essentially doubling up on checking the calculation of the headers, vs. just making sure we're setting them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what you have here is good. I do think it's important to check the trace context of the transaction vs. span is used like you have in this test.

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Header.Get("Traceparent")))
}))
defer server.Close()

tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()

tracer.SetMaxSpans(1)
tx := tracer.StartTransaction("name", "type")
ctx := apm.ContextWithTransaction(context.Background(), tx)

var responseBodies []string
for i := 0; i < 2; i++ {
body, err := doGET(ctx, server.URL)
require.NoError(t, err)
responseBodies = append(responseBodies, body)
}

tx.End()
tracer.Flush(nil)
payloads := transport.Payloads()
require.Len(t, payloads.Spans, 1)
transaction := payloads.Transactions[0]
span := payloads.Spans[0] // for first request

clientTraceContext, err := apmhttp.ParseTraceparentHeader(string(responseBodies[0]))
require.NoError(t, err)
assert.Equal(t, span.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, span.ID, model.SpanID(clientTraceContext.Span))

clientTraceContext, err = apmhttp.ParseTraceparentHeader(string(responseBodies[1]))
require.NoError(t, err)
assert.Equal(t, transaction.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, transaction.ID, model.SpanID(clientTraceContext.Span))
}

func TestClientTransactionUnsampled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte(req.Header.Get("Traceparent")))
}))
defer server.Close()

tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()
tracer.SetSampler(apm.NewRatioSampler(0)) // sample nothing

tx := tracer.StartTransaction("name", "type")
ctx := apm.ContextWithTransaction(context.Background(), tx)
body, err := doGET(ctx, server.URL)
require.NoError(t, err)

tx.End()
tracer.Flush(nil)

payloads := transport.Payloads()
require.Len(t, payloads.Transactions, 1)
require.Len(t, payloads.Spans, 0)
transaction := payloads.Transactions[0]

clientTraceContext, err := apmhttp.ParseTraceparentHeader(string(body))
require.NoError(t, err)
assert.Equal(t, transaction.TraceID, model.TraceID(clientTraceContext.Trace))
assert.Equal(t, transaction.ID, model.SpanID(clientTraceContext.Span))
}

func doGET(ctx context.Context, url string) (string, error) {
client := &http.Client{Transport: apmelasticsearch.WrapRoundTripper(http.DefaultTransport)}
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", err
}
resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
defer resp.Body.Close()

return string(body), nil
}

type readCloser struct {
io.Reader
closed bool
Expand Down
7 changes: 4 additions & 3 deletions module/apmhttp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
propagateLegacyHeader := tx.ShouldPropagateLegacyHeader()
traceContext := tx.TraceContext()
if !traceContext.Options.Recorded() {
r.setHeaders(req, traceContext, propagateLegacyHeader)
SetHeaders(req, traceContext, propagateLegacyHeader)
return r.r.RoundTrip(req)
}

Expand All @@ -117,7 +117,7 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
span = nil
}

r.setHeaders(req, traceContext, propagateLegacyHeader)
SetHeaders(req, traceContext, propagateLegacyHeader)
resp, err := r.r.RoundTrip(req)
if span != nil {
if err != nil {
Expand All @@ -133,7 +133,8 @@ func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return resp, err
}

func (r *roundTripper) setHeaders(req *http.Request, traceContext apm.TraceContext, propagateLegacyHeader bool) {
// SetHeaders sets traceparent and tracestate headers on an http request.
stuartnelson3 marked this conversation as resolved.
Show resolved Hide resolved
func SetHeaders(req *http.Request, traceContext apm.TraceContext, propagateLegacyHeader bool) {
headerValue := FormatTraceparentHeader(traceContext)
if propagateLegacyHeader {
req.Header.Set(ElasticTraceparentHeader, headerValue)
Expand Down