From cb9c69933b45795133a15703364d8c535957173b Mon Sep 17 00:00:00 2001 From: Benjamin Gustin Date: Tue, 10 Oct 2023 11:55:05 +0200 Subject: [PATCH 1/3] also stip password from url if any --- xray/client.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/xray/client.go b/xray/client.go index 020d45a2..7679d64d 100644 --- a/xray/client.go +++ b/xray/client.go @@ -14,6 +14,7 @@ import ( "net/http/httptrace" "net/url" "strconv" + "strings" "github.com/aws/aws-xray-sdk-go/internal/logger" ) @@ -87,7 +88,7 @@ func (rt *roundtripper) RoundTrip(r *http.Request) (*http.Response, error) { } seg.GetHTTP().GetRequest().Method = r.Method - seg.GetHTTP().GetRequest().URL = stripQueryFromURL(*r.URL) + seg.GetHTTP().GetRequest().URL = stripURL(*r.URL) r.Header.Set(TraceIDHeaderKey, seg.DownstreamHeader().String()) seg.Unlock() @@ -119,7 +120,11 @@ func (rt *roundtripper) RoundTrip(r *http.Request) (*http.Response, error) { return resp, err } -func stripQueryFromURL(u url.URL) string { +func stripURL(u url.URL) string { u.RawQuery = "" + _, passSet := u.User.Password() + if passSet { + return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) + } return u.String() } From 96adec0e37245dfca8569c0f18509c8bd69258fc Mon Sep 17 00:00:00 2001 From: Benjamin Gustin Date: Thu, 19 Oct 2023 10:47:35 +0200 Subject: [PATCH 2/3] add TestRoundTripWithBasicAuth unit test --- xray/client_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/xray/client_test.go b/xray/client_test.go index 827b3bed..59442b83 100644 --- a/xray/client_test.go +++ b/xray/client_test.go @@ -16,6 +16,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "sync" "testing" @@ -177,6 +178,62 @@ func TestRoundTripWithQueryParameter(t *testing.T) { assert.Equal(t, headers.RootTraceID, seg.TraceID) } +func TestRoundTripWithBasicAuth(t *testing.T) { + ctx, td := NewTestDaemon() + defer td.Close() + + const content = `200 - Nothing to see` + const responseContentLength = len(content) + + var userInfo = url.UserPassword("user", "pass") + + ch := make(chan XRayHeaders, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + pass, _ := userInfo.Password() + assert.Equal(t, ok, true) + assert.Equal(t, username, userInfo.Username()) + assert.Equal(t, password, pass) + ch <- ParseHeadersForTest(r.Header) + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(content)); err != nil { + panic(err) + } + })) + defer ts.Close() + + client := Client(nil) + + u, err := url.Parse(ts.URL) + if !assert.NoError(t, err) { + return + } + u.User = userInfo + + err = httpDoTest(ctx, client, http.MethodGet, u.String(), nil) + if !assert.NoError(t, err) { + return + } + + seg, err := td.Recv() + if !assert.NoError(t, err) { + return + } + var subseg *Segment + if assert.NoError(t, json.Unmarshal(seg.Subsegments[0], &subseg)) { + assert.Equal(t, "remote", subseg.Namespace) + assert.Equal(t, http.MethodGet, subseg.HTTP.Request.Method) + assert.Equal(t, stripURL(*u), subseg.HTTP.Request.URL) + assert.Equal(t, http.StatusOK, subseg.HTTP.Response.Status) + assert.Equal(t, responseContentLength, subseg.HTTP.Response.ContentLength) + assert.False(t, subseg.Throttle) + assert.False(t, subseg.Error) + assert.False(t, subseg.Fault) + } + headers := <-ch + assert.Equal(t, headers.RootTraceID, seg.TraceID) +} + func TestRoundTripWithError(t *testing.T) { ctx, td := NewTestDaemon() defer td.Close() From 26f2918a6bd6de689d8ff9268caa1d19560ad524 Mon Sep 17 00:00:00 2001 From: Benjamin Gustin Date: Tue, 24 Oct 2023 17:45:44 +0200 Subject: [PATCH 3/3] do not test function against itself --- xray/client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xray/client_test.go b/xray/client_test.go index 59442b83..3bc1fbd3 100644 --- a/xray/client_test.go +++ b/xray/client_test.go @@ -223,7 +223,7 @@ func TestRoundTripWithBasicAuth(t *testing.T) { if assert.NoError(t, json.Unmarshal(seg.Subsegments[0], &subseg)) { assert.Equal(t, "remote", subseg.Namespace) assert.Equal(t, http.MethodGet, subseg.HTTP.Request.Method) - assert.Equal(t, stripURL(*u), subseg.HTTP.Request.URL) + assert.Equal(t, "http://user:***@127.0.0.1:"+u.Port(), subseg.HTTP.Request.URL) assert.Equal(t, http.StatusOK, subseg.HTTP.Response.Status) assert.Equal(t, responseContentLength, subseg.HTTP.Response.ContentLength) assert.False(t, subseg.Throttle)