-
Notifications
You must be signed in to change notification settings - Fork 117
/
client.go
130 lines (111 loc) · 3.29 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// Copyright 2017-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
package xray
import (
"context"
"net/http"
"net/http/httptrace"
"net/url"
"strconv"
"strings"
"github.com/aws/aws-xray-sdk-go/internal/logger"
)
// emptyHostRename is the subsegment name RoundTrip falls back to when an
// outbound request has neither r.Host nor r.URL.Host set.
const (
	emptyHostRename = "empty_host_error"
)
// Client returns a new *http.Client derived from c, or from http.DefaultClient
// when c is nil. The result copies c's CheckRedirect, Jar and Timeout. Its
// Transport is c's transport (http.DefaultTransport if c has none) wrapped
// with RoundTripper, so each request goes through xray tracing. c itself is
// left unchanged.
func Client(c *http.Client) *http.Client {
	base := http.DefaultClient
	if c != nil {
		base = c
	}

	next := base.Transport
	if next == nil {
		next = http.DefaultTransport
	}

	wrapped := &http.Client{Transport: RoundTripper(next)}
	wrapped.CheckRedirect = base.CheckRedirect
	wrapped.Jar = base.Jar
	wrapped.Timeout = base.Timeout
	return wrapped
}
// RoundTripper wraps the provided http roundtripper with xray.Capture,
// sets HTTP-specific xray fields, and adds the trace header to the outbound request.
//
// A nil rt is replaced with http.DefaultTransport, which is the same default
// Client applies. Without this, the returned RoundTripper would panic with a
// nil interface call on its first request.
func RoundTripper(rt http.RoundTripper) http.RoundTripper {
	if rt == nil {
		rt = http.DefaultTransport
	}
	return &roundtripper{Base: rt}
}
// roundtripper is the http.RoundTripper value returned by RoundTripper. Its
// RoundTrip method records each request in an xray subsegment and delegates
// the actual HTTP transaction to Base.
type roundtripper struct {
	// Base is the wrapped round tripper that performs the request.
	Base http.RoundTripper
}
// RoundTrip performs a single HTTP transaction through rt.Base inside an xray
// subsegment created by Capture. It returns the response from Base (nil if
// the request was not sent) and the error reported by Capture.
//
// Naming:
//   - The subsegment is named after r.Host, falling back to r.URL.Host.
//   - When both are empty, it is named emptyHostRename and its namespace is
//     cleared instead of being set to "remote".
//
// When a segment is present in the subsegment context, RoundTrip:
//   - records the request method and the URL as rendered by stripURL;
//   - propagates the trace header downstream;
//   - if a response arrives, records its status code and Content-Length
//     header.
//
// Status classification: 4xx statuses set Error (429 also sets Throttle), and
// 5xx statuses set Fault.
//
// As the http.RoundTripper contract requires, the caller's *http.Request is
// never modified. The trace header is written to a private copy of the header
// map attached to a shallow copy of r.
func (rt *roundtripper) RoundTrip(r *http.Request) (*http.Response, error) {
	var isEmptyHost bool
	var resp *http.Response
	host := r.Host
	if host == "" {
		if h := r.URL.Host; h != "" {
			host = h
		} else {
			host = emptyHostRename
			isEmptyHost = true
		}
	}
	err := Capture(r.Context(), host, func(ctx context.Context) error {
		var err error
		seg := GetSegment(ctx)
		if seg == nil {
			// Nothing to record into: forward the request unchanged.
			resp, err = rt.Base.RoundTrip(r)
			logger.Warnf("failed to record HTTP transaction: segment cannot be found.")
			return err
		}
		ct, e := NewClientTrace(ctx)
		if e != nil {
			// NOTE(review): returning here means the request is never sent.
			return e
		}
		// WithContext returns a shallow copy of r whose Header is still the
		// caller's map. Give the copy its own map so the Set below cannot
		// leak the trace header into the caller's request (e.g. when the
		// request is retried or reused) and cannot panic on a nil Header.
		r = r.WithContext(httptrace.WithClientTrace(ctx, ct.httpTrace))
		header := make(http.Header, len(r.Header)+1)
		for k, v := range r.Header {
			header[k] = v
		}
		r.Header = header
		seg.Lock()
		if isEmptyHost {
			seg.Namespace = ""
		} else {
			seg.Namespace = "remote"
		}
		seg.GetHTTP().GetRequest().Method = r.Method
		seg.GetHTTP().GetRequest().URL = stripURL(*r.URL)
		r.Header.Set(TraceIDHeaderKey, seg.DownstreamHeader().String())
		seg.Unlock()
		resp, err = rt.Base.RoundTrip(r)
		if resp != nil {
			seg.Lock()
			seg.GetHTTP().GetResponse().Status = resp.StatusCode
			// A missing or non-numeric Content-Length header is recorded as 0.
			seg.GetHTTP().GetResponse().ContentLength, _ = strconv.Atoi(resp.Header.Get("Content-Length"))
			if resp.StatusCode >= 400 && resp.StatusCode < 500 {
				seg.Error = true
			}
			if resp.StatusCode == 429 {
				seg.Throttle = true
			}
			if resp.StatusCode >= 500 && resp.StatusCode < 600 {
				seg.Fault = true
			}
			seg.Unlock()
		}
		if err != nil {
			// Report the transport error through the client trace's GotConn hook.
			ct.subsegments.GotConn(nil, err)
		}
		return err
	})
	return resp, err
}
// stripURL renders u for recording in a subsegment. It drops the query string
// and, when the userinfo carries a password (even an empty one), replaces the
// password with "***". u is received by value, so the caller's URL is not
// modified.
func stripURL(u url.URL) string {
	u.RawQuery = ""
	rendered := u.String()

	if _, hasPassword := u.User.Password(); !hasPassword {
		return rendered
	}

	secret := u.User.String() + "@"
	masked := u.User.Username() + ":***@"
	return strings.Replace(rendered, secret, masked, 1)
}