diff --git a/pkg/httputils/http_util.go b/pkg/httputils/http_util.go index 583ba677f..64c2b04c0 100644 --- a/pkg/httputils/http_util.go +++ b/pkg/httputils/http_util.go @@ -18,10 +18,12 @@ package httputils import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/json" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -52,6 +54,14 @@ const ( DefaultTimeout = 500 * time.Millisecond ) +var ( + // DefaultBuiltInTransport is the transport for HTTPWithHeaders. + DefaultBuiltInTransport *http.Transport + + // DefaultBuiltInHTTPClient is the http client for HTTPWithHeaders. + DefaultBuiltInHTTPClient *http.Client +) + // DefaultHTTPClient is the default implementation of SimpleHTTPClient. var DefaultHTTPClient SimpleHTTPClient = &defaultHTTPClient{} @@ -84,6 +94,25 @@ func init() { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } + + DefaultBuiltInTransport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + DefaultBuiltInHTTPClient = &http.Client{ + Transport: DefaultBuiltInTransport, + } + + RegisterProtocolOnTransport(DefaultBuiltInTransport) } // ---------------------------------------------------------------------------- @@ -255,6 +284,10 @@ func HTTPGetWithTLS(url string, headers map[string]string, timeout time.Duration // HTTPWithHeaders sends an HTTP request with headers and specified method. func HTTPWithHeaders(method, url string, headers map[string]string, timeout time.Duration, tlsConfig *tls.Config) (*http.Response, error) { + var ( + cancel func() + ) + req, err := http.NewRequest(method, url, nil) if err != nil { return nil, err @@ -264,33 +297,49 @@ func HTTPWithHeaders(method, url string, headers map[string]string, timeout time req.Header.Add(k, v) } - // copy from http.DefaultTransport - transport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, + if timeout > 0 { + timeoutCtx, cancelFunc := context.WithTimeout(context.Background(), timeout) + req = req.WithContext(timeoutCtx) + cancel = cancelFunc } - RegisterProtocolOnTransport(transport) + + var c = DefaultBuiltInHTTPClient if tlsConfig != nil { + // copy from http.DefaultTransport + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + RegisterProtocolOnTransport(transport) transport.TLSClientConfig = tlsConfig + + c = &http.Client{ + Transport: transport, + } } - c := &http.Client{ - Transport: transport, + res, err := c.Do(req) + if err != nil { + return nil, err } - if timeout > 0 { - c.Timeout = timeout + + if cancel == nil { + return res, nil } - return c.Do(req) + // do cancel() when close the body. + res.Body = newWithFuncReadCloser(res.Body, cancel) + return res, nil } // HTTPStatusOk reports whether the http response code is 200. @@ -502,3 +551,22 @@ func RegisterProtocolOnTransport(tr *http.Transport) { func GetValidURLSchemas() string { return validURLSchemas } + +func newWithFuncReadCloser(rc io.ReadCloser, f func()) io.ReadCloser { + return &withFuncReadCloser{ + f: f, + ReadCloser: rc, + } +} + +type withFuncReadCloser struct { + f func() + io.ReadCloser +} + +func (wrc *withFuncReadCloser) Close() error { + if wrc.f != nil { + wrc.f() + } + return wrc.ReadCloser.Close() +} diff --git a/pkg/httputils/http_util_test.go b/pkg/httputils/http_util_test.go index 21731f1b1..6235c9b16 100644 --- a/pkg/httputils/http_util_test.go +++ b/pkg/httputils/http_util_test.go @@ -17,12 +17,15 @@ package httputils import ( + "context" "crypto/tls" "encoding/json" "fmt" + "io/ioutil" "math/rand" "net" "net/http" + "strings" "sync" "testing" "time" @@ -104,6 +107,30 @@ func (s *HTTPUtilTestSuite) TestHTTPStatusOk(c *check.C) { } } +func (s *HTTPUtilTestSuite) TestHttpGet(c *check.C) { + res, e := HTTPGetTimeout("http://"+s.host, nil, 0) + c.Assert(e, check.IsNil) + code := res.StatusCode + body, e := ioutil.ReadAll(res.Body) + c.Assert(e, check.IsNil) + res.Body.Close() + + checkOk(c, code, body, e, 0) + + res, e = HTTPGetTimeout("http://"+s.host, nil, 60*time.Millisecond) + c.Assert(e, check.IsNil) + code = res.StatusCode + body, e = ioutil.ReadAll(res.Body) + c.Assert(e, check.IsNil) + res.Body.Close() + + checkOk(c, code, body, e, 0) + + _, e = HTTPGetTimeout("http://"+s.host, nil, 20*time.Millisecond) + c.Assert(e, check.NotNil) + c.Assert(strings.Contains(e.Error(), context.DeadlineExceeded.Error()), check.Equals, true) +} + func (s *HTTPUtilTestSuite) TestParseQuery(c *check.C) { type req struct { A int `request:"a"`