diff --git a/pkg/serverless/invocationlifecycle/lifecycle.go b/pkg/serverless/invocationlifecycle/lifecycle.go index 2c210efdc5e08..fec26be6d0de1 100644 --- a/pkg/serverless/invocationlifecycle/lifecycle.go +++ b/pkg/serverless/invocationlifecycle/lifecycle.go @@ -358,6 +358,7 @@ func (lp *LifecycleProcessor) newRequest(lambdaPayloadString []byte, startTime t SpanID: inferredspan.GenerateSpanId(), }, } + lp.requestHandler.inferredSpans[1] = nil lp.requestHandler.triggerTags = make(map[string]string) lp.requestHandler.triggerMetrics = make(map[string]float64) } diff --git a/pkg/serverless/invocationlifecycle/lifecycle_test.go b/pkg/serverless/invocationlifecycle/lifecycle_test.go index d4799305701c7..798df82a29d60 100644 --- a/pkg/serverless/invocationlifecycle/lifecycle_test.go +++ b/pkg/serverless/invocationlifecycle/lifecycle_test.go @@ -1168,6 +1168,58 @@ func TestTriggerTypesLifecycleEventForSNSSQSNoDdContext(t *testing.T) { assert.Equal(t, snsSpan.SpanID, sqsSpan.ParentID) } +func TestTriggerTypesLifecycleEventForSNSSQSThenApiGateway(t *testing.T) { + // SNS-SQS creates two inferred spans. Ensure that then invoking the + // function with an event that should have just one inferred span (API + // Gateway) creates just one inferred span. + var tracePayload *api.Payload + testProcessor := &LifecycleProcessor{ + DetectLambdaLibrary: func() bool { return false }, + ProcessTrace: func(payload *api.Payload) { tracePayload = payload }, + InferredSpansEnabled: true, + } + + // SNS-SQS invocation + startInvocationTime := time.Now() + endInvocationTime := startInvocationTime.Add(time.Second) + + startDetails := &InvocationStartDetails{ + InvokeEventRawPayload: getEventFromFile("snssqs.json"), + InvokedFunctionARN: "arn:aws:lambda:us-east-1:123456789012:function:my-function", + StartTime: startInvocationTime, + } + endDetails := &InvocationEndDetails{ + RequestID: "test-request-id", + EndTime: endInvocationTime, + } + + testProcessor.OnInvokeStart(startDetails) + testProcessor.OnInvokeEnd(endDetails) + + spans := tracePayload.TracerPayload.Chunks[0].Spans + assert.Equal(t, 3, len(spans)) + + // API Gateway invocation + startInvocationTime = endInvocationTime + endInvocationTime = startInvocationTime.Add(time.Second) + + startDetails = &InvocationStartDetails{ + InvokeEventRawPayload: getEventFromFile("api-gateway.json"), + InvokedFunctionARN: "arn:aws:lambda:us-east-1:123456789012:function:my-function", + StartTime: startInvocationTime, + } + endDetails = &InvocationEndDetails{ + RequestID: "test-request-id", + EndTime: endInvocationTime, + } + + testProcessor.OnInvokeStart(startDetails) + testProcessor.OnInvokeEnd(endDetails) + + spans = tracePayload.TracerPayload.Chunks[0].Spans + assert.Equal(t, 2, len(spans)) +} + func TestTriggerTypesLifecycleEventForSQSNoDdContext(t *testing.T) { startInvocationTime := time.Now()