Skip to content

Commit

Permalink
sqlliveness/slstorage: fix bug due to not using a transaction
Browse files Browse the repository at this point in the history
We had a bug where updating a session was not using the transaction. This
exposed it to a problem whereby a concurrent removal of the session would
not be detected and the session could be resurrected.

Fortunately this code moved to using KV from SQL in the 21.2 cycle and
thus no released major release should experience this issue.

Release note: None
  • Loading branch information
ajwerner committed Oct 1, 2021
1 parent 177ed2b commit 4ee0f2b
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pkg/sql/sqlliveness/slstorage/slstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,15 @@ func (s *Storage) Update(
) (sessionExists bool, err error) {
err = s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
k := s.makeSessionKey(sid)
kv, err := s.db.Get(ctx, k)
kv, err := txn.Get(ctx, k)
if err != nil {
return err
}
if sessionExists = kv.Value != nil; !sessionExists {
return nil
}
v := encodeValue(expiration)
return s.db.Put(ctx, k, &v)
return txn.Put(ctx, k, &v)
})
if err != nil || !sessionExists {
s.metrics.WriteFailures.Inc(1)
Expand Down
89 changes: 89 additions & 0 deletions pkg/sql/sqlliveness/slstorage/slstorage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,95 @@ func TestConcurrentAccessSynchronization(t *testing.T) {
})
}

// TestDeleteMidUpdateFails ensures that a session removed while it attempts to
// update itself fails.
func TestDeleteMidUpdateFails(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

ctx := context.Background()
type filterFunc = func(context.Context, roachpb.BatchRequest, *roachpb.BatchResponse) *roachpb.Error
var respFilter atomic.Value
respFilter.Store(filterFunc(nil))
s, sqlDB, kvDB := serverutils.StartServer(t, base.TestServerArgs{
Knobs: base.TestingKnobs{
Store: &kvserver.StoreTestingKnobs{
TestingResponseFilter: func(
ctx context.Context, request roachpb.BatchRequest, resp *roachpb.BatchResponse,
) *roachpb.Error {
if f := respFilter.Load().(filterFunc); f != nil {
return f(ctx, request, resp)
}
return nil
},
},
},
})
defer s.Stopper().Stop(ctx)
tdb := sqlutils.MakeSQLRunner(sqlDB)

// Set up a fake storage implementation using a separate table.
dbName := t.Name()
tdb.Exec(t, `CREATE DATABASE "`+dbName+`"`)
schema := strings.Replace(systemschema.SqllivenessTableSchema,
`CREATE TABLE system.sqlliveness`,
`CREATE TABLE "`+dbName+`".sqlliveness`, 1)
tdb.Exec(t, schema)
tableID := getTableID(t, tdb, dbName, "sqlliveness")

storage := slstorage.NewTestingStorage(
s.Stopper(), s.Clock(), kvDB, keys.SystemSQLCodec, s.ClusterSettings(),
tableID, timeutil.DefaultTimeSource{}.NewTimer,
)

// Insert a session.
ID := sqlliveness.SessionID("foo")
require.NoError(t, storage.Insert(ctx, ID, s.Clock().Now()))

// Install a filter which will send on this channel when we attempt
// to perform an update after the get has evaluated.
getChan := make(chan chan struct{})
respFilter.Store(func(
ctx context.Context, request roachpb.BatchRequest, _ *roachpb.BatchResponse,
) *roachpb.Error {
if get, ok := request.GetArg(roachpb.Get); !ok || !bytes.HasPrefix(
get.(*roachpb.GetRequest).Key,
keys.SystemSQLCodec.TablePrefix(uint32(tableID)),
) {
return nil
}
respFilter.Store(filterFunc(nil))
unblock := make(chan struct{})
getChan <- unblock
<-unblock
return nil
})

// Launch the update.
type result struct {
exists bool
err error
}
resCh := make(chan result)
go func() {
var res result
res.exists, res.err = storage.Update(ctx, ID, s.Clock().Now())
resCh <- res
}()

// Wait for the update to block.
unblock := <-getChan

// Delete the session being updated.
tdb.Exec(t, `DELETE FROM "`+dbName+`".sqlliveness WHERE true`)

// Unblock the update and ensure that it saw that its session was deleted.
close(unblock)
res := <-resCh
require.False(t, res.exists)
require.NoError(t, res.err)
}

func getTableID(
t *testing.T, db *sqlutils.SQLRunner, dbName, tableName string,
) (tableID descpb.ID) {
Expand Down

0 comments on commit 4ee0f2b

Please sign in to comment.