diff --git a/pkg/ccl/cliccl/flags.go b/pkg/ccl/cliccl/flags.go index 58bd3e8d99ba..fa1ebd92313b 100644 --- a/pkg/ccl/cliccl/flags.go +++ b/pkg/ccl/cliccl/flags.go @@ -21,6 +21,7 @@ func init() { cliflagcfg.StringFlag(f, &proxyContext.Denylist, cliflags.DenyList) cliflagcfg.StringFlag(f, &proxyContext.Allowlist, cliflags.AllowList) cliflagcfg.StringFlag(f, &proxyContext.ListenAddr, cliflags.ProxyListenAddr) + cliflagcfg.StringFlag(f, &proxyContext.ProxyProtocolListenAddr, cliflags.ProxyProtocolListenAddr) cliflagcfg.StringFlag(f, &proxyContext.ListenCert, cliflags.ListenCert) cliflagcfg.StringFlag(f, &proxyContext.ListenKey, cliflags.ListenKey) cliflagcfg.StringFlag(f, &proxyContext.MetricsAddress, cliflags.ListenMetrics) diff --git a/pkg/ccl/cliccl/mt_proxy.go b/pkg/ccl/cliccl/mt_proxy.go index c6d905b77589..8da06689c578 100644 --- a/pkg/ccl/cliccl/mt_proxy.go +++ b/pkg/ccl/cliccl/mt_proxy.go @@ -56,9 +56,20 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { log.Infof(ctx, "New proxy with opts: %+v", proxyContext) - proxyLn, err := net.Listen("tcp", proxyContext.ListenAddr) - if err != nil { - return err + var proxyLn net.Listener + if proxyContext.ListenAddr != "" { + proxyLn, err = net.Listen("tcp", proxyContext.ListenAddr) + if err != nil { + return err + } + } + + var proxyProtocolLn net.Listener + if proxyContext.ProxyProtocolListenAddr != "" { + proxyProtocolLn, err = net.Listen("tcp", proxyContext.ProxyProtocolListenAddr) + if err != nil { + return err + } } metricsLn, err := net.Listen("tcp", proxyContext.MetricsAddress) @@ -84,15 +95,14 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { } if err := stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) { - log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr()) - if err := server.Serve(ctx, proxyLn); err != nil { + if err := server.ServeSQL(ctx, proxyLn, proxyProtocolLn); err != nil { errChan <- err } }); err != nil { return err } - return waitForSignals(ctx, server, stopper, proxyLn, errChan) + return waitForSignals(ctx, server, stopper, proxyLn, proxyProtocolLn, errChan) } func initLogging(cmd *cobra.Command) (ctx context.Context, stopper *stop.Stopper, err error) { @@ -110,6 +120,7 @@ func waitForSignals( server *sqlproxyccl.Server, stopper *stop.Stopper, proxyLn net.Listener, + proxyProtocolLn net.Listener, errChan chan error, ) (returnErr error) { // Need to alias the signals if this has to run on non-unix OSes too. @@ -139,7 +150,12 @@ func waitForSignals( // waiting for "shutdownConnectionTimeout" to elapse after which // open TCP connections will be forcefully closed so the server can stop log.Infof(ctx, "stopping tcp listener") - _ = proxyLn.Close() + if proxyLn != nil { + _ = proxyLn.Close() + } + if proxyProtocolLn != nil { + _ = proxyProtocolLn.Close() + } select { case <-server.AwaitNoConnections(ctx): case <-time.After(shutdownConnectionTimeout): diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 423e9404f9dc..30d0aedabd5e 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -73,6 +73,10 @@ type ProxyOptions struct { Denylist string // ListenAddr is the listen address for incoming connections. ListenAddr string + // ProxyProtocolListenAddr is the optional listen address for incoming + // connections for which it will be enforced that the connections have proxy + // headers set. + ProxyProtocolListenAddr string // ListenCert is the file containing PEM-encoded x509 certificate for listen // address. Set to "*" to auto-generate self-signed cert. ListenCert string @@ -113,8 +117,10 @@ type ProxyOptions struct { DisableConnectionRebalancing bool // RequireProxyProtocol changes the server's behavior to support the PROXY // protocol (SQL=required, HTTP=best-effort). With this set to true, the - // PROXY info from upstream will be trusted on both HTTP and SQL, if the - // headers are allowed. + // PROXY info from upstream will be trusted on both HTTP and SQL (on the + // ListenAddr port), if the headers are allowed. The ProxyProtocolListenAddr + // port, if specified, will require the proxy protocol regardless of + // RequireProxyProtocol. RequireProxyProtocol bool // testingKnobs are knobs used for testing. diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 96416e89e35b..ac3a5ba58a79 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -67,6 +67,12 @@ const backendError = "Backend error!" // the test directory server. const notFoundTenantID = 99 +type serverAddresses struct { + listenAddr string + proxyProtocolListenAddr string + httpAddr string +} + func TestProxyHandler_ValidateConnection(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) @@ -91,7 +97,7 @@ func TestProxyHandler_ValidateConnection(t *testing.T) { options := &ProxyOptions{} options.testingKnobs.directoryServer = tds - s, _, _ := newSecureProxyServer(ctx, t, stop, options) + s, _ := newSecureProxyServer(ctx, t, stop, options) t.Run("not found/no cluster name", func(t *testing.T) { err := s.handler.validateConnection(ctx, invalidTenantID, "") @@ -154,7 +160,7 @@ func TestProxyProtocol(t *testing.T) { sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) var validateFn func(h *proxyproto.Header) error - withProxyProtocol := func(p bool) (server *Server, addr, httpAddr string) { + withProxyProtocol := func(p bool) (server *Server, addrs *serverAddresses) { options := &ProxyOptions{ RoutingRule: sql.ServingSQLAddr(), SkipVerify: true, @@ -212,8 +218,6 @@ func TestProxyProtocol(t *testing.T) { } } - s, sqlAddr, httpAddr := withProxyProtocol(true) - defer testutils.TestingHook(&validateFn, func(h *proxyproto.Header) error { if h.SourceAddr.String() != "10.20.30.40:4242" { return errors.Newf("got source addr %s, expected 10.20.30.40:4242", h.SourceAddr) @@ -221,25 +225,71 @@ func TestProxyProtocol(t *testing.T) { return nil })() - // Test SQL. Only request with PROXY should go through. - url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-42.defaultdb?sslmode=require", sqlAddr) - te.TestConnectWithPGConfig( - ctx, t, url, - func(c *pgx.ConnConfig) { - c.DialFunc = proxyDialer - }, - func(conn *pgx.Conn) { - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - require.NoError(t, runTestQuery(ctx, conn)) - }, - ) - _ = te.TestConnectErr(ctx, t, url, codeClientReadFailed, "tls error") + testSQLNoRequiredProxyProtocol := func(s *Server, addr string) { + url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-42.defaultdb?sslmode=require", addr) + // No proxy protocol. + te.TestConnect(ctx, t, url, + func(conn *pgx.Conn) { + t.Log("B") + require.NotZero(t, s.metrics.CurConnCount.Value()) + require.NoError(t, runTestQuery(ctx, conn)) + }, + ) + // Proxy protocol. + _ = te.TestConnectErrWithPGConfig( + ctx, t, url, + func(c *pgx.ConnConfig) { + c.DialFunc = proxyDialer + }, codeClientReadFailed, "tls error", + ) + } + + testSQLRequiredProxyProtocol := func(s *Server, addr string) { + url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-42.defaultdb?sslmode=require", addr) + // No proxy protocol. + _ = te.TestConnectErr(ctx, t, url, codeClientReadFailed, "tls error") + // Proxy protocol. + te.TestConnectWithPGConfig( + ctx, t, url, + func(c *pgx.ConnConfig) { + c.DialFunc = proxyDialer + }, + func(conn *pgx.Conn) { + require.NotZero(t, s.metrics.CurConnCount.Value()) + require.NoError(t, runTestQuery(ctx, conn)) + }, + ) + } - // Test HTTP. Should support with or without PROXY. - client := http.Client{Timeout: timeout} - makeHttpReq(t, &client, httpAddr, true) - proxyClient := http.Client{Transport: &http.Transport{DialContext: proxyDialer}} - makeHttpReq(t, &proxyClient, httpAddr, true) + t.Run("server doesn't require proxy protocol", func(t *testing.T) { + s, addrs := withProxyProtocol(false) + // Test SQL on the default listener. Both should go through. + testSQLNoRequiredProxyProtocol(s, addrs.listenAddr) + // Test SQL on the proxy protocol listener. Only request with PROXY should go + // through. + testSQLRequiredProxyProtocol(s, addrs.proxyProtocolListenAddr) + + // Test HTTP. Shouldn't support PROXY. + client := http.Client{Timeout: timeout} + makeHttpReq(t, &client, addrs.httpAddr, true) + proxyClient := http.Client{Transport: &http.Transport{DialContext: proxyDialer}} + makeHttpReq(t, &proxyClient, addrs.httpAddr, false) + }) + + t.Run("server requires proxy protocol", func(t *testing.T) { + s, addrs := withProxyProtocol(true) + // Test SQL on the default listener. Both should go through. + testSQLRequiredProxyProtocol(s, addrs.listenAddr) + // Test SQL on the proxy protocol listener. Only request with PROXY should go + // through. + testSQLRequiredProxyProtocol(s, addrs.proxyProtocolListenAddr) + + // Test HTTP. Should support with or without PROXY. + client := http.Client{Timeout: timeout} + makeHttpReq(t, &client, addrs.httpAddr, true) + proxyClient := http.Client{Transport: &http.Transport{DialContext: proxyDialer}} + makeHttpReq(t, &proxyClient, addrs.httpAddr, true) + }) } func TestPrivateEndpointsACL(t *testing.T) { @@ -307,7 +357,7 @@ func TestPrivateEndpointsACL(t *testing.T) { PollConfigInterval: 10 * time.Millisecond, } options.testingKnobs.directoryServer = tds - s, sqlAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), options) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), options) timeout := 3 * time.Second proxyDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -344,7 +394,7 @@ func TestPrivateEndpointsACL(t *testing.T) { } t.Run("private connection allowed", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", addrs.listenAddr) te.TestConnectWithPGConfig( ctx, t, url, func(c *pgx.ConnConfig) { @@ -395,7 +445,7 @@ func TestPrivateEndpointsACL(t *testing.T) { }) t.Run("private connection disallowed on another tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErrWithPGConfig( ctx, t, url, func(c *pgx.ConnConfig) { @@ -407,7 +457,7 @@ func TestPrivateEndpointsACL(t *testing.T) { }) t.Run("private connection disallowed on public tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/public-tenant-30.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/public-tenant-30.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErrWithPGConfig( ctx, t, url, func(c *pgx.ConnConfig) { @@ -483,10 +533,10 @@ func TestAllowedCIDRRangesACL(t *testing.T) { PollConfigInterval: 10 * time.Millisecond, } options.testingKnobs.directoryServer = tds - s, sqlAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), options) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), options) t.Run("public connection allowed", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", addrs.listenAddr) te.TestConnect(ctx, t, url, func(conn *pgx.Conn) { // Initial connection. require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) @@ -530,12 +580,12 @@ func TestAllowedCIDRRangesACL(t *testing.T) { }) t.Run("public connection disallowed on another tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, codeProxyRefusedConnection, "connection refused") }) t.Run("public connection disallowed on private tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/private-tenant-30.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/private-tenant-30.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, codeProxyRefusedConnection, "connection refused") }) } @@ -557,11 +607,11 @@ func TestLongDBName(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, addr, _ := newSecureProxyServer( + s, addrs := newSecureProxyServer( ctx, t, stopper, &ProxyOptions{RoutingRule: "127.0.0.1:26257"}) longDB := strings.Repeat("x", 70) // 63 is limit - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, longDB) + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, longDB) _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) } @@ -583,7 +633,7 @@ func TestBackendDownRetry(t *testing.T) { // Set RefreshDelay to -1 so that we could simulate a ListPod call under // the hood, which then triggers an EnsurePod again. opts.testingKnobs.dirOpts = []tenant.DirOption{tenant.RefreshDelay(-1)} - server, addr, _ := newSecureProxyServer(ctx, t, stopper, opts) + server, addrs := newSecureProxyServer(ctx, t, stopper, opts) directoryServer := mustGetTestSimpleDirectoryServer(t, server.handler) callCount := 0 @@ -599,7 +649,7 @@ func TestBackendDownRetry(t *testing.T) { })() // Valid connection, but no backend server running. - pgurl := fmt.Sprintf("postgres://unused:unused@%s/db?options=--cluster=tenant-cluster-28&sslmode=require", addr) + pgurl := fmt.Sprintf("postgres://unused:unused@%s/db?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "cluster tenant-cluster-28 not found") require.Equal(t, 3, callCount) } @@ -614,12 +664,12 @@ func TestFailedConnection(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, proxyAddr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) + s, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) // TODO(asubiotto): consider using datadriven for these, especially if the // proxy becomes more complex. - _, p, err := addr.SplitHostPort(proxyAddr, "") + _, p, err := addr.SplitHostPort(addrs.listenAddr, "") require.NoError(t, err) u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) @@ -674,9 +724,9 @@ func TestUnexpectedError(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", addr) + u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", addrs.listenAddr) // Time how long it takes for pgx.Connect to return. If the proxy handles // errors appropriately, pgx.Connect should return near immediately @@ -716,10 +766,10 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { sqlDB := sqlutils.MakeSQLRunner(db) sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - s, addr, _ := newSecureProxyServer( + s, addrs := newSecureProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, ) - _, port, err := net.SplitHostPort(addr) + _, port, err := net.SplitHostPort(addrs.listenAddr) require.NoError(t, err) for _, tc := range []struct { @@ -732,7 +782,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { name: "failed_SASL_auth_1", url: fmt.Sprintf( "postgres://bob:wrong@%s/tenant-cluster-28.defaultdb?sslmode=require", - addr, + addrs.listenAddr, ), expErr: "failed SASL auth", }, @@ -740,7 +790,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { name: "failed_SASL_auth_2", url: fmt.Sprintf( "postgres://bob@%s/tenant-cluster-28.defaultdb?sslmode=require", - addr, + addrs.listenAddr, ), expErr: "failed SASL auth", }, @@ -765,7 +815,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { name: "database_provides_tenant_ID", url: fmt.Sprintf( "postgres://bob:builder@%s/tenant-cluster-28.defaultdb?sslmode=require", - addr, + addrs.listenAddr, ), }, { @@ -828,12 +878,12 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: true, RoutingRule: "127.0.0.1:26257", }) - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, "defaultdb") + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, "defaultdb") _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") }) @@ -851,13 +901,13 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: false, SkipVerify: true, RoutingRule: "127.0.0.1:26257", }) - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, "defaultdb") + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, "defaultdb") _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") }) @@ -879,13 +929,13 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: false, SkipVerify: false, RoutingRule: "127.0.0.1:26257", }) - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, "defaultdb") + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, "defaultdb") _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") }) @@ -928,11 +978,11 @@ func TestProxyTLSClose(t *testing.T) { return originalFrontendAdmit(conn, incomingTLSConfig) })() - s, addr, _ := newSecureProxyServer( + s, addrs := newSecureProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, ) - url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-28.defaultdb?sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-28.defaultdb?sslmode=require", addrs.listenAddr) conn, err := pgx.Connect(ctx, url) require.NoError(t, err) @@ -1006,9 +1056,9 @@ func TestProxyModifyRequestParams(t *testing.T) { return originalBackendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) })() - s, proxyAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) - u := fmt.Sprintf("postgres://bogususer:foo123@%s/?sslmode=require&authToken=abc123&options=--cluster=tenant-cluster-28&sslmode=require", proxyAddr) + u := fmt.Sprintf("postgres://bogususer:foo123@%s/?sslmode=require&authToken=abc123&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) te.TestConnect(ctx, t, u, func(conn *pgx.Conn) { require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) require.NoError(t, runTestQuery(ctx, conn)) @@ -1041,14 +1091,14 @@ func TestInsecureProxy(t *testing.T) { sqlDB := sqlutils.MakeSQLRunner(db) sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - s, addr, _ := newProxyServer( + s, addrs := newProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, ) - url := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, 0, "failed SASL auth") - url = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) te.TestConnect(ctx, t, url, func(conn *pgx.Conn) { require.NoError(t, runTestQuery(ctx, conn)) }) @@ -1082,9 +1132,9 @@ func TestErroneousFrontend(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) // Generic message here as the Frontend's error is not codeError and // by default we don't pass back error's text. The startup message doesn't @@ -1114,9 +1164,9 @@ func TestErrorHint(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) err := te.TestConnectErr(ctx, t, url, 0, "codeParamsRoutingFailed: Frontend error") pgErr := (*pgconn.PgError)(nil) @@ -1140,9 +1190,9 @@ func TestErroneousBackend(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) // Generic message here as the Backend's error is not codeError and // by default we don't pass back error's text. The startup message has @@ -1166,9 +1216,9 @@ func TestProxyRefuseConn(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + s, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://root:admin@%s?sslmode=require&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://root:admin@%s?sslmode=require&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, codeProxyRefusedConnection, "too many attempts") require.Equal(t, int64(1), s.metrics.RefusedConnCount.Count()) require.Equal(t, int64(0), s.metrics.SuccessfulConnCount.Count()) @@ -1187,7 +1237,7 @@ func TestProxyHandler_handle(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - proxy, _, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + proxy, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) p1, p2 := net.Pipe() require.NoError(t, p1.Close()) @@ -1266,10 +1316,10 @@ func TestDenylistUpdate(t *testing.T) { PollConfigInterval: 10 * time.Millisecond, } opts.testingKnobs.directoryServer = tds - s, addr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), opts) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), opts) // Establish a connection. - url := fmt.Sprintf("postgres://testuser:foo123@%s/defaultdb?sslmode=require&options=--cluster=tenant-cluster-%s&sslmode=require", addr, tenantID) + url := fmt.Sprintf("postgres://testuser:foo123@%s/defaultdb?sslmode=require&options=--cluster=tenant-cluster-%s&sslmode=require", addrs.listenAddr, tenantID) db, err := gosql.Open("postgres", url) db.SetMaxOpenConns(1) defer db.Close() @@ -1348,11 +1398,11 @@ func TestDirectoryConnect(t *testing.T) { // Start the proxy server using the static directory server. opts := &ProxyOptions{SkipVerify: true} opts.testingKnobs.directoryServer = tds - _, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + _, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) t.Run("tenant not found", func(t *testing.T) { - url := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%d", addr, notFoundTenantID) + url := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%d", addrs.listenAddr, notFoundTenantID) _ = te.TestConnectErr(ctx, t, url, codeParamsRoutingFailed, "cluster tenant-cluster-99 not found") }) @@ -1461,8 +1511,8 @@ func TestConnectionRebalancingDisabled(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // Open 12 connections to the first pod. dist := map[string]int{} @@ -1570,10 +1620,10 @@ func TestCancelQuery(t *testing.T) { balancer.RebalanceRate(1), balancer.RebalanceDelay(-1), } - proxy, addr, httpAddr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) connectionString := fmt.Sprintf( "postgres://testuser:hunter2@%s/defaultdb?sslmode=require&sslrootcert=%s&options=--cluster=tenant-cluster-%s", - addr, datapathutils.TestDataPath(t, "testserver.crt"), tenantID, + addrs.listenAddr, datapathutils.TestDataPath(t, "testserver.crt"), tenantID, ) // Open a connection to the first pod. @@ -1658,7 +1708,7 @@ func TestCancelQuery(t *testing.T) { SecretKey: conn.PgConn().SecretKey(), ClientIP: net.IP{127, 0, 0, 1}, } - u := "http://" + httpAddr + "/_status/cancel/" + u := "http://" + addrs.httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) client := http.Client{ Timeout: 10 * time.Second, @@ -1745,7 +1795,7 @@ func TestCancelQuery(t *testing.T) { SecretKey: conn.PgConn().SecretKey(), ClientIP: net.IP{210, 1, 2, 3}, } - u := "http://" + httpAddr + "/_status/cancel/" + u := "http://" + addrs.httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) client := http.Client{ Timeout: 10 * time.Second, @@ -1832,7 +1882,7 @@ func TestCancelQuery(t *testing.T) { SecretKey: conn.PgConn().SecretKey() + 1, ClientIP: net.IP{127, 0, 0, 1}, } - u := "http://" + httpAddr + "/_status/cancel/" + u := "http://" + addrs.httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) client := http.Client{ Timeout: 10 * time.Second, @@ -1901,8 +1951,8 @@ func TestPodWatcher(t *testing.T) { balancer.NoRebalanceLoop(), balancer.RebalanceRate(1.0), } - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // Open 12 connections to it. The balancer should distribute the connections // evenly across 3 SQL pods (i.e. 4 connections each). @@ -1997,9 +2047,9 @@ func TestConnectionMigration(t *testing.T) { // loads. For this test, we will stub out lookupAddr in the connector. We // will alternate between tenant1 and tenant2, starting with tenant1. opts := &ProxyOptions{SkipVerify: true, RoutingRule: tenant1.SQLAddr()} - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // validateMiscMetrics ensures that our invariant of // attempts = success + error_recoverable + error_fatal is valid, and all @@ -2382,10 +2432,10 @@ func TestAcceptedConnCountMetric(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) - goodConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) - badConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=nocluster", addr) + goodConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) + badConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=nocluster", addrs.listenAddr) const ( numGood = 5 @@ -2420,7 +2470,7 @@ func TestAcceptedConnCountMetric(t *testing.T) { go func() { defer wg.Done() - conn, err := net.DialTimeout("tcp", addr, 3*time.Second) + conn, err := net.DialTimeout("tcp", addrs.listenAddr, 3*time.Second) defer func() { if conn != nil { _ = conn.Close() @@ -2472,8 +2522,8 @@ func TestCurConnCountMetric(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // Open 500 connections to the SQL pod. const numConns = 500 @@ -2914,7 +2964,7 @@ func (te *tester) TestConnectErrWithPGConfig( func newSecureProxyServer( ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, -) (server *Server, addr, httpAddr string) { +) (server *Server, addrs *serverAddresses) { // Created via: const _ = ` openssl genrsa -out testdata/testserver.key 2048 @@ -2929,12 +2979,15 @@ openssl req -new -x509 -sha256 -key testdata/testserver.key -out testdata/testse func newProxyServer( ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, -) (server *Server, addr, httpAddr string) { +) (server *Server, addrs *serverAddresses) { const listenAddress = "127.0.0.1:0" ctx, _ = stopper.WithCancelOnQuiesce(ctx) ln, err := net.Listen("tcp", listenAddress) require.NoError(t, err) stopper.AddCloser(stop.CloserFn(func() { _ = ln.Close() })) + proxyProtocolLn, err := net.Listen("tcp", listenAddress) + require.NoError(t, err) + stopper.AddCloser(stop.CloserFn(func() { _ = proxyProtocolLn.Close() })) httpLn, err := net.Listen("tcp", listenAddress) require.NoError(t, err) stopper.AddCloser(stop.CloserFn(func() { _ = httpLn.Close() })) @@ -2943,7 +2996,7 @@ func newProxyServer( require.NoError(t, err) err = server.Stopper.RunAsyncTask(ctx, "proxy-server-serve", func(ctx context.Context) { - _ = server.Serve(ctx, ln) + _ = server.ServeSQL(ctx, ln, proxyProtocolLn) }) require.NoError(t, err) err = server.Stopper.RunAsyncTask(ctx, "proxy-http-server-serve", func(ctx context.Context) { @@ -2951,7 +3004,11 @@ func newProxyServer( }) require.NoError(t, err) - return server, ln.Addr().String(), httpLn.Addr().String() + return server, &serverAddresses{ + listenAddr: ln.Addr().String(), + proxyProtocolListenAddr: proxyProtocolLn.Addr().String(), + httpAddr: httpLn.Addr().String(), + } } func runTestQuery(ctx context.Context, conn *pgx.Conn) error { diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index dbc4a196fbf8..9cacaf1c7b2d 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -209,22 +209,45 @@ func (s *Server) ServeHTTP(ctx context.Context, ln net.Listener) error { return nil } -// Serve serves a listener according to the Options given in NewServer(). +// Serve serves up to two listeners according to the Options given in +// NewServer(). +// +// If ln is not nil, a listener is served which does not require +// proxy protocol headers, unless RequireProxyProtocol is true. +// +// If proxyProtocolLn is not nil, a listener is served which requires proxy +// protocol headers. +// // Incoming client connections are taken through the Postgres handshake and // relayed to the configured backend server. -func (s *Server) Serve(ctx context.Context, ln net.Listener) error { - if s.handler.RequireProxyProtocol { - ln = &proxyproto.Listener{ - Listener: ln, - Policy: func(upstream net.Addr) (proxyproto.Policy, error) { - // REQUIRE enforces the connection to send a PROXY header. - // The connection will be rejected if one was not present. - return proxyproto.REQUIRE, nil - }, - ValidateHeader: s.handler.testingKnobs.validateProxyHeader, +func (s *Server) ServeSQL( + ctx context.Context, ln net.Listener, proxyProtocolLn net.Listener, +) error { + if ln != nil { + if s.handler.RequireProxyProtocol { + ln = s.requireProxyProtocolOnListener(ln) + } + log.Infof(ctx, "proxy server listening at %s", ln.Addr()) + if err := s.Stopper.RunAsyncTask(ctx, "listener-serve", func(ctx context.Context) { + _ = s.serve(ctx, ln) + }); err != nil { + return err } } + if proxyProtocolLn != nil { + proxyProtocolLn = s.requireProxyProtocolOnListener(proxyProtocolLn) + log.Infof(ctx, "proxy with required proxy headers server listening at %s", proxyProtocolLn.Addr()) + if err := s.Stopper.RunAsyncTask(ctx, "proxy-protocol-listener-serve", func(ctx context.Context) { + _ = s.serve(ctx, proxyProtocolLn) + }); err != nil { + return err + } + } + return nil +} +// serve is called by ServeSQL to serve a single listener. +func (s *Server) serve(ctx context.Context, ln net.Listener) error { err := s.Stopper.RunAsyncTask(ctx, "listen-quiesce", func(ctx context.Context) { <-s.Stopper.ShouldQuiesce() if err := ln.Close(); err != nil && !grpcutil.IsClosedConnection(err) { @@ -258,6 +281,18 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error { } } +func (s *Server) requireProxyProtocolOnListener(ln net.Listener) net.Listener { + return &proxyproto.Listener{ + Listener: ln, + Policy: func(upstream net.Addr) (proxyproto.Policy, error) { + // REQUIRE enforces the connection to send a PROXY header. + // The connection will be rejected if one was not present. + return proxyproto.REQUIRE, nil + }, + ValidateHeader: s.handler.testingKnobs.validateProxyHeader, + } +} + // AwaitNoConnections returns a channel that is closed once the server has no open connections. // This is meant to be used after the server has stopped accepting new connections and we are // waiting to shutdown the server without inturrupting existing connections diff --git a/pkg/cli/cliflags/flags_mt.go b/pkg/cli/cliflags/flags_mt.go index 2590109a4e49..b9d52e6d2218 100644 --- a/pkg/cli/cliflags/flags_mt.go +++ b/pkg/cli/cliflags/flags_mt.go @@ -47,6 +47,11 @@ wait for the tenant id to be fully written to the file (with a newline character Description: "Listen address for incoming connections.", } + ProxyProtocolListenAddr = FlagInfo{ + Name: "proxy-protocol-listen-addr", + Description: "Listen address for incoming connections which require proxy protocol headers.", + } + ThrottleBaseDelay = FlagInfo{ Name: "throttle-base-delay", Description: "Initial value for the exponential backoff used to throttle connection attempts.", @@ -94,6 +99,9 @@ wait for the tenant id to be fully written to the file (with a newline character Description: "If true, proxy will not attempt to rebalance connections.", } + // TODO(joel): Remove this flag, and use --listen-addr for a non-proxy + // protocol listener, and use --proxy-protocol-listen-addr for a proxy + // protocol listener. RequireProxyProtocol = FlagInfo{ Name: "require-proxy-protocol", Description: `Requires PROXY protocol on the SQL listener. The HTTP