diff --git a/pkg/kv/kvclient/kvcoord/txn_coord_sender_savepoints.go b/pkg/kv/kvclient/kvcoord/txn_coord_sender_savepoints.go index f3efde257738..fbc72f7bc8e4 100644 --- a/pkg/kv/kvclient/kvcoord/txn_coord_sender_savepoints.go +++ b/pkg/kv/kvclient/kvcoord/txn_coord_sender_savepoints.go @@ -43,6 +43,9 @@ type savepoint struct { seqNum enginepb.TxnSeq // txnSpanRefresher fields. + // TODO(mira): after we remove + // kv.transaction.keep_refresh_spans_on_savepoint_rollback.enabled, we won't + // need these two fields anymore. refreshSpans []roachpb.Span refreshInvalid bool } diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go index 41c525312e69..ab54324cc52b 100644 --- a/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go +++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go @@ -44,6 +44,22 @@ var MaxTxnRefreshSpansBytes = settings.RegisterIntSetting( 1<<22, /* 4 MB */ settings.WithPublic) +// keepRefreshSpansOnSavepointRollback is a boolean flag that, when enabled, +// ensures that all refresh spans accumulated since a savepoint was created are +// kept even after the savepoint is rolled back. This ensures that the reads +// corresponding to the refresh spans are serialized correctly. See #111228 for +// more details. +// The default value of this setting corresponds to the correct new behavior, +// which also matches the Postgres behavior. We don't expect this new behavior +// to impact customers because they should already be able to handle +// serialization errors; in case any unforeseen customer issues arise, the +// setting here allows us to revert to the old behavior. +var keepRefreshSpansOnSavepointRollback = settings.RegisterBoolSetting( + settings.TenantWritable, + "kv.transaction.keep_refresh_spans_on_savepoint_rollback.enabled", + "if enabled, all refresh spans accumulated since a savepoint was created are kept after the savepoint is rolled back", + true) + // txnSpanRefresher is a txnInterceptor that collects the read spans of a // serializable transaction in the event it gets a serializable retry error. It // can then use the set of read spans to avoid retrying the transaction if all @@ -790,6 +806,9 @@ func (sr *txnSpanRefresher) createSavepointLocked(ctx context.Context, s *savepo // TODO(nvanbenschoten): make sure this works correctly with ReadCommitted. // The refresh spans should either be empty when captured into a savepoint or // should be cleared when the savepoint is rolled back to. + // TODO(mira): after we remove + // kv.transaction.keep_refresh_spans_on_savepoint_rollback.enabled, we won't + // need to keep refresh spans in the savepoint anymore. s.refreshSpans = make([]roachpb.Span, len(sr.refreshFootprint.asSlice())) copy(s.refreshSpans, sr.refreshFootprint.asSlice()) s.refreshInvalid = sr.refreshInvalid @@ -797,9 +816,11 @@ func (sr *txnSpanRefresher) createSavepointLocked(ctx context.Context, s *savepo // rollbackToSavepointLocked is part of the txnInterceptor interface. func (sr *txnSpanRefresher) rollbackToSavepointLocked(ctx context.Context, s savepoint) { - sr.refreshFootprint.clear() - sr.refreshFootprint.insert(s.refreshSpans...) - sr.refreshInvalid = s.refreshInvalid + if !keepRefreshSpansOnSavepointRollback.Get(&sr.st.SV) { + sr.refreshFootprint.clear() + sr.refreshFootprint.insert(s.refreshSpans...) + sr.refreshInvalid = s.refreshInvalid + } } // closeLocked implements the txnInterceptor interface. diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher_test.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher_test.go index 849d39d15035..5e386982eeb0 100644 --- a/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher_test.go +++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher_test.go @@ -16,11 +16,13 @@ import ( "strconv" "testing" + "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" "github.com/cockroachdb/cockroach/pkg/kv/kvserver/concurrency/isolation" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -1436,57 +1438,137 @@ func TestTxnSpanRefresherEpochIncrement(t *testing.T) { func TestTxnSpanRefresherSavepoint(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - ctx := context.Background() - tsr, mockSender := makeMockTxnSpanRefresher() - keyA, keyB := roachpb.Key("a"), roachpb.Key("b") - txn := makeTxnProto() + testutils.RunTrueAndFalse(t, "keep-refresh-spans", func(t *testing.T, keepRefreshSpans bool) { + ctx := context.Background() + tsr, mockSender := makeMockTxnSpanRefresher() - read := func(key roachpb.Key) { - ba := &kvpb.BatchRequest{} - ba.Header = kvpb.Header{Txn: &txn} - getArgs := kvpb.GetRequest{RequestHeader: kvpb.RequestHeader{Key: key}} - ba.Add(&getArgs) - mockSender.MockSend(func(ba *kvpb.BatchRequest) (*kvpb.BatchResponse, *kvpb.Error) { - require.Len(t, ba.Requests, 1) - require.IsType(t, &kvpb.GetRequest{}, ba.Requests[0].GetInner()) - - br := ba.CreateReply() - br.Txn = ba.Txn - return br, nil - }) - br, pErr := tsr.SendLocked(ctx, ba) - require.Nil(t, pErr) - require.NotNil(t, br) - } - read(keyA) - require.Equal(t, []roachpb.Span{{Key: keyA}}, tsr.refreshFootprint.asSlice()) + if keepRefreshSpans { + keepRefreshSpansOnSavepointRollback.Override(ctx, &tsr.st.SV, true) + } else { + keepRefreshSpansOnSavepointRollback.Override(ctx, &tsr.st.SV, false) + } + + keyA, keyB := roachpb.Key("a"), roachpb.Key("b") + txn := makeTxnProto() + + read := func(key roachpb.Key) { + ba := &kvpb.BatchRequest{} + ba.Header = kvpb.Header{Txn: &txn} + getArgs := kvpb.GetRequest{RequestHeader: kvpb.RequestHeader{Key: key}} + ba.Add(&getArgs) + mockSender.MockSend(func(ba *kvpb.BatchRequest) (*kvpb.BatchResponse, *kvpb.Error) { + require.Len(t, ba.Requests, 1) + require.IsType(t, &kvpb.GetRequest{}, ba.Requests[0].GetInner()) - s := savepoint{} - tsr.createSavepointLocked(ctx, &s) + br := ba.CreateReply() + br.Txn = ba.Txn + return br, nil + }) + br, pErr := tsr.SendLocked(ctx, ba) + require.Nil(t, pErr) + require.NotNil(t, br) + } + read(keyA) + require.Equal(t, []roachpb.Span{{Key: keyA}}, tsr.refreshFootprint.asSlice()) - // Another read after the savepoint was created. - read(keyB) - require.Equal(t, []roachpb.Span{{Key: keyA}, {Key: keyB}}, tsr.refreshFootprint.asSlice()) + s := savepoint{} + tsr.createSavepointLocked(ctx, &s) - require.Equal(t, []roachpb.Span{{Key: keyA}}, s.refreshSpans) - require.False(t, s.refreshInvalid) + // Another read after the savepoint was created. + read(keyB) + require.Equal(t, []roachpb.Span{{Key: keyA}, {Key: keyB}}, tsr.refreshFootprint.asSlice()) - // Rollback the savepoint and check that refresh spans were overwritten. - tsr.rollbackToSavepointLocked(ctx, s) - require.Equal(t, []roachpb.Span{{Key: keyA}}, tsr.refreshFootprint.asSlice()) + require.Equal(t, []roachpb.Span{{Key: keyA}}, s.refreshSpans) + require.False(t, s.refreshInvalid) - // Check that rolling back to the savepoint resets refreshInvalid. - tsr.refreshInvalid = true - tsr.rollbackToSavepointLocked(ctx, s) - require.False(t, tsr.refreshInvalid) + // Rollback the savepoint. + tsr.rollbackToSavepointLocked(ctx, s) + if keepRefreshSpans { + // Check that refresh spans were kept as such. + require.Equal(t, []roachpb.Span{{Key: keyA}, {Key: keyB}}, tsr.refreshFootprint.asSlice()) + } else { + // Check that refresh spans were overwritten. + require.Equal(t, []roachpb.Span{{Key: keyA}}, tsr.refreshFootprint.asSlice()) + } - // Set refreshInvalid and then create a savepoint. - tsr.refreshInvalid = true - s = savepoint{} - tsr.createSavepointLocked(ctx, &s) - require.True(t, s.refreshInvalid) - // Rollback to the savepoint check that refreshes are still invalid. - tsr.rollbackToSavepointLocked(ctx, s) - require.True(t, tsr.refreshInvalid) + tsr.refreshInvalid = true + tsr.rollbackToSavepointLocked(ctx, s) + if keepRefreshSpans { + // Check that rolling back to the savepoint keeps refreshInvalid as such. + require.True(t, tsr.refreshInvalid) + } else { + // Check that rolling back to the savepoint resets refreshInvalid. + require.False(t, tsr.refreshInvalid) + } + + // Set refreshInvalid and then create a savepoint. + tsr.refreshInvalid = true + s = savepoint{} + tsr.createSavepointLocked(ctx, &s) + require.True(t, s.refreshInvalid) + // Rollback to the savepoint check that refreshes are still invalid. + tsr.rollbackToSavepointLocked(ctx, s) + require.True(t, tsr.refreshInvalid) + }) +} + +// TestRefreshWithSavepoint is an integration test that ensures the correct +// behavior of refreshes under savepoint rollback. The test sets up a write-skew +// example where txn1 reads keyA and writes to keyB, while concurrently txn2 +// reads keyB and writes to keyA. The two txns can't be serialized so one is +// expected to get a serialization error upon commit. +// +// However, with the old behavior of discarding refresh spans upon savepoint +// rollback, the read corresponding to the discarded refresh span is not +// refreshed, so the conflict goes unnoticed and both txns commit successfully. +// See #111228 for more details. +func TestRefreshWithSavepoint(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testutils.RunTrueAndFalse(t, "keep-refresh-spans", func(t *testing.T, keepRefreshSpans bool) { + s, _, kvDB := serverutils.StartServer(t, base.TestServerArgs{}) + ctx := context.Background() + defer s.Stopper().Stop(context.Background()) + + if keepRefreshSpans { + keepRefreshSpansOnSavepointRollback.Override(ctx, &s.ClusterSettings().SV, true) + } else { + keepRefreshSpansOnSavepointRollback.Override(ctx, &s.ClusterSettings().SV, false) + } + + keyA := roachpb.Key("a") + keyB := roachpb.Key("b") + txn1 := kvDB.NewTxn(ctx, "txn1") + txn2 := kvDB.NewTxn(ctx, "txn2") + + spt1, err := txn1.CreateSavepoint(ctx) + require.NoError(t, err) + + _, err = txn1.Get(ctx, keyA) + require.NoError(t, err) + + err = txn1.RollbackToSavepoint(ctx, spt1) + require.NoError(t, err) + + _, err = txn2.Get(ctx, keyB) + require.NoError(t, err) + + err = txn1.Put(ctx, keyB, "bb") + require.NoError(t, err) + + err = txn2.Put(ctx, keyA, "aa") + require.NoError(t, err) + + err = txn1.Commit(ctx) + if keepRefreshSpans { + require.Regexp(t, ".*RETRY_SERIALIZABLE - failed preemptive refresh due to conflicting locks on \"a\"*", err) + } else { + require.NoError(t, err) + } + + err = txn2.Commit(ctx) + require.NoError(t, err) + }) }