Skip to content

Commit

Permalink
martian: add support for TLSHandshakeTimeout
Browse files Browse the repository at this point in the history
Martian Proxy will explicitly do the handshake if accepted connection is tls.Conn.
  • Loading branch information
Choraden authored and mmatczuk committed Oct 16, 2024
1 parent 7d61fcc commit af3e220
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 5 deletions.
1 change: 1 addition & 0 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func (hp *HTTPProxy) configureProxy() error {
hp.proxy.WithoutWarning = true
hp.proxy.ErrorResponse = hp.errorResponse
hp.proxy.IdleTimeout = hp.config.IdleTimeout
hp.proxy.TLSHandshakeTimeout = hp.config.TLSServerConfig.HandshakeTimeout
hp.proxy.ReadTimeout = hp.config.ReadTimeout
hp.proxy.ReadHeaderTimeout = hp.config.ReadHeaderTimeout
hp.proxy.WriteTimeout = hp.config.WriteTimeout
Expand Down
10 changes: 10 additions & 0 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ type Proxy struct {
// If both are zero, there is no timeout.
IdleTimeout time.Duration

// TLSHandshakeTimeout is the maximum amount of time to wait for a TLS handshake.
// The proxy will try to cast accepted connections to tls.Conn and perform a handshake.
// If TLSHandshakeTimeout is zero, no timeout is set.
TLSHandshakeTimeout time.Duration

// ReadTimeout is the maximum duration for reading the entire
// request, including the body. A zero or negative value means
// there will be no timeout.
Expand Down Expand Up @@ -257,6 +262,11 @@ func (p *Proxy) handleLoop(conn net.Conn) {

pc := newProxyConn(p, conn)

if err := pc.maybeHandshakeTLS(); err != nil {
log.Errorf(context.TODO(), "failed to do TLS handshake: %v", err)
return
}

const maxConsecutiveErrors = 5
errorsN := 0
for {
Expand Down
25 changes: 20 additions & 5 deletions internal/martian/proxy_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,33 @@ type proxyConn struct {
}

func newProxyConn(p *Proxy, conn net.Conn) *proxyConn {
v := &proxyConn{
return &proxyConn{
Proxy: p,
brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
conn: conn,
}
}

func (p *proxyConn) maybeHandshakeTLS() error {
tconn, ok := p.conn.(*tls.Conn)
if !ok {
return nil
}

if tconn, ok := conn.(*tls.Conn); ok {
v.secure = true
v.cs = tconn.ConnectionState()
ctx := context.Background()
if p.TLSHandshakeTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), p.TLSHandshakeTimeout)
defer cancel()
}
if err := tconn.HandshakeContext(ctx); err != nil {
return err
}

return v
p.secure = true
p.cs = tconn.ConnectionState()

return nil
}

func (p *proxyConn) readRequest() (*http.Request, error) {
Expand Down
31 changes: 31 additions & 0 deletions internal/martian/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,37 @@ func TestIdleTimeout(t *testing.T) {
}
}

func TestTLSHandshakeTimeout(t *testing.T) {
t.Parallel()

l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(): got %v, want no error", err)
}
_, mc := certs(t)
l = tls.NewListener(l, mc.TLS(context.Background()))

h := testHelper{
Listener: l,
Proxy: func(p *Proxy) {
p.TLSHandshakeTimeout = 100 * time.Millisecond
},
}

c, cancel := h.proxyClient(t)
defer cancel()

conn, err := net.Dial("tcp", c.Addr)
if err != nil {
t.Fatalf("net.Dial(): got %v, want no error", err)
}

time.Sleep(200 * time.Millisecond)
if _, err := conn.Read(make([]byte, 1)); !errors.Is(err, io.EOF) {
t.Fatalf("conn.Read(): got %v, want io.EOF", err)
}
}

func TestReadHeaderTimeout(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit af3e220

Please sign in to comment.