diff --git a/pkg/ccl/cliccl/flags.go b/pkg/ccl/cliccl/flags.go index 0c19bd81e0b5..efa5e3653eff 100644 --- a/pkg/ccl/cliccl/flags.go +++ b/pkg/ccl/cliccl/flags.go @@ -69,6 +69,7 @@ func init() { cliflagcfg.StringFlag(f, &proxyContext.Denylist, cliflags.DenyList) cliflagcfg.StringFlag(f, &proxyContext.Allowlist, cliflags.AllowList) cliflagcfg.StringFlag(f, &proxyContext.ListenAddr, cliflags.ProxyListenAddr) + cliflagcfg.StringFlag(f, &proxyContext.ProxyProtocolListenAddr, cliflags.ProxyProtocolListenAddr) cliflagcfg.StringFlag(f, &proxyContext.ListenCert, cliflags.ListenCert) cliflagcfg.StringFlag(f, &proxyContext.ListenKey, cliflags.ListenKey) cliflagcfg.StringFlag(f, &proxyContext.MetricsAddress, cliflags.ListenMetrics) diff --git a/pkg/ccl/cliccl/mt_proxy.go b/pkg/ccl/cliccl/mt_proxy.go index c6d905b77589..d57bd1b94b49 100644 --- a/pkg/ccl/cliccl/mt_proxy.go +++ b/pkg/ccl/cliccl/mt_proxy.go @@ -61,6 +61,14 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { return err } + var proxyProtocolLn net.Listener + if proxyContext.ProxyProtocolListenAddr != "" { + proxyProtocolLn, err = net.Listen("tcp", proxyContext.ProxyProtocolListenAddr) + if err != nil { + return err + } + } + metricsLn, err := net.Listen("tcp", proxyContext.MetricsAddress) if err != nil { return err @@ -85,13 +93,24 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { if err := stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) { log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr()) - if err := server.Serve(ctx, proxyLn); err != nil { + if err := server.Serve(ctx, proxyLn, proxyContext.RequireProxyProtocol); err != nil { errChan <- err } }); err != nil { return err } + if proxyProtocolLn != nil { + if err := stopper.RunAsyncTask(ctx, "serve-proxy-with-required-proxy-headers", func(ctx context.Context) { + log.Infof(ctx, "proxy with required proxy headers server listening at %s", proxyProtocolLn.Addr()) + if err := server.Serve(ctx, proxyProtocolLn, true /* requireProxyProtocol */); err != nil { + errChan <- err + } + }); err != nil { + return err + } + } + return waitForSignals(ctx, server, stopper, proxyLn, errChan) } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 423e9404f9dc..30d0aedabd5e 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -73,6 +73,10 @@ type ProxyOptions struct { Denylist string // ListenAddr is the listen address for incoming connections. ListenAddr string + // ProxyProtocolListenAddr is the optional listen address for incoming + // connections for which it will be enforced that the connections have proxy + // headers set. + ProxyProtocolListenAddr string // ListenCert is the file containing PEM-encoded x509 certificate for listen // address. Set to "*" to auto-generate self-signed cert. ListenCert string @@ -113,8 +117,10 @@ type ProxyOptions struct { DisableConnectionRebalancing bool // RequireProxyProtocol changes the server's behavior to support the PROXY // protocol (SQL=required, HTTP=best-effort). With this set to true, the - // PROXY info from upstream will be trusted on both HTTP and SQL, if the - // headers are allowed. + // PROXY info from upstream will be trusted on both HTTP and SQL (on the + // ListenAddr port), if the headers are allowed. The ProxyProtocolListenAddr + // port, if specified, will require the proxy protocol regardless of + // RequireProxyProtocol. RequireProxyProtocol bool // testingKnobs are knobs used for testing. diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index 279635754e83..7867c66fbd91 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -216,8 +216,8 @@ func (s *Server) ServeHTTP(ctx context.Context, ln net.Listener) error { // Serve serves a listener according to the Options given in NewServer(). // Incoming client connections are taken through the Postgres handshake and // relayed to the configured backend server. -func (s *Server) Serve(ctx context.Context, ln net.Listener) error { - if s.handler.RequireProxyProtocol { +func (s *Server) Serve(ctx context.Context, ln net.Listener, requireProxyProtocol bool) error { + if requireProxyProtocol { ln = &proxyproto.Listener{ Listener: ln, Policy: func(upstream net.Addr) (proxyproto.Policy, error) { diff --git a/pkg/cli/cliflags/flags_mt.go b/pkg/cli/cliflags/flags_mt.go index 2590109a4e49..75ff46392a5a 100644 --- a/pkg/cli/cliflags/flags_mt.go +++ b/pkg/cli/cliflags/flags_mt.go @@ -47,6 +47,11 @@ wait for the tenant id to be fully written to the file (with a newline character Description: "Listen address for incoming connections.", } + ProxyProtocolListenAddr = FlagInfo{ + Name: "proxy-protocol-listen-addr", + Description: "Listen address for incoming connections which require proxy protocol headers.", + } + ThrottleBaseDelay = FlagInfo{ Name: "throttle-base-delay", Description: "Initial value for the exponential backoff used to throttle connection attempts.",