Skip to content

Commit

Permalink
sqlproxyccl: add proxy protocol listener
Browse files Browse the repository at this point in the history
Resolves #117240

Release note: None
  • Loading branch information
DuskEagle committed Jan 2, 2024
1 parent e218b13 commit f8254c5
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 5 deletions.
1 change: 1 addition & 0 deletions pkg/ccl/cliccl/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func init() {
cliflagcfg.StringFlag(f, &proxyContext.Denylist, cliflags.DenyList)
cliflagcfg.StringFlag(f, &proxyContext.Allowlist, cliflags.AllowList)
cliflagcfg.StringFlag(f, &proxyContext.ListenAddr, cliflags.ProxyListenAddr)
cliflagcfg.StringFlag(f, &proxyContext.ProxyProtocolListenAddr, cliflags.ProxyProtocolListenAddr)
cliflagcfg.StringFlag(f, &proxyContext.ListenCert, cliflags.ListenCert)
cliflagcfg.StringFlag(f, &proxyContext.ListenKey, cliflags.ListenKey)
cliflagcfg.StringFlag(f, &proxyContext.MetricsAddress, cliflags.ListenMetrics)
Expand Down
21 changes: 20 additions & 1 deletion pkg/ccl/cliccl/mt_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) {
return err
}

var proxyProtocolLn net.Listener
if proxyContext.ProxyProtocolListenAddr != "" {
proxyProtocolLn, err = net.Listen("tcp", proxyContext.ProxyProtocolListenAddr)
if err != nil {
return err
}
}

metricsLn, err := net.Listen("tcp", proxyContext.MetricsAddress)
if err != nil {
return err
Expand All @@ -85,13 +93,24 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) {

if err := stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) {
log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr())
if err := server.Serve(ctx, proxyLn); err != nil {
if err := server.Serve(ctx, proxyLn, proxyContext.RequireProxyProtocol); err != nil {
errChan <- err
}
}); err != nil {
return err
}

if proxyProtocolLn != nil {
if err := stopper.RunAsyncTask(ctx, "serve-proxy-with-required-proxy-headers", func(ctx context.Context) {
log.Infof(ctx, "proxy with required proxy headers server listening at %s", proxyProtocolLn.Addr())
if err := server.Serve(ctx, proxyProtocolLn, true /* requireProxyProtocol */); err != nil {
errChan <- err
}
}); err != nil {
return err
}
}

return waitForSignals(ctx, server, stopper, proxyLn, errChan)
}

Expand Down
10 changes: 8 additions & 2 deletions pkg/ccl/sqlproxyccl/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ type ProxyOptions struct {
Denylist string
// ListenAddr is the listen address for incoming connections.
ListenAddr string
// ProxyProtocolListenAddr is the optional listen address for incoming
// connections for which it will be enforced that the connections have proxy
// headers set.
ProxyProtocolListenAddr string
// ListenCert is the file containing PEM-encoded x509 certificate for listen
// address. Set to "*" to auto-generate self-signed cert.
ListenCert string
Expand Down Expand Up @@ -113,8 +117,10 @@ type ProxyOptions struct {
DisableConnectionRebalancing bool
// RequireProxyProtocol changes the server's behavior to support the PROXY
// protocol (SQL=required, HTTP=best-effort). With this set to true, the
// PROXY info from upstream will be trusted on both HTTP and SQL, if the
// headers are allowed.
// PROXY info from upstream will be trusted on both HTTP and SQL (on the
// ListenAddr port), if the headers are allowed. The ProxyProtocolListenAddr
// port, if specified, will require the proxy protocol regardless of
// RequireProxyProtocol.
RequireProxyProtocol bool

// testingKnobs are knobs used for testing.
Expand Down
4 changes: 2 additions & 2 deletions pkg/ccl/sqlproxyccl/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ func (s *Server) ServeHTTP(ctx context.Context, ln net.Listener) error {
// Serve serves a listener according to the Options given in NewServer().
// Incoming client connections are taken through the Postgres handshake and
// relayed to the configured backend server.
func (s *Server) Serve(ctx context.Context, ln net.Listener) error {
if s.handler.RequireProxyProtocol {
func (s *Server) Serve(ctx context.Context, ln net.Listener, requireProxyProtocol bool) error {
if requireProxyProtocol {
ln = &proxyproto.Listener{
Listener: ln,
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
Expand Down
5 changes: 5 additions & 0 deletions pkg/cli/cliflags/flags_mt.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ wait for the tenant id to be fully written to the file (with a newline character
Description: "Listen address for incoming connections.",
}

ProxyProtocolListenAddr = FlagInfo{
Name: "proxy-protocol-listen-addr",
Description: "Listen address for incoming connections which require proxy protocol headers.",
}

ThrottleBaseDelay = FlagInfo{
Name: "throttle-base-delay",
Description: "Initial value for the exponential backoff used to throttle connection attempts.",
Expand Down

0 comments on commit f8254c5

Please sign in to comment.