diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws.go b/contrib/aws/aws-sdk-go-v2/aws/aws.go index 503311c15c..478195b81e 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws.go @@ -104,8 +104,10 @@ func (mw *traceMiddleware) deserializeTraceMiddleware(stack *middleware.Stack) e // Get values out of the request. if req, ok := in.Request.(*smithyhttp.Request); ok { + url := *req.URL + url.User = nil span.SetTag(ext.HTTPMethod, req.Method) - span.SetTag(ext.HTTPURL, req.URL.String()) + span.SetTag(ext.HTTPURL, url.String()) span.SetTag(tagAWSAgent, req.Header.Get("User-Agent")) } diff --git a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go index c654b94998..be562d2cb2 100644 --- a/contrib/aws/aws-sdk-go-v2/aws/aws_test.go +++ b/contrib/aws/aws-sdk-go-v2/aws/aws_test.go @@ -7,8 +7,11 @@ package aws import ( "context" + "encoding/base64" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" @@ -17,6 +20,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAppendMiddleware(t *testing.T) { @@ -199,3 +203,54 @@ func TestAppendMiddleware_WithOpts(t *testing.T) { }) } } + +func TestHTTPCredentials(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + var auth string + + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if enc, ok := r.Header["Authorization"]; ok { + encoded := strings.TrimPrefix(enc[0], "Basic ") + if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil { + auth = string(b64) + } + } + + w.Header().Set("X-Amz-RequestId", "test_req") + w.WriteHeader(200) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + u, err := url.Parse(server.URL) + require.NoError(t, err) + u.User = url.UserPassword("myuser", "mypassword") + + resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { + return aws.Endpoint{ + PartitionID: "aws", + URL: u.String(), + SigningRegion: "eu-west-1", + }, nil + }) + + awsCfg := aws.Config{ + Region: "eu-west-1", + Credentials: aws.AnonymousCredentials{}, + EndpointResolver: resolver, + } + + AppendMiddleware(&awsCfg) + + sqsClient := sqs.NewFromConfig(awsCfg) + sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{}) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL)) + assert.Equal(t, auth, "myuser:mypassword") +} diff --git a/contrib/aws/aws-sdk-go/aws/aws.go b/contrib/aws/aws-sdk-go/aws/aws.go index d4c6bb4105..515ed78084 100644 --- a/contrib/aws/aws-sdk-go/aws/aws.go +++ b/contrib/aws/aws-sdk-go/aws/aws.go @@ -60,6 +60,8 @@ func (h *handlers) Send(req *request.Request) { if req.RetryCount != 0 { return } + url := *req.HTTPRequest.URL + url.User = nil opts := []ddtrace.StartSpanOption{ tracer.SpanType(ext.SpanTypeHTTP), tracer.ServiceName(h.serviceName(req)), @@ -68,7 +70,7 @@ func (h *handlers) Send(req *request.Request) { tracer.Tag(tagAWSOperation, h.awsOperation(req)), tracer.Tag(tagAWSRegion, h.awsRegion(req)), tracer.Tag(ext.HTTPMethod, req.Operation.HTTPMethod), - tracer.Tag(ext.HTTPURL, req.HTTPRequest.URL.String()), + tracer.Tag(ext.HTTPURL, url.String()), tracer.Tag(ext.Component, "aws/aws-sdk-go/aws"), tracer.Tag(ext.SpanKind, ext.SpanKindClient), } diff --git a/contrib/aws/aws-sdk-go/aws/aws_test.go b/contrib/aws/aws-sdk-go/aws/aws_test.go index 2db340da9e..7767ff6889 100644 --- a/contrib/aws/aws-sdk-go/aws/aws_test.go +++ b/contrib/aws/aws-sdk-go/aws/aws_test.go @@ -7,16 +7,23 @@ package aws import ( "context" + "encoding/base64" "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/s3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" @@ -182,3 +189,61 @@ func TestRetries(t *testing.T) { assert.Len(t, mt.FinishedSpans(), 1) assert.Equal(t, mt.FinishedSpans()[0].Tag(tagAWSRetryCount), 3) } + +func TestHTTPCredentials(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + var auth string + + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if enc, ok := r.Header["Authorization"]; ok { + encoded := strings.TrimPrefix(enc[0], "Basic ") + if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil { + auth = string(b64) + } + } + + w.Header().Set("X-Amz-RequestId", "test_req") + w.WriteHeader(200) + w.Write([]byte(`{}`)) + })) + defer server.Close() + + u, err := url.Parse(server.URL) + require.NoError(t, err) + u.User = url.UserPassword("myuser", "mypassword") + + resolver := endpoints.ResolverFunc(func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + return endpoints.ResolvedEndpoint{ + PartitionID: "aws", + URL: u.String(), + SigningRegion: "eu-west-1", + }, nil + }) + + region := "eu-west-1" + awsCfg := aws.Config{ + Region: ®ion, + Credentials: credentials.AnonymousCredentials, + EndpointResolver: resolver, + } + session := WrapSession(session.Must(session.NewSession(&awsCfg))) + + ctx := context.Background() + s3api := s3.New(session) + req, _ := s3api.GetObjectRequest(&s3.GetObjectInput{ + Bucket: aws.String("BUCKET"), + Key: aws.String("KEY"), + }) + req.SetContext(ctx) + err = req.Send() + require.NoError(t, err) + + spans := mt.FinishedSpans() + + s := spans[0] + assert.Equal(t, server.URL+"/BUCKET/KEY", s.Tag(ext.HTTPURL)) + assert.Equal(t, auth, "myuser:mypassword") +} diff --git a/contrib/net/http/roundtripper.go b/contrib/net/http/roundtripper.go index 13bff8c5dd..b348784e6a 100644 --- a/contrib/net/http/roundtripper.go +++ b/contrib/net/http/roundtripper.go @@ -27,11 +27,13 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (res *http.Response, err er return rt.base.RoundTrip(req) } resourceName := rt.cfg.resourceNamer(req) + url := *req.URL + url.User = nil opts := []ddtrace.StartSpanOption{ tracer.SpanType(ext.SpanTypeHTTP), tracer.ResourceName(resourceName), tracer.Tag(ext.HTTPMethod, req.Method), - tracer.Tag(ext.HTTPURL, req.URL.String()), + tracer.Tag(ext.HTTPURL, url.String()), tracer.Tag(ext.Component, "net/http"), tracer.Tag(ext.SpanKind, ext.SpanKindClient), } diff --git a/contrib/net/http/roundtripper_test.go b/contrib/net/http/roundtripper_test.go index 1043aa0d3f..6b6b9243df 100644 --- a/contrib/net/http/roundtripper_test.go +++ b/contrib/net/http/roundtripper_test.go @@ -6,13 +6,17 @@ package http import ( + "encoding/base64" "fmt" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" @@ -179,6 +183,49 @@ func TestRoundTripperNetworkError(t *testing.T) { assert.Equal(t, "net/http", s0.Tag(ext.Component)) } +func TestRoundTripperCredentials(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + var auth string + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if enc, ok := r.Header["Authorization"]; ok { + encoded := strings.TrimPrefix(enc[0], "Basic ") + if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil { + auth = string(b64) + } + } + + })) + defer s.Close() + + rt := WrapRoundTripper(http.DefaultTransport, + WithBefore(func(req *http.Request, span ddtrace.Span) { + span.SetTag("CalledBefore", true) + }), + WithAfter(func(res *http.Response, span ddtrace.Span) { + span.SetTag("CalledAfter", true) + })) + + client := &http.Client{ + Transport: rt, + } + + u, err := url.Parse(s.URL) + require.NoError(t, err) + u.User = url.UserPassword("myuser", "mypassword") + + client.Get(u.String() + "/hello/world") + + spans := mt.FinishedSpans() + require.Len(t, spans, 1) + + s1 := spans[0] + + assert.Equal(t, s.URL+"/hello/world", s1.Tag(ext.HTTPURL)) + assert.Equal(t, auth, "myuser:mypassword") +} + func TestWrapClient(t *testing.T) { c := WrapClient(http.DefaultClient) assert.Equal(t, c, http.DefaultClient)