diff --git a/xrayhttp/client.go b/xrayhttp/client.go index 376ede3..fc249e1 100644 --- a/xrayhttp/client.go +++ b/xrayhttp/client.go @@ -1,8 +1,12 @@ package xrayhttp import ( + "context" + "io" "net/http" + "net/http/httptrace" "strconv" + "sync" "github.com/shogo82148/aws-xray-yasdk-go/xray" "github.com/shogo82148/aws-xray-yasdk-go/xray/schema" @@ -64,12 +68,19 @@ func (rt *roundtripper) RoundTrip(req *http.Request) (*http.Response, error) { } seg.SetHTTPRequest(requestInfo) + // set trace hooks ctx, cancel := WithClientTrace(ctx) defer cancel() + respTracer := &clientResponseTracer{BaseContext: ctx} + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + GotFirstResponseByte: respTracer.GotFirstResponseByte, + }) req = req.WithContext(ctx) + resp, err := rt.Base.RoundTrip(req) if err != nil { seg.AddError(err) + respTracer.Close() return nil, err } @@ -89,5 +100,53 @@ func (rt *roundtripper) RoundTrip(req *http.Request) (*http.Response, error) { if resp.StatusCode >= 500 && resp.StatusCode < 600 { seg.SetFault() } + if resp.StatusCode == http.StatusSwitchingProtocols { + respTracer.Close() + } else { + respTracer.body = resp.Body + resp.Body = respTracer + } return resp, err } + +type clientResponseTracer struct { + BaseContext context.Context + mu sync.RWMutex + body io.ReadCloser + ctx context.Context + seg *xray.Segment +} + +func (r *clientResponseTracer) GotFirstResponseByte() { + r.mu.Lock() + defer r.mu.Unlock() + if r.ctx != nil { + return + } + r.ctx, r.seg = xray.BeginSubsegment(r.BaseContext, "response") +} + +func (r *clientResponseTracer) Read(b []byte) (int, error) { + r.mu.RLock() + body := r.body + r.mu.RUnlock() + if body != nil { + return body.Read(b) + } + return 0, io.EOF +} + +func (r *clientResponseTracer) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + var err error + if r.body != nil { + err = r.body.Close() + } + if r.ctx != nil { + r.seg.Close() + r.ctx, r.seg = nil, nil + } + return err +} diff --git a/xrayhttp/client_test.go b/xrayhttp/client_test.go index a8238c8..5416c6a 100644 --- a/xrayhttp/client_test.go +++ b/xrayhttp/client_test.go @@ -121,6 +121,7 @@ func TestClient(t *testing.T) { }, }, {Name: "request"}, + {Name: "response"}, }, }, }, @@ -223,6 +224,7 @@ func TestClient_StatusTooManyRequests(t *testing.T) { }, }, {Name: "request"}, + {Name: "response"}, }, }, }, @@ -324,6 +326,7 @@ func TestClient_StatusInternalServerError(t *testing.T) { }, }, {Name: "request"}, + {Name: "response"}, }, }, }, @@ -442,6 +445,7 @@ func TestClient_TLS(t *testing.T) { }, }, {Name: "request"}, + {Name: "response"}, }, }, }, @@ -559,6 +563,7 @@ func TestClient_DNS(t *testing.T) { }, }, {Name: "request"}, + {Name: "response"}, }, }, },