diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 9c99cddfe9f9..3fc0f12cca71 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -98,6 +98,7 @@ go_test( "//pkg/testutils/skip", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", + "//pkg/util/ctxgroup", "//pkg/util/leaktest", "//pkg/util/log", "//pkg/util/metric", diff --git a/pkg/ccl/sqlproxyccl/connector.go b/pkg/ccl/sqlproxyccl/connector.go index 19ead3a58382..c6657f4c50bf 100644 --- a/pkg/ccl/sqlproxyccl/connector.go +++ b/pkg/ccl/sqlproxyccl/connector.go @@ -120,8 +120,6 @@ func (c *connector) OpenTenantConnWithToken( // Since this method is only used during connection migration (i.e. proxy // is connecting to the SQL pod), we'll discard all of the messages, and // only return once we've seen a ReadyForQuery message. - // - // NOTE: This will need to be updated when we implement query cancellation. newBackendKeyData, err := readTokenAuthResult(serverConn) if err != nil { return nil, err diff --git a/pkg/ccl/sqlproxyccl/connector_test.go b/pkg/ccl/sqlproxyccl/connector_test.go index bb8de51a6b77..0d8eb941cdef 100644 --- a/pkg/ccl/sqlproxyccl/connector_test.go +++ b/pkg/ccl/sqlproxyccl/connector_test.go @@ -89,8 +89,8 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { defer testutils.TestingHook( &readTokenAuthResult, - func(serverConn net.Conn) error { - return errors.New("bar") + func(serverConn net.Conn) (*pgproto3.BackendKeyData, error) { + return nil, errors.New("bar") }, )() @@ -114,9 +114,18 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { StartupMsg: &pgproto3.StartupMessage{ Parameters: make(map[string]string), }, + CancelInfo: makeCancelInfo( + &net.TCPAddr{IP: net.IP{4, 5, 6, 7}}, + &net.TCPAddr{IP: net.IP{11, 22, 33, 44}}, + ), + } + pipeConn, _ := net.Pipe() + defer pipeConn.Close() + conn := &fakeTCPConn{ + Conn: pipeConn, + remoteAddr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}}, + localAddr: &net.TCPAddr{IP: net.IP{4, 5, 6, 7}}, } - conn, _ := net.Pipe() - defer conn.Close() var openCalled bool c.testingKnobs.dialTenantCluster = func( @@ -134,12 +143,16 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { } var authCalled bool + crdbBackendKeyData := &pgproto3.BackendKeyData{ + ProcessID: 4, + SecretKey: 5, + } defer testutils.TestingHook( &readTokenAuthResult, - func(serverConn net.Conn) error { + func(serverConn net.Conn) (*pgproto3.BackendKeyData, error) { authCalled = true require.Equal(t, conn, serverConn) - return nil + return crdbBackendKeyData, nil }, )() @@ -148,6 +161,7 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { require.True(t, authCalled) require.NoError(t, err) require.Equal(t, conn, crdbConn) + require.Equal(t, crdbBackendKeyData, c.CancelInfo.mu.origBackendKeyData) // Ensure that token is deleted. _, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam] diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 550a04be4162..4676ba7e0abf 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -120,6 +120,8 @@ type ProxyOptions struct { // balancerOpts is used to customize the balancer created by the proxy. balancerOpts []balancer.Option + + httpCancelErrHandler func(err error) } } @@ -491,7 +493,6 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn) // handleCancelRequest handles a pgwire query cancel request by either // forwarding it to a SQL node or to another proxy. func (handler *proxyHandler) handleCancelRequest(cr *proxyCancelRequest, allowForward bool) error { - const timeout = 2 * time.Second if ci, ok := handler.cancelInfoMap.getCancelInfo(cr.SecretKey); ok { return ci.sendCancelToBackend(cr.ClientIP) } @@ -499,13 +500,18 @@ func (handler *proxyHandler) handleCancelRequest(cr *proxyCancelRequest, allowFo if !allowForward { return nil } - u := "https://" + cr.ProxyIP.String() + ":8080/_status/cancel" + u := "http://" + cr.ProxyIP.String() + ":8080/_status/cancel/" reqBody := bytes.NewReader(cr.Encode()) + return forwardCancelRequest(u, reqBody) +} + +var forwardCancelRequest = func(url string, reqBody *bytes.Reader) error { + const timeout = 2 * time.Second client := http.Client{ Timeout: timeout, } - if _, err := client.Post(u, "application/octet-stream", reqBody); err != nil { + if _, err := client.Post(url, "application/octet-stream", reqBody); err != nil { return err } return nil diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 1d251a5b7089..c92be81fbe4f 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -9,12 +9,15 @@ package sqlproxyccl import ( + "bytes" "context" "crypto/tls" gosql "database/sql" "fmt" + "io" "io/ioutil" "net" + "net/http" "os" "sort" "strings" @@ -83,7 +86,7 @@ func TestLongDBName(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, addr := newSecureProxyServer( + s, addr, _ := newSecureProxyServer( ctx, t, stopper, &ProxyOptions{RoutingRule: "127.0.0.1:26257"}) longDB := strings.Repeat("x", 70) // 63 is limit @@ -109,7 +112,7 @@ func TestBackendDownRetry(t *testing.T) { // Set RefreshDelay to -1 so that we could simulate a ListPod call under // the hood, which then triggers an EnsurePod again. opts.testingKnobs.dirOpts = []tenant.DirOption{tenant.RefreshDelay(-1)} - server, addr := newSecureProxyServer(ctx, t, stopper, opts) + server, addr, _ := newSecureProxyServer(ctx, t, stopper, opts) directoryServer := mustGetTestSimpleDirectoryServer(t, server.handler) callCount := 0 @@ -140,7 +143,7 @@ func TestFailedConnection(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, proxyAddr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) + s, proxyAddr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{RoutingRule: "undialable%$!@$"}) // TODO(asubiotto): consider using datadriven for these, especially if the // proxy becomes more complex. @@ -200,7 +203,7 @@ func TestUnexpectedError(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) u := fmt.Sprintf("postgres://root:admin@%s/?sslmode=disable&connect_timeout=5", addr) @@ -242,7 +245,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { sqlDB := sqlutils.MakeSQLRunner(db) sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - s, addr := newSecureProxyServer( + s, addr, _ := newSecureProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, ) _, port, err := net.SplitHostPort(addr) @@ -314,7 +317,7 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: true, RoutingRule: "127.0.0.1:26257", }) @@ -337,7 +340,7 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: false, SkipVerify: true, RoutingRule: "127.0.0.1:26257", @@ -365,7 +368,7 @@ func TestProxyTLSConf(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ + _, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{ Insecure: false, SkipVerify: false, RoutingRule: "127.0.0.1:26257", @@ -414,7 +417,7 @@ func TestProxyTLSClose(t *testing.T) { return originalFrontendAdmit(conn, incomingTLSConfig) })() - s, addr := newSecureProxyServer( + s, addr, _ := newSecureProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, ) @@ -491,7 +494,7 @@ func TestProxyModifyRequestParams(t *testing.T) { return originalBackendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) })() - s, proxyAddr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) + s, proxyAddr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{}) u := fmt.Sprintf("postgres://bogususer:foo123@%s/?sslmode=require&authToken=abc123&options=--cluster=tenant-cluster-28&sslmode=require", proxyAddr) te.TestConnect(ctx, t, u, func(conn *pgx.Conn) { @@ -526,7 +529,7 @@ func TestInsecureProxy(t *testing.T) { sqlDB := sqlutils.MakeSQLRunner(db) sqlDB.Exec(t, `CREATE USER bob WITH PASSWORD 'builder'`) - s, addr := newProxyServer( + s, addr, _ := newProxyServer( ctx, t, sql.Stopper(), &ProxyOptions{RoutingRule: sql.ServingSQLAddr(), SkipVerify: true}, ) @@ -558,7 +561,7 @@ func TestErroneousFrontend(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) @@ -584,7 +587,7 @@ func TestErroneousBackend(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - _, addr := newProxyServer(ctx, t, stopper, &ProxyOptions{}) + _, addr, _ := newProxyServer(ctx, t, stopper, &ProxyOptions{}) url := fmt.Sprintf("postgres://bob:builder@%s/?sslmode=disable&options=--cluster=tenant-cluster-28&sslmode=require", addr) @@ -610,7 +613,7 @@ func TestProxyRefuseConn(t *testing.T) { stopper := stop.NewStopper() defer stopper.Stop(ctx) - s, addr := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) + s, addr, _ := newSecureProxyServer(ctx, t, stopper, &ProxyOptions{}) url := fmt.Sprintf("postgres://root:admin@%s?sslmode=require&options=--cluster=tenant-cluster-28&sslmode=require", addr) te.TestConnectErr(ctx, t, url, codeProxyRefusedConnection, "too many attempts") @@ -678,7 +681,7 @@ func TestDenylistUpdate(t *testing.T) { return originalBackendDial(msg, sql.ServingSQLAddr(), proxyOutgoingTLSConfig) })() - s, addr := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{ + s, addr, _ := newSecureProxyServer(ctx, t, sql.Stopper(), &ProxyOptions{ Denylist: denyList.Name(), PollConfigInterval: 10 * time.Millisecond, }) @@ -739,7 +742,7 @@ func TestDirectoryConnect(t *testing.T) { DirectoryAddr: tdsAddr.String(), Insecure: true, } - _, addr := newProxyServer(ctx, t, srv.Stopper(), opts) + _, addr, _ := newProxyServer(ctx, t, srv.Stopper(), opts) t.Run("fallback when tenant not found", func(t *testing.T) { url := fmt.Sprintf( @@ -833,7 +836,7 @@ func TestConnectionRebalancingDisabled(t *testing.T) { // Start two SQL pods for the test tenant. const podCount = 2 tenantID := serverutils.TestTenantID() - tenants := startTestTenantPods(ctx, t, s, tenantID, podCount) + tenants := startTestTenantPods(ctx, t, s, tenantID, podCount, base.TestingKnobs{}) defer func() { for _, tenant := range tenants { tenant.Stopper().Stop(ctx) @@ -853,7 +856,7 @@ func TestConnectionRebalancingDisabled(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) // Open 12 connections to the first pod. @@ -906,6 +909,236 @@ func TestConnectionRebalancingDisabled(t *testing.T) { require.Len(t, dist, 1) } +func TestCancelQuery(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + defer log.Scope(t).Close(t) + + // Start KV server, and enable session migration. + params, _ := tests.CreateTestServerParams() + s, mainDB, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(ctx) + _, err := mainDB.Exec("ALTER TENANT ALL SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") + require.NoError(t, err) + + // Start two SQL pods for the test tenant. + const podCount = 2 + tenantID := serverutils.TestTenantID() + var cancelFn func(context.Context) error + tenantKnobs := base.TestingKnobs{} + tenantKnobs.SQLExecutor = &sql.ExecutorTestingKnobs{ + BeforeExecute: func(ctx context.Context, stmt string) { + if strings.Contains(stmt, "cancel_me") { + err := cancelFn(ctx) + assert.NoError(t, err) + } + }, + } + tenants := startTestTenantPods(ctx, t, s, tenantID, podCount, tenantKnobs) + defer func() { + for _, tenant := range tenants { + tenant.Stopper().Stop(ctx) + } + }() + + // Use a custom time source for testing. + t0 := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC) + timeSource := timeutil.NewManualTime(t0) + + // Register one SQL pod in the directory server. + tds := tenantdirsvr.NewTestStaticDirectoryServer(s.Stopper(), timeSource) + tds.CreateTenant(tenantID, "tenant-cluster") + tds.AddPod(tenantID, &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: tenants[0].SQLAddr(), + State: tenant.RUNNING, + StateTimestamp: timeSource.Now(), + }) + require.NoError(t, tds.Start(ctx)) + + opts := &ProxyOptions{SkipVerify: true} + opts.testingKnobs.directoryServer = tds + var httpCancelErr error + opts.testingKnobs.httpCancelErrHandler = func(err error) { + httpCancelErr = err + } + opts.testingKnobs.balancerOpts = []balancer.Option{ + balancer.TimeSource(timeSource), + balancer.RebalanceRate(1), + balancer.RebalanceDelay(-1), + } + proxy, addr, httpAddr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + connectionString := fmt.Sprintf( + "postgres://testuser:hunter2@%s/defaultdb?sslmode=require&sslrootcert=%s&options=--cluster=tenant-cluster-%s", + addr, testutils.TestDataPath(t, "testserver.crt"), tenantID, + ) + + // Open a connection to the first pod. + conn, err := pgx.Connect(ctx, connectionString) + require.NoError(t, err) + defer func() { _ = conn.Close(ctx) }() + + // Add a second SQL pod. + tds.AddPod(tenantID, &tenant.Pod{ + TenantID: tenantID.ToUint64(), + Addr: tenants[1].SQLAddr(), + State: tenant.RUNNING, + StateTimestamp: timeSource.Now(), + }) + + // Wait until the update gets propagated to the directory cache. + testutils.SucceedsSoon(t, func() error { + pods, err := proxy.handler.directoryCache.TryLookupTenantPods(ctx, tenantID) + if err != nil { + return err + } + if len(pods) != 2 { + return errors.Newf("expected 2 pods, but got %d", len(pods)) + } + return nil + }) + + t.Run("cancel over sql", func(t *testing.T) { + cancelFn = conn.PgConn().CancelRequest + var b bool + err = conn.QueryRow(ctx, "SELECT pg_sleep(5) AS cancel_me").Scan(&b) + require.Error(t, err) + require.Regexp(t, "query execution canceled", err.Error()) + }) + + t.Run("cancel over http", func(t *testing.T) { + cancelFn = func(ctx context.Context) error { + cancelRequest := proxyCancelRequest{ + ProxyIP: net.IP{}, + SecretKey: conn.PgConn().SecretKey(), + ClientIP: net.IP{127, 0, 0, 1}, + } + u := "http://" + httpAddr + "/_status/cancel/" + reqBody := bytes.NewReader(cancelRequest.Encode()) + client := http.Client{ + Timeout: 1 * time.Second, + } + resp, err := client.Post(u, "application/octet-stream", reqBody) + if err != nil { + return err + } + respBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + assert.Equal(t, "OK", string(respBytes)) + return nil + } + var b bool + err = conn.QueryRow(ctx, "SELECT pg_sleep(5) AS cancel_me").Scan(&b) + require.Error(t, err) + require.Regexp(t, "query execution canceled", err.Error()) + }) + + t.Run("cancel after migrating a session", func(t *testing.T) { + cancelFn = conn.PgConn().CancelRequest + defer testutils.TestingHook(&defaultTransferTimeout, 3*time.Minute)() + origCancelInfo, found := proxy.handler.cancelInfoMap.getCancelInfo(conn.PgConn().SecretKey()) + require.True(t, found) + b := tds.DrainPod(tenantID, tenants[0].SQLAddr()) + require.True(t, b) + testutils.SucceedsSoon(t, func() error { + pods, err := proxy.handler.directoryCache.TryLookupTenantPods(ctx, tenantID) + if err != nil { + return err + } + for _, pod := range pods { + if pod.State == tenant.DRAINING { + return nil + } + } + return errors.New("expected DRAINING pod") + }) + origCancelInfo.mu.RLock() + origKey := origCancelInfo.mu.origBackendKeyData.SecretKey + origCancelInfo.mu.RUnlock() + // Advance the time so that rebalancing will occur. + timeSource.Advance(2 * time.Minute) + proxy.handler.balancer.RebalanceTenant(ctx, tenantID) + testutils.SucceedsSoon(t, func() error { + newCancelInfo, found := proxy.handler.cancelInfoMap.getCancelInfo(conn.PgConn().SecretKey()) + if !found { + return errors.New("expected to find cancel info") + } + newCancelInfo.mu.RLock() + newKey := newCancelInfo.mu.origBackendKeyData.SecretKey + newCancelInfo.mu.RUnlock() + if origKey == newKey { + return errors.Newf("expected %d to differ", origKey) + } + return nil + }) + + err = conn.QueryRow(ctx, "SELECT pg_sleep(5) AS cancel_me").Scan(&b) + require.Error(t, err) + require.Regexp(t, "query execution canceled", err.Error()) + }) + + t.Run("reject cancel from wrong client IP", func(t *testing.T) { + cancelRequest := proxyCancelRequest{ + ProxyIP: net.IP{}, + SecretKey: conn.PgConn().SecretKey(), + ClientIP: net.IP{127, 1, 2, 3}, + } + u := "http://" + httpAddr + "/_status/cancel/" + reqBody := bytes.NewReader(cancelRequest.Encode()) + client := http.Client{ + Timeout: 10 * time.Second, + } + resp, err := client.Post(u, "application/octet-stream", reqBody) + require.NoError(t, err) + respBytes, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "OK", string(respBytes)) + require.Error(t, httpCancelErr) + require.Regexp(t, "mismatched client IP for cancel request", httpCancelErr.Error()) + }) + + t.Run("forward over http", func(t *testing.T) { + var forwardedTo string + var forwardedReq proxyCancelRequest + var wg sync.WaitGroup + wg.Add(1) + defer testutils.TestingHook(&forwardCancelRequest, func(url string, reqBody *bytes.Reader) error { + forwardedTo = url + var err error + reqBytes, err := ioutil.ReadAll(reqBody) + assert.NoError(t, err) + err = forwardedReq.Decode(reqBytes) + assert.NoError(t, err) + wg.Done() + return nil + })() + crdbRequest := &pgproto3.CancelRequest{ + ProcessID: 1, + SecretKey: 2, + } + buf := crdbRequest.Encode(nil /* buf */) + proxyAddr := conn.PgConn().Conn().RemoteAddr() + cancelConn, err := net.Dial(proxyAddr.Network(), proxyAddr.String()) + require.NoError(t, err) + defer cancelConn.Close() + + _, err = cancelConn.Write(buf) + require.NoError(t, err) + _, err = cancelConn.Read(buf) + require.ErrorIs(t, io.EOF, err) + wg.Wait() + require.Equal(t, "http://0.0.0.1:8080/_status/cancel/", forwardedTo) + expectedReq := proxyCancelRequest{ + ProxyIP: net.IP{0, 0, 0, 1}, + SecretKey: 2, + ClientIP: net.IP{127, 0, 0, 1}, + } + require.Equal(t, expectedReq, forwardedReq) + }) +} + func TestPodWatcher(t *testing.T) { defer leaktest.AfterTest(t)() ctx := context.Background() @@ -921,7 +1154,7 @@ func TestPodWatcher(t *testing.T) { // Start four SQL pods for the test tenant. const podCount = 4 tenantID := serverutils.TestTenantID() - tenants := startTestTenantPods(ctx, t, s, tenantID, podCount) + tenants := startTestTenantPods(ctx, t, s, tenantID, podCount, base.TestingKnobs{}) defer func() { for _, tenant := range tenants { tenant.Stopper().Stop(ctx) @@ -948,7 +1181,7 @@ func TestPodWatcher(t *testing.T) { balancer.NoRebalanceLoop(), balancer.RebalanceRate(1.0), } - proxy, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) // Open 12 connections to it. The balancer should distribute the connections @@ -1044,7 +1277,7 @@ func TestConnectionMigration(t *testing.T) { // loads. For this test, we will stub out lookupAddr in the connector. We // will alternate between tenant1 and tenant2, starting with tenant1. opts := &ProxyOptions{SkipVerify: true, RoutingRule: tenant1.SQLAddr()} - proxy, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) @@ -1388,7 +1621,7 @@ func TestCurConnCountMetric(t *testing.T) { // Start a single SQL pod. tenantID := serverutils.TestTenantID() - tenants := startTestTenantPods(ctx, t, s, tenantID, 1) + tenants := startTestTenantPods(ctx, t, s, tenantID, 1, base.TestingKnobs{}) defer func() { for _, tenant := range tenants { tenant.Stopper().Stop(ctx) @@ -1408,7 +1641,7 @@ func TestCurConnCountMetric(t *testing.T) { opts := &ProxyOptions{SkipVerify: true, DisableConnectionRebalancing: true} opts.testingKnobs.directoryServer = tds - proxy, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + proxy, addr, _ := newSecureProxyServer(ctx, t, s.Stopper(), opts) connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-%s", addr, tenantID) // Open 500 connections to the SQL pod. @@ -1819,7 +2052,7 @@ func (te *tester) TestConnectErr( func newSecureProxyServer( ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, -) (server *Server, addr string) { +) (server *Server, addr, httpAddr string) { // Created via: const _ = ` openssl genrsa -out testdata/testserver.key 2048 @@ -1834,10 +2067,15 @@ openssl req -new -x509 -sha256 -key testdata/testserver.key -out testdata/testse func newProxyServer( ctx context.Context, t *testing.T, stopper *stop.Stopper, opts *ProxyOptions, -) (server *Server, addr string) { +) (server *Server, addr, httpAddr string) { const listenAddress = "127.0.0.1:0" + ctx, _ = stopper.WithCancelOnQuiesce(ctx) ln, err := net.Listen("tcp", listenAddress) require.NoError(t, err) + stopper.AddCloser(stop.CloserFn(func() { _ = ln.Close() })) + httpLn, err := net.Listen("tcp", listenAddress) + require.NoError(t, err) + stopper.AddCloser(stop.CloserFn(func() { _ = httpLn.Close() })) server, err = NewServer(ctx, stopper, *opts) require.NoError(t, err) @@ -1846,8 +2084,12 @@ func newProxyServer( _ = server.Serve(ctx, ln) }) require.NoError(t, err) + err = server.Stopper.RunAsyncTask(ctx, "proxy-http-server-serve", func(ctx context.Context) { + _ = server.ServeHTTP(ctx, httpLn) + }) + require.NoError(t, err) - return server, ln.Addr().String() + return server, ln.Addr().String(), httpLn.Addr().String() } func runTestQuery(ctx context.Context, conn *pgx.Conn) error { @@ -1943,6 +2185,7 @@ func startTestTenantPods( ts serverutils.TestServerInterface, tenantID roachpb.TenantID, count int, + knobs base.TestingKnobs, ) []serverutils.TestTenantInterface { t.Helper() @@ -1953,6 +2196,7 @@ func startTestTenantPods( if i != 0 { params.Existing = true } + params.TestingKnobs = knobs tenant, tenantDB := serverutils.StartTenant(t, ts, params) tenant.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) diff --git a/pkg/ccl/sqlproxyccl/server.go b/pkg/ccl/sqlproxyccl/server.go index d5a2d703420f..d42489afeaec 100644 --- a/pkg/ccl/sqlproxyccl/server.go +++ b/pkg/ccl/sqlproxyccl/server.go @@ -115,7 +115,8 @@ func (s *Server) handleVars(w http.ResponseWriter, r *http.Request) { } } -// handleCancel +// handleCancel processes a cancel request that has been forwarded from another +// sqlproxy. func (s *Server) handleCancel(w http.ResponseWriter, r *http.Request) { var retErr error defer func() { @@ -127,6 +128,9 @@ func (s *Server) handleCancel(w http.ResponseWriter, r *http.Request) { r.RemoteAddr, retErr, ) } + if f := s.handler.testingKnobs.httpCancelErrHandler; f != nil { + f(retErr) + } }() buf := make([]byte, proxyCancelRequestLen) n, err := r.Body.Read(buf)