diff --git a/pkg/roachpb/span_stats.go b/pkg/roachpb/span_stats.go index fe2a0f48cde4..83b3337907e6 100644 --- a/pkg/roachpb/span_stats.go +++ b/pkg/roachpb/span_stats.go @@ -12,6 +12,7 @@ package roachpb import ( "fmt" + "time" "github.com/cockroachdb/cockroach/pkg/settings" ) @@ -30,6 +31,15 @@ var SpanStatsBatchLimit = settings.RegisterIntSetting( settings.PositiveInt, ) +var SpanStatsNodeTimeout = settings.RegisterDurationSetting( + settings.TenantWritable, + "server.span_stats.node.timeout", + "the duration allowed for a single node to return span stats data before"+ + " the request is cancelled; if set to 0, there is no timeout", + time.Minute, + settings.NonNegativeDuration, +) + const defaultRangeStatsBatchLimit = 100 // RangeStatsBatchLimit registers the maximum number of ranges to be batched diff --git a/pkg/roachpb/span_stats.proto b/pkg/roachpb/span_stats.proto index fa699f53ac3c..d1c93ef998cc 100644 --- a/pkg/roachpb/span_stats.proto +++ b/pkg/roachpb/span_stats.proto @@ -51,4 +51,6 @@ message SpanStatsResponse { int32 range_count = 2; uint64 approximate_disk_bytes = 3; map span_to_stats = 4; + repeated string errors = 5; + // NEXT ID: 6. } diff --git a/pkg/server/admin.go b/pkg/server/admin.go index 9c3eb272a18f..78c8a0702c7b 100644 --- a/pkg/server/admin.go +++ b/pkg/server/admin.go @@ -3297,6 +3297,7 @@ func (s *systemAdminServer) EnqueueRange( if err := contextutil.RunWithTimeout(ctx, "enqueue range", time.Minute, func(ctx context.Context) error { return s.server.status.iterateNodes( ctx, fmt.Sprintf("enqueue r%d in queue %s", req.RangeID, req.Queue), + noTimeout, dialFn, nodeFn, responseFn, errorFn, ) }); err != nil { diff --git a/pkg/server/api_v2_ranges.go b/pkg/server/api_v2_ranges.go index f0db561ec9a4..b4e9f7c9b209 100644 --- a/pkg/server/api_v2_ranges.go +++ b/pkg/server/api_v2_ranges.go @@ -237,7 +237,11 @@ func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) { } if err := a.status.iterateNodes( - ctx, fmt.Sprintf("details about range %d", rangeID), dialFn, nodeFn, responseFn, errorFn, + ctx, + fmt.Sprintf("details about range %d", rangeID), + noTimeout, + dialFn, nodeFn, + responseFn, errorFn, ); err != nil { apiV2InternalError(ctx, err, w) return diff --git a/pkg/server/index_usage_stats.go b/pkg/server/index_usage_stats.go index 6fcd754814af..d342f6f78a77 100644 --- a/pkg/server/index_usage_stats.go +++ b/pkg/server/index_usage_stats.go @@ -98,6 +98,7 @@ func (s *statusServer) IndexUsageStatistics( // yields an incorrect result. if err := s.iterateNodes(ctx, "requesting index usage stats", + noTimeout, dialFn, fetchIndexUsageStats, aggFn, errFn); err != nil { return nil, err } @@ -194,6 +195,7 @@ func (s *statusServer) ResetIndexUsageStats( if err := s.iterateNodes(ctx, "Resetting index usage stats", + noTimeout, dialFn, resetIndexUsageStats, aggFn, errFn); err != nil { return nil, err } diff --git a/pkg/server/key_visualizer_server.go b/pkg/server/key_visualizer_server.go index 7e8b0144cb8d..9e1c48e035dc 100644 --- a/pkg/server/key_visualizer_server.go +++ b/pkg/server/key_visualizer_server.go @@ -110,8 +110,13 @@ func (s *KeyVisualizerServer) getSamplesFromFanOut( } err := s.status.iterateNodes(ctx, - "iterating nodes for key visualizer samples", dialFn, nodeFn, - responseFn, errorFn) + "iterating nodes for key visualizer samples", + noTimeout, + dialFn, + nodeFn, + responseFn, + errorFn, + ) if err != nil { return nil, err } diff --git a/pkg/server/server.go b/pkg/server/server.go index ea886556ea86..4285656139d0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -881,6 +881,11 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { } // Instantiate the status API server. + var serverTestingKnobs *TestingKnobs + if cfg.TestingKnobs.Server != nil { + serverTestingKnobs = cfg.TestingKnobs.Server.(*TestingKnobs) + } + sStatus := newSystemStatusServer( cfg.AmbientCtx, st, @@ -903,6 +908,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (*Server, error) { clock, rangestats.NewFetcher(db), node, + serverTestingKnobs, ) keyVisualizerServer := &KeyVisualizerServer{ diff --git a/pkg/server/span_stats_server.go b/pkg/server/span_stats_server.go index 7d7bd42fda24..8f2fe8941cd2 100644 --- a/pkg/server/span_stats_server.go +++ b/pkg/server/span_stats_server.go @@ -12,6 +12,7 @@ package server import ( "context" + "fmt" "strconv" "github.com/cockroachdb/cockroach/pkg/keys" @@ -37,8 +38,12 @@ func (s *systemStatusServer) spanStatsFanOut( res := &roachpb.SpanStatsResponse{ SpanToStats: make(map[string]*roachpb.SpanStats), } - // Response level error - var respErr error + // Populate SpanToStats with empty values for each span, + // so that clients may still access stats for a specific span + // in the extreme case of an error encountered on every node. + for _, sp := range req.Spans { + res.SpanToStats[sp.String()] = &roachpb.SpanStats{} + } spansPerNode, err := s.getSpansPerNode(ctx, req) if err != nil { @@ -51,6 +56,14 @@ func (s *systemStatusServer) spanStatsFanOut( ctx context.Context, nodeID roachpb.NodeID, ) (interface{}, error) { + if s.knobs != nil { + if s.knobs.IterateNodesDialCallback != nil { + if err := s.knobs.IterateNodesDialCallback(nodeID); err != nil { + return nil, err + } + } + } + if _, ok := spansPerNode[nodeID]; ok { return s.dialNode(ctx, nodeID) } @@ -58,6 +71,14 @@ func (s *systemStatusServer) spanStatsFanOut( } nodeFn := func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error) { + if s.knobs != nil { + if s.knobs.IterateNodesNodeCallback != nil { + if err := s.knobs.IterateNodesNodeCallback(ctx, nodeID); err != nil { + return nil, err + } + } + } + // `smartDial` may skip this node, so check to see if the client is nil. // If it is, return nil response. if client == nil { @@ -81,24 +102,28 @@ func (s *systemStatusServer) spanStatsFanOut( nodeResponse := resp.(*roachpb.SpanStatsResponse) for spanStr, spanStats := range nodeResponse.SpanToStats { - _, exists := res.SpanToStats[spanStr] - if !exists { - res.SpanToStats[spanStr] = spanStats - } else { - res.SpanToStats[spanStr].ApproximateDiskBytes += spanStats.ApproximateDiskBytes - res.SpanToStats[spanStr].TotalStats.Add(spanStats.TotalStats) + // We are not counting replicas, so only consider range count + // if it has not been set. + if res.SpanToStats[spanStr].RangeCount == 0 { + res.SpanToStats[spanStr].RangeCount = spanStats.RangeCount } + + res.SpanToStats[spanStr].TotalStats.Add(spanStats.TotalStats) + res.SpanToStats[spanStr].ApproximateDiskBytes += spanStats.ApproximateDiskBytes } } errorFn := func(nodeID roachpb.NodeID, err error) { log.Errorf(ctx, nodeErrorMsgPlaceholder, nodeID, err) - respErr = err + errorMessage := fmt.Sprintf("%v", err) + res.Errors = append(res.Errors, errorMessage) } + timeout := roachpb.SpanStatsNodeTimeout.Get(&s.st.SV) if err := s.statusServer.iterateNodes( ctx, "iterating nodes for span stats", + timeout, smartDial, nodeFn, responseFn, @@ -107,7 +132,7 @@ func (s *systemStatusServer) spanStatsFanOut( return nil, err } - return res, respErr + return res, nil } func (s *systemStatusServer) getLocalStats( diff --git a/pkg/server/span_stats_test.go b/pkg/server/span_stats_test.go index c43104bb9d72..b9e73df33262 100644 --- a/pkg/server/span_stats_test.go +++ b/pkg/server/span_stats_test.go @@ -14,6 +14,7 @@ import ( "bytes" "context" "fmt" + "strings" "testing" "github.com/cockroachdb/cockroach/pkg/base" @@ -191,6 +192,137 @@ func TestSpanStatsFanOut(t *testing.T) { } +func TestSpanStatsFanOutFaultTolerance(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + skip.UnderStressWithIssue(t, 108534) + ctx := context.Background() + const numNodes = 5 + + type testCase struct { + name string + dialCallback func(nodeID roachpb.NodeID) error + nodeCallback func(ctx context.Context, nodeID roachpb.NodeID) error + assertions func(res *roachpb.SpanStatsResponse) + } + + containsError := func(errors []string, testString string) bool { + for _, e := range errors { + if strings.Contains(e, testString) { + return true + } + } + return false + } + + testCases := []testCase{ + { + // In a complete failure, no node is able to service requests successfully. + name: "complete-fanout-failure", + dialCallback: func(nodeID roachpb.NodeID) error { + // On the 1st and 2nd node, simulate a connection error. + if nodeID == 1 || nodeID == 2 { + return errors.Newf("error dialing node %d", nodeID) + } + return nil + }, + nodeCallback: func(ctx context.Context, nodeID roachpb.NodeID) error { + // On the 3rd node, simulate some sort of KV error. + if nodeID == 3 { + return errors.Newf("kv error on node %d", nodeID) + } + + // On the 4th and 5th node, simulate a request that takes a very long time. + // In this case, nodeFn will block until the context is cancelled + // i.e. if iterateNodes respects the timeout cluster setting. + if nodeID == 4 || nodeID == 5 { + <-ctx.Done() + // Return an error that mimics the error returned + // when a rpc's context is cancelled: + return errors.Newf("node %d timed out", nodeID) + } + return nil + }, + assertions: func(res *roachpb.SpanStatsResponse) { + // Expect to still be able to access SpanToStats for keys.EverythingSpan + // without panicking, even though there was a failure on every node. + require.Equal(t, int64(0), res.SpanToStats[keys.EverythingSpan.String()].TotalStats.LiveCount) + require.Equal(t, 5, len(res.Errors)) + + require.Equal(t, true, containsError(res.Errors, "error dialing node 1")) + require.Equal(t, true, containsError(res.Errors, "error dialing node 2")) + require.Equal(t, true, containsError(res.Errors, "kv error on node 3")) + require.Equal(t, true, containsError(res.Errors, "node 4 timed out")) + require.Equal(t, true, containsError(res.Errors, "node 5 timed out")) + }, + }, + { + // In a partial failure, nodes 1, 3, and 4 fail, and nodes 2 and 5 succeed. + name: "partial-fanout-failure", + dialCallback: func(nodeID roachpb.NodeID) error { + if nodeID == 1 { + return errors.Newf("error dialing node %d", nodeID) + } + return nil + }, + nodeCallback: func(ctx context.Context, nodeID roachpb.NodeID) error { + if nodeID == 3 { + return errors.Newf("kv error on node %d", nodeID) + } + + if nodeID == 4 { + <-ctx.Done() + // Return an error that mimics the error returned + // when a rpc's context is cancelled: + return errors.Newf("node %d timed out", nodeID) + } + return nil + }, + assertions: func(res *roachpb.SpanStatsResponse) { + require.Greater(t, res.SpanToStats[keys.EverythingSpan.String()].TotalStats.LiveCount, int64(0)) + // 3 nodes could not service their requests. + require.Equal(t, 3, len(res.Errors)) + + require.Equal(t, true, containsError(res.Errors, "error dialing node 1")) + require.Equal(t, true, containsError(res.Errors, "kv error on node 3")) + require.Equal(t, true, containsError(res.Errors, "node 4 timed out")) + + // There should not be any errors for node 2 or node 5. + require.Equal(t, false, containsError(res.Errors, "error dialing node 2")) + require.Equal(t, false, containsError(res.Errors, "node 5 timed out")) + }, + }, + } + + for _, tCase := range testCases { + tCase := tCase + t.Run(tCase.name, func(t *testing.T) { + serverArgs := base.TestServerArgs{} + serverArgs.Knobs.Server = &server.TestingKnobs{ + IterateNodesDialCallback: tCase.dialCallback, + IterateNodesNodeCallback: tCase.nodeCallback, + } + + tc := testcluster.StartTestCluster(t, numNodes, base.TestClusterArgs{ServerArgs: serverArgs}) + defer tc.Stopper().Stop(ctx) + + srv := tc.Server(0) + sqlDB := serverutils.OpenDBConn(t, srv.SQLAddr(), "", false, srv.Stopper()) + _, err := sqlDB.Exec("SET CLUSTER SETTING server.span_stats.node.timeout = '3s'") + require.NoError(t, err) + + statusServer := srv.(*server.TestServer).StatusServer().(serverpb.StatusServer) + res, err := statusServer.SpanStats(ctx, &roachpb.SpanStatsRequest{ + NodeID: "0", // Indicates we want a fan-out. + Spans: []roachpb.Span{keys.EverythingSpan}, + }) + + require.NoError(t, err) + tCase.assertions(res) + }) + } +} + // BenchmarkSpanStats measures the cost of collecting span statistics. func BenchmarkSpanStats(b *testing.B) { skip.UnderShort(b) diff --git a/pkg/server/sql_stats.go b/pkg/server/sql_stats.go index 411166d3bf9c..0fd95d1a820c 100644 --- a/pkg/server/sql_stats.go +++ b/pkg/server/sql_stats.go @@ -78,6 +78,7 @@ func (s *statusServer) ResetSQLStats( var fanoutError error if err := s.iterateNodes(ctx, "reset SQL statistics", + noTimeout, dialFn, resetSQLStats, func(nodeID roachpb.NodeID, resp interface{}) { diff --git a/pkg/server/statements.go b/pkg/server/statements.go index 41a8ff5fc7be..0f55070df790 100644 --- a/pkg/server/statements.go +++ b/pkg/server/statements.go @@ -84,6 +84,7 @@ func (s *statusServer) Statements( } if err := s.iterateNodes(ctx, "statement statistics", + noTimeout, dialFn, nodeStatement, func(nodeID roachpb.NodeID, resp interface{}) { diff --git a/pkg/server/status.go b/pkg/server/status.go index 7490995a0eb6..390676f7350a 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -516,6 +516,7 @@ type systemStatusServer struct { spanConfigReporter spanconfig.Reporter rangeStatsFetcher *rangestats.Fetcher node *Node + knobs *TestingKnobs } // StmtDiagnosticsRequester is the interface into *stmtdiagnostics.Registry @@ -625,6 +626,7 @@ func newSystemStatusServer( clock *hlc.Clock, rangeStatsFetcher *rangestats.Fetcher, node *Node, + knobs *TestingKnobs, ) *systemStatusServer { server := newStatusServer( ambient, @@ -652,6 +654,7 @@ func newSystemStatusServer( spanConfigReporter: spanConfigReporter, rangeStatsFetcher: rangeStatsFetcher, node: node, + knobs: knobs, } } @@ -2468,7 +2471,7 @@ func (s *systemStatusServer) HotRanges( } } - if err := s.iterateNodes(ctx, "hot ranges", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "hot ranges", noTimeout, dialFn, nodeFn, responseFn, errorFn); err != nil { return nil, serverError(ctx, err) } @@ -2821,7 +2824,9 @@ func (s *statusServer) Range( } if err := s.iterateNodes( - ctx, fmt.Sprintf("details about range %d", req.RangeId), dialFn, nodeFn, responseFn, errorFn, + ctx, fmt.Sprintf("details about range %d", req.RangeId), noTimeout, + dialFn, + nodeFn, responseFn, errorFn, ); err != nil { return nil, serverError(ctx, err) } @@ -2850,6 +2855,7 @@ func (s *statusServer) ListLocalSessions( func (s *statusServer) iterateNodes( ctx context.Context, errorCtx string, + nodeFnTimeout time.Duration, dialFn func(ctx context.Context, nodeID roachpb.NodeID) (interface{}, error), nodeFn func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error), responseFn func(nodeID roachpb.NodeID, resp interface{}), @@ -2884,7 +2890,18 @@ func (s *statusServer) iterateNodes( return } - res, err := nodeFn(ctx, client, nodeID) + var res interface{} + if nodeFnTimeout == noTimeout { + res, err = nodeFn(ctx, client, nodeID) + } else { + err = contextutil.RunWithTimeout(ctx, "iterate-nodes-fn", + nodeFnTimeout, func(ctx context.Context) error { + var _err error + res, _err = nodeFn(ctx, client, nodeID) + return _err + }) + } + if err != nil { err = errors.Wrapf(err, "error requesting %s from node %d (%s)", errorCtx, nodeID, nodeStatuses[serverID(nodeID)]) @@ -2947,7 +2964,9 @@ func (s *statusServer) paginatedIterateNodes( errorFn func(nodeID roachpb.NodeID, nodeFnError error), ) (next paginationState, err error) { if limit == 0 { - return paginationState{}, s.iterateNodes(ctx, errorCtx, dialFn, nodeFn, responseFn, errorFn) + return paginationState{}, s.iterateNodes(ctx, errorCtx, noTimeout, + dialFn, + nodeFn, responseFn, errorFn) } nodeStatuses, err := s.serverIterator.getAllNodes(ctx) if err != nil { @@ -3293,7 +3312,7 @@ func (s *statusServer) ListContentionEvents( response.Errors = append(response.Errors, errResponse) } - if err := s.iterateNodes(ctx, "contention events list", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "contention events list", noTimeout, dialFn, nodeFn, responseFn, errorFn); err != nil { return nil, serverError(ctx, err) } return &response, nil @@ -3340,7 +3359,7 @@ func (s *statusServer) ListDistSQLFlows( response.Errors = append(response.Errors, errResponse) } - if err := s.iterateNodes(ctx, "distsql flows list", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "distsql flows list", noTimeout, dialFn, nodeFn, responseFn, errorFn); err != nil { return nil, serverError(ctx, err) } return &response, nil @@ -3445,7 +3464,7 @@ func (s *statusServer) ListExecutionInsights( response.Errors = append(response.Errors, errors.EncodeError(ctx, err)) } - if err := s.iterateNodes(ctx, "execution insights list", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "execution insights list", noTimeout, dialFn, nodeFn, responseFn, errorFn); err != nil { return nil, serverError(ctx, err) } return &response, nil @@ -3805,6 +3824,7 @@ func (s *statusServer) TransactionContentionEvents( } if err := s.iterateNodes(ctx, "txn contention events for node", + noTimeout, dialFn, rpcCallFn, func(nodeID roachpb.NodeID, nodeResp interface{}) { diff --git a/pkg/server/testing_knobs.go b/pkg/server/testing_knobs.go index 9b5b094e7fb5..059b2437c0f7 100644 --- a/pkg/server/testing_knobs.go +++ b/pkg/server/testing_knobs.go @@ -11,6 +11,7 @@ package server import ( + "context" "net" "time" @@ -142,6 +143,15 @@ type TestingKnobs struct { // DrainReportCh, if set, is a channel that will be notified when // the SQL service shuts down. DrainReportCh chan struct{} + + // IterateNodesDialCallback is used to mock dial errors in a cluster + // fan-out. It is invoked by the dialFn argument of server.iterateNodes. + IterateNodesDialCallback func(nodeID roachpb.NodeID) error + + // IterateNodesNodeCallback is used to mock errors of the rpc invoked + // on a remote node in a cluster fan-out. It is invoked by the nodeFn argument + // of server.iterateNodes. + IterateNodesNodeCallback func(ctx context.Context, nodeID roachpb.NodeID) error } // ModuleTestingKnobs is part of the base.ModuleTestingKnobs interface.