Skip to content

Commit

Permalink
Create headersOrMultiheadersCarrier.
Browse files Browse the repository at this point in the history
  • Loading branch information
purple4reina committed Nov 27, 2024
1 parent aaaa24f commit bc39e47
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
10 changes: 10 additions & 0 deletions pkg/serverless/trace/propagation/carriers.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ func headersCarrier(hdrs map[string]string) (tracer.TextMapReader, error) {
return tracer.TextMapCarrier(hdrs), nil
}

// headersOrMultiheadersCarrier returns the tracer.TextMapReader used to extract
// trace context from a Headers field of form map[string]string or MultiValueHeaders
// field of form map[string][]string.
func headersOrMultiheadersCarrier(hdrs map[string]string, multiHdrs map[string][]string) (tracer.TextMapReader, error) {
if len(hdrs) > 0 {
return headersCarrier(hdrs)
}
return tracer.HTTPHeadersCarrier(multiHdrs), nil
}

// extractTraceContextFromStepFunctionContext extracts the execution ARN, state name, and state entered time and uses them to generate Trace ID and Parent ID
// The logic is based on the trace context conversion in Logs To Traces, dd-trace-py, dd-trace-js, etc.
func extractTraceContextFromStepFunctionContext(event events.StepFunctionPayload) (*TraceContext, error) {
Expand Down
51 changes: 50 additions & 1 deletion pkg/serverless/trace/propagation/carriers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ func TestHeadersCarrier(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
tm, err := headersCarrier(tc.event)
t.Logf("rawPayloadCarrier returned TextMapReader=%#v error=%#v", tm, err)
t.Logf("headersCarrier returned TextMapReader=%#v error=%#v", tm, err)
assert.Equal(t, tc.expErr != nil, err != nil)
if tc.expErr != nil && err != nil {
assert.Equal(t, tc.expErr.Error(), err.Error())
Expand All @@ -826,6 +826,55 @@ func TestHeadersCarrier(t *testing.T) {
}
}

func TestHeadersOrMultiheadersCarrier(t *testing.T) {
testcases := []struct {
name string
hdrs map[string]string
multiHdrs map[string][]string
expMap map[string]string
}{
{
name: "nil-map",
hdrs: headersMapNone,
multiHdrs: toMultiValueHeaders(headersMapNone),
expMap: headersMapEmpty,
},
{
name: "empty-map",
hdrs: headersMapEmpty,
multiHdrs: toMultiValueHeaders(headersMapEmpty),
expMap: headersMapEmpty,
},
{
name: "headers-and-multiheaders",
hdrs: headersMapDD,
multiHdrs: toMultiValueHeaders(headersMapW3C),
expMap: headersMapDD,
},
{
name: "just-headers",
hdrs: headersMapDD,
multiHdrs: toMultiValueHeaders(headersMapEmpty),
expMap: headersMapDD,
},
{
name: "just-multiheaders",
hdrs: headersMapEmpty,
multiHdrs: toMultiValueHeaders(headersMapW3C),
expMap: headersMapW3C,
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
tm, err := headersOrMultiheadersCarrier(tc.hdrs, tc.multiHdrs)
t.Logf("headersOrMultiheadersCarrier returned TextMapReader=%#v error=%#v", tm, err)
assert.Nil(t, err)
assert.Equal(t, tc.expMap, getMapFromCarrier(tm))
})
}
}

func Test_stringToDdSpanId(t *testing.T) {
type args struct {
execArn string
Expand Down
6 changes: 1 addition & 5 deletions pkg/serverless/trace/propagation/extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,7 @@ func (e Extractor) extract(event interface{}) (*TraceContext, error) {
case events.APIGatewayCustomAuthorizerRequestTypeRequest:
carrier, err = headersCarrier(ev.Headers)
case events.ALBTargetGroupRequest:
if len(ev.Headers) > 0 {
carrier, err = headersCarrier(ev.Headers)
} else {
carrier = tracer.HTTPHeadersCarrier(ev.MultiValueHeaders)
}
carrier, err = headersOrMultiheadersCarrier(ev.Headers, ev.MultiValueHeaders)
case events.LambdaFunctionURLRequest:
carrier, err = headersCarrier(ev.Headers)
case events.StepFunctionPayload:
Expand Down

0 comments on commit bc39e47

Please sign in to comment.