From bc39e4774173c92ccc1772784305379f873a6878 Mon Sep 17 00:00:00 2001 From: Rey Abolofia Date: Wed, 27 Nov 2024 10:15:14 -0800 Subject: [PATCH] Create headersOrMultiheadersCarrier. --- pkg/serverless/trace/propagation/carriers.go | 10 ++++ .../trace/propagation/carriers_test.go | 51 ++++++++++++++++++- pkg/serverless/trace/propagation/extractor.go | 6 +-- 3 files changed, 61 insertions(+), 6 deletions(-) diff --git a/pkg/serverless/trace/propagation/carriers.go b/pkg/serverless/trace/propagation/carriers.go index 9cfa1255baabb..ee2062664780e 100644 --- a/pkg/serverless/trace/propagation/carriers.go +++ b/pkg/serverless/trace/propagation/carriers.go @@ -253,6 +253,16 @@ func headersCarrier(hdrs map[string]string) (tracer.TextMapReader, error) { return tracer.TextMapCarrier(hdrs), nil } +// headersOrMultiheadersCarrier returns the tracer.TextMapReader used to extract +// trace context from a Headers field of form map[string]string or MultiValueHeaders +// field of form map[string][]string. +func headersOrMultiheadersCarrier(hdrs map[string]string, multiHdrs map[string][]string) (tracer.TextMapReader, error) { + if len(hdrs) > 0 { + return headersCarrier(hdrs) + } + return tracer.HTTPHeadersCarrier(multiHdrs), nil +} + // extractTraceContextFromStepFunctionContext extracts the execution ARN, state name, and state entered time and uses them to generate Trace ID and Parent ID // The logic is based on the trace context conversion in Logs To Traces, dd-trace-py, dd-trace-js, etc. func extractTraceContextFromStepFunctionContext(event events.StepFunctionPayload) (*TraceContext, error) { diff --git a/pkg/serverless/trace/propagation/carriers_test.go b/pkg/serverless/trace/propagation/carriers_test.go index c58b294b74e39..16e382343f15f 100644 --- a/pkg/serverless/trace/propagation/carriers_test.go +++ b/pkg/serverless/trace/propagation/carriers_test.go @@ -816,7 +816,7 @@ func TestHeadersCarrier(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { tm, err := headersCarrier(tc.event) - t.Logf("rawPayloadCarrier returned TextMapReader=%#v error=%#v", tm, err) + t.Logf("headersCarrier returned TextMapReader=%#v error=%#v", tm, err) assert.Equal(t, tc.expErr != nil, err != nil) if tc.expErr != nil && err != nil { assert.Equal(t, tc.expErr.Error(), err.Error()) @@ -826,6 +826,55 @@ func TestHeadersCarrier(t *testing.T) { } } +func TestHeadersOrMultiheadersCarrier(t *testing.T) { + testcases := []struct { + name string + hdrs map[string]string + multiHdrs map[string][]string + expMap map[string]string + }{ + { + name: "nil-map", + hdrs: headersMapNone, + multiHdrs: toMultiValueHeaders(headersMapNone), + expMap: headersMapEmpty, + }, + { + name: "empty-map", + hdrs: headersMapEmpty, + multiHdrs: toMultiValueHeaders(headersMapEmpty), + expMap: headersMapEmpty, + }, + { + name: "headers-and-multiheaders", + hdrs: headersMapDD, + multiHdrs: toMultiValueHeaders(headersMapW3C), + expMap: headersMapDD, + }, + { + name: "just-headers", + hdrs: headersMapDD, + multiHdrs: toMultiValueHeaders(headersMapEmpty), + expMap: headersMapDD, + }, + { + name: "just-multiheaders", + hdrs: headersMapEmpty, + multiHdrs: toMultiValueHeaders(headersMapW3C), + expMap: headersMapW3C, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + tm, err := headersOrMultiheadersCarrier(tc.hdrs, tc.multiHdrs) + t.Logf("headersOrMultiheadersCarrier returned TextMapReader=%#v error=%#v", tm, err) + assert.Nil(t, err) + assert.Equal(t, tc.expMap, getMapFromCarrier(tm)) + }) + } +} + func Test_stringToDdSpanId(t *testing.T) { type args struct { execArn string diff --git a/pkg/serverless/trace/propagation/extractor.go b/pkg/serverless/trace/propagation/extractor.go index d0537af1a3e5b..d9fe9b883275c 100644 --- a/pkg/serverless/trace/propagation/extractor.go +++ b/pkg/serverless/trace/propagation/extractor.go @@ -112,11 +112,7 @@ func (e Extractor) extract(event interface{}) (*TraceContext, error) { case events.APIGatewayCustomAuthorizerRequestTypeRequest: carrier, err = headersCarrier(ev.Headers) case events.ALBTargetGroupRequest: - if len(ev.Headers) > 0 { - carrier, err = headersCarrier(ev.Headers) - } else { - carrier = tracer.HTTPHeadersCarrier(ev.MultiValueHeaders) - } + carrier, err = headersOrMultiheadersCarrier(ev.Headers, ev.MultiValueHeaders) case events.LambdaFunctionURLRequest: carrier, err = headersCarrier(ev.Headers) case events.StepFunctionPayload: