Skip to content

Commit

Permalink
[SVLS-5034] Create trace context from Step Function execution details (
Browse files Browse the repository at this point in the history
  • Loading branch information
agocs authored Sep 19, 2024
1 parent 87763a2 commit a4534bc
Show file tree
Hide file tree
Showing 18 changed files with 516 additions and 20 deletions.
3 changes: 3 additions & 0 deletions pkg/serverless/daemon/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ func (s *StartInvocation) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Debug("a context has been found, sending the context to the tracer")
w.Header().Set(invocationlifecycle.TraceIDHeader, fmt.Sprintf("%v", s.daemon.InvocationProcessor.GetExecutionInfo().TraceID))
w.Header().Set(invocationlifecycle.SamplingPriorityHeader, fmt.Sprintf("%v", s.daemon.InvocationProcessor.GetExecutionInfo().SamplingPriority))
if s.daemon.InvocationProcessor.GetExecutionInfo().TraceIDUpper64Hex != "" {
w.Header().Set(invocationlifecycle.TraceTagsHeader, fmt.Sprintf("%s=%s", invocationlifecycle.Upper64BitsTag, s.daemon.InvocationProcessor.GetExecutionInfo().TraceIDUpper64Hex))
}
}
}

Expand Down
7 changes: 7 additions & 0 deletions pkg/serverless/invocationlifecycle/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ const (
// SamplingPriorityHeader is the header containing the sampling priority for execution and/or inferred spans
SamplingPriorityHeader = "x-datadog-sampling-priority"

// TraceTagsHeader is the header containing trace tags, e.g. the upper 64 bits tag
TraceTagsHeader = "x-datadog-tags"

// Upper64BitsTag is the tag for the upper 64 bits of the trace ID, if it exists
Upper64BitsTag = "_dd.p.tid"

// Lambda function trigger span tag values
apiGateway = "api-gateway"
applicationLoadBalancer = "application-load-balancer"
Expand All @@ -47,4 +53,5 @@ const (
sns = "sns"
sqs = "sqs"
functionURL = "lambda-function-url"
stepFunction = "step-function"
)
4 changes: 4 additions & 0 deletions pkg/serverless/invocationlifecycle/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,7 @@ func (lp *LifecycleProcessor) initFromLambdaFunctionURLEvent(event events.Lambda
lp.addTag(tagFunctionTriggerEventSourceArn, fmt.Sprintf("arn:aws:lambda:%v:%v:url:%v", region, accountID, functionName))
lp.addTags(trigger.GetTagsFromLambdaFunctionURLRequest(event))
}

// initFromStepFunctionPayload records the Step Function payload as the event
// held by the processor's request handler.
func (lp *LifecycleProcessor) initFromStepFunctionPayload(event events.StepFunctionPayload) {
	handler := lp.requestHandler
	handler.event = event
}
17 changes: 16 additions & 1 deletion pkg/serverless/invocationlifecycle/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func (lp *LifecycleProcessor) OnInvokeStart(startDetails *InvocationStartDetails
if err != nil {
log.Debugf("[lifecycle] Failed to parse event payload: %v", err)
}

eventType := trigger.GetEventType(lowercaseEventPayload)
if eventType == trigger.Unknown {
log.Debugf("[lifecycle] Failed to extract event type")
Expand Down Expand Up @@ -230,6 +229,22 @@ func (lp *LifecycleProcessor) OnInvokeStart(startDetails *InvocationStartDetails
}
ev = event
lp.initFromLambdaFunctionURLEvent(event, region, account, resource)
case trigger.LegacyStepFunctionEvent:
var event events.StepFunctionEvent
if err := json.Unmarshal(payloadBytes, &event); err != nil {
log.Debugf("Failed to unmarshal %s event: %s", stepFunction, err)
break
}
ev = event.Payload
lp.initFromStepFunctionPayload(event.Payload)
case trigger.StepFunctionEvent:
var eventPayload events.StepFunctionPayload
if err := json.Unmarshal(payloadBytes, &eventPayload); err != nil {
log.Debugf("Failed to unmarshal %s event: %s", stepFunction, err)
break
}
ev = eventPayload
lp.initFromStepFunctionPayload(eventPayload)
default:
log.Debug("Skipping adding trigger types and inferred spans as a non-supported payload was received.")
}
Expand Down
72 changes: 72 additions & 0 deletions pkg/serverless/invocationlifecycle/lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,78 @@ func TestStartExecutionSpanWithLambdaLibrary(t *testing.T) {
assert.Equal(t, startInvocationTime, testProcessor.GetExecutionInfo().startTime)
}

func TestStartExecutionSpanStepFunctionEvent(t *testing.T) {
extraTags := &logs.Tags{
Tags: []string{"functionname:test-function"},
}
demux := createDemultiplexer(t)
mockProcessTrace := func(*api.Payload) {}
mockDetectLambdaLibrary := func() bool { return false }

eventPayload := `{"Execution":{"Id":"arn:aws:states:us-east-1:425362996713:execution:agocsTestSF:bc9f281c-3daa-4e5a-9a60-471a3810bf44","Input":{},"StartTime":"2024-07-30T19:55:52.976Z","Name":"bc9f281c-3daa-4e5a-9a60-471a3810bf44","RoleArn":"arn:aws:iam::425362996713:role/test-serverless-stepfunctions-dev-AgocsTestSFRole-tRkeFXScjyk4","RedriveCount":0},"StateMachine":{"Id":"arn:aws:states:us-east-1:425362996713:stateMachine:agocsTestSF","Name":"agocsTestSF"},"State":{"Name":"agocsTest1","EnteredTime":"2024-07-30T19:55:53.018Z","RetryCount":0}}`
startInvocationTime := time.Now()
startDetails := InvocationStartDetails{
StartTime: startInvocationTime,
InvokeEventRawPayload: []byte(eventPayload),
InvokedFunctionARN: "arn:aws:lambda:us-east-1:123456789012:function:my-function",
}

testProcessor := LifecycleProcessor{
ExtraTags: extraTags,
ProcessTrace: mockProcessTrace,
DetectLambdaLibrary: mockDetectLambdaLibrary,
Demux: demux,
}

testProcessor.OnInvokeStart(&startDetails)

assert.NotNil(t, testProcessor.GetExecutionInfo())

assert.Equal(t, uint64(0), testProcessor.GetExecutionInfo().SpanID)
assert.Equal(t, uint64(5744042798732701615), testProcessor.GetExecutionInfo().TraceID)
assert.Equal(t, uint64(2902498116043018663), testProcessor.GetExecutionInfo().parentID)
assert.Equal(t, sampler.SamplingPriority(1), testProcessor.GetExecutionInfo().SamplingPriority)
upper64 := testProcessor.GetExecutionInfo().TraceIDUpper64Hex
assert.Equal(t, "1914fe7789eb32be", upper64)
assert.Equal(t, startInvocationTime, testProcessor.GetExecutionInfo().startTime)
}

func TestLegacyLambdaStartExecutionSpanStepFunctionEvent(t *testing.T) {
extraTags := &logs.Tags{
Tags: []string{"functionname:test-function"},
}
demux := createDemultiplexer(t)
mockProcessTrace := func(*api.Payload) {}
mockDetectLambdaLibrary := func() bool { return false }

eventPayload := `{"Payload":{"Execution":{"Id":"arn:aws:states:us-east-1:425362996713:execution:agocsTestSF:bc9f281c-3daa-4e5a-9a60-471a3810bf44","Input":{},"StartTime":"2024-07-30T19:55:52.976Z","Name":"bc9f281c-3daa-4e5a-9a60-471a3810bf44","RoleArn":"arn:aws:iam::425362996713:role/test-serverless-stepfunctions-dev-AgocsTestSFRole-tRkeFXScjyk4","RedriveCount":0},"StateMachine":{"Id":"arn:aws:states:us-east-1:425362996713:stateMachine:agocsTestSF","Name":"agocsTestSF"},"State":{"Name":"agocsTest1","EnteredTime":"2024-07-30T19:55:53.018Z","RetryCount":0}}}`
startInvocationTime := time.Now()
startDetails := InvocationStartDetails{
StartTime: startInvocationTime,
InvokeEventRawPayload: []byte(eventPayload),
InvokedFunctionARN: "arn:aws:lambda:us-east-1:123456789012:function:my-function",
}

testProcessor := LifecycleProcessor{
ExtraTags: extraTags,
ProcessTrace: mockProcessTrace,
DetectLambdaLibrary: mockDetectLambdaLibrary,
Demux: demux,
}

testProcessor.OnInvokeStart(&startDetails)

assert.NotNil(t, testProcessor.GetExecutionInfo())

assert.Equal(t, uint64(0), testProcessor.GetExecutionInfo().SpanID)
assert.Equal(t, uint64(5744042798732701615), testProcessor.GetExecutionInfo().TraceID)
assert.Equal(t, uint64(2902498116043018663), testProcessor.GetExecutionInfo().parentID)
assert.Equal(t, sampler.SamplingPriority(1), testProcessor.GetExecutionInfo().SamplingPriority)
upper64 := testProcessor.GetExecutionInfo().TraceIDUpper64Hex
assert.Equal(t, "1914fe7789eb32be", upper64)
assert.Equal(t, startInvocationTime, testProcessor.GetExecutionInfo().startTime)
}

func TestEndExecutionSpanNoLambdaLibrary(t *testing.T) {
t.Setenv(functionNameEnvVar, "TestFunction")

Expand Down
19 changes: 13 additions & 6 deletions pkg/serverless/invocationlifecycle/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ var /* const */ runtimeRegex = regexp.MustCompile(`^(dotnet|go|java|ruby)(\d+(\.

// ExecutionStartInfo is saved information from when an execution span was started
type ExecutionStartInfo struct {
startTime time.Time
TraceID uint64
SpanID uint64
parentID uint64
requestPayload []byte
SamplingPriority sampler.SamplingPriority
startTime time.Time
TraceID uint64
TraceIDUpper64Hex string
SpanID uint64
parentID uint64
requestPayload []byte
SamplingPriority sampler.SamplingPriority
}

// startExecutionSpan records information from the start of the invocation.
Expand All @@ -63,6 +64,12 @@ func (lp *LifecycleProcessor) startExecutionSpan(event interface{}, rawPayload [
inferredSpan.Span.TraceID = traceContext.TraceID
inferredSpan.Span.ParentID = traceContext.ParentID
}
if traceContext.TraceIDUpper64Hex != "" {
executionContext.TraceIDUpper64Hex = traceContext.TraceIDUpper64Hex
lp.requestHandler.SetMetaTag(Upper64BitsTag, traceContext.TraceIDUpper64Hex)
} else {
delete(lp.requestHandler.triggerTags, Upper64BitsTag)
}
} else {
executionContext.TraceID = 0
executionContext.parentID = 0
Expand Down
95 changes: 95 additions & 0 deletions pkg/serverless/invocationlifecycle/trace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ func TestStartExecutionSpan(t *testing.T) {
reqHeadersWithCtx.Set("x-datadog-sampling-priority", "3")
reqHeadersWithCtx.Set("traceparent", "00-00000000000000000000000000000006-0000000000000006-01")

stepFunctionEvent := events.StepFunctionPayload{
Execution: struct {
ID string
}{
ID: "arn:aws:states:us-east-1:425362996713:execution:agocsTestSF:aa6c9316-713a-41d4-9c30-61131716744f",
},
State: struct {
Name string
EnteredTime string
}{
Name: "agocsTest1",
EnteredTime: "2024-07-30T20:46:20.824Z",
},
}

testcases := []struct {
name string
event interface{}
Expand Down Expand Up @@ -315,6 +330,20 @@ func TestStartExecutionSpan(t *testing.T) {
SamplingPriority: sampler.SamplingPriority(1),
},
},
{
name: "step function event",
event: stepFunctionEvent,
payload: payloadWithoutCtx,
reqHeaders: reqHeadersWithoutCtx,
infSpanEnabled: false,
propStyle: "datadog",
expectCtx: &ExecutionStartInfo{
TraceID: 5377636026938777059,
TraceIDUpper64Hex: "6fb5c3a05c73dbfe",
parentID: 8947638978974359093,
SamplingPriority: 1,
},
},
}

for _, tc := range testcases {
Expand All @@ -333,6 +362,7 @@ func TestStartExecutionSpan(t *testing.T) {
requestHandler: &RequestHandler{
executionInfo: actualCtx,
inferredSpans: [2]*inferredspan.InferredSpan{inferredSpan},
triggerTags: make(map[string]string),
},
}
startDetails := &InvocationStartDetails{
Expand Down Expand Up @@ -697,6 +727,71 @@ func TestEndExecutionSpanWithTimeout(t *testing.T) {
assert.Equal(t, "Datadog detected an Impending Timeout", executionSpan.Meta["error.msg"])
}

// TestEndExecutionSpanWithStepFunctions starts an execution span from a Step
// Function payload, ends it, and checks that the resulting span carries the
// derived trace ID together with the _dd.p.tid meta tag.
func TestEndExecutionSpanWithStepFunctions(t *testing.T) {
	const upper64Hex = "6fb5c3a05c73dbfe"

	t.Setenv(functionNameEnvVar, "TestFunction")

	execInfo := &ExecutionStartInfo{}
	processor := &LifecycleProcessor{
		requestHandler: &RequestHandler{
			executionInfo: execInfo,
			// The tag is already present before the span is started.
			triggerTags: map[string]string{"_dd.p.tid": upper64Hex},
		},
	}

	invokedAt := time.Now()
	payload := events.StepFunctionPayload{
		Execution: struct {
			ID string
		}{
			ID: "arn:aws:states:us-east-1:425362996713:execution:agocsTestSF:aa6c9316-713a-41d4-9c30-61131716744f",
		},
		State: struct {
			Name        string
			EnteredTime string
		}{
			Name:        "agocsTest1",
			EnteredTime: "2024-07-30T20:46:20.824Z",
		},
	}

	processor.startExecutionSpan(payload, []byte("[]"), &InvocationStartDetails{
		StartTime:          invokedAt,
		InvokeEventHeaders: http.Header{},
	})

	assert.Equal(t, uint64(5377636026938777059), execInfo.TraceID)
	assert.Equal(t, uint64(8947638978974359093), execInfo.parentID)
	assert.Equal(t, upper64Hex, processor.requestHandler.triggerTags["_dd.p.tid"])

	elapsed := time.Second
	span := processor.endExecutionSpan(&InvocationEndDetails{
		EndTime:            invokedAt.Add(elapsed),
		IsError:            false,
		RequestID:          "test-request-id",
		ResponseRawPayload: []byte(`{"response":"test response payload"}`),
		ColdStart:          true,
		ProactiveInit:      false,
		Runtime:            "dotnet6",
	})

	assert.Equal(t, "aws.lambda", span.Name)
	assert.Equal(t, "aws.lambda", span.Service)
	assert.Equal(t, "TestFunction", span.Resource)
	assert.Equal(t, "serverless", span.Type)
	assert.Equal(t, execInfo.TraceID, span.TraceID)
	assert.Equal(t, execInfo.SpanID, span.SpanID)
	assert.Equal(t, invokedAt.UnixNano(), span.Start)
	assert.Equal(t, elapsed.Nanoseconds(), span.Duration)
	assert.Equal(t, upper64Hex, span.Meta["_dd.p.tid"])
}

func TestParseLambdaPayload(t *testing.T) {
assert.Equal(t, []byte(""), ParseLambdaPayload([]byte("")))
assert.Equal(t, []byte("{}"), ParseLambdaPayload([]byte("{}")))
Expand Down
79 changes: 69 additions & 10 deletions pkg/serverless/trace/propagation/carriers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package propagation

import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
Expand Down Expand Up @@ -36,16 +37,17 @@ const (
var rootRegex = regexp.MustCompile("Root=1-[0-9a-fA-F]{8}-00000000[0-9a-fA-F]{16}")

var (
errorAWSTraceHeaderMismatch = errors.New("AWSTraceHeader does not match expected regex")
errorAWSTraceHeaderEmpty = errors.New("AWSTraceHeader does not contain trace ID and parent ID")
errorStringNotFound = errors.New("String value not found in _datadog payload")
errorUnsupportedDataType = errors.New("Unsupported DataType in _datadog payload")
errorNoDDContextFound = errors.New("No Datadog trace context found")
errorUnsupportedPayloadType = errors.New("Unsupported type for _datadog payload")
errorUnsupportedTypeType = errors.New("Unsupported type in _datadog payload")
errorUnsupportedValueType = errors.New("Unsupported value type in _datadog payload")
errorUnsupportedTypeValue = errors.New("Unsupported Type in _datadog payload")
errorCouldNotUnmarshal = errors.New("Could not unmarshal the invocation event payload")
errorAWSTraceHeaderMismatch = errors.New("AWSTraceHeader does not match expected regex")
errorAWSTraceHeaderEmpty = errors.New("AWSTraceHeader does not contain trace ID and parent ID")
errorStringNotFound = errors.New("String value not found in _datadog payload")
errorUnsupportedDataType = errors.New("Unsupported DataType in _datadog payload")
errorNoDDContextFound = errors.New("No Datadog trace context found")
errorUnsupportedPayloadType = errors.New("Unsupported type for _datadog payload")
errorUnsupportedTypeType = errors.New("Unsupported type in _datadog payload")
errorUnsupportedValueType = errors.New("Unsupported value type in _datadog payload")
errorUnsupportedTypeValue = errors.New("Unsupported Type in _datadog payload")
errorCouldNotUnmarshal = errors.New("Could not unmarshal the invocation event payload")
errorNoStepFunctionContextFound = errors.New("no Step Function context found in Step Function event")
)

// extractTraceContextfromAWSTraceHeader extracts trace context from the
Expand Down Expand Up @@ -220,3 +222,60 @@ func rawPayloadCarrier(rawPayload []byte) (tracer.TextMapReader, error) {
func headersCarrier(hdrs map[string]string) (tracer.TextMapReader, error) {
	// The header map is exposed as a text-map carrier as-is; this path never
	// produces an error.
	carrier := tracer.TextMapCarrier(hdrs)
	return carrier, nil
}

// extractTraceContextFromStepFunctionContext builds a trace context from a
// Step Function event: the execution ARN determines the trace ID, and the
// execution ARN combined with the state name and state entered time
// determines the parent ID. The sampling priority is always
// sampler.PriorityAutoKeep.
//
// It returns errorNoStepFunctionContextFound when any of the three inputs is
// empty.
//
// The logic is based on the trace context conversion in Logs To Traces,
// dd-trace-py, dd-trace-js, etc.
func extractTraceContextFromStepFunctionContext(event events.StepFunctionPayload) (*TraceContext, error) {
	executionARN := event.Execution.ID
	state := event.State

	// All three values feed the hashes below, so none of them may be missing.
	if executionARN == "" || state.Name == "" || state.EnteredTime == "" {
		return nil, errorNoStepFunctionContextFound
	}

	lowTraceID, highTraceIDHex := stringToDdTraceIDs(executionARN)
	spanID := stringToDdSpanID(executionARN, state.Name, state.EnteredTime)

	return &TraceContext{
		TraceID:           lowTraceID,
		TraceIDUpper64Hex: highTraceIDHex,
		ParentID:          spanID,
		SamplingPriority:  sampler.PriorityAutoKeep,
	}, nil
}

// stringToDdSpanID hashes the Execution ARN, state name, and state entered time to generate a 64-bit span ID
func stringToDdSpanID(execArn string, stateName string, stateEnteredTime string) uint64 {
uniqueSpanString := fmt.Sprintf("%s#%s#%s", execArn, stateName, stateEnteredTime)
spanHash := sha256.Sum256([]byte(uniqueSpanString))
parentID := getPositiveUInt64(spanHash[0:8])
return parentID
}

// stringToDdTraceIDs hashes toHash (an Execution ARN) with SHA-256 and splits
// the start of the digest into the two halves of a 128-bit trace ID.
//
// The first return value is the lower 64 bits, taken from digest bytes 8-15.
// The second is the upper 64 bits, taken from digest bytes 0-7 and rendered
// as a hexadecimal string.
func stringToDdTraceIDs(toHash string) (uint64, string) {
	digest := sha256.Sum256([]byte(toHash))
	upperBits := getPositiveUInt64(digest[:8])
	lowerBits := getPositiveUInt64(digest[8:16])
	return lowerBits, getHexEncodedString(upperBits)
}

// getPositiveUInt64 reads the first 8 bytes of hashBytes as a big-endian
// integer, clears the most significant bit, and returns the result. A value
// that would be zero after masking is reported as 1, so the result is never
// zero.
//
// hashBytes must hold at least 8 bytes; shorter input panics on indexing.
func getPositiveUInt64(hashBytes []byte) uint64 {
	var value uint64
	for idx := 0; idx < 8; idx++ {
		// After the shift the low byte is zero, so OR-ing in the next byte
		// is the same as adding it.
		value = value<<8 | uint64(hashBytes[idx])
	}

	// Drop the top bit so the value also fits in a signed 64-bit integer.
	value &^= 1 << 63

	if value != 0 {
		return value
	}
	return 1
}

// getHexEncodedString renders toEncode as a fixed-width, 16-character,
// lowercase hexadecimal string.
//
// The result is used as the hex form of the upper 64 bits of a 128-bit trace
// ID (the _dd.p.tid tag value). That value has to cover all 64 bits, so
// leading zeros are kept: an unpadded "%x" would yield a shorter string
// whenever the high nibbles are zero, which tracers that validate the tag
// length treat as malformed.
func getHexEncodedString(toEncode uint64) string {
	return fmt.Sprintf("%016x", toEncode)
}
Loading

0 comments on commit a4534bc

Please sign in to comment.