From af259faacfa25d7516920ac4bd976eaeeaabbd82 Mon Sep 17 00:00:00 2001 From: Chengpeng Yan <41809508+Reminiscent@users.noreply.github.com> Date: Thu, 16 Dec 2021 10:18:35 +0800 Subject: [PATCH] planner: support the plan cache aware of bindings (#30169) --- bindinfo/bind_serial_test.go | 312 ++++++++++++++++++++++++++++++++++ executor/explainfor_test.go | 14 +- executor/prepared.go | 3 +- planner/core/cache.go | 5 +- planner/core/cache_test.go | 2 +- planner/core/common_plans.go | 17 +- planner/core/optimizer.go | 3 + planner/optimize.go | 54 ++++-- server/driver_tidb.go | 5 +- session/session.go | 3 +- sessionctx/stmtctx/stmtctx.go | 3 + testkit/testkit.go | 24 +++ 12 files changed, 415 insertions(+), 30 deletions(-) diff --git a/bindinfo/bind_serial_test.go b/bindinfo/bind_serial_test.go index 7cf49c087278f..53b674a8ec715 100644 --- a/bindinfo/bind_serial_test.go +++ b/bindinfo/bind_serial_test.go @@ -16,7 +16,9 @@ package bindinfo_test import ( "context" + "crypto/tls" "fmt" + "strconv" "testing" "github.com/pingcap/tidb/bindinfo" @@ -26,10 +28,320 @@ import ( "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/terror" + plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/session/txninfo" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util" "github.com/stretchr/testify/require" ) +// mockSessionManager is a mocked session manager which is used for test. +type mockSessionManager1 struct { + PS []*util.ProcessInfo +} + +func (msm *mockSessionManager1) ShowTxnList() []*txninfo.TxnInfo { + return nil +} + +// ShowProcessList implements the SessionManager.ShowProcessList interface. +func (msm *mockSessionManager1) ShowProcessList() map[uint64]*util.ProcessInfo { + ret := make(map[uint64]*util.ProcessInfo) + for _, item := range msm.PS { + ret[item.ID] = item + } + return ret +} + +func (msm *mockSessionManager1) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { + for _, item := range msm.PS { + if item.ID == id { + return item, true + } + } + return &util.ProcessInfo{}, false +} + +// Kill implements the SessionManager.Kill interface. +func (msm *mockSessionManager1) Kill(cid uint64, query bool) { +} + +func (msm *mockSessionManager1) KillAllConnections() { +} + +func (msm *mockSessionManager1) UpdateTLSConfig(cfg *tls.Config) { +} + +func (msm *mockSessionManager1) ServerID() uint64 { + return 1 +} + +func TestPrepareCacheWithBinding(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + orgEnable := plannercore.PreparedPlanCacheEnabled() + defer func() { + plannercore.SetPreparedPlanCache(orgEnable) + }() + plannercore.SetPreparedPlanCache(true) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1(a int, b int, c int, key idx_b(b), key idx_c(c))") + tk.MustExec("create table t2(a int, b int, c int, key idx_b(b), key idx_c(c))") + + // TestDMLSQLBind + tk.MustExec("prepare stmt1 from 'delete from t1 where b = 1 and c > 1';") + tk.MustExec("execute stmt1;") + require.Equal(t, "t1:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess := tk.Session().ShowProcess() + ps := []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res := tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_b(b)"), res.Rows()) + tk.MustExec("execute stmt1;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for delete from t1 where b = 1 and c > 1 using delete /*+ use_index(t1,idx_c) */ from t1 where b = 1 and c > 1") + + tk.MustExec("execute stmt1;") + require.Equal(t, "t1:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_c(c)"), res.Rows()) + + tk.MustExec("prepare stmt2 from 'delete t1, t2 from t1 inner join t2 on t1.b = t2.b';") + tk.MustExec("execute stmt2;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "HashJoin"), res.Rows()) + tk.MustExec("execute stmt2;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for delete t1, t2 from t1 inner join t2 on t1.b = t2.b using delete /*+ inl_join(t1) */ t1, t2 from t1 inner join t2 on t1.b = t2.b") + + tk.MustExec("execute stmt2;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "IndexJoin"), res.Rows()) + + tk.MustExec("prepare stmt3 from 'update t1 set a = 1 where b = 1 and c > 1';") + tk.MustExec("execute stmt3;") + require.Equal(t, "t1:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_b(b)"), res.Rows()) + tk.MustExec("execute stmt3;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for update t1 set a = 1 where b = 1 and c > 1 using update /*+ use_index(t1,idx_c) */ t1 set a = 1 where b = 1 and c > 1") + + tk.MustExec("execute stmt3;") + require.Equal(t, "t1:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_c(c)"), res.Rows()) + + tk.MustExec("prepare stmt4 from 'update t1, t2 set t1.a = 1 where t1.b = t2.b';") + tk.MustExec("execute stmt4;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "HashJoin"), res.Rows()) + tk.MustExec("execute stmt4;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for update t1, t2 set t1.a = 1 where t1.b = t2.b using update /*+ inl_join(t1) */ t1, t2 set t1.a = 1 where t1.b = t2.b") + + tk.MustExec("execute stmt4;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "IndexJoin"), res.Rows()) + + tk.MustExec("prepare stmt5 from 'insert into t1 select * from t2 where t2.b = 2 and t2.c > 2';") + tk.MustExec("execute stmt5;") + require.Equal(t, "t2:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_b(b)"), res.Rows()) + tk.MustExec("execute stmt5;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1 using insert /*+ use_index(t2,idx_c) */ into t1 select * from t2 where t2.b = 1 and t2.c > 1") + + tk.MustExec("execute stmt5;") + require.Equal(t, "t2:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_b(b)"), res.Rows()) + + tk.MustExec("drop global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1") + tk.MustExec("create global binding for insert into t1 select * from t2 where t2.b = 1 and t2.c > 1 using insert into t1 select /*+ use_index(t2,idx_c) */ * from t2 where t2.b = 1 and t2.c > 1") + + tk.MustExec("execute stmt5;") + require.Equal(t, "t2:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_c(c)"), res.Rows()) + + tk.MustExec("prepare stmt6 from 'replace into t1 select * from t2 where t2.b = 2 and t2.c > 2';") + tk.MustExec("execute stmt6;") + require.Equal(t, "t2:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_b(b)"), res.Rows()) + tk.MustExec("execute stmt6;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for replace into t1 select * from t2 where t2.b = 1 and t2.c > 1 using replace into t1 select /*+ use_index(t2,idx_c) */ * from t2 where t2.b = 1 and t2.c > 1") + + tk.MustExec("execute stmt6;") + require.Equal(t, "t2:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "idx_c(c)"), res.Rows()) + + // TestExplain + tk.MustExec("drop table if exists t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1(id int)") + tk.MustExec("create table t2(id int)") + + tk.MustExec("prepare stmt1 from 'SELECT * from t1,t2 where t1.id = t2.id';") + tk.MustExec("execute stmt1;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "HashJoin")) + tk.MustExec("execute stmt1;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("prepare stmt2 from 'SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id';") + tk.MustExec("execute stmt2;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "MergeJoin")) + tk.MustExec("execute stmt2;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + + tk.MustExec("execute stmt1;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "MergeJoin")) + + tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") + + tk.MustExec("create index index_id on t1(id)") + tk.MustExec("prepare stmt1 from 'SELECT * from t1 use index(index_id)';") + tk.MustExec("execute stmt1;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "IndexReader")) + tk.MustExec("execute stmt1;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for SELECT * from t1 using SELECT * from t1 ignore index(index_id)") + tk.MustExec("execute stmt1;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.False(t, tk.HasPlan4ExplainFor(res, "IndexReader")) + tk.MustExec("execute stmt1;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + // Add test for SetOprStmt + tk.MustExec("prepare stmt1 from 'SELECT * from t1 union SELECT * from t1';") + tk.MustExec("execute stmt1;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.False(t, tk.HasPlan4ExplainFor(res, "IndexReader")) + tk.MustExec("execute stmt1;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("prepare stmt2 from 'SELECT * from t1 use index(index_id) union SELECT * from t1';") + tk.MustExec("execute stmt2;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "IndexReader")) + tk.MustExec("execute stmt2;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec("create global binding for SELECT * from t1 union SELECT * from t1 using SELECT * from t1 use index(index_id) union SELECT * from t1") + + tk.MustExec("execute stmt1;") + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.HasPlan4ExplainFor(res, "IndexReader")) + + tk.MustExec("drop global binding for SELECT * from t1 union SELECT * from t1") + + // TestBindingSymbolList + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, INDEX ia (a), INDEX ib (b));") + tk.MustExec("insert into t value(1, 1);") + tk.MustExec("prepare stmt1 from 'select a, b from t where a = 3 limit 1, 100';") + tk.MustExec("execute stmt1;") + require.Equal(t, "t:ia", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "ia(a)"), res.Rows()) + tk.MustExec("execute stmt1;") + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + + tk.MustExec(`create global binding for select a, b from t where a = 1 limit 0, 1 using select a, b from t use index (ib) where a = 1 limit 0, 1`) + + // after binding + tk.MustExec("execute stmt1;") + require.Equal(t, "t:ib", tk.Session().GetSessionVars().StmtCtx.IndexNames[0]) + tkProcess = tk.Session().ShowProcess() + ps = []*util.ProcessInfo{tkProcess} + tk.Session().SetSessionManager(&mockSessionManager1{PS: ps}) + res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) + require.True(t, tk.MustUseIndex4ExplainFor(res, "ib(b)"), res.Rows()) +} + func TestExplain(t *testing.T) { store, clean := testkit.CreateMockStore(t) defer clean() diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index 078ed7e6bb548..5befb576c132a 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -1124,16 +1124,20 @@ func (s *testPrepareSerialSuite) TestSPM4PlanCache(c *C) { tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1")) tk.MustQuery("execute stmt;").Check(testkit.Rows()) - tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + // The bindSQL has changed, the previous cache is invalid. + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("0")) tk.MustQuery("execute stmt;").Check(testkit.Rows()) tkProcess = tk.Se.ShowProcess() ps = []*util.ProcessInfo{tkProcess} tk.Se.SetSessionManager(&mockSessionManager1{PS: ps}) res = tk.MustQuery("explain for connection " + strconv.FormatUint(tkProcess.ID, 10)) - // The binding does not take effect for caches that have been cached. - c.Assert(res.Rows()[0][0], Matches, ".*TableReader.*") - c.Assert(res.Rows()[1][0], Matches, ".*TableFullScan.*") - tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("0")) + // We can use the new binding. + c.Assert(res.Rows()[0][0], Matches, ".*IndexReader.*") + c.Assert(res.Rows()[1][0], Matches, ".*IndexFullScan.*") + tk.MustQuery("execute stmt;").Check(testkit.Rows()) + tk.MustQuery("select @@last_plan_from_cache").Check(testkit.Rows("1")) + tk.MustQuery("execute stmt;").Check(testkit.Rows()) + tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1")) tk.MustExec("delete from mysql.bind_info where default_db='test';") tk.MustExec("admin reload bindings;") diff --git a/executor/prepared.go b/executor/prepared.go index 3013aba0de9cd..82a030e76b6c1 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -314,8 +314,9 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error { prepared := preparedObj.PreparedAst delete(vars.PreparedStmtNameToID, e.Name) if plannercore.PreparedPlanCacheEnabled() { + bindSQL := planner.GetBindSQL4PlanCache(e.ctx, prepared.Stmt) e.ctx.PreparedPlanCache().Delete(plannercore.NewPSTMTPlanCacheKey( - vars, id, prepared.SchemaVersion, + vars, id, prepared.SchemaVersion, bindSQL, )) } vars.RemovePreparedStmt(id) diff --git a/planner/core/cache.go b/planner/core/cache.go index ea6d0c32e3b39..a386c4a5a3649 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -74,6 +74,7 @@ type pstmtPlanCacheKey struct { timezoneOffset int isolationReadEngines map[kv.StoreType]struct{} selectLimit uint64 + bindSQL string hash []byte } @@ -104,6 +105,7 @@ func (key *pstmtPlanCacheKey) Hash() []byte { key.hash = append(key.hash, kv.TiFlash.Name()...) } key.hash = codec.EncodeInt(key.hash, int64(key.selectLimit)) + key.hash = append(key.hash, hack.Slice(key.bindSQL)...) } return key.hash } @@ -125,7 +127,7 @@ func SetPstmtIDSchemaVersion(key kvcache.Key, pstmtID uint32, schemaVersion int6 } // NewPSTMTPlanCacheKey creates a new pstmtPlanCacheKey object. -func NewPSTMTPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, schemaVersion int64) kvcache.Key { +func NewPSTMTPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, schemaVersion int64, bindSQL string) kvcache.Key { timezoneOffset := 0 if sessionVars.TimeZone != nil { _, timezoneOffset = time.Now().In(sessionVars.TimeZone).Zone() @@ -139,6 +141,7 @@ func NewPSTMTPlanCacheKey(sessionVars *variable.SessionVars, pstmtID uint32, sch timezoneOffset: timezoneOffset, isolationReadEngines: make(map[kv.StoreType]struct{}), selectLimit: sessionVars.SelectLimit, + bindSQL: bindSQL, } for k, v := range sessionVars.IsolationReadEngines { key.isolationReadEngines[k] = v diff --git a/planner/core/cache_test.go b/planner/core/cache_test.go index ff0ab53fa558b..074d1e4cf2828 100644 --- a/planner/core/cache_test.go +++ b/planner/core/cache_test.go @@ -28,6 +28,6 @@ func TestCacheKey(t *testing.T) { ctx.GetSessionVars().SQLMode = mysql.ModeNone ctx.GetSessionVars().TimeZone = time.UTC ctx.GetSessionVars().ConnectionID = 0 - key := NewPSTMTPlanCacheKey(ctx.GetSessionVars(), 1, 1) + key := NewPSTMTPlanCacheKey(ctx.GetSessionVars(), 1, 1, "") require.Equal(t, []byte{0x74, 0x65, 0x73, 0x74, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x74, 0x69, 0x64, 0x62, 0x74, 0x69, 0x6b, 0x76, 0x74, 0x69, 0x66, 0x6c, 0x61, 0x73, 0x68, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, key.Hash()) } diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 9af659ef6c972..d3e56600b2c25 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -401,8 +401,10 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, } stmtCtx.UseCache = prepared.UseCache + var bindSQL string if prepared.UseCache { - cacheKey = NewPSTMTPlanCacheKey(sctx.GetSessionVars(), e.ExecID, prepared.SchemaVersion) + bindSQL = GetBindSQL4PlanCache(sctx, prepared.Stmt) + cacheKey = NewPSTMTPlanCacheKey(sctx.GetSessionVars(), e.ExecID, prepared.SchemaVersion, bindSQL) } tps := make([]*types.FieldType, len(e.UsingVars)) for i, param := range e.UsingVars { @@ -468,6 +470,14 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, if err != nil { return err } + if len(bindSQL) > 0 { + // When the `len(bindSQL) > 0`, it means we use the binding. + // So we need to record this. + err = sessVars.SetSystemVar(variable.TiDBFoundInBinding, variable.BoolToOnOff(true)) + if err != nil { + return err + } + } if metrics.ResettablePlanCacheCounterFortTest { metrics.PlanCacheCounter.WithLabelValues("prepare").Inc() } else { @@ -500,8 +510,11 @@ REBUILD: // rebuild key to exclude kv.TiFlash when stmt is not read only if _, isolationReadContainTiFlash := sessVars.IsolationReadEngines[kv.TiFlash]; isolationReadContainTiFlash && !IsReadOnly(stmt, sessVars) { delete(sessVars.IsolationReadEngines, kv.TiFlash) - cacheKey = NewPSTMTPlanCacheKey(sctx.GetSessionVars(), e.ExecID, prepared.SchemaVersion) + cacheKey = NewPSTMTPlanCacheKey(sessVars, e.ExecID, prepared.SchemaVersion, sessVars.StmtCtx.BindSQL) sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} + } else { + // We need to reconstruct the plan cache key based on the bindSQL. + cacheKey = NewPSTMTPlanCacheKey(sessVars, e.ExecID, prepared.SchemaVersion, sessVars.StmtCtx.BindSQL) } cached := NewPSTMTPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, tps) preparedStmt.NormalizedPlan, preparedStmt.PlanDigest = NormalizePlan(p) diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 5103ea4fdf38d..89b156e632cea 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -44,6 +44,9 @@ import ( // OptimizeAstNode optimizes the query to a physical plan directly. var OptimizeAstNode func(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (Plan, types.NameSlice, error) +// GetBindSQL4PlanCache get the bindSQL for the ast.StmtNode +var GetBindSQL4PlanCache func(sctx sessionctx.Context, stmtNode ast.StmtNode) (bindSQL string) + // AllowCartesianProduct means whether tidb allows cartesian join without equal conditions. var AllowCartesianProduct = atomic.NewBool(true) diff --git a/planner/optimize.go b/planner/optimize.go index 363c3e6f5374a..b16fc09a238f0 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -99,6 +99,28 @@ func GetExecuteForUpdateReadIS(node ast.Node, sctx sessionctx.Context) infoschem return nil } +// GetBindSQL4PlanCache used to get the bindSQL for plan cache to build the plan cache key. +func GetBindSQL4PlanCache(sctx sessionctx.Context, stmtNode ast.StmtNode) (bindSQL string) { + bindRecord, _, match := matchSQLBinding(sctx, stmtNode) + if match { + bindSQL = bindRecord.Bindings[0].BindSQL + } + return bindSQL +} + +func matchSQLBinding(sctx sessionctx.Context, stmtNode ast.StmtNode) (bindRecord *bindinfo.BindRecord, scope string, matched bool) { + useBinding := sctx.GetSessionVars().UsePlanBaselines + if !useBinding || stmtNode == nil { + return nil, "", false + } + var err error + bindRecord, scope, err = getBindRecord(sctx, stmtNode) + if err != nil || bindRecord == nil || len(bindRecord.Bindings) == 0 { + return nil, "", false + } + return bindRecord, scope, true +} + // Optimize does optimization and creates a Plan. // The node must be prepared first. func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (plannercore.Plan, types.NameSlice, error) { @@ -149,16 +171,9 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in if !ok { useBinding = false } - var ( - bindRecord *bindinfo.BindRecord - scope string - err error - ) - if useBinding { - bindRecord, scope, err = getBindRecord(sctx, stmtNode) - if err != nil || bindRecord == nil || len(bindRecord.Bindings) == 0 { - useBinding = false - } + bindRecord, scope, match := matchSQLBinding(sctx, stmtNode) + if !match { + useBinding = false } if ok { // add the extra Limit after matching the bind record @@ -166,14 +181,15 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in node = stmtNode } - var names types.NameSlice - var bestPlan, bestPlanFromBind plannercore.Plan + var ( + names types.NameSlice + bestPlan, bestPlanFromBind plannercore.Plan + chosenBinding bindinfo.Binding + err error + ) if useBinding { minCost := math.MaxFloat64 - var ( - bindStmtHints stmtctx.StmtHints - chosenBinding bindinfo.Binding - ) + var bindStmtHints stmtctx.StmtHints originHints := hint.CollectHint(stmtNode) // bindRecord must be not nil when coming here, try to find the best binding. for _, binding := range bindRecord.Bindings { @@ -206,7 +222,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in for _, warn := range warns { sessVars.StmtCtx.AppendWarning(warn) } - if err := setFoundInBinding(sctx, true); err != nil { + if err := setFoundInBinding(sctx, true, chosenBinding.BindSQL); err != nil { logutil.BgLogger().Warn("set tidb_found_in_binding failed", zap.Error(err)) } if sessVars.StmtCtx.InVerboseExplain { @@ -694,13 +710,15 @@ func handleStmtHints(hints []*ast.TableOptimizerHint) (stmtHints stmtctx.StmtHin return } -func setFoundInBinding(sctx sessionctx.Context, opt bool) error { +func setFoundInBinding(sctx sessionctx.Context, opt bool, bindSQL string) error { vars := sctx.GetSessionVars() + vars.StmtCtx.BindSQL = bindSQL err := vars.SetSystemVar(variable.TiDBFoundInBinding, variable.BoolToOnOff(opt)) return err } func init() { plannercore.OptimizeAstNode = Optimize + plannercore.GetBindSQL4PlanCache = GetBindSQL4PlanCache plannercore.IsReadOnly = IsReadOnly } diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 6dae49084eeee..9a13eea632962 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/planner" "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -164,8 +165,10 @@ func (ts *TiDBStatement) Close() error { if !ok { return errors.Errorf("invalid CachedPrepareStmt type") } + preparedAst := preparedObj.PreparedAst + bindSQL := planner.GetBindSQL4PlanCache(ts.ctx, preparedAst.Stmt) ts.ctx.PreparedPlanCache().Delete(core.NewPSTMTPlanCacheKey( - ts.ctx.GetSessionVars(), ts.id, preparedObj.PreparedAst.SchemaVersion)) + ts.ctx.GetSessionVars(), ts.id, preparedObj.PreparedAst.SchemaVersion, bindSQL)) } ts.ctx.GetSessionVars().RemovePreparedStmt(ts.id) } diff --git a/session/session.go b/session/session.go index bfc5288a7ff4f..465de576b37c7 100644 --- a/session/session.go +++ b/session/session.go @@ -309,7 +309,8 @@ func (s *session) cleanRetryInfo() { preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt) if ok { preparedAst = preparedObj.PreparedAst - cacheKey = plannercore.NewPSTMTPlanCacheKey(s.sessionVars, firstStmtID, preparedAst.SchemaVersion) + bindSQL := planner.GetBindSQL4PlanCache(s, preparedAst.Stmt) + cacheKey = plannercore.NewPSTMTPlanCacheKey(s.sessionVars, firstStmtID, preparedAst.SchemaVersion, bindSQL) } } } diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 3125ae419641e..e41eb4766b47b 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -153,6 +153,9 @@ type StatementContext struct { normalized string digest *parser.Digest } + // BindSQL used to construct the key for plan cache. It records the binding used by the stmt. + // If the binding is not used by the stmt, the value is empty + BindSQL string // planNormalized use for cache the normalized plan, avoid duplicate builds. planNormalized string planDigest *parser.Digest diff --git a/testkit/testkit.go b/testkit/testkit.go index 50af5e0178a1b..c99791efe369a 100644 --- a/testkit/testkit.go +++ b/testkit/testkit.go @@ -128,6 +128,16 @@ func (tk *TestKit) HasPlan(sql string, plan string, args ...interface{}) bool { return false } +// HasPlan4ExplainFor checks if the result execution plan contains specific plan. +func (tk *TestKit) HasPlan4ExplainFor(result *Result, plan string) bool { + for i := range result.rows { + if strings.Contains(result.rows[i][0], plan) { + return true + } + } + return false +} + // Exec executes a sql statement using the prepared stmt API func (tk *TestKit) Exec(sql string, args ...interface{}) (sqlexec.RecordSet, error) { ctx := context.Background() @@ -228,6 +238,20 @@ func (tk *TestKit) MustUseIndex(sql string, index string, args ...interface{}) b return false } +// MustUseIndex4ExplainFor checks if the result execution plan contains specific index(es). +func (tk *TestKit) MustUseIndex4ExplainFor(result *Result, index string) bool { + for i := range result.rows { + // It depends on whether we enable to collect the execution info. + if strings.Contains(result.rows[i][3], "index:"+index) { + return true + } + if strings.Contains(result.rows[i][4], "index:"+index) { + return true + } + } + return false +} + // CheckExecResult checks the affected rows and the insert id after executing MustExec. func (tk *TestKit) CheckExecResult(affectedRows, insertID int64) { tk.require.Equal(int64(tk.Session().AffectedRows()), affectedRows)