From 32ea1a14780fe7e58d773fea9f2da182ba9ad5ae Mon Sep 17 00:00:00 2001 From: Joel Kenny Date: Tue, 2 Jan 2024 18:45:31 +0000 Subject: [PATCH] sqlproxyccl: add proxy protocol listener To support GCP Private Service Connect, we need to have a listener in SQLProxy which expects packets to contain proxy protocol headers. This listener will be used for all traffic inbound from PSC. At the same time, SQLProxy must continue to accept connections through the public Internet which will not contain proxy protocol headers, and for which any proxy protocol headers we receive cannot be trusted. This commit introduces an optional second listener in SQLProxy, controlled by `--proxy-protocol-listen-addr`, which requires proxy protocol even as the primary listener doesn't. Private Service Connect will direct traffic to this second listener. Resolves #117240 Release note: None --- pkg/ccl/cliccl/flags.go | 1 + pkg/ccl/cliccl/mt_proxy.go | 30 ++- pkg/ccl/sqlproxyccl/proxy_handler.go | 10 +- pkg/ccl/sqlproxyccl/proxy_handler_test.go | 239 ++++++++++++++-------- pkg/ccl/sqlproxyccl/server.go | 57 +++++- pkg/cli/cliflags/flags_mt.go | 8 + 6 files changed, 234 insertions(+), 111 deletions(-) diff --git a/pkg/ccl/cliccl/flags.go b/pkg/ccl/cliccl/flags.go index 0c19bd81e0b5..efa5e3653eff 100644 --- a/pkg/ccl/cliccl/flags.go +++ b/pkg/ccl/cliccl/flags.go @@ -69,6 +69,7 @@ func init() { cliflagcfg.StringFlag(f, &proxyContext.Denylist, cliflags.DenyList) cliflagcfg.StringFlag(f, &proxyContext.Allowlist, cliflags.AllowList) cliflagcfg.StringFlag(f, &proxyContext.ListenAddr, cliflags.ProxyListenAddr) + cliflagcfg.StringFlag(f, &proxyContext.ProxyProtocolListenAddr, cliflags.ProxyProtocolListenAddr) cliflagcfg.StringFlag(f, &proxyContext.ListenCert, cliflags.ListenCert) cliflagcfg.StringFlag(f, &proxyContext.ListenKey, cliflags.ListenKey) cliflagcfg.StringFlag(f, &proxyContext.MetricsAddress, cliflags.ListenMetrics) diff --git a/pkg/ccl/cliccl/mt_proxy.go b/pkg/ccl/cliccl/mt_proxy.go index c6d905b77589..8da06689c578 100644 --- a/pkg/ccl/cliccl/mt_proxy.go +++ b/pkg/ccl/cliccl/mt_proxy.go @@ -56,9 +56,20 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { log.Infof(ctx, "New proxy with opts: %+v", proxyContext) - proxyLn, err := net.Listen("tcp", proxyContext.ListenAddr) - if err != nil { - return err + var proxyLn net.Listener + if proxyContext.ListenAddr != "" { + proxyLn, err = net.Listen("tcp", proxyContext.ListenAddr) + if err != nil { + return err + } + } + + var proxyProtocolLn net.Listener + if proxyContext.ProxyProtocolListenAddr != "" { + proxyProtocolLn, err = net.Listen("tcp", proxyContext.ProxyProtocolListenAddr) + if err != nil { + return err + } } metricsLn, err := net.Listen("tcp", proxyContext.MetricsAddress) @@ -84,15 +95,14 @@ func runStartSQLProxy(cmd *cobra.Command, args []string) (returnErr error) { } if err := stopper.RunAsyncTask(ctx, "serve-proxy", func(ctx context.Context) { - log.Infof(ctx, "proxy server listening at %s", proxyLn.Addr()) - if err := server.Serve(ctx, proxyLn); err != nil { + if err := server.ServeSQL(ctx, proxyLn, proxyProtocolLn); err != nil { errChan <- err } }); err != nil { return err } - return waitForSignals(ctx, server, stopper, proxyLn, errChan) + return waitForSignals(ctx, server, stopper, proxyLn, proxyProtocolLn, errChan) } func initLogging(cmd *cobra.Command) (ctx context.Context, stopper *stop.Stopper, err error) { @@ -110,6 +120,7 @@ func waitForSignals( server *sqlproxyccl.Server, stopper *stop.Stopper, proxyLn net.Listener, + proxyProtocolLn net.Listener, errChan chan error, ) (returnErr error) { // Need to alias the signals if this has to run on non-unix OSes too. @@ -139,7 +150,12 @@ func waitForSignals( // waiting for "shutdownConnectionTimeout" to elapse after which // open TCP connections will be forcefully closed so the server can stop log.Infof(ctx, "stopping tcp listener") - _ = proxyLn.Close() + if proxyLn != nil { + _ = proxyLn.Close() + } + if proxyProtocolLn != nil { + _ = proxyProtocolLn.Close() + } select { case <-server.AwaitNoConnections(ctx): case <-time.After(shutdownConnectionTimeout): diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 423e9404f9dc..30d0aedabd5e 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -73,6 +73,10 @@ type ProxyOptions struct { Denylist string // ListenAddr is the listen address for incoming connections. ListenAddr string + // ProxyProtocolListenAddr is the optional listen address for incoming + // connections for which it will be enforced that the connections have proxy + // headers set. + ProxyProtocolListenAddr string // ListenCert is the file containing PEM-encoded x509 certificate for listen // address. Set to "*" to auto-generate self-signed cert. ListenCert string @@ -113,8 +117,10 @@ type ProxyOptions struct { DisableConnectionRebalancing bool // RequireProxyProtocol changes the server's behavior to support the PROXY // protocol (SQL=required, HTTP=best-effort). With this set to true, the - // PROXY info from upstream will be trusted on both HTTP and SQL, if the - // headers are allowed. + // PROXY info from upstream will be trusted on both HTTP and SQL (on the + // ListenAddr port), if the headers are allowed. The ProxyProtocolListenAddr + // port, if specified, will require the proxy protocol regardless of + // RequireProxyProtocol. RequireProxyProtocol bool // testingKnobs are knobs used for testing. diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 205eff0afebe..7ff93a383a4b 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -65,6 +65,12 @@ const backendError = "Backend error!" // the test directory server. const notFoundTenantID = 99 +type serverAddresses struct { + listenAddr string + proxyProtocolListenAddr string + httpAddr string +} + func TestProxyHandler_ValidateConnection(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) @@ -89,7 +95,7 @@ func TestProxyHandler_ValidateConnection(t *testing.T) { options := &ProxyOptions{} options.testingKnobs.directoryServer = tds - s, _, _ := newSecureProxyServer(ctx, t, stop, options) + s, _ := newSecureProxyServer(ctx, t, stop, options) t.Run("not found/no cluster name", func(t *testing.T) { err := s.handler.validateConnection(ctx, invalidTenantID, "") @@ -150,7 +156,7 @@ func TestProxyProtocol(t *testing.T) { sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) var validateFn func(h *proxyproto.Header) error - withProxyProtocol := func(p bool) (server *Server, addr, httpAddr string) { + withProxyProtocol := func(p bool) (server *Server, addrs *serverAddresses) { options := &ProxyOptions{ RoutingRule: ts.AdvSQLAddr(), SkipVerify: true, @@ -208,8 +214,6 @@ func TestProxyProtocol(t *testing.T) { } } - s, sqlAddr, httpAddr := withProxyProtocol(true) - defer testutils.TestingHook(&validateFn, func(h *proxyproto.Header) error { if h.SourceAddr.String() != "10.20.30.40:4242" { return errors.Newf("got source addr %s, expected 10.20.30.40:4242", h.SourceAddr) @@ -217,25 +221,71 @@ func TestProxyProtocol(t *testing.T) { return nil })() - // Test SQL. Only request with PROXY should go through. - url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-42.defaultdb?sslmode=require", sqlAddr) - te.TestConnectWithPGConfig( - ctx, t, url, - func(c *pgx.ConnConfig) { - c.DialFunc = proxyDialer - }, - func(conn *pgx.Conn) { - require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) - require.NoError(t, runTestQuery(ctx, conn)) - }, - ) - _ = te.TestConnectErr(ctx, t, url, codeClientReadFailed, "tls error") + testSQLNoRequiredProxyProtocol := func(s *Server, addr string) { + url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-42.defaultdb?sslmode=require", addr) + // No proxy protocol. + te.TestConnect(ctx, t, url, + func(conn *pgx.Conn) { + t.Log("B") + require.NotZero(t, s.metrics.CurConnCount.Value()) + require.NoError(t, runTestQuery(ctx, conn)) + }, + ) + // Proxy protocol. + _ = te.TestConnectErrWithPGConfig( + ctx, t, url, + func(c *pgx.ConnConfig) { + c.DialFunc = proxyDialer + }, codeClientReadFailed, "tls error", + ) + } + + testSQLRequiredProxyProtocol := func(s *Server, addr string) { + url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-42.defaultdb?sslmode=require", addr) + // No proxy protocol. + _ = te.TestConnectErr(ctx, t, url, codeClientReadFailed, "tls error") + // Proxy protocol. + te.TestConnectWithPGConfig( + ctx, t, url, + func(c *pgx.ConnConfig) { + c.DialFunc = proxyDialer + }, + func(conn *pgx.Conn) { + require.NotZero(t, s.metrics.CurConnCount.Value()) + require.NoError(t, runTestQuery(ctx, conn)) + }, + ) + } - // Test HTTP. Should support with or without PROXY. - client := http.Client{Timeout: timeout} - makeHttpReq(t, &client, httpAddr, true) - proxyClient := http.Client{Transport: &http.Transport{DialContext: proxyDialer}} - makeHttpReq(t, &proxyClient, httpAddr, true) + t.Run("server doesn't require proxy protocol", func(t *testing.T) { + s, addrs := withProxyProtocol(false) + // Test SQL on the default listener. Both should go through. + testSQLNoRequiredProxyProtocol(s, addrs.listenAddr) + // Test SQL on the proxy protocol listener. Only request with PROXY should go + // through. + testSQLRequiredProxyProtocol(s, addrs.proxyProtocolListenAddr) + + // Test HTTP. Shouldn't support PROXY. + client := http.Client{Timeout: timeout} + makeHttpReq(t, &client, addrs.httpAddr, true) + proxyClient := http.Client{Transport: &http.Transport{DialContext: proxyDialer}} + makeHttpReq(t, &proxyClient, addrs.httpAddr, false) + }) + + t.Run("server requires proxy protocol", func(t *testing.T) { + s, addrs := withProxyProtocol(true) + // Test SQL on the default listener. Both should go through. + testSQLRequiredProxyProtocol(s, addrs.listenAddr) + // Test SQL on the proxy protocol listener. Only request with PROXY should go + // through. + testSQLRequiredProxyProtocol(s, addrs.proxyProtocolListenAddr) + + // Test HTTP. Should support with or without PROXY. + client := http.Client{Timeout: timeout} + makeHttpReq(t, &client, addrs.httpAddr, true) + proxyClient := http.Client{Transport: &http.Transport{DialContext: proxyDialer}} + makeHttpReq(t, &proxyClient, addrs.httpAddr, true) + }) } func TestPrivateEndpointsACL(t *testing.T) { @@ -297,7 +347,7 @@ func TestPrivateEndpointsACL(t *testing.T) { PollConfigInterval: 10 * time.Millisecond, } options.testingKnobs.directoryServer = tds - s, sqlAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), options) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), options) timeout := 3 * time.Second proxyDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -334,7 +384,7 @@ func TestPrivateEndpointsACL(t *testing.T) { } t.Run("private connection allowed", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", addrs.listenAddr) te.TestConnectWithPGConfig( ctx, t, url, func(c *pgx.ConnConfig) { @@ -384,7 +434,7 @@ func TestPrivateEndpointsACL(t *testing.T) { }) t.Run("private connection disallowed on another tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErrWithPGConfig( ctx, t, url, func(c *pgx.ConnConfig) { @@ -396,7 +446,7 @@ func TestPrivateEndpointsACL(t *testing.T) { }) t.Run("private connection disallowed on public tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/public-tenant-30.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/public-tenant-30.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErrWithPGConfig( ctx, t, url, func(c *pgx.ConnConfig) { @@ -466,10 +516,10 @@ func TestAllowedCIDRRangesACL(t *testing.T) { PollConfigInterval: 10 * time.Millisecond, } options.testingKnobs.directoryServer = tds - s, sqlAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), options) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), options) t.Run("public connection allowed", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/my-tenant-10.defaultdb?sslmode=require", addrs.listenAddr) te.TestConnect(ctx, t, url, func(conn *pgx.Conn) { // Initial connection. require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) @@ -512,12 +562,12 @@ func TestAllowedCIDRRangesACL(t *testing.T) { }) t.Run("public connection disallowed on another tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/other-tenant-20.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, codeProxyRefusedConnection, "connection refused") }) t.Run("public connection disallowed on private tenant", func(t *testing.T) { - url := fmt.Sprintf("postgres://bob:builder@%s/private-tenant-30.defaultdb?sslmode=require", sqlAddr) + url := fmt.Sprintf("postgres://bob:builder@%s/private-tenant-30.defaultdb?sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, codeProxyRefusedConnection, "connection refused") }) } @@ -539,11 +589,11 @@ func TestLongDBName(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, addr, _ := newSecureProxyServer( + s, addrs := newSecureProxyServer( ctx, t, stopper, &ProxyOptions{RoutingRule: "127.0.0.1:26257"}) longDB := strings.Repeat("x", 70) // 63 is limit - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, longDB) + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, longDB) _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) } @@ -565,7 +615,7 @@ func TestBackendDownRetry(t *testing.T) { // Set RefreshDelay to -1 so that we could simulate a ListPod call under // the hood, which then triggers an EnsurePod again. opts.testingKnobs.dirOpts = []tenant.DirOption{tenant.RefreshDelay(-1)} - server, addr, _ := newSecureProxyServer(ctx, t, stopper, opts) + server, addrs := newSecureProxyServer(ctx, t, stopper, opts) directoryServer := mustGetTestSimpleDirectoryServer(t, server.handler) callCount := 0 @@ -581,7 +631,7 @@ func TestBackendDownRetry(t *testing.T) { })() // Valid connection, but no backend server running. - pgurl := fmt.Sprintf("postgres://unused:unused@%s/db?options=--cluster=tenant-cluster-28&sslmode=require", addr) + pgurl := fmt.Sprintf("postgres://unused:unused@%s/db?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "cluster tenant-cluster-28 not found") require.Equal(t, 3, callCount) } @@ -596,12 +646,12 @@ func TestFailedConnection(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, proxyAddr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) + s, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) // TODO(asubiotto): consider using datadriven for these, especially if the // proxy becomes more complex. - _, p, err := addr.SplitHostPort(proxyAddr, "") + _, p, err := addr.SplitHostPort(addrs.listenAddr, "") require.NoError(t, err) u := fmt.Sprintf("postgres://unused:unused@localhost:%s/", p) @@ -656,9 +706,9 @@ func TestUnexpectedError(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", addr) + u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", addrs.listenAddr) // Time how long it takes for pgx.Connect to return. If the proxy handles // errors appropriately, pgx.Connect should return near immediately @@ -695,10 +745,10 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { sqlDB := sqlutils.MakeSQLRunner(db) sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - s, addr, _ := newSecureProxyServer( + s, addrs := newSecureProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: ts.AdvSQLAddr(), SkipVerify: true}, ) - _, port, err := net.SplitHostPort(addr) + _, port, err := net.SplitHostPort(addrs.listenAddr) require.NoError(t, err) for _, tc := range []struct { @@ -711,7 +761,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { name: "failed_SASL_auth_1", url: fmt.Sprintf( "postgres://bob:wrong@%s/tenant-cluster-28.defaultdb?sslmode=require", - addr, + addrs.listenAddr, ), expErr: "failed SASL auth", }, @@ -719,7 +769,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { name: "failed_SASL_auth_2", url: fmt.Sprintf( "postgres://bob@%s/tenant-cluster-28.defaultdb?sslmode=require", - addr, + addrs.listenAddr, ), expErr: "failed SASL auth", }, @@ -744,7 +794,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { name: "database_provides_tenant_ID", url: fmt.Sprintf( "postgres://bob:builder@%s/tenant-cluster-28.defaultdb?sslmode=require", - addr, + addrs.listenAddr, ), }, { @@ -807,12 +857,12 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: true, RoutingRule: "127.0.0.1:26257", }) - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, "defaultdb") + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, "defaultdb") _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") }) @@ -830,13 +880,13 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: false, SkipVerify: true, RoutingRule: "127.0.0.1:26257", }) - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, "defaultdb") + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, "defaultdb") _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") }) @@ -858,13 +908,13 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: false, SkipVerify: false, RoutingRule: "127.0.0.1:26257", }) - pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addr, "defaultdb") + pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s?options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr, "defaultdb") _ = te.TestConnectErr(ctx, t, pgurl, codeParamsRoutingFailed, "boom") }) @@ -903,11 +953,11 @@ func TestProxyTLSClose(t *testing.T) { return originalFrontendAdmit(conn, incomingTLSConfig) })() - s, addr, _ := newSecureProxyServer( + s, addrs := newSecureProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: ts.AdvSQLAddr(), SkipVerify: true}, ) - url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-28.defaultdb?sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/tenant-cluster-28.defaultdb?sslmode=require", addrs.listenAddr) conn, err := pgx.Connect(ctx, url) require.NoError(t, err) @@ -977,9 +1027,9 @@ func TestProxyModifyRequestParams(t *testing.T) { return originalBackendDial(ctx, msg, ts.AdvSQLAddr(), proxyOutgoingTLSConfig) })() - s, proxyAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) - u := fmt.Sprintf("postgres://bogususer:foo123@%s/?sslmode=require&authToken=abc123&options=--cluster=tenant-cluster-28&sslmode=require", proxyAddr) + u := fmt.Sprintf("postgres://bogususer:foo123@%s/?sslmode=require&authToken=abc123&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) te.TestConnect(ctx, t, u, func(conn *pgx.Conn) { require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) require.NoError(t, runTestQuery(ctx, conn)) @@ -1007,14 +1057,14 @@ func TestInsecureProxy(t *testing.T) { sqlDB := sqlutils.MakeSQLRunner(db) sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - s, addr, _ := newProxyServer( + s, addrs := newProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: ts.AdvSQLAddr(), SkipVerify: true}, ) - url := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:wrong@%s?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, 0, "failed SASL auth") - url = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url = fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) te.TestConnect(ctx, t, url, func(conn *pgx.Conn) { require.NoError(t, runTestQuery(ctx, conn)) }) @@ -1048,9 +1098,9 @@ func TestErroneousFrontend(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) // Generic message here as the Frontend's error is not codeError and // by default we don't pass back error's text. The startup message doesn't @@ -1080,9 +1130,9 @@ func TestErrorHint(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) err := te.TestConnectErr(ctx, t, url, 0, "codeParamsRoutingFailed: Frontend error") pgErr := (*pgconn.PgError)(nil) @@ -1106,9 +1156,9 @@ func TestErroneousBackend(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addrs := newProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) // Generic message here as the Backend's error is not codeError and // by default we don't pass back error's text. The startup message has @@ -1132,9 +1182,9 @@ func TestProxyRefuseConn(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + s, addrs := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) - url := fmt.Sprintf("postgres://root:admin@%s?sslmode=require&options=--cluster=tenant-cluster-28&sslmode=require", addr) + url := fmt.Sprintf("postgres://root:admin@%s?sslmode=require&options=--cluster=tenant-cluster-28&sslmode=require", addrs.listenAddr) _ = te.TestConnectErr(ctx, t, url, codeProxyRefusedConnection, "too many attempts") require.Equal(t, int64(1), s.metrics.RefusedConnCount.Count()) require.Equal(t, int64(0), s.metrics.SuccessfulConnCount.Count()) @@ -1153,7 +1203,7 @@ func TestProxyHandler_handle(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - proxy, _, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + proxy, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) p1, p2 := net.Pipe() require.NoError(t, p1.Close()) @@ -1230,10 +1280,10 @@ func TestDenylistUpdate(t *testing.T) { PollConfigInterval: 10 * time.Millisecond, } opts.testingKnobs.directoryServer = tds - s, addr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), opts) + s, addrs := newSecureProxyServer(ctx, t, sql.Stopper(), opts) // Establish a connection. - url := fmt.Sprintf("postgres://testuser:foo123@%s/defaultdb?sslmode=require&options=--cluster=tenant-cluster-%s&sslmode=require", addr, tenantID) + url := fmt.Sprintf("postgres://testuser:foo123@%s/defaultdb?sslmode=require&options=--cluster=tenant-cluster-%s&sslmode=require", addrs.listenAddr, tenantID) db, err := gosql.Open("postgres", url) db.SetMaxOpenConns(1) defer db.Close() @@ -1313,11 +1363,11 @@ func TestDirectoryConnect(t *testing.T) { // Start the proxy server using the static directory server. opts := &ProxyOptions{SkipVerify: true} opts.testingKnobs.directoryServer = tds - _, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + _, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) t.Run("tenant not found", func(t *testing.T) { - url := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%d", addr, notFoundTenantID) + url := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%d", addrs.listenAddr, notFoundTenantID) _ = te.TestConnectErr(ctx, t, url, codeParamsRoutingFailed, "cluster tenant-cluster-99 not found") }) @@ -1430,8 +1480,8 @@ func TestConnectionRebalancingDisabled(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // Open 12 connections to the first pod. dist := map[string]int{} @@ -1540,10 +1590,10 @@ func TestCancelQuery(t *testing.T) { balancer.RebalanceRate(1), balancer.RebalanceDelay(-1), } - proxy, addr, httpAddr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) connectionString := fmt.Sprintf( "postgres://testuser:hunter2@%s/defaultdb?sslmode=require&sslrootcert=%s&options=--cluster=tenant-cluster-%s", - addr, datapathutils.TestDataPath(t, "testserver.crt"), tenantID, + addrs.listenAddr, datapathutils.TestDataPath(t, "testserver.crt"), tenantID, ) // Open a connection to the first pod. @@ -1614,7 +1664,7 @@ func TestCancelQuery(t *testing.T) { SecretKey: conn.PgConn().SecretKey(), ClientIP: net.IP{127, 0, 0, 1}, } - u := "http://" + httpAddr + "/_status/cancel/" + u := "http://" + addrs.httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) client := http.Client{ Timeout: 10 * time.Second, @@ -1689,7 +1739,7 @@ func TestCancelQuery(t *testing.T) { SecretKey: conn.PgConn().SecretKey(), ClientIP: net.IP{210, 1, 2, 3}, } - u := "http://" + httpAddr + "/_status/cancel/" + u := "http://" + addrs.httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) client := http.Client{ Timeout: 10 * time.Second, @@ -1753,7 +1803,7 @@ func TestCancelQuery(t *testing.T) { SecretKey: conn.PgConn().SecretKey() + 1, ClientIP: net.IP{127, 0, 0, 1}, } - u := "http://" + httpAddr + "/_status/cancel/" + u := "http://" + addrs.httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) client := http.Client{ Timeout: 10 * time.Second, @@ -1811,8 +1861,8 @@ func TestPodWatcher(t *testing.T) { balancer.NoRebalanceLoop(), balancer.RebalanceRate(1.0), } - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // Open 12 connections to it. The balancer should distribute the connections // evenly across 3 SQL pods (i.e. 4 connections each). @@ -1907,9 +1957,9 @@ func TestConnectionMigration(t *testing.T) { // loads. For this test, we will stub out lookupAddr in the connector. We // will alternate between tenant1 and tenant2, starting with tenant1. opts := &ProxyOptions{SkipVerify: true, RoutingRule: tenant1.SQLAddr()} - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // validateMiscMetrics ensures that our invariant of // attempts = success + error_recoverable + error_fatal is valid, and all @@ -2296,10 +2346,10 @@ func TestAcceptedConnCountMetric(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) - goodConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) - badConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=nocluster", addr) + goodConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) + badConnStr := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=nocluster", addrs.listenAddr) const ( numGood = 5 @@ -2334,7 +2384,7 @@ func TestAcceptedConnCountMetric(t *testing.T) { go func() { defer wg.Done() - conn, err := net.DialTimeout("tcp", addr, 3*time.Second) + conn, err := net.DialTimeout("tcp", addrs.listenAddr, 3*time.Second) defer func() { if conn != nil { _ = conn.Close() @@ -2387,8 +2437,8 @@ func TestCurConnCountMetric(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) - connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) + proxy, addrs := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addrs.listenAddr, tenantID) // Open 500 connections to the SQL pod. const numConns = 500 @@ -2829,7 +2879,7 @@ func (te *tester) TestConnectErrWithPGConfig( func newSecureProxyServer( ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, -) (server *Server, addr, httpAddr string) { +) (server *Server, addrs *serverAddresses) { // Created via: const _ = ` openssl genrsa -out testdata/testserver.key 2048 @@ -2844,12 +2894,15 @@ openssl req -new -x509 -sha256 -key testdata/testserver.key -out testdata/testse func newProxyServer( ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, -) (server *Server, addr, httpAddr string) { +) (server *Server, addrs *serverAddresses) { const listenAddress = "127.0.0.1:0" ctx, _ = stopper.WithCancelOnQuiesce(ctx) ln, err := net.Listen("tcp", listenAddress) require.NoError(t, err) stopper.AddCloser(stop.CloserFn(func() { _ = ln.Close() })) + proxyProtocolLn, err := net.Listen("tcp", listenAddress) + require.NoError(t, err) + stopper.AddCloser(stop.CloserFn(func() { _ = proxyProtocolLn.Close() })) httpLn, err := net.Listen("tcp", listenAddress) require.NoError(t, err) stopper.AddCloser(stop.CloserFn(func() { _ = httpLn.Close() })) @@ -2858,7 +2911,7 @@ func newProxyServer( require.NoError(t, err) err = server.Stopper.RunAsyncTask(ctx, "proxy-server-serve", func(ctx context.Context) { - _ = server.Serve(ctx, ln) + _ = server.ServeSQL(ctx, ln, proxyProtocolLn) }) require.NoError(t, err) err = server.Stopper.RunAsyncTask(ctx, "proxy-http-server-serve", func(ctx context.Context) { @@ -2866,7 +2919,11 @@ func newProxyServer( }) require.NoError(t, err) - return server, ln.Addr().String(), httpLn.Addr().String() + return server, &serverAddresses{ + listenAddr: ln.Addr().String(), + proxyProtocolListenAddr: proxyProtocolLn.Addr().String(), + httpAddr: httpLn.Addr().String(), + } } func runTestQuery(ctx context.Context, conn *pgx.Conn) error { diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index 279635754e83..08b5e9e5e941 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -213,22 +213,45 @@ func (s *Server) ServeHTTP(ctx context.Context, ln net.Listener) error { return nil } -// Serve serves a listener according to the Options given in NewServer(). +// Serve serves up to two listeners according to the Options given in +// NewServer(). +// +// If ln is not nil, a listener is served which does not require +// proxy protocol headers, unless RequireProxyProtocol is true. +// +// If proxyProtocolLn is not nil, a listener is served which requires proxy +// protocol headers. +// // Incoming client connections are taken through the Postgres handshake and // relayed to the configured backend server. -func (s *Server) Serve(ctx context.Context, ln net.Listener) error { - if s.handler.RequireProxyProtocol { - ln = &proxyproto.Listener{ - Listener: ln, - Policy: func(upstream net.Addr) (proxyproto.Policy, error) { - // REQUIRE enforces the connection to send a PROXY header. - // The connection will be rejected if one was not present. - return proxyproto.REQUIRE, nil - }, - ValidateHeader: s.handler.testingKnobs.validateProxyHeader, +func (s *Server) ServeSQL( + ctx context.Context, ln net.Listener, proxyProtocolLn net.Listener, +) error { + if ln != nil { + if s.handler.RequireProxyProtocol { + ln = s.requireProxyProtocolOnListener(ln) + } + log.Infof(ctx, "proxy server listening at %s", ln.Addr()) + if err := s.Stopper.RunAsyncTask(ctx, "listener-serve", func(ctx context.Context) { + _ = s.serve(ctx, ln) + }); err != nil { + return err } } + if proxyProtocolLn != nil { + proxyProtocolLn = s.requireProxyProtocolOnListener(proxyProtocolLn) + log.Infof(ctx, "proxy with required proxy headers server listening at %s", proxyProtocolLn.Addr()) + if err := s.Stopper.RunAsyncTask(ctx, "proxy-protocol-listener-serve", func(ctx context.Context) { + _ = s.serve(ctx, proxyProtocolLn) + }); err != nil { + return err + } + } + return nil +} +// serve is called by ServeSQL to serve a single listener. +func (s *Server) serve(ctx context.Context, ln net.Listener) error { err := s.Stopper.RunAsyncTask(ctx, "listen-quiesce", func(ctx context.Context) { <-s.Stopper.ShouldQuiesce() if err := ln.Close(); err != nil && !grpcutil.IsClosedConnection(err) { @@ -262,6 +285,18 @@ func (s *Server) Serve(ctx context.Context, ln net.Listener) error { } } +func (s *Server) requireProxyProtocolOnListener(ln net.Listener) net.Listener { + return &proxyproto.Listener{ + Listener: ln, + Policy: func(upstream net.Addr) (proxyproto.Policy, error) { + // REQUIRE enforces the connection to send a PROXY header. + // The connection will be rejected if one was not present. + return proxyproto.REQUIRE, nil + }, + ValidateHeader: s.handler.testingKnobs.validateProxyHeader, + } +} + // AwaitNoConnections returns a channel that is closed once the server has no open connections. // This is meant to be used after the server has stopped accepting new connections and we are // waiting to shutdown the server without inturrupting existing connections diff --git a/pkg/cli/cliflags/flags_mt.go b/pkg/cli/cliflags/flags_mt.go index 2590109a4e49..b9d52e6d2218 100644 --- a/pkg/cli/cliflags/flags_mt.go +++ b/pkg/cli/cliflags/flags_mt.go @@ -47,6 +47,11 @@ wait for the tenant id to be fully written to the file (with a newline character Description: "Listen address for incoming connections.", } + ProxyProtocolListenAddr = FlagInfo{ + Name: "proxy-protocol-listen-addr", + Description: "Listen address for incoming connections which require proxy protocol headers.", + } + ThrottleBaseDelay = FlagInfo{ Name: "throttle-base-delay", Description: "Initial value for the exponential backoff used to throttle connection attempts.", @@ -94,6 +99,9 @@ wait for the tenant id to be fully written to the file (with a newline character Description: "If true, proxy will not attempt to rebalance connections.", } + // TODO(joel): Remove this flag, and use --listen-addr for a non-proxy + // protocol listener, and use --proxy-protocol-listen-addr for a proxy + // protocol listener. RequireProxyProtocol = FlagInfo{ Name: "require-proxy-protocol", Description: `Requires PROXY protocol on the SQL listener. The HTTP