diff --git a/http_proxy.go b/http_proxy.go index 7564844a..a8641859 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -250,6 +250,7 @@ func (hp *HTTPProxy) configureProxy() error { hp.proxy.WithoutWarning = true hp.proxy.ErrorResponse = hp.errorResponse hp.proxy.IdleTimeout = hp.config.IdleTimeout + hp.proxy.TLSHandshakeTimeout = hp.config.TLSServerConfig.HandshakeTimeout hp.proxy.ReadTimeout = hp.config.ReadTimeout hp.proxy.ReadHeaderTimeout = hp.config.ReadHeaderTimeout hp.proxy.WriteTimeout = hp.config.WriteTimeout diff --git a/internal/martian/proxy.go b/internal/martian/proxy.go index 86a4c9a0..5682380a 100644 --- a/internal/martian/proxy.go +++ b/internal/martian/proxy.go @@ -83,6 +83,11 @@ type Proxy struct { // If both are zero, there is no timeout. IdleTimeout time.Duration + // TLSHandshakeTimeout is the maximum amount of time to wait for a TLS handshake. + // The proxy will try to cast accepted connections to tls.Conn and perform a handshake. + // If TLSHandshakeTimeout is zero, no timeout is set. + TLSHandshakeTimeout time.Duration + // ReadTimeout is the maximum duration for reading the entire // request, including the body. A zero or negative value means // there will be no timeout. @@ -257,6 +262,11 @@ func (p *Proxy) handleLoop(conn net.Conn) { pc := newProxyConn(p, conn) + if err := pc.maybeHandshakeTLS(); err != nil { + log.Errorf(context.TODO(), "failed to do TLS handshake: %v", err) + return + } + const maxConsecutiveErrors = 5 errorsN := 0 for { diff --git a/internal/martian/proxy_conn.go b/internal/martian/proxy_conn.go index 232909df..2a603152 100644 --- a/internal/martian/proxy_conn.go +++ b/internal/martian/proxy_conn.go @@ -41,18 +41,33 @@ type proxyConn struct { } func newProxyConn(p *Proxy, conn net.Conn) *proxyConn { - v := &proxyConn{ + return &proxyConn{ Proxy: p, brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), conn: conn, } +} + +func (p *proxyConn) maybeHandshakeTLS() error { + tconn, ok := p.conn.(*tls.Conn) + if !ok { + return nil + } - if tconn, ok := conn.(*tls.Conn); ok { - v.secure = true - v.cs = tconn.ConnectionState() + ctx := context.Background() + if p.TLSHandshakeTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), p.TLSHandshakeTimeout) + defer cancel() + } + if err := tconn.HandshakeContext(ctx); err != nil { + return err } - return v + p.secure = true + p.cs = tconn.ConnectionState() + + return nil } func (p *proxyConn) readRequest() (*http.Request, error) { diff --git a/internal/martian/proxy_test.go b/internal/martian/proxy_test.go index d51d5c86..ba3bc7da 100644 --- a/internal/martian/proxy_test.go +++ b/internal/martian/proxy_test.go @@ -1838,6 +1838,37 @@ func TestIdleTimeout(t *testing.T) { } } +func TestTLSHandshakeTimeout(t *testing.T) { + t.Parallel() + + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen(): got %v, want no error", err) + } + _, mc := certs(t) + l = tls.NewListener(l, mc.TLS(context.Background())) + + h := testHelper{ + Listener: l, + Proxy: func(p *Proxy) { + p.TLSHandshakeTimeout = 100 * time.Millisecond + }, + } + + c, cancel := h.proxyClient(t) + defer cancel() + + conn, err := net.Dial("tcp", c.Addr) + if err != nil { + t.Fatalf("net.Dial(): got %v, want no error", err) + } + + time.Sleep(200 * time.Millisecond) + if _, err := conn.Read(make([]byte, 1)); !errors.Is(err, io.EOF) { + t.Fatalf("conn.Read(): got %v, want io.EOF", err) + } +} + func TestReadHeaderTimeout(t *testing.T) { t.Parallel()