Skip to content

Commit

Permalink
Merge branch 'master' into prepare_priviledge_check
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Aug 10, 2022
2 parents bb1865c + b150be8 commit c64c554
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 18 deletions.
7 changes: 4 additions & 3 deletions sessiontxn/internal/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table/temptable"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -62,9 +61,11 @@ func CommitBeforeEnterNewTxn(ctx context.Context, sctx sessionctx.Context) error
}

// GetSnapshotWithTS returns a snapshot with ts.
func GetSnapshotWithTS(s sessionctx.Context, ts uint64) kv.Snapshot {
func GetSnapshotWithTS(s sessionctx.Context, ts uint64, interceptor kv.SnapshotInterceptor) kv.Snapshot {
snap := s.GetStore().GetSnapshot(kv.Version{Ver: ts})
snap.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(s))
if interceptor != nil {
snap.SetOption(kv.SnapInterceptor, interceptor)
}
if s.GetSessionVars().InRestrictedSQL {
snap.SetOption(kv.RequestSourceInternal, true)
}
Expand Down
8 changes: 6 additions & 2 deletions sessiontxn/isolation/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func (p *baseTxnContextProvider) ActivateTxn() (kv.Transaction, error) {
if readReplicaType.IsFollowerRead() {
txn.SetOption(kv.ReplicaRead, readReplicaType)
}
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx))
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema))

if sessVars.StmtCtx.WeakConsistency {
txn.SetOption(kv.IsolationLevel, kv.RC)
Expand Down Expand Up @@ -393,7 +393,11 @@ func (p *baseTxnContextProvider) getSnapshotByTS(snapshotTS uint64) (kv.Snapshot
}

sessVars := p.sctx.GetSessionVars()
snapshot := internal.GetSnapshotWithTS(p.sctx, snapshotTS)
snapshot := internal.GetSnapshotWithTS(
p.sctx,
snapshotTS,
temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema),
)

replicaReadType := sessVars.GetReplicaRead()
if replicaReadType.IsFollowerRead() && !sessVars.StmtCtx.RCCheckTS {
Expand Down
8 changes: 6 additions & 2 deletions sessiontxn/staleread/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (p *StalenessTxnContextProvider) activateStaleTxn() error {
TxnScope: txnScope,
},
}
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx))
txn.SetOption(kv.SnapInterceptor, temptable.SessionSnapshotInterceptor(p.sctx, is))

p.is = is
err = p.sctx.GetSessionVars().SetSystemVar(variable.TiDBSnapshot, "")
Expand Down Expand Up @@ -209,7 +209,11 @@ func (p *StalenessTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot,
}

sessVars := p.sctx.GetSessionVars()
snapshot := internal.GetSnapshotWithTS(p.sctx, p.ts)
snapshot := internal.GetSnapshotWithTS(
p.sctx,
p.ts,
temptable.SessionSnapshotInterceptor(p.sctx, p.is),
)

replicaReadType := sessVars.GetReplicaRead()
if replicaReadType.IsFollowerRead() {
Expand Down
19 changes: 10 additions & 9 deletions sessiontxn/txn_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/pingcap/tidb/sessiontxn"
"github.com/pingcap/tidb/sessiontxn/internal"
"github.com/pingcap/tidb/sessiontxn/staleread"
"github.com/pingcap/tidb/table/temptable"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/testkit"
"github.com/pingcap/tidb/tests/realtikvtest"
Expand Down Expand Up @@ -308,7 +309,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -318,7 +319,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3", "10"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.False(t, isSnapshotEqual(t, compareSnap2, snap))
Expand All @@ -338,7 +339,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -348,7 +349,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t").Check(testkit.Rows("1", "3", "10"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap2, snap))
Expand All @@ -367,7 +368,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -377,7 +378,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap2, snap))
Expand All @@ -398,7 +399,7 @@ func TestGetSnapshot(t *testing.T) {
check: func(t *testing.T, sctx sessionctx.Context) {
ts, err := mgr.GetStmtReadTS()
require.NoError(t, err)
compareSnap := internal.GetSnapshotWithTS(sctx, ts)
compareSnap := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err := mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap, snap))
Expand All @@ -408,7 +409,7 @@ func TestGetSnapshot(t *testing.T) {
tk.MustQuery("select * from t for update").Check(testkit.Rows("1", "3"))
ts, err = mgr.GetStmtForUpdateTS()
require.NoError(t, err)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts)
compareSnap2 := internal.GetSnapshotWithTS(sctx, ts, nil)
snap, err = mgr.GetSnapshotWithStmtReadTS()
require.NoError(t, err)
require.True(t, isSnapshotEqual(t, compareSnap2, snap))
Expand Down Expand Up @@ -494,7 +495,7 @@ func TestSnapshotInterceptor(t *testing.T) {
}

// Also check GetSnapshotWithTS
snap := internal.GetSnapshotWithTS(tk.Session(), 0)
snap := internal.GetSnapshotWithTS(tk.Session(), 0, temptable.SessionSnapshotInterceptor(tk.Session(), sessiontxn.GetTxnManager(tk.Session()).GetTxnInfoSchema()))
val, err := snap.Get(context.Background(), k)
require.NoError(t, err)
require.Equal(t, []byte("v1"), val)
Expand Down
4 changes: 2 additions & 2 deletions table/temptable/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ type TemporaryTableSnapshotInterceptor struct {
}

// SessionSnapshotInterceptor creates a new snapshot interceptor for temporary table data fetch
func SessionSnapshotInterceptor(sctx sessionctx.Context) kv.SnapshotInterceptor {
func SessionSnapshotInterceptor(sctx sessionctx.Context, is infoschema.InfoSchema) kv.SnapshotInterceptor {
return NewTemporaryTableSnapshotInterceptor(
sctx.GetInfoSchema().(infoschema.InfoSchema),
is,
getSessionData(sctx),
)
}
Expand Down

0 comments on commit c64c554

Please sign in to comment.