diff --git a/pkg/sql/sqlliveness/slstorage/slstorage.go b/pkg/sql/sqlliveness/slstorage/slstorage.go index 86f004843759..68cb548da36d 100644 --- a/pkg/sql/sqlliveness/slstorage/slstorage.go +++ b/pkg/sql/sqlliveness/slstorage/slstorage.go @@ -402,7 +402,7 @@ func (s *Storage) Update( ) (sessionExists bool, err error) { err = s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error { k := s.makeSessionKey(sid) - kv, err := s.db.Get(ctx, k) + kv, err := txn.Get(ctx, k) if err != nil { return err } @@ -410,7 +410,7 @@ func (s *Storage) Update( return nil } v := encodeValue(expiration) - return s.db.Put(ctx, k, &v) + return txn.Put(ctx, k, &v) }) if err != nil || !sessionExists { s.metrics.WriteFailures.Inc(1) diff --git a/pkg/sql/sqlliveness/slstorage/slstorage_test.go b/pkg/sql/sqlliveness/slstorage/slstorage_test.go index 7568eb068761..4930fc5f2472 100644 --- a/pkg/sql/sqlliveness/slstorage/slstorage_test.go +++ b/pkg/sql/sqlliveness/slstorage/slstorage_test.go @@ -649,6 +649,95 @@ func TestConcurrentAccessSynchronization(t *testing.T) { }) } +// TestDeleteMidUpdateFails ensures that a session removed while it attempts to +// update itself fails. +func TestDeleteMidUpdateFails(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + type filterFunc = func(context.Context, roachpb.BatchRequest, *roachpb.BatchResponse) *roachpb.Error + var respFilter atomic.Value + respFilter.Store(filterFunc(nil)) + s, sqlDB, kvDB := serverutils.StartServer(t, base.TestServerArgs{ + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + TestingResponseFilter: func( + ctx context.Context, request roachpb.BatchRequest, resp *roachpb.BatchResponse, + ) *roachpb.Error { + if f := respFilter.Load().(filterFunc); f != nil { + return f(ctx, request, resp) + } + return nil + }, + }, + }, + }) + defer s.Stopper().Stop(ctx) + tdb := sqlutils.MakeSQLRunner(sqlDB) + + // Set up a fake storage implementation using a separate table. + dbName := t.Name() + tdb.Exec(t, `CREATE DATABASE "`+dbName+`"`) + schema := strings.Replace(systemschema.SqllivenessTableSchema, + `CREATE TABLE system.sqlliveness`, + `CREATE TABLE "`+dbName+`".sqlliveness`, 1) + tdb.Exec(t, schema) + tableID := getTableID(t, tdb, dbName, "sqlliveness") + + storage := slstorage.NewTestingStorage( + s.Stopper(), s.Clock(), kvDB, keys.SystemSQLCodec, s.ClusterSettings(), + tableID, timeutil.DefaultTimeSource{}.NewTimer, + ) + + // Insert a session. + ID := sqlliveness.SessionID("foo") + require.NoError(t, storage.Insert(ctx, ID, s.Clock().Now())) + + // Install a filter which will send on this channel when we attempt + // to perform an update after the get has evaluated. + getChan := make(chan chan struct{}) + respFilter.Store(func( + ctx context.Context, request roachpb.BatchRequest, _ *roachpb.BatchResponse, + ) *roachpb.Error { + if get, ok := request.GetArg(roachpb.Get); !ok || !bytes.HasPrefix( + get.(*roachpb.GetRequest).Key, + keys.SystemSQLCodec.TablePrefix(uint32(tableID)), + ) { + return nil + } + respFilter.Store(filterFunc(nil)) + unblock := make(chan struct{}) + getChan <- unblock + <-unblock + return nil + }) + + // Launch the update. + type result struct { + exists bool + err error + } + resCh := make(chan result) + go func() { + var res result + res.exists, res.err = storage.Update(ctx, ID, s.Clock().Now()) + resCh <- res + }() + + // Wait for the update to block. + unblock := <-getChan + + // Delete the session being updated. + tdb.Exec(t, `DELETE FROM "`+dbName+`".sqlliveness WHERE true`) + + // Unblock the update and ensure that it saw that its session was deleted. + close(unblock) + res := <-resCh + require.False(t, res.exists) + require.NoError(t, res.err) +} + func getTableID( t *testing.T, db *sqlutils.SQLRunner, dbName, tableName string, ) (tableID descpb.ID) {