diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws.go b/contrib/aws/aws-sdk-go-v2/aws/aws.go index f193914cbd..2a2bbf5c38 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws.go @@ -31,6 +31,10 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/smithy-go/middleware" smithyhttp "github.com/aws/smithy-go/transport/http" + + eventBridgeTracer "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/internal/eventbridge" + snsTracer "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/internal/sns" + sqsTracer "gopkg.in/DataDog/dd-trace-go.v1/contrib/aws/internal/sqs" ) const componentName = "aws/aws-sdk-go-v2/aws" @@ -105,6 +109,16 @@ func (mw *traceMiddleware) startTraceMiddleware(stack *middleware.Stack) error { } span, spanctx := tracer.StartSpanFromContext(ctx, spanName(serviceID, operation), opts...) + // Inject trace context + switch serviceID { + case "SQS": + sqsTracer.EnrichOperation(span, in, operation) + case "SNS": + snsTracer.EnrichOperation(span, in, operation) + case "EventBridge": + eventBridgeTracer.EnrichOperation(span, in, operation) + } + // Handle initialize and continue through the middleware chain. out, metadata, err = next.HandleInitialize(spanctx, in) if err != nil && (mw.cfg.errCheck == nil || mw.cfg.errCheck(err)) { diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go index 88fba42c5d..09768a3f37 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go @@ -8,6 +8,7 @@ package aws import ( "context" "encoding/base64" + "encoding/json" "net/http" "net/http/httptest" "net/url" @@ -24,12 +25,13 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/eventbridge" + eventBridgeTypes "github.com/aws/aws-sdk-go-v2/service/eventbridge/types" "github.com/aws/aws-sdk-go-v2/service/kinesis" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/sfn" "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/aws/aws-sdk-go-v2/service/sqs" - "github.com/aws/aws-sdk-go-v2/service/sqs/types" + sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" "github.com/aws/smithy-go/middleware" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -281,6 +283,66 @@ func TestAppendMiddlewareSqsReceiveMessage(t *testing.T) { } } +func TestAppendMiddlewareSqsSendMessage(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + expectedStatusCode := 200 + server := mockAWS(expectedStatusCode) + defer server.Close() + + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: server.URL, + SigningRegion: "eu-west-1", + }, nil + }) + + awsCfg := aws.Config{ + Region: "eu-west-1", + Credentials: aws.AnonymousCredentials{}, + EndpointResolver: resolver, + } + + AppendMiddleware(&awsCfg) + + sqsClient := sqs.NewFromConfig(awsCfg) + sendMessageInput := &sqs.SendMessageInput{ + MessageBody: aws.String("test message"), + QueueUrl: aws.String("https://sqs.us-west-2.amazonaws.com/123456789012/MyQueueName"), + } + _, err := sqsClient.SendMessage(context.Background(), sendMessageInput) + require.NoError(t, err) + + spans := mt.FinishedSpans() + require.Len(t, spans, 1) + + s := spans[0] + assert.Equal(t, "SQS.request", s.OperationName()) + assert.Equal(t, "SendMessage", s.Tag("aws.operation")) + assert.Equal(t, "SQS", s.Tag("aws.service")) + assert.Equal(t, "MyQueueName", s.Tag("queuename")) + assert.Equal(t, "SQS.SendMessage", s.Tag(ext.ResourceName)) + assert.Equal(t, "aws.SQS", s.Tag(ext.ServiceName)) + + // Check for trace context injection + assert.NotNil(t, sendMessageInput.MessageAttributes) + assert.Contains(t, sendMessageInput.MessageAttributes, "_datadog") + ddAttr := sendMessageInput.MessageAttributes["_datadog"] + assert.Equal(t, "String", *ddAttr.DataType) + assert.NotEmpty(t, *ddAttr.StringValue) + + // Decode and verify the injected trace context + var traceContext map[string]string + err = json.Unmarshal([]byte(*ddAttr.StringValue), &traceContext) + assert.NoError(t, err) + assert.Contains(t, traceContext, "x-datadog-trace-id") + assert.Contains(t, traceContext, "x-datadog-parent-id") + assert.NotEmpty(t, traceContext["x-datadog-trace-id"]) + assert.NotEmpty(t, traceContext["x-datadog-parent-id"]) +} + func TestAppendMiddlewareS3ListObjects(t *testing.T) { tests := []struct { name string @@ -441,6 +503,22 @@ func TestAppendMiddlewareSnsPublish(t *testing.T) { assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) assert.Equal(t, "aws/aws-sdk-go-v2/aws", s.Tag(ext.Component)) assert.Equal(t, ext.SpanKindClient, s.Tag(ext.SpanKind)) + + // Check for trace context injection + assert.NotNil(t, tt.publishInput.MessageAttributes) + assert.Contains(t, tt.publishInput.MessageAttributes, "_datadog") + ddAttr := tt.publishInput.MessageAttributes["_datadog"] + assert.Equal(t, "Binary", *ddAttr.DataType) + assert.NotEmpty(t, ddAttr.BinaryValue) + + // Decode and verify the injected trace context + var traceContext map[string]string + err := json.Unmarshal(ddAttr.BinaryValue, &traceContext) + assert.NoError(t, err) + assert.Contains(t, traceContext, "x-datadog-trace-id") + assert.Contains(t, traceContext, "x-datadog-parent-id") + assert.NotEmpty(t, traceContext["x-datadog-trace-id"]) + assert.NotEmpty(t, traceContext["x-datadog-parent-id"]) }) } } @@ -657,6 +735,62 @@ func TestAppendMiddlewareEventBridgePutRule(t *testing.T) { } } +func TestAppendMiddlewareEventBridgePutEvents(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + expectedStatusCode := 200 + server := mockAWS(expectedStatusCode) + defer server.Close() + + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: server.URL, + SigningRegion: "eu-west-1", + }, nil + }) + + awsCfg := aws.Config{ + Region: "eu-west-1", + Credentials: aws.AnonymousCredentials{}, + EndpointResolver: resolver, + } + + AppendMiddleware(&awsCfg) + + eventbridgeClient := eventbridge.NewFromConfig(awsCfg) + putEventsInput := &eventbridge.PutEventsInput{ + Entries: []eventBridgeTypes.PutEventsRequestEntry{ + { + EventBusName: aws.String("my-event-bus"), + Detail: aws.String(`{"key": "value"}`), + }, + }, + } + eventbridgeClient.PutEvents(context.Background(), putEventsInput) + + spans := mt.FinishedSpans() + require.Len(t, spans, 1) + + s := spans[0] + assert.Equal(t, "PutEvents", s.Tag("aws.operation")) + assert.Equal(t, "EventBridge.PutEvents", s.Tag(ext.ResourceName)) + + // Check for trace context injection + assert.Len(t, putEventsInput.Entries, 1) + entry := putEventsInput.Entries[0] + var detail map[string]interface{} + err := json.Unmarshal([]byte(*entry.Detail), &detail) + assert.NoError(t, err) + assert.Contains(t, detail, "_datadog") + ddData, ok := detail["_datadog"].(map[string]interface{}) + assert.True(t, ok) + assert.Contains(t, ddData, "x-datadog-start-time") + assert.Contains(t, ddData, "x-datadog-resource-name") + assert.Equal(t, "my-event-bus", ddData["x-datadog-resource-name"]) +} + func TestAppendMiddlewareSfnDescribeStateMachine(t *testing.T) { tests := []struct { name string @@ -971,8 +1105,8 @@ func TestMessagingNamingSchema(t *testing.T) { _, err = sqsClient.SendMessage(ctx, msg) require.NoError(t, err) - entry := types.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")} - batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []types.SendMessageBatchRequestEntry{entry}} + entry := sqsTypes.SendMessageBatchRequestEntry{Id: aws.String("1"), MessageBody: aws.String("body")} + batchMsg := &sqs.SendMessageBatchInput{QueueUrl: sqsResp.QueueUrl, Entries: []sqsTypes.SendMessageBatchRequestEntry{entry}} _, err = sqsClient.SendMessageBatch(ctx, batchMsg) require.NoError(t, err) diff --git a/contrib/aws/internal/eventbridge/eventbridge.go b/contrib/aws/internal/eventbridge/eventbridge.go new file mode 100644 index 0000000000..5a2a56068e --- /dev/null +++ b/contrib/aws/internal/eventbridge/eventbridge.go @@ -0,0 +1,112 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package eventbridge + +import ( + "encoding/json" + "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/eventbridge" + "github.com/aws/aws-sdk-go-v2/service/eventbridge/types" + "github.com/aws/smithy-go/middleware" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/log" + "strconv" + "time" +) + +const ( + datadogKey = "_datadog" + startTimeKey = "x-datadog-start-time" + resourceNameKey = "x-datadog-resource-name" + maxSizeBytes = 256 * 1024 // 256 KB +) + +func EnrichOperation(span tracer.Span, in middleware.InitializeInput, operation string) { + switch operation { + case "PutEvents": + handlePutEvents(span, in) + } +} + +func handlePutEvents(span tracer.Span, in middleware.InitializeInput) { + params, ok := in.Parameters.(*eventbridge.PutEventsInput) + if !ok { + log.Debug("Unable to read PutEvents params") + return + } + + // Create trace context + carrier := tracer.TextMapCarrier{} + err := tracer.Inject(span.Context(), carrier) + if err != nil { + log.Debug("Unable to inject trace context: %s", err) + return + } + + // Add start time + startTimeMillis := time.Now().UnixMilli() + carrier[startTimeKey] = strconv.FormatInt(startTimeMillis, 10) + + carrierJSON, err := json.Marshal(carrier) + if err != nil { + log.Debug("Unable to marshal trace context: %s", err) + return + } + + // Remove last '}' + reusedTraceContext := string(carrierJSON[:len(carrierJSON)-1]) + + for i := range params.Entries { + injectTraceContext(reusedTraceContext, ¶ms.Entries[i]) + } +} + +func injectTraceContext(baseTraceContext string, entryPtr *types.PutEventsRequestEntry) { + if entryPtr == nil { + return + } + + // Build the complete trace context + var traceContext string + if entryPtr.EventBusName != nil { + traceContext = fmt.Sprintf(`%s,"%s":"%s"}`, baseTraceContext, resourceNameKey, *entryPtr.EventBusName) + } else { + traceContext = baseTraceContext + "}" + } + + // Get current detail string + var detail string + if entryPtr.Detail == nil || *entryPtr.Detail == "" { + detail = "{}" + } else { + detail = *entryPtr.Detail + } + + // Basic JSON structure validation + if len(detail) < 2 || detail[len(detail)-1] != '}' { + log.Debug("Unable to parse detail JSON. Not injecting trace context into EventBridge payload.") + return + } + + // Create new detail string + var newDetail string + if len(detail) > 2 { + // Case where detail is not empty + newDetail = fmt.Sprintf(`%s,"%s":%s}`, detail[:len(detail)-1], datadogKey, traceContext) + } else { + // Cae where detail is empty + newDetail = fmt.Sprintf(`{"%s":%s}`, datadogKey, traceContext) + } + + // Check sizes + if len(newDetail) > maxSizeBytes { + log.Debug("Payload size too large to pass context") + return + } + + entryPtr.Detail = aws.String(newDetail) +} diff --git a/contrib/aws/internal/eventbridge/eventbridge_test.go b/contrib/aws/internal/eventbridge/eventbridge_test.go new file mode 100644 index 0000000000..77c9ab1e72 --- /dev/null +++ b/contrib/aws/internal/eventbridge/eventbridge_test.go @@ -0,0 +1,192 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package eventbridge + +import ( + "context" + "encoding/json" + "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/eventbridge" + "github.com/aws/aws-sdk-go-v2/service/eventbridge/types" + "github.com/aws/smithy-go/middleware" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "strings" + "testing" +) + +func TestEnrichOperation(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + span := tracer.StartSpan("test-span") + + input := middleware.InitializeInput{ + Parameters: &eventbridge.PutEventsInput{ + Entries: []types.PutEventsRequestEntry{ + { + Detail: aws.String(`{"@123": "value", "_foo": "bar"}`), + EventBusName: aws.String("test-bus"), + }, + { + Detail: aws.String(`{"@123": "data", "_foo": "bar"}`), + EventBusName: aws.String("test-bus-2"), + }, + }, + }, + } + + EnrichOperation(span, input, "PutEvents") + + params, ok := input.Parameters.(*eventbridge.PutEventsInput) + require.True(t, ok) + require.Len(t, params.Entries, 2) + + for _, entry := range params.Entries { + var detail map[string]interface{} + err := json.Unmarshal([]byte(*entry.Detail), &detail) + require.NoError(t, err) + + assert.Contains(t, detail, "@123") // make sure user data still exists + assert.Contains(t, detail, "_foo") + assert.Contains(t, detail, datadogKey) + ddData, ok := detail[datadogKey].(map[string]interface{}) + require.True(t, ok) + + assert.Contains(t, ddData, startTimeKey) + assert.Contains(t, ddData, resourceNameKey) + assert.Equal(t, *entry.EventBusName, ddData[resourceNameKey]) + } +} + +func TestInjectTraceContext(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + ctx := context.Background() + span, _ := tracer.StartSpanFromContext(ctx, "test-span") + baseTraceContext := fmt.Sprintf(`{"x-datadog-trace-id":"%d","x-datadog-parent-id":"%d","x-datadog-start-time":"123456789"`, span.Context().TraceID(), span.Context().SpanID()) + + tests := []struct { + name string + entry types.PutEventsRequestEntry + expected func(*testing.T, *types.PutEventsRequestEntry) + }{ + { + name: "Inject into empty detail", + entry: types.PutEventsRequestEntry{ + EventBusName: aws.String("test-bus"), + }, + expected: func(t *testing.T, entry *types.PutEventsRequestEntry) { + assert.NotNil(t, entry.Detail) + var detail map[string]interface{} + err := json.Unmarshal([]byte(*entry.Detail), &detail) + require.NoError(t, err) + assert.Contains(t, detail, datadogKey) + }, + }, + { + name: "Inject into existing detail", + entry: types.PutEventsRequestEntry{ + Detail: aws.String(`{"existing": "data"}`), + EventBusName: aws.String("test-bus"), + }, + expected: func(t *testing.T, entry *types.PutEventsRequestEntry) { + var detail map[string]interface{} + err := json.Unmarshal([]byte(*entry.Detail), &detail) + require.NoError(t, err) + assert.Contains(t, detail, "existing") + assert.Equal(t, "data", detail["existing"]) + assert.Contains(t, detail, datadogKey) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + injectTraceContext(baseTraceContext, &tt.entry) + tt.expected(t, &tt.entry) + + var detail map[string]interface{} + err := json.Unmarshal([]byte(*tt.entry.Detail), &detail) + require.NoError(t, err) + + ddData := detail[datadogKey].(map[string]interface{}) + assert.Contains(t, ddData, startTimeKey) + assert.Contains(t, ddData, resourceNameKey) + assert.Equal(t, *tt.entry.EventBusName, ddData[resourceNameKey]) + + // Check that start time exists and is not empty + startTime, ok := ddData[startTimeKey] + assert.True(t, ok) + assert.Equal(t, startTime, "123456789") + + carrier := tracer.TextMapCarrier{} + for k, v := range ddData { + if s, ok := v.(string); ok { + carrier[k] = s + } + } + + extractedSpanContext, err := tracer.Extract(&carrier) + assert.NoError(t, err) + assert.Equal(t, span.Context().TraceID(), extractedSpanContext.TraceID()) + assert.Equal(t, span.Context().SpanID(), extractedSpanContext.SpanID()) + }) + } +} + +func TestInjectTraceContextSizeLimit(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + baseTraceContext := `{"x-datadog-trace-id":"12345","x-datadog-parent-id":"67890","x-datadog-start-time":"123456789"` + + tests := []struct { + name string + entry types.PutEventsRequestEntry + expected func(*testing.T, *types.PutEventsRequestEntry) + }{ + { + name: "Do not inject when payload is too large", + entry: types.PutEventsRequestEntry{ + Detail: aws.String(`{"large": "` + strings.Repeat("a", maxSizeBytes-50) + `"}`), + EventBusName: aws.String("test-bus"), + }, + expected: func(t *testing.T, entry *types.PutEventsRequestEntry) { + assert.GreaterOrEqual(t, len(*entry.Detail), maxSizeBytes-50) + assert.NotContains(t, *entry.Detail, datadogKey) + assert.True(t, strings.HasPrefix(*entry.Detail, `{"large": "`)) + assert.True(t, strings.HasSuffix(*entry.Detail, `"}`)) + }, + }, + { + name: "Inject when payload is just under the limit", + entry: types.PutEventsRequestEntry{ + Detail: aws.String(`{"large": "` + strings.Repeat("a", maxSizeBytes-1000) + `"}`), + EventBusName: aws.String("test-bus"), + }, + expected: func(t *testing.T, entry *types.PutEventsRequestEntry) { + assert.Less(t, len(*entry.Detail), maxSizeBytes) + var detail map[string]interface{} + err := json.Unmarshal([]byte(*entry.Detail), &detail) + require.NoError(t, err) + assert.Contains(t, detail, datadogKey) + assert.Contains(t, detail, "large") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + injectTraceContext(baseTraceContext, &tt.entry) + tt.expected(t, &tt.entry) + }) + } +} diff --git a/contrib/aws/internal/sns/sns.go b/contrib/aws/internal/sns/sns.go new file mode 100644 index 0000000000..b40ca5ea85 --- /dev/null +++ b/contrib/aws/internal/sns/sns.go @@ -0,0 +1,105 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package sns + +import ( + "encoding/json" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sns" + "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/aws/smithy-go/middleware" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/log" +) + +const ( + datadogKey = "_datadog" + maxMessageAttributes = 10 +) + +func EnrichOperation(span tracer.Span, in middleware.InitializeInput, operation string) { + switch operation { + case "Publish": + handlePublish(span, in) + case "PublishBatch": + handlePublishBatch(span, in) + } +} + +func handlePublish(span tracer.Span, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sns.PublishInput) + if !ok { + log.Debug("Unable to read PublishInput params") + return + } + + traceContext, err := getTraceContext(span) + if err != nil { + log.Debug("Unable to get trace context: %s", err.Error()) + return + } + + if params.MessageAttributes == nil { + params.MessageAttributes = make(map[string]types.MessageAttributeValue) + } + + injectTraceContext(traceContext, params.MessageAttributes) +} + +func handlePublishBatch(span tracer.Span, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sns.PublishBatchInput) + if !ok { + log.Debug("Unable to read PublishBatch params") + return + } + + traceContext, err := getTraceContext(span) + if err != nil { + log.Debug("Unable to get trace context: %s", err.Error()) + return + } + + for i := range params.PublishBatchRequestEntries { + if params.PublishBatchRequestEntries[i].MessageAttributes == nil { + params.PublishBatchRequestEntries[i].MessageAttributes = make(map[string]types.MessageAttributeValue) + } + injectTraceContext(traceContext, params.PublishBatchRequestEntries[i].MessageAttributes) + } +} + +func getTraceContext(span tracer.Span) (types.MessageAttributeValue, error) { + carrier := tracer.TextMapCarrier{} + err := tracer.Inject(span.Context(), carrier) + if err != nil { + return types.MessageAttributeValue{}, err + } + + jsonBytes, err := json.Marshal(carrier) + if err != nil { + return types.MessageAttributeValue{}, err + } + + // Use Binary since SNS subscription filter policies fail silently with JSON + // strings. https://github.com/DataDog/datadog-lambda-js/pull/269 + attribute := types.MessageAttributeValue{ + DataType: aws.String("Binary"), + BinaryValue: jsonBytes, + } + + return attribute, nil +} + +func injectTraceContext(traceContext types.MessageAttributeValue, messageAttributes map[string]types.MessageAttributeValue) { + // SNS only allows a maximum of 10 message attributes. + // https://docs.aws.amazon.com/sns/latest/dg/sns-message-attributes.html + // Only inject if there's room. + if len(messageAttributes) >= maxMessageAttributes { + log.Info("Cannot inject trace context: message already has maximum allowed attributes") + return + } + + messageAttributes[datadogKey] = traceContext +} diff --git a/contrib/aws/internal/sns/sns_test.go b/contrib/aws/internal/sns/sns_test.go new file mode 100644 index 0000000000..0f955680f0 --- /dev/null +++ b/contrib/aws/internal/sns/sns_test.go @@ -0,0 +1,177 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package sns + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sns" + "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/aws/smithy-go/middleware" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" +) + +func TestEnrichOperation(t *testing.T) { + tests := []struct { + name string + operation string + input middleware.InitializeInput + setup func(context.Context) tracer.Span + check func(*testing.T, middleware.InitializeInput) + }{ + { + name: "Publish", + operation: "Publish", + input: middleware.InitializeInput{ + Parameters: &sns.PublishInput{ + Message: aws.String("test message"), + TopicArn: aws.String("arn:aws:sns:us-east-1:123456789012:test-topic"), + }, + }, + setup: func(ctx context.Context) tracer.Span { + span, _ := tracer.StartSpanFromContext(ctx, "test-span") + return span + }, + check: func(t *testing.T, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sns.PublishInput) + require.True(t, ok) + require.NotNil(t, params) + require.NotNil(t, params.MessageAttributes) + assert.Contains(t, params.MessageAttributes, datadogKey) + assert.NotNil(t, params.MessageAttributes[datadogKey].DataType) + assert.Equal(t, "Binary", *params.MessageAttributes[datadogKey].DataType) + assert.NotNil(t, params.MessageAttributes[datadogKey].BinaryValue) + assert.NotEmpty(t, params.MessageAttributes[datadogKey].BinaryValue) + }, + }, + { + name: "PublishBatch", + operation: "PublishBatch", + input: middleware.InitializeInput{ + Parameters: &sns.PublishBatchInput{ + TopicArn: aws.String("arn:aws:sns:us-east-1:123456789012:test-topic"), + PublishBatchRequestEntries: []types.PublishBatchRequestEntry{ + { + Id: aws.String("1"), + Message: aws.String("test message 1"), + }, + { + Id: aws.String("2"), + Message: aws.String("test message 2"), + }, + }, + }, + }, + setup: func(ctx context.Context) tracer.Span { + span, _ := tracer.StartSpanFromContext(ctx, "test-span") + return span + }, + check: func(t *testing.T, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sns.PublishBatchInput) + require.True(t, ok) + require.NotNil(t, params) + require.NotNil(t, params.PublishBatchRequestEntries) + require.Len(t, params.PublishBatchRequestEntries, 2) + + for _, entry := range params.PublishBatchRequestEntries { + require.NotNil(t, entry.MessageAttributes) + assert.Contains(t, entry.MessageAttributes, datadogKey) + assert.NotNil(t, entry.MessageAttributes[datadogKey].DataType) + assert.Equal(t, "Binary", *entry.MessageAttributes[datadogKey].DataType) + assert.NotNil(t, entry.MessageAttributes[datadogKey].BinaryValue) + assert.NotEmpty(t, entry.MessageAttributes[datadogKey].BinaryValue) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + ctx := context.Background() + span := tt.setup(ctx) + + EnrichOperation(span, tt.input, tt.operation) + + if tt.check != nil { + tt.check(t, tt.input) + } + }) + } +} + +func TestInjectTraceContext(t *testing.T) { + tests := []struct { + name string + existingAttributes int + expectInjection bool + }{ + { + name: "Inject with no existing attributes", + existingAttributes: 0, + expectInjection: true, + }, + { + name: "Inject with some existing attributes", + existingAttributes: 5, + expectInjection: true, + }, + { + name: "No injection when at max attributes", + existingAttributes: maxMessageAttributes, + expectInjection: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + span := tracer.StartSpan("test-span") + + messageAttributes := make(map[string]types.MessageAttributeValue) + for i := 0; i < tt.existingAttributes; i++ { + messageAttributes[fmt.Sprintf("attr%d", i)] = types.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: aws.String("value"), + } + } + + traceContext, err := getTraceContext(span) + assert.NoError(t, err) + injectTraceContext(traceContext, messageAttributes) + + if tt.expectInjection { + assert.Contains(t, messageAttributes, datadogKey) + assert.NotNil(t, messageAttributes[datadogKey].DataType) + assert.Equal(t, "Binary", *messageAttributes[datadogKey].DataType) + assert.NotNil(t, messageAttributes[datadogKey].BinaryValue) + assert.NotEmpty(t, messageAttributes[datadogKey].BinaryValue) + + carrier := tracer.TextMapCarrier{} + err := json.Unmarshal(messageAttributes[datadogKey].BinaryValue, &carrier) + assert.NoError(t, err) + + extractedSpanContext, err := tracer.Extract(carrier) + assert.NoError(t, err) + assert.Equal(t, span.Context().TraceID(), extractedSpanContext.TraceID()) + assert.Equal(t, span.Context().SpanID(), extractedSpanContext.SpanID()) + } else { + assert.NotContains(t, messageAttributes, datadogKey) + } + }) + } +} diff --git a/contrib/aws/internal/sqs/sqs.go b/contrib/aws/internal/sqs/sqs.go new file mode 100644 index 0000000000..9fbd8a9f90 --- /dev/null +++ b/contrib/aws/internal/sqs/sqs.go @@ -0,0 +1,103 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package sqs + +import ( + "encoding/json" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/aws/smithy-go/middleware" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/internal/log" +) + +const ( + datadogKey = "_datadog" + maxMessageAttributes = 10 +) + +func EnrichOperation(span tracer.Span, in middleware.InitializeInput, operation string) { + switch operation { + case "SendMessage": + handleSendMessage(span, in) + case "SendMessageBatch": + handleSendMessageBatch(span, in) + } +} + +func handleSendMessage(span tracer.Span, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sqs.SendMessageInput) + if !ok { + log.Debug("Unable to read SendMessage params") + return + } + + traceContext, err := getTraceContext(span) + if err != nil { + log.Debug("Unable to get trace context: %s", err.Error()) + return + } + + if params.MessageAttributes == nil { + params.MessageAttributes = make(map[string]types.MessageAttributeValue) + } + + injectTraceContext(traceContext, params.MessageAttributes) +} + +func handleSendMessageBatch(span tracer.Span, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sqs.SendMessageBatchInput) + if !ok { + log.Debug("Unable to read SendMessageBatch params") + return + } + + traceContext, err := getTraceContext(span) + if err != nil { + log.Debug("Unable to get trace context: %s", err.Error()) + return + } + + for i := range params.Entries { + if params.Entries[i].MessageAttributes == nil { + params.Entries[i].MessageAttributes = make(map[string]types.MessageAttributeValue) + } + injectTraceContext(traceContext, params.Entries[i].MessageAttributes) + } +} + +func getTraceContext(span tracer.Span) (types.MessageAttributeValue, error) { + carrier := tracer.TextMapCarrier{} + err := tracer.Inject(span.Context(), carrier) + if err != nil { + return types.MessageAttributeValue{}, err + } + + jsonBytes, err := json.Marshal(carrier) + if err != nil { + return types.MessageAttributeValue{}, err + } + + attribute := types.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: aws.String(string(jsonBytes)), + } + + return attribute, nil +} + +func injectTraceContext(traceContext types.MessageAttributeValue, messageAttributes map[string]types.MessageAttributeValue) { + // SQS only allows a maximum of 10 message attributes. + // https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html#sqs-message-attributes + // Only inject if there's room. + if len(messageAttributes) >= maxMessageAttributes { + log.Info("Cannot inject trace context: message already has maximum allowed attributes") + return + } + + messageAttributes[datadogKey] = traceContext +} diff --git a/contrib/aws/internal/sqs/sqs_test.go b/contrib/aws/internal/sqs/sqs_test.go new file mode 100644 index 0000000000..1a66adab09 --- /dev/null +++ b/contrib/aws/internal/sqs/sqs_test.go @@ -0,0 +1,181 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016 Datadog, Inc. + +package sqs + +import ( + "context" + "encoding/json" + "fmt" + "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/aws/smithy-go/middleware" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" +) + +func TestEnrichOperation(t *testing.T) { + tests := []struct { + name string + operation string + input middleware.InitializeInput + setup func(context.Context) tracer.Span + check func(*testing.T, middleware.InitializeInput) + }{ + { + name: "SendMessage", + operation: "SendMessage", + input: middleware.InitializeInput{ + Parameters: &sqs.SendMessageInput{ + MessageBody: aws.String("test message"), + QueueUrl: aws.String("https://sqs.us-east-1.amazonaws.com/1234567890/test-queue"), + }, + }, + setup: func(ctx context.Context) tracer.Span { + span, _ := tracer.StartSpanFromContext(ctx, "test-span") + return span + }, + check: func(t *testing.T, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sqs.SendMessageInput) + require.True(t, ok) + require.NotNil(t, params) + require.NotNil(t, params.MessageAttributes) + assert.Contains(t, params.MessageAttributes, datadogKey) + assert.NotNil(t, params.MessageAttributes[datadogKey].DataType) + assert.Equal(t, "String", *params.MessageAttributes[datadogKey].DataType) + assert.NotNil(t, params.MessageAttributes[datadogKey].StringValue) + assert.NotEmpty(t, *params.MessageAttributes[datadogKey].StringValue) + }, + }, + { + name: "SendMessageBatch", + operation: "SendMessageBatch", + input: middleware.InitializeInput{ + Parameters: &sqs.SendMessageBatchInput{ + QueueUrl: aws.String("https://sqs.us-east-1.amazonaws.com/1234567890/test-queue"), + Entries: []types.SendMessageBatchRequestEntry{ + { + Id: aws.String("1"), + MessageBody: aws.String("test message 1"), + }, + { + Id: aws.String("2"), + MessageBody: aws.String("test message 2"), + }, + { + Id: aws.String("3"), + MessageBody: aws.String("test message 3"), + }, + }, + }, + }, + setup: func(ctx context.Context) tracer.Span { + span, _ := tracer.StartSpanFromContext(ctx, "test-span") + return span + }, + check: func(t *testing.T, in middleware.InitializeInput) { + params, ok := in.Parameters.(*sqs.SendMessageBatchInput) + require.True(t, ok) + require.NotNil(t, params) + require.NotNil(t, params.Entries) + require.Len(t, params.Entries, 3) + + for _, entry := range params.Entries { + require.NotNil(t, entry.MessageAttributes) + assert.Contains(t, entry.MessageAttributes, datadogKey) + assert.NotNil(t, entry.MessageAttributes[datadogKey].DataType) + assert.Equal(t, "String", *entry.MessageAttributes[datadogKey].DataType) + assert.NotNil(t, entry.MessageAttributes[datadogKey].StringValue) + assert.NotEmpty(t, *entry.MessageAttributes[datadogKey].StringValue) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + ctx := context.Background() + span := tt.setup(ctx) + + EnrichOperation(span, tt.input, tt.operation) + + if tt.check != nil { + tt.check(t, tt.input) + } + }) + } +} + +func TestInjectTraceContext(t *testing.T) { + tests := []struct { + name string + existingAttributes int + expectInjection bool + }{ + { + name: "Inject with no existing attributes", + existingAttributes: 0, + expectInjection: true, + }, + { + name: "Inject with some existing attributes", + existingAttributes: 5, + expectInjection: true, + }, + { + name: "No injection when at max attributes", + existingAttributes: maxMessageAttributes, + expectInjection: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + span := tracer.StartSpan("test-span") + + messageAttributes := make(map[string]types.MessageAttributeValue) + for i := 0; i < tt.existingAttributes; i++ { + messageAttributes[fmt.Sprintf("attr%d", i)] = types.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: aws.String("value"), + } + } + + traceContext, err := getTraceContext(span) + assert.NoError(t, err) + injectTraceContext(traceContext, messageAttributes) + + if tt.expectInjection { + assert.Contains(t, messageAttributes, datadogKey) + assert.NotNil(t, messageAttributes[datadogKey].DataType) + assert.Equal(t, "String", *messageAttributes[datadogKey].DataType) + assert.NotNil(t, messageAttributes[datadogKey].StringValue) + assert.NotEmpty(t, *messageAttributes[datadogKey].StringValue) + + carrier := tracer.TextMapCarrier{} + err := json.Unmarshal([]byte(*messageAttributes[datadogKey].StringValue), &carrier) + assert.NoError(t, err) + + extractedSpanContext, err := tracer.Extract(carrier) + assert.NoError(t, err) + assert.Equal(t, span.Context().TraceID(), extractedSpanContext.TraceID()) + assert.Equal(t, span.Context().SpanID(), extractedSpanContext.SpanID()) + } else { + assert.NotContains(t, messageAttributes, datadogKey) + } + }) + } +}