Skip to content

Commit

Permalink
martian: export roundTripper dial and proxyURL and remove setters
Browse files Browse the repository at this point in the history
Proxy can be now configured using exported variables only.
This is copied from http.Server, and we use initOnce pattern instead of constructor.

This allows to avoid multiple overwrites of transport fields and side effects caused by invalid setter call order.
  • Loading branch information
mmatczuk committed Feb 13, 2024
1 parent ad48766 commit 9cc9854
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 131 deletions.
16 changes: 3 additions & 13 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (hp *HTTPProxy) configureHTTPS() error {
}

func (hp *HTTPProxy) configureProxy() error {
hp.proxy = martian.NewProxy()
hp.proxy = new(martian.Proxy)

if hp.config.MITM != nil {
mc, err := newMartianMITMConfig(hp.config.MITM)
Expand Down Expand Up @@ -264,17 +264,7 @@ func (hp *HTTPProxy) configureProxy() error {
hp.proxy.ReadTimeout = hp.config.ReadTimeout
hp.proxy.ReadHeaderTimeout = hp.config.ReadHeaderTimeout
hp.proxy.WriteTimeout = hp.config.WriteTimeout
// Martian has an intertwined logic for setting http.Transport and the dialer.
// The dialer is wrapped, so that additional syscalls are made to the dialed connections.
// As a result the dialer needs to be reset.
if tr, ok := hp.transport.(*http.Transport); ok {
// Note: The order matters. DialContext needs to be set first.
// SetRoundTripper overwrites tr.DialContext with hp.proxy.dial.
hp.proxy.SetDialContext(tr.DialContext)
hp.proxy.SetRoundTripper(tr)
} else {
hp.proxy.SetRoundTripper(hp.transport)
}
hp.proxy.RoundTripper = hp.transport

switch {
case hp.config.UpstreamProxyFunc != nil:
Expand All @@ -299,7 +289,7 @@ func (hp *HTTPProxy) configureProxy() error {
if hp.config.ProxyLocalhost == DirectProxyLocalhost {
hp.proxyFunc = hp.directLocalhost(hp.proxyFunc)
}
hp.proxy.SetUpstreamProxyFunc(hp.proxyFunc)
hp.proxy.ProxyURL = hp.proxyFunc

mw := hp.middlewareStack()
hp.proxy.RequestModifier = mw
Expand Down
2 changes: 1 addition & 1 deletion http_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestNopDialer(t *testing.T) {
},
Host: "foobar",
}
_, err = p.proxy.GetRoundTripper().RoundTrip(req)
_, err = p.proxy.RoundTripper.RoundTrip(req)
if !errors.Is(err, nopDialerErr) {
t.Fatalf("expected %v, got %v", nopDialerErr, err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/martian/h2/testing/fixture.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (f *Fixture) Close() error {
}

func newProxy(spf []h2.StreamProcessorFactory) (*martian.Proxy, error) {
p := martian.NewProxy()
p := new(martian.Proxy)
mc, err := mitm.NewConfig(CA, CAKey)
if err != nil {
return nil, fmt.Errorf("creating mitm config: %w", err)
Expand All @@ -167,7 +167,7 @@ func newProxy(spf []h2.StreamProcessorFactory) (*martian.Proxy, error) {
RootCAs: RootCAs,
},
}
p.SetRoundTripper(tr)
p.RoundTripper = tr

return p, nil
}
Expand Down
111 changes: 52 additions & 59 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ type Proxy struct {
RequestModifier
ResponseModifier

// RoundTripper specifies the round tripper to use for requests.
RoundTripper http.RoundTripper

// DialContext specifies the dial function for creating unencrypted TCP connections.
// If not set and the RoundTripper is an *http.Transport, the Transport's DialContext is used.
DialContext func(context.Context, string, string) (net.Conn, error)

// ProxyURL specifies the upstream proxy to use for requests.
// If not set and the RoundTripper is an *http.Transport, the Transport's ProxyURL is used.
ProxyURL func(*http.Request) (*url.URL, error)

// AllowHTTP disables automatic HTTP to HTTPS upgrades when the listener is TLS.
AllowHTTP bool

Expand Down Expand Up @@ -107,81 +118,61 @@ type Proxy struct {
// TestingSkipRoundTrip skips the round trip for requests and returns a 200 OK response.
TestingSkipRoundTrip bool

roundTripper http.RoundTripper
dial func(context.Context, string, string) (net.Conn, error)
initOnce sync.Once

proxyURL func(*http.Request) (*url.URL, error)
conns sync.WaitGroup
connsMu sync.Mutex // protects conns.Add/Wait from concurrent access
closing chan bool
closeOnce sync.Once
}

// NewProxy returns a new HTTP proxy.
func NewProxy() *Proxy {
proxy := &Proxy{
roundTripper: &http.Transport{
// TODO(adamtanner): This forces the http.Transport to not upgrade requests
// to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2.
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: time.Second,
},
closing: make(chan bool),

BaseContex: context.Background(),
}
proxy.SetDialContext((&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext)
return proxy
}

// GetRoundTripper gets the http.RoundTripper of the proxy.
func (p *Proxy) GetRoundTripper() http.RoundTripper {
return p.roundTripper
}

// SetRoundTripper sets the http.RoundTripper of the proxy.
func (p *Proxy) SetRoundTripper(rt http.RoundTripper) {
p.roundTripper = rt

if tr, ok := p.roundTripper.(*http.Transport); ok {
tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
tr.Proxy = p.proxyURL
tr.DialContext = p.dial
}
}

// SetUpstreamProxy sets the proxy that receives requests from this proxy.
func (p *Proxy) SetUpstreamProxy(proxyURL *url.URL) {
p.SetUpstreamProxyFunc(http.ProxyURL(proxyURL))
}
func (p *Proxy) init() {
p.initOnce.Do(func() {
if p.RoundTripper == nil {
p.RoundTripper = &http.Transport{
// TODO(adamtanner): This forces the http.Transport to not upgrade requests
// to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2.
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: time.Second,
}
}

// SetUpstreamProxyFunc sets proxy function as in http.Transport.Proxy.
func (p *Proxy) SetUpstreamProxyFunc(f func(*http.Request) (*url.URL, error)) {
p.proxyURL = f
if t, ok := p.RoundTripper.(*http.Transport); ok {
if p.DialContext == nil {
p.DialContext = t.DialContext
} else {
t.DialContext = p.DialContext
}
if p.ProxyURL == nil {
p.ProxyURL = t.Proxy
} else {
t.Proxy = p.ProxyURL
}
}

if tr, ok := p.roundTripper.(*http.Transport); ok {
tr.Proxy = f
}
}
if p.DialContext == nil {
p.DialContext = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext
}

// SetDialContext sets the dial func used to establish a connection.
func (p *Proxy) SetDialContext(dial func(context.Context, string, string) (net.Conn, error)) {
p.dial = dial
if p.BaseContex == nil {
p.BaseContex = context.Background()
}

if tr, ok := p.roundTripper.(*http.Transport); ok {
tr.DialContext = p.dial
}
p.closing = make(chan bool)
})
}

// Close sets the proxy to the closing state so it stops receiving new connections,
// finishes processing any inflight requests, and closes existing connections without
// reading anymore requests from them.
func (p *Proxy) Close() {
p.init()

p.closeOnce.Do(func() {
log.Infof(context.TODO(), "closing down proxy")

Expand Down Expand Up @@ -209,6 +200,8 @@ func (p *Proxy) Closing() bool {
func (p *Proxy) Serve(l net.Listener) error {
defer l.Close()

p.init()

var delay time.Duration
for {
if p.Closing() {
Expand Down Expand Up @@ -335,7 +328,7 @@ func (p *Proxy) roundTrip(req *http.Request) (*http.Response, error) {
return proxyutil.NewResponse(200, http.NoBody, req), nil
}

return p.roundTripper.RoundTrip(req)
return p.RoundTripper.RoundTrip(req)
}

func (p *Proxy) errorResponse(req *http.Request, err error) *http.Response {
Expand Down
14 changes: 7 additions & 7 deletions internal/martian/proxy_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
ctx := req.Context()

var proxyURL *url.URL
if p.proxyURL != nil {
u, err := p.proxyURL(req)
if p.ProxyURL != nil {
u, err := p.ProxyURL(req)
if err != nil {
return nil, nil, err
}
Expand All @@ -54,7 +54,7 @@ func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
if proxyURL == nil {
log.Debugf(ctx, "CONNECT to host directly: %s", req.URL.Host)

conn, err := p.dial(ctx, "tcp", req.URL.Host)
conn, err := p.DialContext(ctx, "tcp", req.URL.Host)
if err != nil {
return nil, nil, err
}
Expand All @@ -79,9 +79,9 @@ func (p *Proxy) connectHTTP(req *http.Request, proxyURL *url.URL) (res *http.Res

var d *dialvia.HTTPProxyDialer
if proxyURL.Scheme == "https" {
d = dialvia.HTTPSProxy(p.dial, proxyURL, p.clientTLSConfig())
d = dialvia.HTTPSProxy(p.DialContext, proxyURL, p.clientTLSConfig())
} else {
d = dialvia.HTTPProxy(p.dial, proxyURL)
d = dialvia.HTTPProxy(p.DialContext, proxyURL)
}
d.ConnectRequestModifier = p.ConnectRequestModifier

Expand All @@ -107,7 +107,7 @@ func (p *Proxy) connectHTTP(req *http.Request, proxyURL *url.URL) (res *http.Res
}

func (p *Proxy) clientTLSConfig() *tls.Config {
if tr, ok := p.roundTripper.(*http.Transport); ok && tr.TLSClientConfig != nil {
if tr, ok := p.RoundTripper.(*http.Transport); ok && tr.TLSClientConfig != nil {
return tr.TLSClientConfig.Clone()
}

Expand All @@ -119,7 +119,7 @@ func (p *Proxy) connectSOCKS5(req *http.Request, proxyURL *url.URL) (*http.Respo

log.Debugf(ctx, "CONNECT with upstream SOCKS5 proxy: %s", proxyURL.Host)

d := dialvia.SOCKS5Proxy(p.dial, proxyURL)
d := dialvia.SOCKS5Proxy(p.DialContext, proxyURL)

conn, err := d.DialContext(ctx, "tcp", req.URL.Host)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions internal/martian/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type proxyHandler struct {

// Handler returns proxy as http.Handler, see [proxyHandler] for details.
func (p *Proxy) Handler() http.Handler {
p.init()
return proxyHandler{p}
}

Expand Down
Loading

0 comments on commit 9cc9854

Please sign in to comment.