diff --git a/pkg/roachpb/span_stats.go b/pkg/roachpb/span_stats.go index d750c23ef1a4..18e5586721e2 100644 --- a/pkg/roachpb/span_stats.go +++ b/pkg/roachpb/span_stats.go @@ -12,6 +12,7 @@ package roachpb import ( "fmt" + "time" "github.com/cockroachdb/cockroach/pkg/settings" ) @@ -30,6 +31,15 @@ var SpanStatsBatchLimit = settings.RegisterIntSetting( settings.PositiveInt, ) +var SpanStatsNodeTimeout = settings.RegisterDurationSetting( + settings.TenantWritable, + "server.span_stats.node.timeout", + "the duration allowed for a single node to return span stats data before"+ + " the request is cancelled; if set to 0, there is no timeout", + time.Minute, + settings.NonNegativeDuration, +) + const defaultRangeStatsBatchLimit = 100 // RangeStatsBatchLimit registers the maximum number of ranges to be batched diff --git a/pkg/roachpb/span_stats.proto b/pkg/roachpb/span_stats.proto index 2bae523f3a64..ef5fc4ef434a 100644 --- a/pkg/roachpb/span_stats.proto +++ b/pkg/roachpb/span_stats.proto @@ -67,5 +67,7 @@ message SpanStatsResponse { map span_to_stats = 4; - // NEXT ID: 5. + repeated string errors = 5; + + // NEXT ID: 6. } diff --git a/pkg/server/admin.go b/pkg/server/admin.go index 3172f3d75200..7b9679990fee 100644 --- a/pkg/server/admin.go +++ b/pkg/server/admin.go @@ -3199,6 +3199,7 @@ func (s *systemAdminServer) EnqueueRange( if err := timeutil.RunWithTimeout(ctx, "enqueue range", time.Minute, func(ctx context.Context) error { return s.server.status.iterateNodes( ctx, fmt.Sprintf("enqueue r%d in queue %s", req.RangeID, req.Queue), + noTimeout, dialFn, nodeFn, responseFn, errorFn, ) }); err != nil { diff --git a/pkg/server/api_v2_ranges.go b/pkg/server/api_v2_ranges.go index af737c591fc2..fc952ed4a06b 100644 --- a/pkg/server/api_v2_ranges.go +++ b/pkg/server/api_v2_ranges.go @@ -247,7 +247,11 @@ func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) { } if err := a.status.iterateNodes( - ctx, fmt.Sprintf("details about range %d", rangeID), dialFn, nodeFn, responseFn, errorFn, + ctx, + fmt.Sprintf("details about range %d", rangeID), + noTimeout, + dialFn, nodeFn, + responseFn, errorFn, ); err != nil { srverrors.APIV2InternalError(ctx, err, w) return diff --git a/pkg/server/index_usage_stats.go b/pkg/server/index_usage_stats.go index 81eeb28dc21c..6bf8bef46d75 100644 --- a/pkg/server/index_usage_stats.go +++ b/pkg/server/index_usage_stats.go @@ -100,6 +100,7 @@ func (s *statusServer) IndexUsageStatistics( // yields an incorrect result. if err := s.iterateNodes(ctx, "requesting index usage stats", + noTimeout, dialFn, fetchIndexUsageStats, aggFn, errFn); err != nil { return nil, err } @@ -196,6 +197,7 @@ func (s *statusServer) ResetIndexUsageStats( if err := s.iterateNodes(ctx, "Resetting index usage stats", + noTimeout, dialFn, resetIndexUsageStats, aggFn, errFn); err != nil { return nil, err } diff --git a/pkg/server/key_visualizer_server.go b/pkg/server/key_visualizer_server.go index 7e8b0144cb8d..9e1c48e035dc 100644 --- a/pkg/server/key_visualizer_server.go +++ b/pkg/server/key_visualizer_server.go @@ -110,8 +110,13 @@ func (s *KeyVisualizerServer) getSamplesFromFanOut( } err := s.status.iterateNodes(ctx, - "iterating nodes for key visualizer samples", dialFn, nodeFn, - responseFn, errorFn) + "iterating nodes for key visualizer samples", + noTimeout, + dialFn, + nodeFn, + responseFn, + errorFn, + ) if err != nil { return nil, err } diff --git a/pkg/server/server.go b/pkg/server/server.go index 7b2ce8aef9e3..31a465e4998e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -976,6 +976,11 @@ func NewServer(cfg Config, stopper *stop.Stopper) (serverctl.ServerStartupInterf } // Instantiate the status API server. + var serverTestingKnobs *TestingKnobs + if cfg.TestingKnobs.Server != nil { + serverTestingKnobs = cfg.TestingKnobs.Server.(*TestingKnobs) + } + sStatus := newSystemStatusServer( cfg.AmbientCtx, st, @@ -998,6 +1003,7 @@ func NewServer(cfg Config, stopper *stop.Stopper) (serverctl.ServerStartupInterf clock, rangestats.NewFetcher(db), node, + serverTestingKnobs, ) keyVisualizerServer := &KeyVisualizerServer{ diff --git a/pkg/server/span_stats_server.go b/pkg/server/span_stats_server.go index 0d149c9bf437..4be4069771f9 100644 --- a/pkg/server/span_stats_server.go +++ b/pkg/server/span_stats_server.go @@ -12,6 +12,7 @@ package server import ( "context" + "fmt" "strconv" "github.com/cockroachdb/cockroach/pkg/keys" @@ -37,8 +38,12 @@ func (s *systemStatusServer) spanStatsFanOut( res := &roachpb.SpanStatsResponse{ SpanToStats: make(map[string]*roachpb.SpanStats), } - // Response level error - var respErr error + // Populate SpanToStats with empty values for each span, + // so that clients may still access stats for a specific span + // in the extreme case of an error encountered on every node. + for _, sp := range req.Spans { + res.SpanToStats[sp.String()] = &roachpb.SpanStats{} + } spansPerNode, err := s.getSpansPerNode(ctx, req) if err != nil { @@ -51,6 +56,14 @@ func (s *systemStatusServer) spanStatsFanOut( ctx context.Context, nodeID roachpb.NodeID, ) (interface{}, error) { + if s.knobs != nil { + if s.knobs.IterateNodesDialCallback != nil { + if err := s.knobs.IterateNodesDialCallback(nodeID); err != nil { + return nil, err + } + } + } + if _, ok := spansPerNode[nodeID]; ok { return s.dialNode(ctx, nodeID) } @@ -58,6 +71,14 @@ func (s *systemStatusServer) spanStatsFanOut( } nodeFn := func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error) { + if s.knobs != nil { + if s.knobs.IterateNodesNodeCallback != nil { + if err := s.knobs.IterateNodesNodeCallback(ctx, nodeID); err != nil { + return nil, err + } + } + } + // `smartDial` may skip this node, so check to see if the client is nil. // If it is, return nil response. if client == nil { @@ -81,23 +102,26 @@ func (s *systemStatusServer) spanStatsFanOut( nodeResponse := resp.(*roachpb.SpanStatsResponse) for spanStr, spanStats := range nodeResponse.SpanToStats { - _, exists := res.SpanToStats[spanStr] - if !exists { - res.SpanToStats[spanStr] = spanStats - } else { - res.SpanToStats[spanStr].Add(spanStats) + // We are not counting replicas, so only consider range count + // if it has not been set. + if res.SpanToStats[spanStr].RangeCount == 0 { + res.SpanToStats[spanStr].RangeCount = spanStats.RangeCount } + res.SpanToStats[spanStr].Add(spanStats) } } errorFn := func(nodeID roachpb.NodeID, err error) { log.Errorf(ctx, nodeErrorMsgPlaceholder, nodeID, err) - respErr = err + errorMessage := fmt.Sprintf("%v", err) + res.Errors = append(res.Errors, errorMessage) } + timeout := roachpb.SpanStatsNodeTimeout.Get(&s.st.SV) if err := s.statusServer.iterateNodes( ctx, "iterating nodes for span stats", + timeout, smartDial, nodeFn, responseFn, @@ -106,7 +130,7 @@ func (s *systemStatusServer) spanStatsFanOut( return nil, err } - return res, respErr + return res, nil } func (s *systemStatusServer) getLocalStats( diff --git a/pkg/server/span_stats_test.go b/pkg/server/span_stats_test.go index 33508a3c0871..b1b8d50f14f8 100644 --- a/pkg/server/span_stats_test.go +++ b/pkg/server/span_stats_test.go @@ -14,12 +14,14 @@ import ( "bytes" "context" "fmt" + "strings" "testing" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/keys" "github.com/cockroachdb/cockroach/pkg/kv/kvserver" "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/server" "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -191,6 +193,135 @@ func TestSpanStatsFanOut(t *testing.T) { } +func TestSpanStatsFanOutFaultTolerance(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + skip.UnderStressWithIssue(t, 108534) + ctx := context.Background() + const numNodes = 5 + + type testCase struct { + name string + dialCallback func(nodeID roachpb.NodeID) error + nodeCallback func(ctx context.Context, nodeID roachpb.NodeID) error + assertions func(res *roachpb.SpanStatsResponse) + } + + containsError := func(errors []string, testString string) bool { + for _, e := range errors { + if strings.Contains(e, testString) { + return true + } + } + return false + } + + testCases := []testCase{ + { + // In a complete failure, no node is able to service requests successfully. + name: "complete-fanout-failure", + dialCallback: func(nodeID roachpb.NodeID) error { + // On the 1st and 2nd node, simulate a connection error. + if nodeID == 1 || nodeID == 2 { + return errors.Newf("error dialing node %d", nodeID) + } + return nil + }, + nodeCallback: func(ctx context.Context, nodeID roachpb.NodeID) error { + // On the 3rd node, simulate some sort of KV error. + if nodeID == 3 { + return errors.Newf("kv error on node %d", nodeID) + } + + // On the 4th and 5th node, simulate a request that takes a very long time. + // In this case, nodeFn will block until the context is cancelled + // i.e. if iterateNodes respects the timeout cluster setting. + if nodeID == 4 || nodeID == 5 { + <-ctx.Done() + // Return an error that mimics the error returned + // when a rpc's context is cancelled: + return errors.Newf("node %d timed out", nodeID) + } + return nil + }, + assertions: func(res *roachpb.SpanStatsResponse) { + // Expect to still be able to access SpanToStats for keys.EverythingSpan + // without panicking, even though there was a failure on every node. + require.Equal(t, int64(0), res.SpanToStats[keys.EverythingSpan.String()].TotalStats.LiveCount) + require.Equal(t, 5, len(res.Errors)) + + require.Equal(t, true, containsError(res.Errors, "error dialing node 1")) + require.Equal(t, true, containsError(res.Errors, "error dialing node 2")) + require.Equal(t, true, containsError(res.Errors, "kv error on node 3")) + require.Equal(t, true, containsError(res.Errors, "node 4 timed out")) + require.Equal(t, true, containsError(res.Errors, "node 5 timed out")) + }, + }, + { + // In a partial failure, nodes 1, 3, and 4 fail, and nodes 2 and 5 succeed. + name: "partial-fanout-failure", + dialCallback: func(nodeID roachpb.NodeID) error { + if nodeID == 1 { + return errors.Newf("error dialing node %d", nodeID) + } + return nil + }, + nodeCallback: func(ctx context.Context, nodeID roachpb.NodeID) error { + if nodeID == 3 { + return errors.Newf("kv error on node %d", nodeID) + } + + if nodeID == 4 { + <-ctx.Done() + // Return an error that mimics the error returned + // when a rpc's context is cancelled: + return errors.Newf("node %d timed out", nodeID) + } + return nil + }, + assertions: func(res *roachpb.SpanStatsResponse) { + require.Greater(t, res.SpanToStats[keys.EverythingSpan.String()].TotalStats.LiveCount, int64(0)) + // 3 nodes could not service their requests. + require.Equal(t, 3, len(res.Errors)) + + require.Equal(t, true, containsError(res.Errors, "error dialing node 1")) + require.Equal(t, true, containsError(res.Errors, "kv error on node 3")) + require.Equal(t, true, containsError(res.Errors, "node 4 timed out")) + + // There should not be any errors for node 2 or node 5. + require.Equal(t, false, containsError(res.Errors, "error dialing node 2")) + require.Equal(t, false, containsError(res.Errors, "node 5 timed out")) + }, + }, + } + + for _, tCase := range testCases { + tCase := tCase + t.Run(tCase.name, func(t *testing.T) { + serverArgs := base.TestServerArgs{} + serverArgs.Knobs.Server = &server.TestingKnobs{ + IterateNodesDialCallback: tCase.dialCallback, + IterateNodesNodeCallback: tCase.nodeCallback, + } + + tc := testcluster.StartTestCluster(t, numNodes, base.TestClusterArgs{ServerArgs: serverArgs}) + defer tc.Stopper().Stop(ctx) + + sqlDB := tc.Server(0).SQLConn(t, "defaultdb") + _, err := sqlDB.Exec("SET CLUSTER SETTING server.span_stats.node.timeout = '3s'") + require.NoError(t, err) + + res, err := tc.GetStatusClient(t, 0).SpanStats(ctx, &roachpb.SpanStatsRequest{ + NodeID: "0", // Indicates we want a fan-out. + Spans: []roachpb.Span{keys.EverythingSpan}, + }) + + require.NoError(t, err) + tCase.assertions(res) + }) + } +} + // BenchmarkSpanStats measures the cost of collecting span statistics. func BenchmarkSpanStats(b *testing.B) { skip.UnderShort(b) diff --git a/pkg/server/sql_stats.go b/pkg/server/sql_stats.go index ac8404b8eb02..304ea26d3c0c 100644 --- a/pkg/server/sql_stats.go +++ b/pkg/server/sql_stats.go @@ -79,6 +79,7 @@ func (s *statusServer) ResetSQLStats( var fanoutError error if err := s.iterateNodes(ctx, "reset SQL statistics", + noTimeout, dialFn, resetSQLStats, func(nodeID roachpb.NodeID, resp interface{}) { diff --git a/pkg/server/statements.go b/pkg/server/statements.go index 35c0d9b48206..dd07a4b3851f 100644 --- a/pkg/server/statements.go +++ b/pkg/server/statements.go @@ -85,6 +85,7 @@ func (s *statusServer) Statements( } if err := s.iterateNodes(ctx, "statement statistics", + noTimeout, dialFn, nodeStatement, func(nodeID roachpb.NodeID, resp interface{}) { diff --git a/pkg/server/status.go b/pkg/server/status.go index 184e166aca91..058170d3aa8a 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -500,6 +500,7 @@ type systemStatusServer struct { spanConfigReporter spanconfig.Reporter rangeStatsFetcher *rangestats.Fetcher node *Node + knobs *TestingKnobs } // StmtDiagnosticsRequester is the interface into *stmtdiagnostics.Registry @@ -618,6 +619,7 @@ func newSystemStatusServer( clock *hlc.Clock, rangeStatsFetcher *rangestats.Fetcher, node *Node, + knobs *TestingKnobs, ) *systemStatusServer { server := newStatusServer( ambient, @@ -645,6 +647,7 @@ func newSystemStatusServer( spanConfigReporter: spanConfigReporter, rangeStatsFetcher: rangeStatsFetcher, node: node, + knobs: knobs, } } @@ -1635,7 +1638,9 @@ func (s *statusServer) fetchProfileFromAllNodes( errorFn := func(nodeID roachpb.NodeID, err error) { response.profDataByNodeID[nodeID] = &profData{err: err} } - if err := s.iterateNodes(ctx, opName, dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes( + ctx, opName, noTimeout, dialFn, nodeFn, responseFn, errorFn, + ); err != nil { return nil, srverrors.ServerError(ctx, err) } var data []byte @@ -2053,7 +2058,13 @@ func (s *systemStatusServer) NetworkConnectivity( response.ErrorsByNodeID[nodeID] = err.Error() } - if err := s.iterateNodes(ctx, "network connectivity", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "network connectivity", + noTimeout, + dialFn, + nodeFn, + responseFn, + errorFn, + ); err != nil { return nil, srverrors.ServerError(ctx, err) } @@ -2637,7 +2648,13 @@ func (s *systemStatusServer) HotRanges( } } - if err := s.iterateNodes(ctx, "hot ranges", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "hot ranges", + noTimeout, + dialFn, + nodeFn, + responseFn, + errorFn, + ); err != nil { return nil, srverrors.ServerError(ctx, err) } @@ -2990,7 +3007,9 @@ func (s *statusServer) Range( } if err := s.iterateNodes( - ctx, fmt.Sprintf("details about range %d", req.RangeId), dialFn, nodeFn, responseFn, errorFn, + ctx, fmt.Sprintf("details about range %d", req.RangeId), noTimeout, + dialFn, + nodeFn, responseFn, errorFn, ); err != nil { return nil, srverrors.ServerError(ctx, err) } @@ -3019,6 +3038,7 @@ func (s *statusServer) ListLocalSessions( func (s *statusServer) iterateNodes( ctx context.Context, errorCtx string, + nodeFnTimeout time.Duration, dialFn func(ctx context.Context, nodeID roachpb.NodeID) (interface{}, error), nodeFn func(ctx context.Context, client interface{}, nodeID roachpb.NodeID) (interface{}, error), responseFn func(nodeID roachpb.NodeID, resp interface{}), @@ -3053,7 +3073,18 @@ func (s *statusServer) iterateNodes( return } - res, err := nodeFn(ctx, client, nodeID) + var res interface{} + if nodeFnTimeout == noTimeout { + res, err = nodeFn(ctx, client, nodeID) + } else { + err = timeutil.RunWithTimeout(ctx, "iterate-nodes-fn", + nodeFnTimeout, func(ctx context.Context) error { + var _err error + res, _err = nodeFn(ctx, client, nodeID) + return _err + }) + } + if err != nil { err = errors.Wrapf(err, "error requesting %s from node %d (%s)", errorCtx, nodeID, nodeStatuses[serverID(nodeID)]) @@ -3116,7 +3147,9 @@ func (s *statusServer) paginatedIterateNodes( errorFn func(nodeID roachpb.NodeID, nodeFnError error), ) (next paginationState, err error) { if limit == 0 { - return paginationState{}, s.iterateNodes(ctx, errorCtx, dialFn, nodeFn, responseFn, errorFn) + return paginationState{}, s.iterateNodes(ctx, errorCtx, noTimeout, + dialFn, + nodeFn, responseFn, errorFn) } nodeStatuses, err := s.serverIterator.getAllNodes(ctx) if err != nil { @@ -3462,7 +3495,9 @@ func (s *statusServer) ListContentionEvents( response.Errors = append(response.Errors, errResponse) } - if err := s.iterateNodes(ctx, "contention events list", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "contention events list", noTimeout, + dialFn, nodeFn, + responseFn, errorFn); err != nil { return nil, srverrors.ServerError(ctx, err) } return &response, nil @@ -3509,7 +3544,9 @@ func (s *statusServer) ListDistSQLFlows( response.Errors = append(response.Errors, errResponse) } - if err := s.iterateNodes(ctx, "distsql flows list", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "distsql flows list", noTimeout, dialFn, + nodeFn, + responseFn, errorFn); err != nil { return nil, srverrors.ServerError(ctx, err) } return &response, nil @@ -3569,7 +3606,9 @@ func (s *statusServer) ListExecutionInsights( response.Errors = append(response.Errors, errors.EncodeError(ctx, err)) } - if err := s.iterateNodes(ctx, "execution insights list", dialFn, nodeFn, responseFn, errorFn); err != nil { + if err := s.iterateNodes(ctx, "execution insights list", noTimeout, + dialFn, nodeFn, + responseFn, errorFn); err != nil { return nil, srverrors.ServerError(ctx, err) } return &response, nil @@ -3929,6 +3968,7 @@ func (s *statusServer) TransactionContentionEvents( } if err := s.iterateNodes(ctx, "txn contention events for node", + noTimeout, dialFn, rpcCallFn, func(nodeID roachpb.NodeID, nodeResp interface{}) { diff --git a/pkg/server/testing_knobs.go b/pkg/server/testing_knobs.go index ccb3fbd9a935..b9de4039c15b 100644 --- a/pkg/server/testing_knobs.go +++ b/pkg/server/testing_knobs.go @@ -11,6 +11,7 @@ package server import ( + "context" "net" "time" @@ -147,6 +148,15 @@ type TestingKnobs struct { // system.tenants table. This is useful for tests that want to verify that // the tenant connector can't start when the record doesn't exist. ShutdownTenantConnectorEarlyIfNoRecordPresent bool + + // IterateNodesDialCallback is used to mock dial errors in a cluster + // fan-out. It is invoked by the dialFn argument of server.iterateNodes. + IterateNodesDialCallback func(nodeID roachpb.NodeID) error + + // IterateNodesNodeCallback is used to mock errors of the rpc invoked + // on a remote node in a cluster fan-out. It is invoked by the nodeFn argument + // of server.iterateNodes. + IterateNodesNodeCallback func(ctx context.Context, nodeID roachpb.NodeID) error } // ModuleTestingKnobs is part of the base.ModuleTestingKnobs interface.