From c3b8757df8d619e4342639d9700aab0cddfa61de Mon Sep 17 00:00:00 2001 From: Ti Chi Robot Date: Mon, 16 Oct 2023 10:25:57 +0800 Subject: [PATCH] planner: store the hints of session variable (#45814) (#46047) close pingcap/tidb#45812 --- planner/core/plan_cache.go | 11 ++++- planner/core/plan_cache_utils.go | 6 ++- planner/core/point_get_plan.go | 2 + session/sessiontest/BUILD.bazel | 1 + session/sessiontest/session_test.go | 68 +++++++++++++++++++++++++++++ sessionctx/stmtctx/BUILD.bazel | 2 +- sessionctx/stmtctx/stmtctx.go | 38 +++++++++++++++- sessionctx/stmtctx/stmtctx_test.go | 23 ++++++++++ 8 files changed, 146 insertions(+), 5 deletions(-) diff --git a/planner/core/plan_cache.go b/planner/core/plan_cache.go index c58020a36964a..d3490e74907bb 100644 --- a/planner/core/plan_cache.go +++ b/planner/core/plan_cache.go @@ -247,6 +247,9 @@ func getCachedPointPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmt } sessVars.FoundInPlanCache = true stmtCtx.PointExec = true + if pointGetPlan, ok := plan.(*PointGetPlan); ok && pointGetPlan != nil && pointGetPlan.stmtHints != nil { + sessVars.StmtCtx.StmtHints = *pointGetPlan.stmtHints + } return plan, names, true, nil } @@ -287,6 +290,7 @@ func getCachedPlan(sctx sessionctx.Context, isNonPrepared bool, cacheKey kvcache core_metrics.GetPlanCacheHitCounter(isNonPrepared).Inc() } stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) + stmtCtx.StmtHints = *cachedVal.stmtHints return cachedVal.Plan, cachedVal.OutPutNames, true, nil } @@ -329,7 +333,7 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isNonPrepared } sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} } - cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts) + cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, matchOpts, &stmtCtx.StmtHints) stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p) stmtCtx.SetPlan(p) stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest) @@ -759,12 +763,15 @@ func tryCachePointPlan(_ context.Context, sctx sessionctx.Context, names types.NameSlice ) - if _, _ok := p.(*PointGetPlan); _ok { + if plan, _ok := p.(*PointGetPlan); _ok { ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p) names = p.OutputNames() if err != nil { return err } + if ok { + plan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone() + } } if ok { diff --git a/planner/core/plan_cache_utils.go b/planner/core/plan_cache_utils.go index 0ffbbfa9033ce..17b068e20d1b1 100644 --- a/planner/core/plan_cache_utils.go +++ b/planner/core/plan_cache_utils.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" @@ -335,6 +336,8 @@ type PlanCacheValue struct { // matchOpts stores some fields help to choose a suitable plan matchOpts *utilpc.PlanCacheMatchOpts + // stmtHints stores the hints which set session variables, because the hints won't be processed using cached plan. + stmtHints *stmtctx.StmtHints } func (v *PlanCacheValue) varTypesUnchanged(txtVarTps []*types.FieldType) bool { @@ -385,7 +388,7 @@ func (v *PlanCacheValue) MemoryUsage() (sum int64) { // NewPlanCacheValue creates a SQLCacheValue. func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool, - matchOpts *utilpc.PlanCacheMatchOpts) *PlanCacheValue { + matchOpts *utilpc.PlanCacheMatchOpts, stmtHints *stmtctx.StmtHints) *PlanCacheValue { dstMap := make(map[*model.TableInfo]bool) for k, v := range srcMap { dstMap[k] = v @@ -399,6 +402,7 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta OutPutNames: names, TblInfo2UnionScan: dstMap, matchOpts: matchOpts, + stmtHints: stmtHints.Clone(), } } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 43241b68fde7a..5f670f224a5cd 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -96,6 +96,8 @@ type PointGetPlan struct { // probeParents records the IndexJoins and Applys with this operator in their inner children. // Please see comments in PhysicalPlan for details. probeParents []PhysicalPlan + // stmtHints should restore in executing context. + stmtHints *stmtctx.StmtHints } func (p *PointGetPlan) getEstRowCountForDisplay() float64 { diff --git a/session/sessiontest/BUILD.bazel b/session/sessiontest/BUILD.bazel index c8348dd42c6e2..a7b0a0bc4402e 100644 --- a/session/sessiontest/BUILD.bazel +++ b/session/sessiontest/BUILD.bazel @@ -26,6 +26,7 @@ go_test( "//privilege/privileges", "//session", "//sessionctx", + "//sessionctx/stmtctx", "//sessionctx/variable", "//store/copr", "//store/mockstore", diff --git a/session/sessiontest/session_test.go b/session/sessiontest/session_test.go index 6d2d502e5fd9d..6124d4e4e269e 100644 --- a/session/sessiontest/session_test.go +++ b/session/sessiontest/session_test.go @@ -41,6 +41,7 @@ import ( "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/copr" "github.com/pingcap/tidb/store/mockstore" @@ -3602,3 +3603,70 @@ func TestSQLModeOp(t *testing.T) { a = mysql.SetSQLMode(s, mysql.ModeAllowInvalidDates) require.Equal(t, mysql.ModeNoBackslashEscapes|mysql.ModeOnlyFullGroupBy|mysql.ModeAllowInvalidDates, a) } + +func TestPrepareExecuteWithSQLHints(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + se := tk.Session() + se.SetConnectionID(1) + tk.MustExec("use test") + tk.MustExec("create table t(a int primary key)") + + type hintCheck struct { + hint string + check func(*stmtctx.StmtHints) + } + + hintChecks := []hintCheck{ + { + hint: "MEMORY_QUOTA(1024 MB)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMemQuotaHint) + require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery) + }, + }, + { + hint: "READ_CONSISTENT_REPLICA()", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasReplicaReadHint) + require.Equal(t, byte(kv.ReplicaReadFollower), stmtHint.ReplicaRead) + }, + }, + { + hint: "MAX_EXECUTION_TIME(1000)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasMaxExecutionTime) + require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime) + }, + }, + { + hint: "USE_TOJA(TRUE)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint) + require.True(t, stmtHint.AllowInSubqToJoinAndAgg) + }, + }, + { + hint: "RESOURCE_GROUP(rg1)", + check: func(stmtHint *stmtctx.StmtHints) { + require.True(t, stmtHint.HasResourceGroup) + require.Equal(t, "rg1", stmtHint.ResourceGroup) + }, + }, + } + + for i, check := range hintChecks { + // common path + tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute stmt%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + // fast path + tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint)) + for j := 0; j < 10; j++ { + tk.MustQuery(fmt.Sprintf("execute fast%d", i)) + check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints) + } + } +} diff --git a/sessionctx/stmtctx/BUILD.bazel b/sessionctx/stmtctx/BUILD.bazel index d02edd86138a1..4d7be5fc6faf7 100644 --- a/sessionctx/stmtctx/BUILD.bazel +++ b/sessionctx/stmtctx/BUILD.bazel @@ -37,7 +37,7 @@ go_test( ], embed = [":stmtctx"], flaky = True, - shard_count = 5, + shard_count = 6, deps = [ "//kv", "//sessionctx/variable", diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index c853354209f37..e2e02f35abd1e 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -415,7 +415,6 @@ type StatementContext struct { type StmtHints struct { // Hint Information MemQuotaQuery int64 - ApplyCacheCapacity int64 MaxExecutionTime uint64 ReplicaRead byte AllowInSubqToJoinAndAgg bool @@ -446,6 +445,43 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool { return sh.ForceNthPlan != -1 } +// Clone the StmtHints struct and returns the pointer of the new one. +func (sh *StmtHints) Clone() *StmtHints { + var ( + vars map[string]string + tableHints []*ast.TableOptimizerHint + ) + if len(sh.SetVars) > 0 { + vars = make(map[string]string, len(sh.SetVars)) + for k, v := range sh.SetVars { + vars[k] = v + } + } + if len(sh.OriginalTableHints) > 0 { + tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints)) + copy(tableHints, sh.OriginalTableHints) + } + return &StmtHints{ + MemQuotaQuery: sh.MemQuotaQuery, + MaxExecutionTime: sh.MaxExecutionTime, + ReplicaRead: sh.ReplicaRead, + AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg, + NoIndexMergeHint: sh.NoIndexMergeHint, + StraightJoinOrder: sh.StraightJoinOrder, + EnableCascadesPlanner: sh.EnableCascadesPlanner, + ForceNthPlan: sh.ForceNthPlan, + ResourceGroup: sh.ResourceGroup, + HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint, + HasMemQuotaHint: sh.HasMemQuotaHint, + HasReplicaReadHint: sh.HasReplicaReadHint, + HasMaxExecutionTime: sh.HasMaxExecutionTime, + HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint, + HasResourceGroup: sh.HasResourceGroup, + SetVars: vars, + OriginalTableHints: tableHints, + } +} + // StmtCacheKey represents the key type in the StmtCache. type StmtCacheKey int diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index 461168ee4b607..9a3951278befc 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "math/rand" + "reflect" "sort" "testing" "time" @@ -273,3 +274,25 @@ func TestApproxRuntimeInfo(t *testing.T) { require.Equal(t, d.TotBackoffTime[backoff], timeSum) } } + +func TestStmtHintsClone(t *testing.T) { + hints := stmtctx.StmtHints{} + value := reflect.ValueOf(&hints).Elem() + for i := 0; i < value.NumField(); i++ { + field := value.Field(i) + switch field.Kind() { + case reflect.Int, reflect.Int32, reflect.Int64: + field.SetInt(1) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + field.SetUint(1) + case reflect.Uint8: // byte + field.SetUint(1) + case reflect.Bool: + field.SetBool(true) + case reflect.String: + field.SetString("test") + default: + } + } + require.Equal(t, hints, *hints.Clone()) +}