Skip to content

Commit

Permalink
keep listener trying to accept connections and don't error on invalid…
Browse files Browse the repository at this point in the history
… upstream
  • Loading branch information
Peter Wilson authored and pires committed Oct 8, 2024
1 parent cd8a402 commit e6823d9
Showing 1 changed file with 59 additions and 43 deletions.
102 changes: 59 additions & 43 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@ package proxyproto

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)

// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
var DefaultReadHeaderTimeout = 10 * time.Second
var (
// DefaultReadHeaderTimeout is how long header processing waits for header to
// be read from the wire, if Listener.ReaderHeaderTimeout is not set.
// It's kept as a global variable so to make it easier to find and override,
// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
DefaultReadHeaderTimeout = 10 * time.Second

// ErrInvalidUpstream should be returned when an upstream connection address
// is not trusted, and therefore is invalid.
ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information")
)

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
Expand Down Expand Up @@ -73,53 +81,61 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
}
}

// Accept waits for and returns the next connection to the listener.
// Accept waits for and returns the next valid connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
for {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil

proxyHeaderPolicy := USE
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()

newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)
if errors.Is(err, ErrInvalidUpstream) {
// keep listening for other connections
continue
}

// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}
return nil, err
}
// Handle a connection as a regular one
if proxyHeaderPolicy == SKIP {
return conn, nil
}
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout
newConn := NewConn(
conn,
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)

return newConn, nil
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DefaultReadHeaderTimeout
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout

return newConn, nil
}
}

// Close closes the underlying listener.
Expand Down

0 comments on commit e6823d9

Please sign in to comment.