diff --git a/http_proxy.go b/http_proxy.go index 32636cd1..c0640439 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -229,7 +229,7 @@ func (hp *HTTPProxy) configureHTTPS() error { } func (hp *HTTPProxy) configureProxy() error { - hp.proxy = martian.NewProxy() + hp.proxy = new(martian.Proxy) if hp.config.MITM != nil { mc, err := newMartianMITMConfig(hp.config.MITM) @@ -264,17 +264,7 @@ func (hp *HTTPProxy) configureProxy() error { hp.proxy.ReadTimeout = hp.config.ReadTimeout hp.proxy.ReadHeaderTimeout = hp.config.ReadHeaderTimeout hp.proxy.WriteTimeout = hp.config.WriteTimeout - // Martian has an intertwined logic for setting http.Transport and the dialer. - // The dialer is wrapped, so that additional syscalls are made to the dialed connections. - // As a result the dialer needs to be reset. - if tr, ok := hp.transport.(*http.Transport); ok { - // Note: The order matters. DialContext needs to be set first. - // SetRoundTripper overwrites tr.DialContext with hp.proxy.dial. - hp.proxy.SetDialContext(tr.DialContext) - hp.proxy.SetRoundTripper(tr) - } else { - hp.proxy.SetRoundTripper(hp.transport) - } + hp.proxy.RoundTripper = hp.transport switch { case hp.config.UpstreamProxyFunc != nil: @@ -299,7 +289,7 @@ func (hp *HTTPProxy) configureProxy() error { if hp.config.ProxyLocalhost == DirectProxyLocalhost { hp.proxyFunc = hp.directLocalhost(hp.proxyFunc) } - hp.proxy.SetUpstreamProxyFunc(hp.proxyFunc) + hp.proxy.ProxyURL = hp.proxyFunc mw := hp.middlewareStack() hp.proxy.RequestModifier = mw diff --git a/http_proxy_test.go b/http_proxy_test.go index d1805fa7..a5112d38 100644 --- a/http_proxy_test.go +++ b/http_proxy_test.go @@ -104,7 +104,7 @@ func TestNopDialer(t *testing.T) { }, Host: "foobar", } - _, err = p.proxy.GetRoundTripper().RoundTrip(req) + _, err = p.proxy.RoundTripper.RoundTrip(req) if !errors.Is(err, nopDialerErr) { t.Fatalf("expected %v, got %v", nopDialerErr, err) } diff --git a/internal/martian/h2/testing/fixture.go b/internal/martian/h2/testing/fixture.go index 9ff6f4e8..464a4925 100644 --- a/internal/martian/h2/testing/fixture.go +++ b/internal/martian/h2/testing/fixture.go @@ -147,7 +147,7 @@ func (f *Fixture) Close() error { } func newProxy(spf []h2.StreamProcessorFactory) (*martian.Proxy, error) { - p := martian.NewProxy() + p := new(martian.Proxy) mc, err := mitm.NewConfig(CA, CAKey) if err != nil { return nil, fmt.Errorf("creating mitm config: %w", err) @@ -167,7 +167,7 @@ func newProxy(spf []h2.StreamProcessorFactory) (*martian.Proxy, error) { RootCAs: RootCAs, }, } - p.SetRoundTripper(tr) + p.RoundTripper = tr return p, nil } diff --git a/internal/martian/proxy.go b/internal/martian/proxy.go index 2e2e075d..67394b68 100644 --- a/internal/martian/proxy.go +++ b/internal/martian/proxy.go @@ -36,6 +36,17 @@ type Proxy struct { RequestModifier ResponseModifier + // RoundTripper specifies the round tripper to use for requests. + RoundTripper http.RoundTripper + + // DialContext specifies the dial function for creating unencrypted TCP connections. + // If not set and the RoundTripper is an *http.Transport, the Transport's DialContext is used. + DialContext func(context.Context, string, string) (net.Conn, error) + + // ProxyURL specifies the upstream proxy to use for requests. + // If not set and the RoundTripper is an *http.Transport, the Transport's ProxyURL is used. + ProxyURL func(*http.Request) (*url.URL, error) + // AllowHTTP disables automatic HTTP to HTTPS upgrades when the listener is TLS. AllowHTTP bool @@ -107,81 +118,61 @@ type Proxy struct { // TestingSkipRoundTrip skips the round trip for requests and returns a 200 OK response. TestingSkipRoundTrip bool - roundTripper http.RoundTripper - dial func(context.Context, string, string) (net.Conn, error) + initOnce sync.Once - proxyURL func(*http.Request) (*url.URL, error) conns sync.WaitGroup connsMu sync.Mutex // protects conns.Add/Wait from concurrent access closing chan bool closeOnce sync.Once } -// NewProxy returns a new HTTP proxy. -func NewProxy() *Proxy { - proxy := &Proxy{ - roundTripper: &http.Transport{ - // TODO(adamtanner): This forces the http.Transport to not upgrade requests - // to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2. - TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), - Proxy: http.ProxyFromEnvironment, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: time.Second, - }, - closing: make(chan bool), - - BaseContex: context.Background(), - } - proxy.SetDialContext((&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext) - return proxy -} - -// GetRoundTripper gets the http.RoundTripper of the proxy. -func (p *Proxy) GetRoundTripper() http.RoundTripper { - return p.roundTripper -} - -// SetRoundTripper sets the http.RoundTripper of the proxy. -func (p *Proxy) SetRoundTripper(rt http.RoundTripper) { - p.roundTripper = rt - - if tr, ok := p.roundTripper.(*http.Transport); ok { - tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) - tr.Proxy = p.proxyURL - tr.DialContext = p.dial - } -} - -// SetUpstreamProxy sets the proxy that receives requests from this proxy. -func (p *Proxy) SetUpstreamProxy(proxyURL *url.URL) { - p.SetUpstreamProxyFunc(http.ProxyURL(proxyURL)) -} +func (p *Proxy) init() { + p.initOnce.Do(func() { + if p.RoundTripper == nil { + p.RoundTripper = &http.Transport{ + // TODO(adamtanner): This forces the http.Transport to not upgrade requests + // to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2. + TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), + Proxy: http.ProxyFromEnvironment, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: time.Second, + } + } -// SetUpstreamProxyFunc sets proxy function as in http.Transport.Proxy. -func (p *Proxy) SetUpstreamProxyFunc(f func(*http.Request) (*url.URL, error)) { - p.proxyURL = f + if t, ok := p.RoundTripper.(*http.Transport); ok { + if p.DialContext == nil { + p.DialContext = t.DialContext + } else { + t.DialContext = p.DialContext + } + if p.ProxyURL == nil { + p.ProxyURL = t.Proxy + } else { + t.Proxy = p.ProxyURL + } + } - if tr, ok := p.roundTripper.(*http.Transport); ok { - tr.Proxy = f - } -} + if p.DialContext == nil { + p.DialContext = (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext + } -// SetDialContext sets the dial func used to establish a connection. -func (p *Proxy) SetDialContext(dial func(context.Context, string, string) (net.Conn, error)) { - p.dial = dial + if p.BaseContex == nil { + p.BaseContex = context.Background() + } - if tr, ok := p.roundTripper.(*http.Transport); ok { - tr.DialContext = p.dial - } + p.closing = make(chan bool) + }) } // Close sets the proxy to the closing state so it stops receiving new connections, // finishes processing any inflight requests, and closes existing connections without // reading anymore requests from them. func (p *Proxy) Close() { + p.init() + p.closeOnce.Do(func() { log.Infof(context.TODO(), "closing down proxy") @@ -209,6 +200,8 @@ func (p *Proxy) Closing() bool { func (p *Proxy) Serve(l net.Listener) error { defer l.Close() + p.init() + var delay time.Duration for { if p.Closing() { @@ -335,7 +328,7 @@ func (p *Proxy) roundTrip(req *http.Request) (*http.Response, error) { return proxyutil.NewResponse(200, http.NoBody, req), nil } - return p.roundTripper.RoundTrip(req) + return p.RoundTripper.RoundTrip(req) } func (p *Proxy) errorResponse(req *http.Request, err error) *http.Response { diff --git a/internal/martian/proxy_connect.go b/internal/martian/proxy_connect.go index fdd05327..65f48a65 100644 --- a/internal/martian/proxy_connect.go +++ b/internal/martian/proxy_connect.go @@ -43,8 +43,8 @@ func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) { ctx := req.Context() var proxyURL *url.URL - if p.proxyURL != nil { - u, err := p.proxyURL(req) + if p.ProxyURL != nil { + u, err := p.ProxyURL(req) if err != nil { return nil, nil, err } @@ -54,7 +54,7 @@ func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) { if proxyURL == nil { log.Debugf(ctx, "CONNECT to host directly: %s", req.URL.Host) - conn, err := p.dial(ctx, "tcp", req.URL.Host) + conn, err := p.DialContext(ctx, "tcp", req.URL.Host) if err != nil { return nil, nil, err } @@ -79,9 +79,9 @@ func (p *Proxy) connectHTTP(req *http.Request, proxyURL *url.URL) (res *http.Res var d *dialvia.HTTPProxyDialer if proxyURL.Scheme == "https" { - d = dialvia.HTTPSProxy(p.dial, proxyURL, p.clientTLSConfig()) + d = dialvia.HTTPSProxy(p.DialContext, proxyURL, p.clientTLSConfig()) } else { - d = dialvia.HTTPProxy(p.dial, proxyURL) + d = dialvia.HTTPProxy(p.DialContext, proxyURL) } d.ConnectRequestModifier = p.ConnectRequestModifier @@ -107,7 +107,7 @@ func (p *Proxy) connectHTTP(req *http.Request, proxyURL *url.URL) (res *http.Res } func (p *Proxy) clientTLSConfig() *tls.Config { - if tr, ok := p.roundTripper.(*http.Transport); ok && tr.TLSClientConfig != nil { + if tr, ok := p.RoundTripper.(*http.Transport); ok && tr.TLSClientConfig != nil { return tr.TLSClientConfig.Clone() } @@ -119,7 +119,7 @@ func (p *Proxy) connectSOCKS5(req *http.Request, proxyURL *url.URL) (*http.Respo log.Debugf(ctx, "CONNECT with upstream SOCKS5 proxy: %s", proxyURL.Host) - d := dialvia.SOCKS5Proxy(p.dial, proxyURL) + d := dialvia.SOCKS5Proxy(p.DialContext, proxyURL) conn, err := d.DialContext(ctx, "tcp", req.URL.Host) if err != nil { diff --git a/internal/martian/proxy_handler.go b/internal/martian/proxy_handler.go index cfb7d62d..d0d52592 100644 --- a/internal/martian/proxy_handler.go +++ b/internal/martian/proxy_handler.go @@ -81,6 +81,7 @@ type proxyHandler struct { // Handler returns proxy as http.Handler, see [proxyHandler] for details. func (p *Proxy) Handler() http.Handler { + p.init() return proxyHandler{p} } diff --git a/internal/martian/proxy_test.go b/internal/martian/proxy_test.go index 68fb176b..c5d0a281 100644 --- a/internal/martian/proxy_test.go +++ b/internal/martian/proxy_test.go @@ -140,11 +140,11 @@ func TestIntegrationTemporaryTimeout(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() - p.SetRoundTripper(tr) + p.RoundTripper = tr p.SetTimeout(200 * time.Millisecond) // Start the proxy with a listener that will return a temporary error on @@ -184,11 +184,11 @@ func TestIntegrationHTTP(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() - p.SetRoundTripper(tr) + p.RoundTripper = tr p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() @@ -240,7 +240,7 @@ func TestIntegrationHTTP100Continue(t *testing.T) { } l := newListener(t) - p := NewProxy() + p := new(Proxy) if *withTLS { p.AllowHTTP = true } @@ -349,7 +349,7 @@ func TestIntegrationHTTP101SwitchingProtocols(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) if *withTLS { p.AllowHTTP = true } @@ -462,7 +462,7 @@ func TestIntegrationUnexpectedUpstreamFailure(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) if *withTLS { p.AllowHTTP = true } @@ -569,12 +569,12 @@ func TestIntegrationHTTPUpstreamProxy(t *testing.T) { t.Fatalf("net.Listen(): got %v, want no error", err) } - upstream := NewProxy() + upstream := new(Proxy) defer upstream.Close() utr := martiantest.NewTransport() utr.Respond(299) - upstream.SetRoundTripper(utr) + upstream.RoundTripper = utr upstream.SetTimeout(600 * time.Millisecond) go upstream.Serve(ul) @@ -582,14 +582,14 @@ func TestIntegrationHTTPUpstreamProxy(t *testing.T) { // Start second proxy, will write to upstream proxy. pl := newListener(t) - proxy := NewProxy() + proxy := new(Proxy) if *withTLS { proxy.AllowHTTP = true } defer proxy.Close() // Set proxy's upstream proxy to the host:port of the first proxy. - proxy.SetUpstreamProxy(&url.URL{ + proxy.ProxyURL = http.ProxyURL(&url.URL{ Host: ul.Addr().String(), }) proxy.SetTimeout(600 * time.Millisecond) @@ -629,11 +629,11 @@ func TestIntegrationHTTPUpstreamProxyError(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() // Set proxy's upstream proxy to invalid host:port to force failure. - p.SetUpstreamProxy(&url.URL{ + p.ProxyURL = http.ProxyURL(&url.URL{ Host: "[::]:0", }) p.SetTimeout(600 * time.Millisecond) @@ -682,7 +682,7 @@ func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() // Test TLS server. @@ -774,7 +774,7 @@ func TestIntegrationConnect(t *testing.T) { //nolint:tparallel // Subtests share t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() // Test TLS server. @@ -948,12 +948,12 @@ func TestIntegrationConnectUpstreamProxy(t *testing.T) { t.Fatalf("net.Listen(): got %v, want no error", err) } - upstream := NewProxy() + upstream := new(Proxy) defer upstream.Close() utr := martiantest.NewTransport() utr.Respond(299) - upstream.SetRoundTripper(utr) + upstream.RoundTripper = utr ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) if err != nil { @@ -971,11 +971,11 @@ func TestIntegrationConnectUpstreamProxy(t *testing.T) { // Start second proxy, will CONNECT to upstream proxy. pl := newListener(t) - proxy := NewProxy() + proxy := new(Proxy) defer proxy.Close() // Set proxy's upstream proxy to the host:port of the first proxy. - proxy.SetUpstreamProxy(&url.URL{ + proxy.ProxyURL = http.ProxyURL(&url.URL{ Scheme: "http", Host: ul.Addr().String(), }) @@ -1066,7 +1066,7 @@ func TestIntegrationConnectFunc(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) p.ConnectFunc = func(req *http.Request) (*http.Response, io.ReadWriteCloser, error) { pr, pw := io.Pipe() return proxyutil.NewResponse(200, nil, req), pipeConn{pr, pw}, nil @@ -1123,7 +1123,7 @@ func TestIntegrationConnectTerminateTLS(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() // Test TLS server. @@ -1145,7 +1145,7 @@ func TestIntegrationConnectTerminateTLS(t *testing.T) { ServerName: "example.com", RootCAs: roots, } - p.SetRoundTripper(rt) + p.RoundTripper = rt tl, err := net.Listen("tcp", "[::]:0") if err != nil { @@ -1240,7 +1240,7 @@ func TestIntegrationMITM(t *testing.T) { } l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() @@ -1251,7 +1251,7 @@ func TestIntegrationMITM(t *testing.T) { return res, nil }) - p.SetRoundTripper(tr) + p.RoundTripper = tr p.SetTimeout(600 * time.Millisecond) ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour) @@ -1336,15 +1336,11 @@ func TestIntegrationTransparentHTTP(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() - p.SetRoundTripper(tr) - - if got, want := p.GetRoundTripper(), tr; got != want { - t.Errorf("proxy.GetRoundTripper: got %v, want %v", got, want) - } + p.RoundTripper = tr p.SetTimeout(200 * time.Millisecond) @@ -1416,7 +1412,7 @@ func TestIntegrationTransparentMITM(t *testing.T) { } l = tls.NewListener(l, mc.TLS()) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() @@ -1427,7 +1423,7 @@ func TestIntegrationTransparentMITM(t *testing.T) { return res, nil }) - p.SetRoundTripper(tr) + p.RoundTripper = tr tm := martiantest.NewModifier() p.RequestModifier = tm @@ -1487,13 +1483,13 @@ func TestIntegrationFailedRoundTrip(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() trerr := errors.New("round trip error") tr.RespondError(trerr) - p.SetRoundTripper(tr) + p.RoundTripper = tr p.SetTimeout(200 * time.Millisecond) go serve(p, l) @@ -1535,14 +1531,14 @@ func TestIntegrationSkipRoundTrip(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) p.TestingSkipRoundTrip = true defer p.Close() // Transport will be skipped, no 500. tr := martiantest.NewTransport() tr.Respond(500) - p.SetRoundTripper(tr) + p.RoundTripper = tr p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() @@ -1583,7 +1579,7 @@ func TestHTTPThroughConnectWithMITM(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) p.TestingSkipRoundTrip = true defer p.Close() @@ -1685,7 +1681,7 @@ func TestTLSHandshakeTimeoutWithMITM(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) p.MITMTLSHandshakeTimeout = 200 * time.Millisecond p.TestingSkipRoundTrip = true defer p.Close() @@ -1801,7 +1797,7 @@ func TestServerClosesConnection(t *testing.T) { if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } - p := NewProxy() + p := new(Proxy) p.MITMConfig = mc defer p.Close() @@ -1863,7 +1859,7 @@ func TestRacyClose(t *testing.T) { } defer l.Close() // to make p.Serve exit - p := NewProxy() + p := new(Proxy) go serve(p, l) defer p.Close() @@ -1884,11 +1880,11 @@ func TestIdleTimeout(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() - p.SetRoundTripper(tr) + p.RoundTripper = tr // Reset read and write timeouts. p.SetTimeout(0) @@ -1912,11 +1908,11 @@ func TestReadHeaderTimeout(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() - p.SetRoundTripper(tr) + p.RoundTripper = tr // Reset read and write timeouts. p.SetTimeout(0) @@ -1946,11 +1942,11 @@ func TestReadHeaderConnectionReset(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() - p.SetRoundTripper(tr) + p.RoundTripper = tr // Reset read and write timeouts. p.SetTimeout(0) @@ -1977,11 +1973,11 @@ func TestConnectRequestModifier(t *testing.T) { t.Parallel() l := newListener(t) - p := NewProxy() + p := new(Proxy) defer p.Close() tr := martiantest.NewTransport() - p.SetRoundTripper(tr) + p.RoundTripper = tr headerName, headerValue := "X-Request-ID", "12345" p.ConnectRequestModifier = func(req *http.Request) error { @@ -2003,7 +1999,7 @@ func TestConnectRequestModifier(t *testing.T) { } })) - p.SetUpstreamProxy(&url.URL{ + p.ProxyURL = http.ProxyURL(&url.URL{ Scheme: "http", Host: tl.Addr().String(), })