diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics.go b/pkg/sql/stmtdiagnostics/statement_diagnostics.go index 47adb48dca73..d4a652bc216b 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics.go @@ -36,7 +36,9 @@ var pollingInterval = settings.RegisterDurationSetting( settings.TenantReadOnly, "sql.stmt_diagnostics.poll_interval", "rate at which the stmtdiagnostics.Registry polls for requests, set to zero to disable", - 10*time.Second) + 10*time.Second, + settings.NonNegativeDuration, +) var bundleChunkSize = settings.RegisterByteSizeSetting( settings.TenantWritable, @@ -158,19 +160,27 @@ func (r *Registry) Start(ctx context.Context, stopper *stop.Stopper) { func (r *Registry) poll(ctx context.Context) { var ( - timer timeutil.Timer + timer timeutil.Timer + // We need to store timer.C reference separately because timer.Stop() + // (called when polling is disabled) puts timer into the pool and + // prohibits further usage of stored timer.C. + timerC = timer.C lastPoll time.Time deadline time.Time pollIntervalChanged = make(chan struct{}, 1) maybeResetTimer = func() { - if interval := pollingInterval.Get(&r.st.SV); interval <= 0 { - // Setting the interval to a non-positive value stops the polling. + if interval := pollingInterval.Get(&r.st.SV); interval == 0 { + // Setting the interval to zero stops the polling. timer.Stop() + // nil out the channel so that it'd block forever in the loop + // below (until the polling interval is changed). + timerC = nil } else { newDeadline := lastPoll.Add(interval) if deadline.IsZero() || !deadline.Equal(newDeadline) { deadline = newDeadline timer.Reset(timeutil.Until(deadline)) + timerC = timer.C } } } @@ -195,7 +205,7 @@ func (r *Registry) poll(ctx context.Context) { select { case <-pollIntervalChanged: continue // go back around and maybe reset the timer - case <-timer.C: + case <-timerC: timer.Read = true case <-ctx.Done(): return diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go b/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go index c345bd6f2c08..4acd4b1bfe88 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics_test.go @@ -50,6 +50,11 @@ func TestDiagnosticsRequest(t *testing.T) { _, err := db.Exec("CREATE TABLE test (x int PRIMARY KEY)") require.NoError(t, err) + // Disable polling interval since we're inserting requests directly into the + // registry manually and want precise control of updating the registry. + _, err = db.Exec("SET CLUSTER SETTING sql.stmt_diagnostics.poll_interval = '0';") + require.NoError(t, err) + var collectUntilExpirationEnabled bool isCompleted := func(reqID int64) (completed bool, diagnosticsID gosql.NullInt64) { completedQuery := "SELECT completed, statement_diagnostics_id FROM system.statement_diagnostics_requests WHERE ID = $1" @@ -76,28 +81,12 @@ func TestDiagnosticsRequest(t *testing.T) { require.True(t, completed) require.True(t, diagnosticsID.Valid) } - // checkMaybeCompleted returns an error if 'completed' value for the given - // request is different from expectedCompleted. - checkMaybeCompleted := func(reqID int64, expectedCompleted bool) error { - completed, diagnosticsID := isCompleted(reqID) - if completed != expectedCompleted { - return errors.Newf("expected completed to be %t, but found %t", expectedCompleted, completed) - } - // diagnosticsID is NULL when the request hasn't been completed yet. - require.True(t, diagnosticsID.Valid == expectedCompleted) - return nil - } setCollectUntilExpiration := func(v bool) { collectUntilExpirationEnabled = v _, err := db.Exec( fmt.Sprintf("SET CLUSTER SETTING sql.stmt_diagnostics.collect_continuously.enabled = %t", v)) require.NoError(t, err) } - setPollInterval := func(d time.Duration) { - _, err := db.Exec( - fmt.Sprintf("SET CLUSTER SETTING sql.stmt_diagnostics.poll_interval = '%s'", d)) - require.NoError(t, err) - } var minExecutionLatency, expiresAfter time.Duration var samplingProbability float64 @@ -259,26 +248,21 @@ func TestDiagnosticsRequest(t *testing.T) { } for _, expiresAfter := range []time.Duration{0, time.Second} { t.Run(fmt.Sprintf("expiresAfter=%s", expiresAfter), func(t *testing.T) { - // TODO(yuzefovich): for some reason occasionally the - // bundle for the request is collected, so we use - // SucceedsSoon. Figure it out. - testutils.SucceedsSoon(t, func() error { - reqID, err := registry.InsertRequestInternal( - ctx, fprint, samplingProbability, minExecutionLatency, expiresAfter, - ) - require.NoError(t, err) - checkNotCompleted(reqID) - - err = registry.CancelRequest(ctx, reqID) - require.NoError(t, err) - checkNotCompleted(reqID) - - // Run the query that is slow enough to satisfy the - // conditional request. - _, err = db.Exec("SELECT pg_sleep(0.2)") - require.NoError(t, err) - return checkMaybeCompleted(reqID, false /* expectedCompleted */) - }) + reqID, err := registry.InsertRequestInternal( + ctx, fprint, samplingProbability, minExecutionLatency, expiresAfter, + ) + require.NoError(t, err) + checkNotCompleted(reqID) + + err = registry.CancelRequest(ctx, reqID) + require.NoError(t, err) + checkNotCompleted(reqID) + + // Run the query that is slow enough to satisfy the + // conditional request. + _, err = db.Exec("SELECT pg_sleep(0.2)") + require.NoError(t, err) + checkNotCompleted(reqID) }) } }) @@ -324,7 +308,13 @@ func TestDiagnosticsRequest(t *testing.T) { require.NoError(t, err) wg.Wait() - return checkMaybeCompleted(reqID, true /* expectedCompleted */) + + completed, diagnosticsID := isCompleted(reqID) + if !completed { + return errors.New("expected request to be completed") + } + require.True(t, diagnosticsID.Valid) + return nil }) }) } @@ -451,21 +441,16 @@ func TestDiagnosticsRequest(t *testing.T) { // Sleep until expiration (and then some), and then run the query. time.Sleep(expiresAfter + 100*time.Millisecond) - setPollInterval(10 * time.Millisecond) - defer setPollInterval(stmtdiagnostics.PollingInterval.Default()) - - // We should not find the request and a subsequent executions should not - // capture anything. - testutils.SucceedsSoon(t, func() error { - if found := registry.TestingFindRequest(reqID); found { - return errors.New("expected expired request to no longer be tracked") - } - return nil - }) - + // Even though the request has expired, it hasn't been removed from the + // registry yet (because we disabled the polling interval). When we run + // the query that matches the fingerprint, the expired request is + // removed, and the bundle is not collected. _, err = db.Exec("SELECT pg_sleep(0.01)") // run the query require.NoError(t, err) checkNotCompleted(reqID) + + // Sanity check that the request is no longer in the registry. + require.False(t, registry.TestingFindRequest(reqID)) }) }