diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 3fc0f12cca71..9c99cddfe9f9 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -98,7 +98,6 @@ go_test( "//pkg/testutils/skip", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", - "//pkg/util/ctxgroup", "//pkg/util/leaktest", "//pkg/util/log", "//pkg/util/metric", diff --git a/pkg/ccl/sqlproxyccl/frontend_admitter_test.go b/pkg/ccl/sqlproxyccl/frontend_admitter_test.go index d74ce056fc06..5577d50f8454 100644 --- a/pkg/ccl/sqlproxyccl/frontend_admitter_test.go +++ b/pkg/ccl/sqlproxyccl/frontend_admitter_test.go @@ -16,6 +16,7 @@ import ( "path/filepath" "testing" + "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/jackc/pgconn" @@ -79,7 +80,10 @@ func TestFrontendAdmitWithClientSSLRequire(t *testing.T) { defer cancel() go func() { - cfg, err := pgconn.ParseConfig("postgres://localhost?sslmode=require") + cfg, err := pgconn.ParseConfig(fmt.Sprintf( + "postgres://localhost?sslmode=require&sslrootcert=%s", + testutils.TestDataPath(t, "testserver.crt"), + )) cfg.TLSConfig.ServerName = "test" require.NoError(t, err) require.NotNil(t, cfg) @@ -132,7 +136,12 @@ func TestFrontendAdmitRequireEncryption(t *testing.T) { func TestFrontendAdmitWithCancel(t *testing.T) { defer leaktest.AfterTest(t)() - cli, srv := net.Pipe() + cli, srvPipe := net.Pipe() + srv := &fakeTCPConn{ + Conn: srvPipe, + remoteAddr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}}, + localAddr: &net.TCPAddr{IP: net.IP{4, 5, 6, 7}}, + } require.NoError(t, srv.SetReadDeadline(timeutil.Now().Add(3e9))) require.NoError(t, cli.SetReadDeadline(timeutil.Now().Add(3e9))) @@ -152,7 +161,12 @@ func TestFrontendAdmitWithCancel(t *testing.T) { func TestFrontendAdmitWithSSLAndCancel(t *testing.T) { defer leaktest.AfterTest(t)() - cli, srv := net.Pipe() + cli, srvPipe := net.Pipe() + srv := &fakeTCPConn{ + Conn: srvPipe, + remoteAddr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}}, + localAddr: &net.TCPAddr{IP: net.IP{4, 5, 6, 7}}, + } require.NoError(t, srv.SetReadDeadline(timeutil.Now().Add(3e9))) require.NoError(t, cli.SetReadDeadline(timeutil.Now().Add(3e9))) diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index ffe9b61bf4f9..cf7abe31f044 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -37,6 +37,12 @@ type metrics struct { ConnMigrationAttemptedCount *metric.Counter ConnMigrationAttemptedLatency *metric.Histogram ConnMigrationTransferResponseMessageSize *metric.Histogram + + QueryCancelReceivedPGWire *metric.Counter + QueryCancelReceivedHTTP *metric.Counter + QueryCancelForwarded *metric.Counter + QueryCancelIgnored *metric.Counter + QueryCancelSuccessful *metric.Counter } // MetricStruct implements the metrics.Struct interface. @@ -174,6 +180,36 @@ var ( Measurement: "Bytes", Unit: metric.Unit_BYTES, } + metaQueryCancelReceivedPGWire = metric.Metadata{ + Name: "proxy.query_cancel.received.pgwire", + Help: "Number of query cancel requests this proxy received over pgwire", + Measurement: "Query Cancel Requests", + Unit: metric.Unit_COUNT, + } + metaQueryCancelReceivedHTTP = metric.Metadata{ + Name: "proxy.query_cancel.received.http", + Help: "Number of query cancel requests this proxy received over HTTP", + Measurement: "Query Cancel Requests", + Unit: metric.Unit_COUNT, + } + metaQueryCancelIgnored = metric.Metadata{ + Name: "proxy.query_cancel.ignored", + Help: "Number of query cancel requests this proxy ignored", + Measurement: "Query Cancel Requests", + Unit: metric.Unit_COUNT, + } + metaQueryCancelForwarded = metric.Metadata{ + Name: "proxy.query_cancel.forwarded", + Help: "Number of query cancel requests this proxy forwarded to another proxy", + Measurement: "Query Cancel Requests", + Unit: metric.Unit_COUNT, + } + metaQueryCancelSuccessful = metric.Metadata{ + Name: "proxy.query_cancel.successful", + Help: "Number of query cancel requests this proxy forwarded to the tenant", + Measurement: "Query Cancel Requests", + Unit: metric.Unit_COUNT, + } ) // makeProxyMetrics instantiates the metrics holder for proxy monitoring. @@ -215,6 +251,11 @@ func makeProxyMetrics() metrics { maxExpectedTransferResponseMessageSize, 1, ), + QueryCancelReceivedPGWire: metric.NewCounter(metaQueryCancelReceivedPGWire), + QueryCancelReceivedHTTP: metric.NewCounter(metaQueryCancelReceivedHTTP), + QueryCancelIgnored: metric.NewCounter(metaQueryCancelIgnored), + QueryCancelForwarded: metric.NewCounter(metaQueryCancelForwarded), + QueryCancelSuccessful: metric.NewCounter(metaQueryCancelSuccessful), } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 4676ba7e0abf..751dd292c6d4 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -492,14 +492,32 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn net.Conn) // handleCancelRequest handles a pgwire query cancel request by either // forwarding it to a SQL node or to another proxy. -func (handler *proxyHandler) handleCancelRequest(cr *proxyCancelRequest, allowForward bool) error { +func (handler *proxyHandler) handleCancelRequest( + cr *proxyCancelRequest, allowForward bool, +) (retErr error) { + if allowForward { + handler.metrics.QueryCancelReceivedPGWire.Inc(1) + } else { + handler.metrics.QueryCancelReceivedHTTP.Inc(1) + } + var triedForward bool + defer func() { + if retErr != nil { + handler.metrics.QueryCancelIgnored.Inc(1) + } else if triedForward { + handler.metrics.QueryCancelForwarded.Inc(1) + } else { + handler.metrics.QueryCancelSuccessful.Inc(1) + } + }() if ci, ok := handler.cancelInfoMap.getCancelInfo(cr.SecretKey); ok { return ci.sendCancelToBackend(cr.ClientIP) } // Only forward the request if it hasn't already been sent to the correct proxy. if !allowForward { - return nil + return errors.Newf("ignoring cancel request with unfamiliar key: %d", cr.SecretKey) } + triedForward = true u := "http://" + cr.ProxyIP.String() + ":8080/_status/cancel/" reqBody := bytes.NewReader(cr.Encode()) return forwardCancelRequest(u, reqBody) diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index c92be81fbe4f..15c499a1354a 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -924,13 +924,12 @@ func TestCancelQuery(t *testing.T) { // Start two SQL pods for the test tenant. const podCount = 2 tenantID := serverutils.TestTenantID() - var cancelFn func(context.Context) error + var cancelFn func() tenantKnobs := base.TestingKnobs{} tenantKnobs.SQLExecutor = &sql.ExecutorTestingKnobs{ BeforeExecute: func(ctx context.Context, stmt string) { if strings.Contains(stmt, "cancel_me") { - err := cancelFn(ctx) - assert.NoError(t, err) + cancelFn() } }, } @@ -998,16 +997,58 @@ func TestCancelQuery(t *testing.T) { return nil }) + clearMetrics := func(t *testing.T, metrics *metrics) { + metrics.QueryCancelSuccessful.Clear() + metrics.QueryCancelIgnored.Clear() + metrics.QueryCancelForwarded.Clear() + metrics.QueryCancelReceivedPGWire.Clear() + metrics.QueryCancelReceivedHTTP.Clear() + + testutils.SucceedsSoon(t, func() error { + if metrics.QueryCancelSuccessful.Count() != 0 || + metrics.QueryCancelIgnored.Count() != 0 || + metrics.QueryCancelForwarded.Count() != 0 || + metrics.QueryCancelReceivedPGWire.Count() != 0 || + metrics.QueryCancelReceivedHTTP.Count() != 0 { + return errors.Newf("expected metrics to update, got: "+ + "QueryCancelSuccessful=%d, QueryCancelIgnored=%d "+ + "QueryCancelForwarded=%d QueryCancelReceivedPGWire=%d QueryCancelReceivedHTTP=%d", + metrics.QueryCancelSuccessful.Count(), metrics.QueryCancelIgnored.Count(), + metrics.QueryCancelForwarded.Count(), metrics.QueryCancelReceivedPGWire.Count(), + metrics.QueryCancelReceivedHTTP.Count(), + ) + } + return nil + }) + } + t.Run("cancel over sql", func(t *testing.T) { - cancelFn = conn.PgConn().CancelRequest + clearMetrics(t, proxy.metrics) + cancelFn = func() { + _ = conn.PgConn().CancelRequest(ctx) + } var b bool err = conn.QueryRow(ctx, "SELECT pg_sleep(5) AS cancel_me").Scan(&b) require.Error(t, err) require.Regexp(t, "query execution canceled", err.Error()) + testutils.SucceedsSoon(t, func() error { + if proxy.metrics.QueryCancelSuccessful.Count() != 1 || + proxy.metrics.QueryCancelReceivedPGWire.Count() != 1 { + return errors.Newf("expected metrics to update, got: "+ + "QueryCancelSuccessful=%d, QueryCancelIgnored=%d "+ + "QueryCancelForwarded=%d QueryCancelReceivedPGWire=%d QueryCancelReceivedHTTP=%d", + proxy.metrics.QueryCancelSuccessful.Count(), proxy.metrics.QueryCancelIgnored.Count(), + proxy.metrics.QueryCancelForwarded.Count(), proxy.metrics.QueryCancelReceivedPGWire.Count(), + proxy.metrics.QueryCancelReceivedHTTP.Count(), + ) + } + return nil + }) }) t.Run("cancel over http", func(t *testing.T) { - cancelFn = func(ctx context.Context) error { + clearMetrics(t, proxy.metrics) + cancelFn = func() { cancelRequest := proxyCancelRequest{ ProxyIP: net.IP{}, SecretKey: conn.PgConn().SecretKey(), @@ -1016,27 +1057,41 @@ func TestCancelQuery(t *testing.T) { u := "http://" + httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) client := http.Client{ - Timeout: 1 * time.Second, + Timeout: 10 * time.Second, } resp, err := client.Post(u, "application/octet-stream", reqBody) - if err != nil { - return err + if !assert.NoError(t, err) { + return } respBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err + if !assert.NoError(t, err) { + return } assert.Equal(t, "OK", string(respBytes)) - return nil } var b bool err = conn.QueryRow(ctx, "SELECT pg_sleep(5) AS cancel_me").Scan(&b) require.Error(t, err) require.Regexp(t, "query execution canceled", err.Error()) + testutils.SucceedsSoon(t, func() error { + if proxy.metrics.QueryCancelSuccessful.Count() != 1 || + proxy.metrics.QueryCancelReceivedHTTP.Count() != 1 { + return errors.Newf("expected metrics to update, got: "+ + "QueryCancelSuccessful=%d, QueryCancelIgnored=%d "+ + "QueryCancelForwarded=%d QueryCancelReceivedPGWire=%d QueryCancelReceivedHTTP=%d", + proxy.metrics.QueryCancelSuccessful.Count(), proxy.metrics.QueryCancelIgnored.Count(), + proxy.metrics.QueryCancelForwarded.Count(), proxy.metrics.QueryCancelReceivedPGWire.Count(), + proxy.metrics.QueryCancelReceivedHTTP.Count(), + ) + } + return nil + }) }) t.Run("cancel after migrating a session", func(t *testing.T) { - cancelFn = conn.PgConn().CancelRequest + cancelFn = func() { + _ = conn.PgConn().CancelRequest(ctx) + } defer testutils.TestingHook(&defaultTransferTimeout, 3*time.Minute)() origCancelInfo, found := proxy.handler.cancelInfoMap.getCancelInfo(conn.PgConn().SecretKey()) require.True(t, found) @@ -1080,10 +1135,11 @@ func TestCancelQuery(t *testing.T) { }) t.Run("reject cancel from wrong client IP", func(t *testing.T) { + clearMetrics(t, proxy.metrics) cancelRequest := proxyCancelRequest{ ProxyIP: net.IP{}, SecretKey: conn.PgConn().SecretKey(), - ClientIP: net.IP{127, 1, 2, 3}, + ClientIP: net.IP{210, 1, 2, 3}, } u := "http://" + httpAddr + "/_status/cancel/" reqBody := bytes.NewReader(cancelRequest.Encode()) @@ -1097,9 +1153,23 @@ func TestCancelQuery(t *testing.T) { assert.Equal(t, "OK", string(respBytes)) require.Error(t, httpCancelErr) require.Regexp(t, "mismatched client IP for cancel request", httpCancelErr.Error()) + testutils.SucceedsSoon(t, func() error { + if proxy.metrics.QueryCancelIgnored.Count() != 1 || + proxy.metrics.QueryCancelReceivedHTTP.Count() != 1 { + return errors.Newf("expected metrics to update, got: "+ + "QueryCancelSuccessful=%d, QueryCancelIgnored=%d "+ + "QueryCancelForwarded=%d QueryCancelReceivedPGWire=%d QueryCancelReceivedHTTP=%d", + proxy.metrics.QueryCancelSuccessful.Count(), proxy.metrics.QueryCancelIgnored.Count(), + proxy.metrics.QueryCancelForwarded.Count(), proxy.metrics.QueryCancelReceivedPGWire.Count(), + proxy.metrics.QueryCancelReceivedHTTP.Count(), + ) + } + return nil + }) }) t.Run("forward over http", func(t *testing.T) { + clearMetrics(t, proxy.metrics) var forwardedTo string var forwardedReq proxyCancelRequest var wg sync.WaitGroup @@ -1116,7 +1186,7 @@ func TestCancelQuery(t *testing.T) { })() crdbRequest := &pgproto3.CancelRequest{ ProcessID: 1, - SecretKey: 2, + SecretKey: conn.PgConn().SecretKey() + 1, } buf := crdbRequest.Encode(nil /* buf */) proxyAddr := conn.PgConn().Conn().RemoteAddr() @@ -1132,10 +1202,57 @@ func TestCancelQuery(t *testing.T) { require.Equal(t, "http://0.0.0.1:8080/_status/cancel/", forwardedTo) expectedReq := proxyCancelRequest{ ProxyIP: net.IP{0, 0, 0, 1}, - SecretKey: 2, + SecretKey: conn.PgConn().SecretKey() + 1, ClientIP: net.IP{127, 0, 0, 1}, } require.Equal(t, expectedReq, forwardedReq) + testutils.SucceedsSoon(t, func() error { + if proxy.metrics.QueryCancelForwarded.Count() != 1 || + proxy.metrics.QueryCancelReceivedPGWire.Count() != 1 { + return errors.Newf("expected metrics to update, got: "+ + "QueryCancelSuccessful=%d, QueryCancelIgnored=%d "+ + "QueryCancelForwarded=%d QueryCancelReceivedPGWire=%d QueryCancelReceivedHTTP=%d", + proxy.metrics.QueryCancelSuccessful.Count(), proxy.metrics.QueryCancelIgnored.Count(), + proxy.metrics.QueryCancelForwarded.Count(), proxy.metrics.QueryCancelReceivedPGWire.Count(), + proxy.metrics.QueryCancelReceivedHTTP.Count(), + ) + } + return nil + }) + }) + + t.Run("ignore unknown secret key", func(t *testing.T) { + clearMetrics(t, proxy.metrics) + cancelRequest := proxyCancelRequest{ + ProxyIP: net.IP{}, + SecretKey: conn.PgConn().SecretKey() + 1, + ClientIP: net.IP{127, 0, 0, 1}, + } + u := "http://" + httpAddr + "/_status/cancel/" + reqBody := bytes.NewReader(cancelRequest.Encode()) + client := http.Client{ + Timeout: 10 * time.Second, + } + resp, err := client.Post(u, "application/octet-stream", reqBody) + require.NoError(t, err) + respBytes, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "OK", string(respBytes)) + require.Error(t, httpCancelErr) + require.Regexp(t, "ignoring cancel request with unfamiliar key", httpCancelErr.Error()) + testutils.SucceedsSoon(t, func() error { + if proxy.metrics.QueryCancelIgnored.Count() != 1 || + proxy.metrics.QueryCancelReceivedHTTP.Count() != 1 { + return errors.Newf("expected metrics to update, got: "+ + "QueryCancelSuccessful=%d, QueryCancelIgnored=%d "+ + "QueryCancelForwarded=%d QueryCancelReceivedPGWire=%d QueryCancelReceivedHTTP=%d", + proxy.metrics.QueryCancelSuccessful.Count(), proxy.metrics.QueryCancelIgnored.Count(), + proxy.metrics.QueryCancelForwarded.Count(), proxy.metrics.QueryCancelReceivedPGWire.Count(), + proxy.metrics.QueryCancelReceivedHTTP.Count(), + ) + } + return nil + }) }) }