diff --git a/pkg/planner/core/BUILD.bazel b/pkg/planner/core/BUILD.bazel index c429b032d5002..b1d214f11a43b 100644 --- a/pkg/planner/core/BUILD.bazel +++ b/pkg/planner/core/BUILD.bazel @@ -22,7 +22,6 @@ go_library( "indexmerge_path.go", "indexmerge_unfinished_path.go", "initialize.go", - "logical_cte.go", "logical_datasource.go", "logical_index_scan.go", "logical_initialize.go", @@ -287,6 +286,7 @@ go_test( "//pkg/planner/core/base", "//pkg/planner/core/operator/logicalop", "//pkg/planner/core/operator/physicalop", + "//pkg/planner/core/rule", "//pkg/planner/property", "//pkg/planner/util", "//pkg/planner/util/coretestsdk", diff --git a/pkg/planner/core/collect_column_stats_usage.go b/pkg/planner/core/collect_column_stats_usage.go index 3045e6b6a88e2..0bdf6be5a66ce 100644 --- a/pkg/planner/core/collect_column_stats_usage.go +++ b/pkg/planner/core/collect_column_stats_usage.go @@ -293,7 +293,7 @@ func (c *columnStatsUsageCollector) collectFromPlan(lp base.LogicalPlan) { c.collectPredicateColumnsForUnionAll(x) case *logicalop.LogicalPartitionUnionAll: c.collectPredicateColumnsForUnionAll(&x.LogicalUnionAll) - case *LogicalCTE: + case *logicalop.LogicalCTE: // Visit SeedPartLogicalPlan and RecursivePartLogicalPlan first. c.collectFromPlan(x.Cte.SeedPartLogicalPlan) if x.Cte.RecursivePartLogicalPlan != nil { diff --git a/pkg/planner/core/collect_column_stats_usage_test.go b/pkg/planner/core/collect_column_stats_usage_test.go index fd26b49e11878..7f027f8a51af3 100644 --- a/pkg/planner/core/collect_column_stats_usage_test.go +++ b/pkg/planner/core/collect_column_stats_usage_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/pkg/infoschema" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/util/hint" "github.com/stretchr/testify/require" ) @@ -396,7 +397,7 @@ func TestCollectHistNeededColumns(t *testing.T) { flags := builder.GetOptFlag() // JoinReOrder may need columns stats so collecting hist-needed columns must happen before JoinReOrder. // Hence, we disable JoinReOrder and PruneColumnsAgain here. - flags &= ^(flagJoinReOrder | flagPrunColumnsAgain) + flags &= ^(rule.FlagJoinReOrder | rule.FlagPruneColumnsAgain) lp, err = logicalOptimize(ctx, flags, lp) require.NoError(t, err, comment) checkColumnStatsUsageForStatsLoad(t, s.is, lp, tt.res, comment) diff --git a/pkg/planner/core/core_init.go b/pkg/planner/core/core_init.go index aa64607250bef..b191d90f3a85c 100644 --- a/pkg/planner/core/core_init.go +++ b/pkg/planner/core/core_init.go @@ -34,11 +34,13 @@ func init() { utilfuncp.GetTaskPlanCost = getTaskPlanCost utilfuncp.CanPushToCopImpl = canPushToCopImpl utilfuncp.PushDownTopNForBaseLogicalPlan = pushDownTopNForBaseLogicalPlan + utilfuncp.FindBestTask4LogicalCTE = findBestTask4LogicalCTE utilfuncp.FindBestTask4LogicalShow = findBestTask4LogicalShow utilfuncp.FindBestTask4LogicalCTETable = findBestTask4LogicalCTETable utilfuncp.FindBestTask4LogicalMemTable = findBestTask4LogicalMemTable utilfuncp.FindBestTask4LogicalTableDual = findBestTask4LogicalTableDual utilfuncp.FindBestTask4LogicalShowDDLJobs = findBestTask4LogicalShowDDLJobs + utilfuncp.ExhaustPhysicalPlans4LogicalCTE = exhaustPhysicalPlans4LogicalCTE utilfuncp.ExhaustPhysicalPlans4LogicalSort = exhaustPhysicalPlans4LogicalSort utilfuncp.ExhaustPhysicalPlans4LogicalTopN = exhaustPhysicalPlans4LogicalTopN utilfuncp.ExhaustPhysicalPlans4LogicalLock = exhaustPhysicalPlans4LogicalLock @@ -63,6 +65,7 @@ func init() { utilfuncp.PushDownTopNForBaseLogicalPlan = pushDownTopNForBaseLogicalPlan utilfuncp.AttachPlan2Task = attachPlan2Task utilfuncp.WindowIsTopN = windowIsTopN + utilfuncp.DoOptimize = doOptimize // For mv index init. cardinality.GetTblInfoForUsedStatsByPhysicalID = getTblInfoForUsedStatsByPhysicalID diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index 9c7b9d9233bb3..b4655e9ef9d46 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -2480,7 +2480,7 @@ func canPushToCopImpl(lp base.LogicalPlan, storeTp kv.StoreType, considerDual bo return false case *logicalop.LogicalSequence: return storeTp == kv.TiFlash - case *LogicalCTE: + case *logicalop.LogicalCTE: if storeTp != kv.TiFlash { return false } @@ -3073,7 +3073,8 @@ func exhaustPhysicalPlans4LogicalMaxOneRow(lp base.LogicalPlan, prop *property.P return []base.PhysicalPlan{mor}, true, nil } -func exhaustPhysicalPlans4LogicalCTE(p *LogicalCTE, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { +func exhaustPhysicalPlans4LogicalCTE(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + p := lp.(*logicalop.LogicalCTE) pcte := PhysicalCTE{CTE: p.Cte}.Init(p.SCtx(), p.StatsInfo()) if prop.IsFlashProp() { pcte.storageSender = PhysicalExchangeSender{ diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index 4e5aa75339799..7c47717e0a773 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -35,6 +35,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/opcode" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/planner/util/coreusage" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" @@ -1227,9 +1228,9 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, planCtx *exp // and don't need to append a scalar value, we can rewrite it to inner join. if planCtx.builder.ctx.GetSessionVars().GetAllowInSubqToJoinAndAgg() && !v.Not && !asScalar && len(corCols) == 0 && collFlag { // We need to try to eliminate the agg and the projection produced by this operation. - planCtx.builder.optFlag |= flagEliminateAgg - planCtx.builder.optFlag |= flagEliminateProjection - planCtx.builder.optFlag |= flagJoinReOrder + planCtx.builder.optFlag |= rule.FlagEliminateAgg + planCtx.builder.optFlag |= rule.FlagEliminateProjection + planCtx.builder.optFlag |= rule.FlagJoinReOrder // Build distinct for the inner query. agg, err := planCtx.builder.buildDistinct(np, np.Schema().Len()) if err != nil { @@ -1375,7 +1376,7 @@ func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, planCtx } func hasCTEConsumerInSubPlan(p base.LogicalPlan) bool { - if _, ok := p.(*LogicalCTE); ok { + if _, ok := p.(*logicalop.LogicalCTE); ok { return true } for _, child := range p.Children() { diff --git a/pkg/planner/core/find_best_task.go b/pkg/planner/core/find_best_task.go index 77f63158d879b..48334328d9623 100644 --- a/pkg/planner/core/find_best_task.go +++ b/pkg/planner/core/find_best_task.go @@ -2900,7 +2900,8 @@ func getOriginalPhysicalIndexScan(ds *DataSource, prop *property.PhysicalPropert return is } -func findBestTask4LogicalCTE(p *LogicalCTE, prop *property.PhysicalProperty, counter *base.PlanCounterTp, pop *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { +func findBestTask4LogicalCTE(lp base.LogicalPlan, prop *property.PhysicalProperty, counter *base.PlanCounterTp, pop *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { + p := lp.(*logicalop.LogicalCTE) if p.ChildLen() > 0 { return p.BaseLogicalPlan.FindBestTask(prop, counter, pop) } diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index 7304c3c1a0006..5e9ebd6c90729 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -45,6 +45,7 @@ import ( "github.com/pingcap/tidb/pkg/planner/core/base" core_metrics "github.com/pingcap/tidb/pkg/planner/core/metrics" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/coreusage" @@ -132,7 +133,7 @@ func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) { func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expression) (base.LogicalPlan, []expression.Expression, error) { ectx := p.SCtx().GetExprCtx().GetEvalCtx() - b.optFlag |= flagResolveExpand + b.optFlag |= rule.FlagResolveExpand // Rollup syntax require expand OP to do the data expansion, different data replica supply the different grouping layout. distinctGbyExprs, gbyExprsRefPos := expression.DeduplicateGbyExpression(gbyItems) @@ -243,19 +244,19 @@ func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expr func (b *PlanBuilder) buildAggregation(ctx context.Context, p base.LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression, correlatedAggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, map[int]int, error) { - b.optFlag |= flagBuildKeyInfo - b.optFlag |= flagPushDownAgg + b.optFlag |= rule.FlagBuildKeyInfo + b.optFlag |= rule.FlagPushDownAgg // We may apply aggregation eliminate optimization. // So we add the flagMaxMinEliminate to try to convert max/min to topn and flagPushDownTopN to handle the newly added topn operator. - b.optFlag |= flagMaxMinEliminate - b.optFlag |= flagPushDownTopN + b.optFlag |= rule.FlagMaxMinEliminate + b.optFlag |= rule.FlagPushDownTopN // when we eliminate the max and min we may add `is not null` filter. - b.optFlag |= flagPredicatePushDown - b.optFlag |= flagEliminateAgg - b.optFlag |= flagEliminateProjection + b.optFlag |= rule.FlagPredicatePushDown + b.optFlag |= rule.FlagEliminateAgg + b.optFlag |= rule.FlagEliminateProjection if b.ctx.GetSessionVars().EnableSkewDistinctAgg { - b.optFlag |= flagSkewDistinctAgg + b.optFlag |= rule.FlagSkewDistinctAgg } // flag it if cte contain aggregation if b.buildingCTE { @@ -439,7 +440,7 @@ func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSet case *ast.SelectStmt: ci := b.prepareCTECheckForSubQuery() defer resetCTECheckForSubQuery(ci) - b.optFlag = b.optFlag | flagConstantPropagation + b.optFlag = b.optFlag | rule.FlagConstantPropagation p, err = b.buildSelect(ctx, v) case *ast.SetOprStmt: ci := b.prepareCTECheckForSubQuery() @@ -553,11 +554,11 @@ func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (base.L return b.buildResultSetNode(ctx, joinNode.Left, false) } - b.optFlag = b.optFlag | flagPredicatePushDown + b.optFlag = b.optFlag | rule.FlagPredicatePushDown // Add join reorder flag regardless of inner join or outer join. - b.optFlag = b.optFlag | flagJoinReOrder - b.optFlag |= flagPredicateSimplification - b.optFlag |= flagConvertOuterToInnerJoin + b.optFlag = b.optFlag | rule.FlagJoinReOrder + b.optFlag |= rule.FlagPredicateSimplification + b.optFlag |= rule.FlagConvertOuterToInnerJoin leftPlan, err := b.buildResultSetNode(ctx, joinNode.Left, false) if err != nil { @@ -589,12 +590,12 @@ func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (base.L switch joinNode.Tp { case ast.LeftJoin: // left outer join need to be checked elimination - b.optFlag = b.optFlag | flagEliminateOuterJoin + b.optFlag = b.optFlag | rule.FlagEliminateOuterJoin joinPlan.JoinType = logicalop.LeftOuterJoin util.ResetNotNullFlag(joinPlan.Schema(), leftPlan.Schema().Len(), joinPlan.Schema().Len()) case ast.RightJoin: // right outer join need to be checked elimination - b.optFlag = b.optFlag | flagEliminateOuterJoin + b.optFlag = b.optFlag | rule.FlagEliminateOuterJoin joinPlan.JoinType = logicalop.RightOuterJoin util.ResetNotNullFlag(joinPlan.Schema(), 0, leftPlan.Schema().Len()) default: @@ -907,9 +908,9 @@ func (b *PlanBuilder) coalesceCommonColumns(p *logicalop.LogicalJoin, leftPlan, } func (b *PlanBuilder) buildSelection(ctx context.Context, p base.LogicalPlan, where ast.ExprNode, aggMapper map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, error) { - b.optFlag |= flagPredicatePushDown - b.optFlag |= flagDeriveTopNFromWindow - b.optFlag |= flagPredicateSimplification + b.optFlag |= rule.FlagPredicatePushDown + b.optFlag |= rule.FlagDeriveTopNFromWindow + b.optFlag |= rule.FlagPredicateSimplification if b.curClause != havingClause { b.curClause = whereClause } @@ -1296,7 +1297,7 @@ func (b *PlanBuilder) buildProjection(ctx context.Context, p base.LogicalPlan, f if err != nil { return nil, nil, 0, err } - b.optFlag |= flagEliminateProjection + b.optFlag |= rule.FlagEliminateProjection b.curClause = fieldList proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx, b.getSelectOffset()) schema := expression.NewSchema(make([]*expression.Column, 0, len(fields))...) @@ -1490,8 +1491,8 @@ func (b *PlanBuilder) buildProjection(ctx context.Context, p base.LogicalPlan, f } func (b *PlanBuilder) buildDistinct(child base.LogicalPlan, length int) (*logicalop.LogicalAggregation, error) { - b.optFlag = b.optFlag | flagBuildKeyInfo - b.optFlag = b.optFlag | flagPushDownAgg + b.optFlag = b.optFlag | rule.FlagBuildKeyInfo + b.optFlag = b.optFlag | rule.FlagPushDownAgg plan4Agg := logicalop.LogicalAggregation{ AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()), GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]), @@ -1613,7 +1614,7 @@ func (b *PlanBuilder) buildProjection4Union(_ context.Context, u *logicalop.Logi exprs[i] = srcCol } } - b.optFlag |= flagEliminateProjection + b.optFlag |= rule.FlagEliminateProjection proj := logicalop.LogicalProjection{Exprs: exprs}.Init(b.ctx, b.getSelectOffset()) proj.SetSchema(u.Schema().Clone()) // reset the schema type to make the "not null" flag right. @@ -1725,7 +1726,7 @@ func (b *PlanBuilder) buildSemiJoinForSetOperator( if err != nil { return nil, err } - b.optFlag |= flagConvertOuterToInnerJoin + b.optFlag |= rule.FlagConvertOuterToInnerJoin joinPlan := logicalop.LogicalJoin{JoinType: joinType}.Init(b.ctx, b.getSelectOffset()) joinPlan.SetChildren(leftPlan, rightPlan) @@ -2086,7 +2087,7 @@ func extractLimitCountOffset(ctx expression.BuildContext, limit *ast.Limit) (cou } func (b *PlanBuilder) buildLimit(src base.LogicalPlan, limit *ast.Limit) (base.LogicalPlan, error) { - b.optFlag = b.optFlag | flagPushDownTopN + b.optFlag = b.optFlag | rule.FlagPushDownTopN var ( offset, count uint64 err error @@ -3966,7 +3967,7 @@ func (b *PlanBuilder) tryToBuildSequence(ctes []*cteInfo, p base.LogicalPlan) ba } lctes := make([]base.LogicalPlan, 0, len(ctes)+1) for _, cte := range ctes { - lcte := LogicalCTE{ + lcte := logicalop.LogicalCTE{ Cte: cte.cteClass, CteAsName: cte.def.Name, CteName: cte.def.Name, @@ -3976,7 +3977,7 @@ func (b *PlanBuilder) tryToBuildSequence(ctes []*cteInfo, p base.LogicalPlan) ba lcte.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) lctes = append(lctes, lcte) } - b.optFlag |= flagPushDownSequence + b.optFlag |= rule.FlagPushDownSequence seq := logicalop.LogicalSequence{}.Init(b.ctx, b.getSelectOffset()) seq.SetChildren(append(lctes, p)...) seq.SetOutputNames(p.OutputNames().Shallow()) @@ -4205,7 +4206,7 @@ func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName } if cte.cteClass == nil { - cte.cteClass = &CTEClass{ + cte.cteClass = &logicalop.CTEClass{ IsDistinct: cte.isDistinct, SeedPartLogicalPlan: cte.seedLP, RecursivePartLogicalPlan: cte.recurLP, @@ -4219,7 +4220,7 @@ func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName } } var p base.LogicalPlan - lp := LogicalCTE{CteAsName: tn.Name, CteName: tn.Name, Cte: cte.cteClass, SeedStat: cte.seedStat}.Init(b.ctx, b.getSelectOffset()) + lp := logicalop.LogicalCTE{CteAsName: tn.Name, CteName: tn.Name, Cte: cte.cteClass, SeedStat: cte.seedStat}.Init(b.ctx, b.getSelectOffset()) prevSchema := cte.seedLP.Schema().Clone() lp.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) @@ -4321,7 +4322,7 @@ func (b *PlanBuilder) buildDataSourceFromCTEMerge(ctx context.Context, cte *ast. } func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, asName *model.CIStr) (base.LogicalPlan, error) { - b.optFlag |= flagPredicateSimplification + b.optFlag |= rule.FlagPredicateSimplification dbName := tn.Schema sessionVars := b.ctx.GetSessionVars() @@ -4379,10 +4380,10 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as // If `UseDynamicPruneMode` already been false, then we don't need to check whether execute `flagPartitionProcessor` // otherwise we need to check global stats initialized for each partition table if !b.ctx.GetSessionVars().IsDynamicPartitionPruneEnabled() { - b.optFlag = b.optFlag | flagPartitionProcessor + b.optFlag = b.optFlag | rule.FlagPartitionProcessor } else { if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPruneMode { - b.optFlag = b.optFlag | flagPartitionProcessor + b.optFlag = b.optFlag | rule.FlagPartitionProcessor } else { h := domain.GetDomain(b.ctx).StatsHandle() tblStats := h.GetTableStats(tableInfo) @@ -4404,7 +4405,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as }) if usePartitionProcessor { - b.optFlag = b.optFlag | flagPartitionProcessor + b.optFlag = b.optFlag | rule.FlagPartitionProcessor b.ctx.GetSessionVars().StmtCtx.UseDynamicPruneMode = false if isDynamicEnabled { b.ctx.GetSessionVars().StmtCtx.AppendWarning( @@ -4431,7 +4432,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as return nil, plannererrors.ErrPartitionClauseOnNonpartitioned } - possiblePaths, err := getPossibleAccessPaths(b.ctx, b.TableHints(), tn.IndexHints, tbl, dbName, tblName, b.isForUpdateRead, b.optFlag&flagPartitionProcessor > 0) + possiblePaths, err := getPossibleAccessPaths(b.ctx, b.TableHints(), tn.IndexHints, tbl, dbName, tblName, b.isForUpdateRead, b.optFlag&rule.FlagPartitionProcessor > 0) if err != nil { return nil, err } @@ -4499,7 +4500,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as for _, indexCol := range index.Columns { colInfo := tbl.Cols()[indexCol.Offset] if colInfo.IsGenerated() && !colInfo.GeneratedStored { - b.optFlag |= flagGcSubstitute + b.optFlag |= rule.FlagGcSubstitute break } } @@ -4683,7 +4684,7 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as if dirty || tableInfo.TempTableType == model.TempTableLocal || tableInfo.TableCacheStatusType == model.TableCacheStatusEnable { us := logicalop.LogicalUnionScan{HandleCols: handleCols}.Init(b.ctx, b.getSelectOffset()) us.SetChildren(ds) - if tableInfo.Partition != nil && b.optFlag&flagPartitionProcessor == 0 { + if tableInfo.Partition != nil && b.optFlag&rule.FlagPartitionProcessor == 0 { // Adding ExtraPhysTblIDCol for UnionScan (transaction buffer handling) // Not using old static prune mode // Single TableReader for all partitions, needs the PhysTblID from storage @@ -5075,7 +5076,7 @@ func (b *PlanBuilder) buildProjUponView(_ context.Context, dbName model.CIStr, t // buildApplyWithJoinType builds apply plan with outerPlan and innerPlan, which apply join with particular join type for // every row from outerPlan and the whole innerPlan. func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan base.LogicalPlan, tp logicalop.JoinType, markNoDecorrelate bool) base.LogicalPlan { - b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate | flagConvertOuterToInnerJoin + b.optFlag = b.optFlag | rule.FlagPredicatePushDown | rule.FlagBuildKeyInfo | rule.FlagDecorrelate | rule.FlagConvertOuterToInnerJoin ap := logicalop.LogicalApply{LogicalJoin: logicalop.LogicalJoin{JoinType: tp}, NoDecorrelate: markNoDecorrelate}.Init(b.ctx, b.getSelectOffset()) ap.SetChildren(outerPlan, innerPlan) ap.SetOutputNames(make([]*types.FieldName, outerPlan.Schema().Len()+innerPlan.Schema().Len())) @@ -5084,7 +5085,7 @@ func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan base.LogicalPl setIsInApplyForCTE(innerPlan, ap.Schema()) // Note that, tp can only be LeftOuterJoin or InnerJoin, so we don't consider other outer joins. if tp == logicalop.LeftOuterJoin { - b.optFlag = b.optFlag | flagEliminateOuterJoin + b.optFlag = b.optFlag | rule.FlagEliminateOuterJoin util.ResetNotNullFlag(ap.Schema(), outerPlan.Schema().Len(), ap.Schema().Len()) } for i := outerPlan.Schema().Len(); i < ap.Schema().Len(); i++ { @@ -5097,7 +5098,7 @@ func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan base.LogicalPl // buildSemiApply builds apply plan with outerPlan and innerPlan, which apply semi-join for every row from outerPlan and the whole innerPlan. func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan base.LogicalPlan, condition []expression.Expression, asScalar, not, considerRewrite, markNoDecorrelate bool) (base.LogicalPlan, error) { - b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate + b.optFlag = b.optFlag | rule.FlagPredicatePushDown | rule.FlagBuildKeyInfo | rule.FlagDecorrelate join, err := b.buildSemiJoin(outerPlan, innerPlan, condition, asScalar, not, considerRewrite) if err != nil { @@ -5116,7 +5117,7 @@ func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan base.LogicalPlan, cond // It's better to handle this in CTEExec.Close(), but cte storage is closed when SQL is finished. func setIsInApplyForCTE(p base.LogicalPlan, apSchema *expression.Schema) { switch x := p.(type) { - case *LogicalCTE: + case *logicalop.LogicalCTE: if len(coreusage.ExtractCorColumnsBySchema4LogicalPlan(p, apSchema)) > 0 { x.Cte.IsInApply = true } @@ -5139,7 +5140,7 @@ func (b *PlanBuilder) buildMaxOneRow(p base.LogicalPlan) base.LogicalPlan { } func (b *PlanBuilder) buildSemiJoin(outerPlan, innerPlan base.LogicalPlan, onCondition []expression.Expression, asScalar, not, forceRewrite bool) (*logicalop.LogicalJoin, error) { - b.optFlag |= flagConvertOuterToInnerJoin + b.optFlag |= rule.FlagConvertOuterToInnerJoin joinPlan := logicalop.LogicalJoin{}.Init(b.ctx, b.getSelectOffset()) for i, expr := range onCondition { onCondition[i] = expr.Decorrelate(outerPlan.Schema()) @@ -5173,7 +5174,7 @@ func (b *PlanBuilder) buildSemiJoin(outerPlan, innerPlan base.LogicalPlan, onCon joinPlan.SetPreferredJoinTypeAndOrder(b.TableHints()) if forceRewrite { joinPlan.PreferJoinType |= h.PreferRewriteSemiJoin - b.optFlag |= flagSemiJoinRewrite + b.optFlag |= rule.FlagSemiJoinRewrite } return joinPlan, nil } @@ -5367,7 +5368,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( updt.names = p.OutputNames() // We cannot apply projection elimination when building the subplan, because // columns in orderedList cannot be resolved. (^flagEliminateProjection should also be applied in postOptimize) - updt.SelectPlan, _, err = DoOptimize(ctx, b.ctx, b.optFlag&^flagEliminateProjection, p) + updt.SelectPlan, _, err = DoOptimize(ctx, b.ctx, b.optFlag&^rule.FlagEliminateProjection, p) if err != nil { return nil, err } @@ -5924,7 +5925,7 @@ func getWindowName(name string) string { // buildProjectionForWindow builds the projection for expressions in the window specification that is not an column, // so after the projection, window functions only needs to deal with columns. func (b *PlanBuilder) buildProjectionForWindow(ctx context.Context, p base.LogicalPlan, spec *ast.WindowSpec, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, []property.SortItem, []property.SortItem, []expression.Expression, error) { - b.optFlag |= flagEliminateProjection + b.optFlag |= rule.FlagEliminateProjection var partitionItems, orderItems []*ast.ByItem if spec.PartitionBy != nil { @@ -5983,7 +5984,7 @@ func (b *PlanBuilder) buildProjectionForWindow(ctx context.Context, p base.Logic } func (b *PlanBuilder) buildArgs4WindowFunc(ctx context.Context, p base.LogicalPlan, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) ([]expression.Expression, error) { - b.optFlag |= flagEliminateProjection + b.optFlag |= rule.FlagEliminateProjection newArgList := make([]expression.Expression, 0, len(args)) // use below index for created a new col definition @@ -7137,7 +7138,7 @@ func (b *PlanBuilder) buildWith(ctx context.Context, w *ast.WithClause) ([]*cteI b.allocIDForCTEStorage++ saveFlag := b.optFlag // Init the flag to flagPrunColumns, otherwise it's missing. - b.optFlag = flagPrunColumns + b.optFlag = rule.FlagPruneColumns if b.ctx.GetSessionVars().EnableForceInlineCTE() { b.outerCTEs[len(b.outerCTEs)-1].forceInlineByHintOrVar = true } @@ -7168,7 +7169,7 @@ func (b *PlanBuilder) buildProjection4CTEUnion(_ context.Context, seed base.Logi exprs[i] = col } } - b.optFlag |= flagEliminateProjection + b.optFlag |= rule.FlagEliminateProjection proj := logicalop.LogicalProjection{Exprs: exprs}.Init(b.ctx, b.getSelectOffset()) proj.SetSchema(resSchema) proj.SetChildren(recur) diff --git a/pkg/planner/core/logical_plan_trace_test.go b/pkg/planner/core/logical_plan_trace_test.go index 531ed47319466..36774d60e7219 100644 --- a/pkg/planner/core/logical_plan_trace_test.go +++ b/pkg/planner/core/logical_plan_trace_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/pkg/config" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/util/hint" "github.com/stretchr/testify/require" ) @@ -35,7 +36,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }{ { sql: "select count(1) from t join (select count(1) from t where false) as tmp;", - flags: []uint64{flagPrunColumns}, + flags: []uint64{rule.FlagPruneColumns}, assertRuleName: "column_prune", assertRuleSteps: []assertTraceStep{ { @@ -66,7 +67,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select a from t where b > 5;", - flags: []uint64{flagPrunColumns}, + flags: []uint64{rule.FlagPruneColumns}, assertRuleName: "column_prune", assertRuleSteps: []assertTraceStep{ { @@ -77,7 +78,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from t as t1 where t1.a < (select sum(t2.a) from t as t2 where t2.b = t1.b);", - flags: []uint64{flagDecorrelate, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagDecorrelate, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "decorrelate", assertRuleSteps: []assertTraceStep{ { @@ -104,7 +105,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from t as t1 join t as t2 on t1.a = t2.a where t1.a < 1;", - flags: []uint64{flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "predicate_push_down", assertRuleSteps: []assertTraceStep{ { @@ -127,7 +128,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from t where a < 1;", - flags: []uint64{flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "predicate_push_down", assertRuleSteps: []assertTraceStep{ { @@ -142,7 +143,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from t as t1 left join t as t2 on t1.a = t2.a order by t1.a limit 10;", - flags: []uint64{flagPrunColumns, flagBuildKeyInfo, flagPushDownTopN}, + flags: []uint64{rule.FlagPruneColumns, rule.FlagBuildKeyInfo, rule.FlagPushDownTopN}, assertRuleName: "topn_push_down", assertRuleSteps: []assertTraceStep{ { @@ -169,7 +170,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from t order by a limit 10", - flags: []uint64{flagPrunColumns, flagBuildKeyInfo, flagPushDownTopN}, + flags: []uint64{rule.FlagPruneColumns, rule.FlagBuildKeyInfo, rule.FlagPushDownTopN}, assertRuleName: "topn_push_down", assertRuleSteps: []assertTraceStep{ { @@ -188,7 +189,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from pt3 where ptn > 3;", - flags: []uint64{flagPartitionProcessor, flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPartitionProcessor, rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "partition_processor", assertRuleSteps: []assertTraceStep{ { @@ -199,7 +200,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from pt3 where ptn = 1;", - flags: []uint64{flagPartitionProcessor, flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPartitionProcessor, rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "partition_processor", assertRuleSteps: []assertTraceStep{ { @@ -210,7 +211,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from pt2 where ptn in (1,2,3);", - flags: []uint64{flagPartitionProcessor, flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPartitionProcessor, rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "partition_processor", assertRuleSteps: []assertTraceStep{ { @@ -221,7 +222,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from pt2 where ptn = 1;", - flags: []uint64{flagPartitionProcessor, flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPartitionProcessor, rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "partition_processor", assertRuleSteps: []assertTraceStep{ { @@ -232,7 +233,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from pt1 where ptn > 100;", - flags: []uint64{flagPartitionProcessor, flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPartitionProcessor, rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "partition_processor", assertRuleSteps: []assertTraceStep{ { @@ -243,7 +244,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from pt1 where ptn in (10,20);", - flags: []uint64{flagPartitionProcessor, flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPartitionProcessor, rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "partition_processor", assertRuleSteps: []assertTraceStep{ { @@ -254,7 +255,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from pt1 where ptn < 4;", - flags: []uint64{flagPartitionProcessor, flagPredicatePushDown, flagBuildKeyInfo, flagPrunColumns}, + flags: []uint64{rule.FlagPartitionProcessor, rule.FlagPredicatePushDown, rule.FlagBuildKeyInfo, rule.FlagPruneColumns}, assertRuleName: "partition_processor", assertRuleSteps: []assertTraceStep{ { @@ -265,7 +266,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from (t t1, t t2, t t3,t t4) union all select * from (t t5, t t6, t t7,t t8)", - flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagDecorrelate, flagPredicatePushDown, flagEliminateOuterJoin, flagJoinReOrder}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagPruneColumns, rule.FlagDecorrelate, rule.FlagPredicatePushDown, rule.FlagEliminateOuterJoin, rule.FlagJoinReOrder}, assertRuleName: "join_reorder", assertRuleSteps: []assertTraceStep{ { @@ -276,7 +277,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select * from t t1, t t2, t t3 where t1.a=t2.a and t3.a=t2.a and t1.a=t3.a", - flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagDecorrelate, flagPredicatePushDown, flagEliminateOuterJoin, flagJoinReOrder}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagPruneColumns, rule.FlagDecorrelate, rule.FlagPredicatePushDown, rule.FlagEliminateOuterJoin, rule.FlagJoinReOrder}, assertRuleName: "join_reorder", assertRuleSteps: []assertTraceStep{ { @@ -287,7 +288,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select min(distinct a) from t group by a", - flags: []uint64{flagBuildKeyInfo, flagEliminateAgg}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagEliminateAgg}, assertRuleName: "aggregation_eliminate", assertRuleSteps: []assertTraceStep{ { @@ -302,7 +303,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select 1+num from (select 1+a as num from t) t1;", - flags: []uint64{flagEliminateProjection}, + flags: []uint64{rule.FlagEliminateProjection}, assertRuleName: "projection_eliminate", assertRuleSteps: []assertTraceStep{ { @@ -313,7 +314,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select count(*) from t a , t b, t c", - flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagPushDownAgg}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagPruneColumns, rule.FlagPushDownAgg}, assertRuleName: "aggregation_push_down", assertRuleSteps: []assertTraceStep{ { @@ -324,7 +325,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select sum(c1) from (select c c1, d c2 from t a union all select a c1, b c2 from t b) x group by c2", - flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagPushDownAgg}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagPruneColumns, rule.FlagPushDownAgg}, assertRuleName: "aggregation_push_down", assertRuleSteps: []assertTraceStep{ { @@ -343,7 +344,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select max(a)-min(a) from t", - flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagMaxMinEliminate}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagPruneColumns, rule.FlagMaxMinEliminate}, assertRuleName: "max_min_eliminate", assertRuleSteps: []assertTraceStep{ { @@ -362,7 +363,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select max(e) from t", - flags: []uint64{flagBuildKeyInfo, flagPrunColumns, flagMaxMinEliminate}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagPruneColumns, rule.FlagMaxMinEliminate}, assertRuleName: "max_min_eliminate", assertRuleSteps: []assertTraceStep{ { @@ -373,7 +374,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select t1.b,t1.c from t as t1 left join t as t2 on t1.a = t2.a;", - flags: []uint64{flagBuildKeyInfo, flagEliminateOuterJoin}, + flags: []uint64{rule.FlagBuildKeyInfo, rule.FlagEliminateOuterJoin}, assertRuleName: "outer_join_eliminate", assertRuleSteps: []assertTraceStep{ { @@ -384,7 +385,7 @@ func TestSingleRuleTraceStep(t *testing.T) { }, { sql: "select count(distinct t1.a, t1.b) from t t1 left join t t2 on t1.b = t2.b", - flags: []uint64{flagPrunColumns, flagBuildKeyInfo, flagEliminateOuterJoin}, + flags: []uint64{rule.FlagPruneColumns, rule.FlagBuildKeyInfo, rule.FlagEliminateOuterJoin}, assertRuleName: "outer_join_eliminate", assertRuleSteps: []assertTraceStep{ { diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index 1ce76663b5398..6c303de79b9a1 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -42,7 +42,7 @@ var ( _ base.LogicalPlan = &logicalop.LogicalMemTable{} _ base.LogicalPlan = &logicalop.LogicalShow{} _ base.LogicalPlan = &logicalop.LogicalShowDDLJobs{} - _ base.LogicalPlan = &LogicalCTE{} + _ base.LogicalPlan = &logicalop.LogicalCTE{} _ base.LogicalPlan = &logicalop.LogicalCTETable{} _ base.LogicalPlan = &logicalop.LogicalSequence{} ) diff --git a/pkg/planner/core/logical_plans_test.go b/pkg/planner/core/logical_plans_test.go index 89f8e812dcb78..3cb0af0b486a9 100644 --- a/pkg/planner/core/logical_plans_test.go +++ b/pkg/planner/core/logical_plans_test.go @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/terror" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/sessionctx" @@ -140,7 +141,7 @@ func TestPredicatePushDown(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagConvertOuterToInnerJoin|flagPredicatePushDown|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagConvertOuterToInnerJoin|rule.FlagPredicatePushDown|rule.FlagDecorrelate|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, p.(base.LogicalPlan)) require.NoError(t, err) testdata.OnRecord(func() { output[ith] = ToString(p) @@ -160,7 +161,7 @@ func TestImplicitCastNotNullFlag(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagJoinReOrder|flagPrunColumns|flagEliminateProjection, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagJoinReOrder|rule.FlagPruneColumns|rule.FlagEliminateProjection, p.(base.LogicalPlan)) require.NoError(t, err) // AggFuncs[0] is count; AggFuncs[1] is bit_and, args[0] is return type of the implicit cast castNotNullFlag := (p.(*logicalop.LogicalProjection).Children()[0].(*logicalop.LogicalSelection).Children()[0].(*logicalop.LogicalAggregation).AggFuncs[1].Args[0].GetType(s.ctx.GetExprCtx().GetEvalCtx()).GetFlag()) & mysql.NotNullFlag @@ -178,7 +179,7 @@ func TestEliminateProjectionUnderUnion(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagJoinReOrder|flagPrunColumns|flagEliminateProjection, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagJoinReOrder|rule.FlagPruneColumns|rule.FlagEliminateProjection, p.(base.LogicalPlan)) require.NoError(t, err) // after folding constants, the null flag should keep the same with the old one's (i.e., the schema's). schemaNullFlag := p.(*logicalop.LogicalProjection).Children()[0].(*logicalop.LogicalJoin).Children()[1].Children()[1].(*logicalop.LogicalProjection).Schema().Columns[0].RetType.GetFlag() & mysql.NotNullFlag @@ -206,7 +207,7 @@ func TestJoinPredicatePushDown(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err, comment) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagDecorrelate|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, p.(base.LogicalPlan)) require.NoError(t, err, comment) proj, ok := p.(*logicalop.LogicalProjection) require.True(t, ok, comment) @@ -247,7 +248,7 @@ func TestOuterWherePredicatePushDown(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err, comment) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagDecorrelate|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, p.(base.LogicalPlan)) require.NoError(t, err, comment) proj, ok := p.(*logicalop.LogicalProjection) require.True(t, ok, comment) @@ -293,7 +294,7 @@ func TestSimplifyOuterJoin(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err, comment) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain|flagConvertOuterToInnerJoin, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagConvertOuterToInnerJoin, p.(base.LogicalPlan)) require.NoError(t, err, comment) planString := ToString(p) testdata.OnRecord(func() { @@ -334,7 +335,7 @@ func TestAntiSemiJoinConstFalse(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err, comment) - p, err = logicalOptimize(context.TODO(), flagDecorrelate|flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagDecorrelate|rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, p.(base.LogicalPlan)) require.NoError(t, err, comment) require.Equal(t, ca.best, ToString(p), comment) join, _ := p.(base.LogicalPlan).Children()[0].(*logicalop.LogicalJoin) @@ -363,7 +364,7 @@ func TestDeriveNotNullConds(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err, comment) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain|flagDecorrelate, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagDecorrelate, p.(base.LogicalPlan)) require.NoError(t, err, comment) testdata.OnRecord(func() { output[i].Plan = ToString(p) @@ -475,7 +476,7 @@ func TestDupRandJoinCondsPushDown(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(context.Background(), s.sctx, stmt, s.is) require.NoError(t, err, comment) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown, p.(base.LogicalPlan)) require.NoError(t, err, comment) proj, ok := p.(*logicalop.LogicalProjection) require.True(t, ok, comment) @@ -545,7 +546,7 @@ func TestTablePartition(t *testing.T) { }) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, isChoices[ca.IsIdx]) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain|flagPredicatePushDown|flagPartitionProcessor, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagDecorrelate|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagPredicatePushDown|rule.FlagPartitionProcessor, p.(base.LogicalPlan)) require.NoError(t, err) planString := ToString(p) testdata.OnRecord(func() { @@ -572,7 +573,7 @@ func TestSubquery(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) if lp, ok := p.(base.LogicalPlan); ok { - p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagDecorrelate|flagPrunColumns|flagPrunColumnsAgain|flagSemiJoinRewrite, lp) + p, err = logicalOptimize(context.TODO(), rule.FlagBuildKeyInfo|rule.FlagDecorrelate|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagSemiJoinRewrite, lp) require.NoError(t, err) } testdata.OnRecord(func() { @@ -601,7 +602,7 @@ func TestPlanBuilder(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) if lp, ok := p.(base.LogicalPlan); ok { - p, err = logicalOptimize(context.TODO(), flagPrunColumns|flagPrunColumnsAgain, lp) + p, err = logicalOptimize(context.TODO(), rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, lp) require.NoError(t, err) } testdata.OnRecord(func() { @@ -625,7 +626,7 @@ func TestJoinReOrder(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagJoinReOrder, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagJoinReOrder, p.(base.LogicalPlan)) require.NoError(t, err) planString := ToString(p) testdata.OnRecord(func() { @@ -654,7 +655,7 @@ func TestEagerAggregation(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain|flagPushDownAgg, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagBuildKeyInfo|rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagPushDownAgg, p.(base.LogicalPlan)) require.NoError(t, err) testdata.OnRecord(func() { output[ith] = ToString(p) @@ -680,7 +681,7 @@ func TestColumnPruning(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - lp, err := logicalOptimize(ctx, flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain, p.(base.LogicalPlan)) + lp, err := logicalOptimize(ctx, rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, p.(base.LogicalPlan)) require.NoError(t, err) testdata.OnRecord(func() { output[i] = make(map[int][]string) @@ -709,7 +710,7 @@ func TestSortByItemsPruning(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - lp, err := logicalOptimize(ctx, flagEliminateProjection|flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain, p.(base.LogicalPlan)) + lp, err := logicalOptimize(ctx, rule.FlagEliminateProjection|rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, p.(base.LogicalPlan)) require.NoError(t, err) checkOrderByItems(lp, t, &output[i], comment) } @@ -739,7 +740,7 @@ func TestProjectionEliminator(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagPrunColumns|flagPrunColumnsAgain|flagEliminateProjection, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagBuildKeyInfo|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagEliminateProjection, p.(base.LogicalPlan)) require.NoError(t, err) require.Equal(t, tt.best, ToString(p), fmt.Sprintf("for %s %d", tt.sql, ith)) } @@ -753,7 +754,7 @@ func TestCS3389(t *testing.T) { require.NoError(t, err) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagPrunColumns|flagPrunColumnsAgain|flagEliminateProjection|flagJoinReOrder, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagBuildKeyInfo|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagEliminateProjection|rule.FlagJoinReOrder, p.(base.LogicalPlan)) require.NoError(t, err) // Assert that all Projection is not empty and there is no Projection between Aggregation and Join. @@ -1062,7 +1063,7 @@ func TestUniqueKeyInfo(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - lp, err := logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo, p.(base.LogicalPlan)) + lp, err := logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagBuildKeyInfo, p.(base.LogicalPlan)) require.NoError(t, err) testdata.OnRecord(func() { output[ith] = make(map[int][][]string) @@ -1086,7 +1087,7 @@ func TestAggPrune(t *testing.T) { p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err) - p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagPrunColumnsAgain|flagBuildKeyInfo|flagEliminateAgg|flagEliminateProjection, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain|rule.FlagBuildKeyInfo|rule.FlagEliminateAgg|rule.FlagEliminateProjection, p.(base.LogicalPlan)) require.NoError(t, err) planString := ToString(p) testdata.OnRecord(func() { @@ -2191,7 +2192,7 @@ func TestResolvingCorrelatedAggregate(t *testing.T) { require.NoError(t, err, comment) p, err := BuildLogicalPlanForTest(ctx, s.sctx, stmt, s.is) require.NoError(t, err, comment) - p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagEliminateProjection|flagPrunColumns|flagPrunColumnsAgain, p.(base.LogicalPlan)) + p, err = logicalOptimize(context.TODO(), rule.FlagBuildKeyInfo|rule.FlagEliminateProjection|rule.FlagPruneColumns|rule.FlagPruneColumnsAgain, p.(base.LogicalPlan)) require.NoError(t, err, comment) require.Equal(t, tt.best, ToString(p), comment) } @@ -2348,7 +2349,7 @@ func TestRollupExpand(t *testing.T) { require.Equal(t, builder.currentBlockExpand.DistinctSize, 3) require.Equal(t, len(builder.currentBlockExpand.DistinctGroupByCol), 2) - _, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagJoinReOrder|flagPrunColumns|flagEliminateProjection|flagResolveExpand, p.(base.LogicalPlan)) + _, err = logicalOptimize(context.TODO(), rule.FlagPredicatePushDown|rule.FlagJoinReOrder|rule.FlagPruneColumns|rule.FlagEliminateProjection|rule.FlagResolveExpand, p.(base.LogicalPlan)) require.NoError(t, err) expand := builder.currentBlockExpand diff --git a/pkg/planner/core/operator/logicalop/BUILD.bazel b/pkg/planner/core/operator/logicalop/BUILD.bazel index c58bffccc93d7..4cb48ff7463ab 100644 --- a/pkg/planner/core/operator/logicalop/BUILD.bazel +++ b/pkg/planner/core/operator/logicalop/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "base_logical_plan.go", "logical_aggregation.go", "logical_apply.go", + "logical_cte.go", "logical_cte_table.go", "logical_expand.go", "logical_join.go", diff --git a/pkg/planner/core/logical_cte.go b/pkg/planner/core/operator/logicalop/logical_cte.go similarity index 92% rename from pkg/planner/core/logical_cte.go rename to pkg/planner/core/operator/logicalop/logical_cte.go index 8bb9aceaa54d6..c387f2be2bfc5 100644 --- a/pkg/planner/core/logical_cte.go +++ b/pkg/planner/core/operator/logicalop/logical_cte.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package logicalop import ( "context" @@ -22,18 +22,18 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/planner/cardinality" "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util/coreusage" "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" + "github.com/pingcap/tidb/pkg/planner/util/utilfuncp" "github.com/pingcap/tidb/pkg/util/plancodec" "github.com/pingcap/tidb/pkg/util/size" ) // LogicalCTE is for CTE. type LogicalCTE struct { - logicalop.LogicalSchemaProducer + LogicalSchemaProducer Cte *CTEClass CteAsName model.CIStr @@ -45,7 +45,7 @@ type LogicalCTE struct { // Init only assigns type and context. func (p LogicalCTE) Init(ctx base.PlanContext, offset int) *LogicalCTE { - p.BaseLogicalPlan = logicalop.NewBaseLogicalPlan(ctx, plancodec.TypeCTE, &p, offset) + p.BaseLogicalPlan = NewBaseLogicalPlan(ctx, plancodec.TypeCTE, &p, offset) return &p } @@ -147,16 +147,16 @@ func (p *LogicalCTE) PruneColumns(_ []*expression.Column, _ *optimizetrace.Logic // FindBestTask implements the base.LogicalPlan.<3rd> interface. func (p *LogicalCTE) FindBestTask(prop *property.PhysicalProperty, counter *base.PlanCounterTp, pop *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { - return findBestTask4LogicalCTE(p, prop, counter, pop) + return utilfuncp.FindBestTask4LogicalCTE(p, prop, counter, pop) } // BuildKeyInfo inherits the BaseLogicalPlan.<4th> implementation. // PushDownTopN implements the base.LogicalPlan.<5th> interface. func (p *LogicalCTE) PushDownTopN(topNLogicalPlan base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { - var topN *logicalop.LogicalTopN + var topN *LogicalTopN if topNLogicalPlan != nil { - topN = topNLogicalPlan.(*logicalop.LogicalTopN) + topN = topNLogicalPlan.(*LogicalTopN) } if topN != nil { return topN.AttachChild(p, opt) @@ -185,12 +185,12 @@ func (p *LogicalCTE) DeriveStats(_ []*property.StatsInfo, selfSchema *expression // Build push-downed predicates. if len(p.Cte.PushDownPredicates) > 0 { newCond := expression.ComposeDNFCondition(p.SCtx().GetExprCtx(), p.Cte.PushDownPredicates...) - newSel := logicalop.LogicalSelection{Conditions: []expression.Expression{newCond}}.Init(p.SCtx(), p.Cte.SeedPartLogicalPlan.QueryBlockOffset()) + newSel := LogicalSelection{Conditions: []expression.Expression{newCond}}.Init(p.SCtx(), p.Cte.SeedPartLogicalPlan.QueryBlockOffset()) newSel.SetChildren(p.Cte.SeedPartLogicalPlan) p.Cte.SeedPartLogicalPlan = newSel - p.Cte.OptFlag |= flagPredicatePushDown + p.Cte.OptFlag = ruleutil.SetPredicatePushDownFlag(p.Cte.OptFlag) } - p.Cte.SeedPartLogicalPlan, p.Cte.SeedPartPhysicalPlan, _, err = doOptimize(context.TODO(), p.SCtx(), p.Cte.OptFlag, p.Cte.SeedPartLogicalPlan) + p.Cte.SeedPartLogicalPlan, p.Cte.SeedPartPhysicalPlan, _, err = utilfuncp.DoOptimize(context.TODO(), p.SCtx(), p.Cte.OptFlag, p.Cte.SeedPartLogicalPlan) if err != nil { return nil, err } @@ -210,7 +210,7 @@ func (p *LogicalCTE) DeriveStats(_ []*property.StatsInfo, selfSchema *expression } if p.Cte.RecursivePartLogicalPlan != nil { if p.Cte.RecursivePartPhysicalPlan == nil { - p.Cte.RecursivePartPhysicalPlan, _, err = DoOptimize(context.TODO(), p.SCtx(), p.Cte.OptFlag, p.Cte.RecursivePartLogicalPlan) + _, p.Cte.RecursivePartPhysicalPlan, _, err = utilfuncp.DoOptimize(context.TODO(), p.SCtx(), p.Cte.OptFlag, p.Cte.RecursivePartLogicalPlan) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func (p *LogicalCTE) DeriveStats(_ []*property.StatsInfo, selfSchema *expression // ExhaustPhysicalPlans implements the base.LogicalPlan.<14th> interface. func (p *LogicalCTE) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - return exhaustPhysicalPlans4LogicalCTE(p, prop) + return utilfuncp.ExhaustPhysicalPlans4LogicalCTE(p, prop) } // ExtractCorrelatedCols implements the base.LogicalPlan.<15th> interface. diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index ebdf33444df6b..27d8545a03b8c 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -69,36 +69,6 @@ var AllowCartesianProduct = atomic.NewBool(true) // IsReadOnly check whether the ast.Node is a read only statement. var IsReadOnly func(node ast.Node, vars *variable.SessionVars) bool -// Note: The order of flags is same as the order of optRule in the list. -// Do not mess up the order. -const ( - flagGcSubstitute uint64 = 1 << iota - flagPrunColumns - flagStabilizeResults - flagBuildKeyInfo - flagDecorrelate - flagSemiJoinRewrite - flagEliminateAgg - flagSkewDistinctAgg - flagEliminateProjection - flagMaxMinEliminate - flagConstantPropagation - flagConvertOuterToInnerJoin - flagPredicatePushDown - flagEliminateOuterJoin - flagPartitionProcessor - flagCollectPredicateColumnsPoint - flagPushDownAgg - flagDeriveTopNFromWindow - flagPredicateSimplification - flagPushDownTopN - flagSyncWaitStatsLoadPoint - flagJoinReOrder - flagPrunColumnsAgain - flagPushDownSequence - flagResolveExpand -) - var optRuleList = []base.LogicalOptRule{ &GcSubstituter{}, &ColumnPruner{}, @@ -311,21 +281,21 @@ func doOptimize( } func adjustOptimizationFlags(flag uint64, logic base.LogicalPlan) uint64 { - // If there is something after flagPrunColumns, do flagPrunColumnsAgain. - if flag&flagPrunColumns > 0 && flag-flagPrunColumns > flagPrunColumns { - flag |= flagPrunColumnsAgain + // If there is something after flagPrunColumns, do FlagPruneColumnsAgain. + if flag&rule.FlagPruneColumns > 0 && flag-rule.FlagPruneColumns > rule.FlagPruneColumns { + flag |= rule.FlagPruneColumnsAgain } if checkStableResultMode(logic.SCtx()) { - flag |= flagStabilizeResults + flag |= rule.FlagStabilizeResults } if logic.SCtx().GetSessionVars().StmtCtx.StraightJoinOrder { // When we use the straight Join Order hint, we should disable the join reorder optimization. - flag &= ^flagJoinReOrder + flag &= ^rule.FlagJoinReOrder } - flag |= flagCollectPredicateColumnsPoint - flag |= flagSyncWaitStatsLoadPoint + flag |= rule.FlagCollectPredicateColumnsPoint + flag |= rule.FlagSyncWaitStatsLoadPoint if !logic.SCtx().GetSessionVars().StmtCtx.UseDynamicPruneMode { - flag |= flagPartitionProcessor // apply partition pruning under static mode + flag |= rule.FlagPartitionProcessor // apply partition pruning under static mode } return flag } diff --git a/pkg/planner/core/physical_plans.go b/pkg/planner/core/physical_plans.go index f988fd11c19f3..cbf7d5bb8dc88 100644 --- a/pkg/planner/core/physical_plans.go +++ b/pkg/planner/core/physical_plans.go @@ -2679,7 +2679,7 @@ type PhysicalCTE struct { SeedPlan base.PhysicalPlan RecurPlan base.PhysicalPlan - CTE *CTEClass + CTE *logicalop.CTEClass cteAsName model.CIStr cteName model.CIStr diff --git a/pkg/planner/core/plan_cache_utils.go b/pkg/planner/core/plan_cache_utils.go index 13b8565b09c82..1048821f6d64a 100644 --- a/pkg/planner/core/plan_cache_utils.go +++ b/pkg/planner/core/plan_cache_utils.go @@ -35,6 +35,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" "github.com/pingcap/tidb/pkg/sessionctx" @@ -174,7 +175,7 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context, return nil, nil, 0, err } - if cacheable && destBuilder.optFlag&flagPartitionProcessor > 0 { + if cacheable && destBuilder.optFlag&rule.FlagPartitionProcessor > 0 { // dynamic prune mode is not used, could be that global statistics not yet available! cacheable = false reason = "static partition prune mode used" diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index 304a4f34cd6cf..254babce79def 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/terror" "github.com/pingcap/tidb/pkg/planner/core/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/core/rule" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/privilege" @@ -168,7 +169,7 @@ type cteInfo struct { // seedStat is shared between logicalCTE and logicalCTETable. seedStat *property.StatsInfo // The LogicalCTEs that reference the same table should share the same CteClass. - cteClass *CTEClass + cteClass *logicalop.CTEClass // isInline will determine whether it can be inlined when **CTE is used** isInline bool @@ -490,7 +491,7 @@ func (b *PlanBuilder) ResetForReuse() *PlanBuilder { // Build builds the ast node to a Plan. func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (base.Plan, error) { - b.optFlag |= flagPrunColumns + b.optFlag |= rule.FlagPruneColumns switch x := node.(type) { case *ast.AdminStmt: return b.buildAdmin(ctx, x) @@ -3413,7 +3414,7 @@ func (b *PlanBuilder) buildShow(ctx context.Context, show *ast.ShowStmt) (base.P } } if np != p { - b.optFlag |= flagEliminateProjection + b.optFlag |= rule.FlagEliminateProjection fieldsLen := len(p.Schema().Columns) proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, fieldsLen)}.Init(b.ctx, 0) schema := expression.NewSchema(make([]*expression.Column, 0, fieldsLen)...) diff --git a/pkg/planner/core/recheck_cte.go b/pkg/planner/core/recheck_cte.go index 6fce7b61faa35..096b21a9307e8 100644 --- a/pkg/planner/core/recheck_cte.go +++ b/pkg/planner/core/recheck_cte.go @@ -16,6 +16,7 @@ package core import ( "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" "github.com/pingcap/tidb/pkg/util/intset" ) @@ -32,7 +33,7 @@ func findCTEs( visited *intset.FastIntSet, isRootTree bool, ) { - if cteReader, ok := p.(*LogicalCTE); ok { + if cteReader, ok := p.(*logicalop.LogicalCTE); ok { cte := cteReader.Cte if !isRootTree { // Set it to false since it's referenced by other CTEs. diff --git a/pkg/planner/core/rule/BUILD.bazel b/pkg/planner/core/rule/BUILD.bazel index e37b3eaaf3852..ba59f45745e39 100644 --- a/pkg/planner/core/rule/BUILD.bazel +++ b/pkg/planner/core/rule/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "rule", srcs = [ + "logical_rules.go", "rule_build_key_info.go", "rule_constant_propagation.go", "rule_init.go", diff --git a/pkg/planner/core/rule/logical_rules.go b/pkg/planner/core/rule/logical_rules.go new file mode 100644 index 0000000000000..0d84115b4a87d --- /dev/null +++ b/pkg/planner/core/rule/logical_rules.go @@ -0,0 +1,50 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rule + +// Note: The order of flags is same as the order of optRule in the list. +// Do not mess up the order. +const ( + FlagGcSubstitute uint64 = 1 << iota + FlagPruneColumns + FlagStabilizeResults + FlagBuildKeyInfo + FlagDecorrelate + FlagSemiJoinRewrite + FlagEliminateAgg + FlagSkewDistinctAgg + FlagEliminateProjection + FlagMaxMinEliminate + FlagConstantPropagation + FlagConvertOuterToInnerJoin + FlagPredicatePushDown + FlagEliminateOuterJoin + FlagPartitionProcessor + FlagCollectPredicateColumnsPoint + FlagPushDownAgg + FlagDeriveTopNFromWindow + FlagPredicateSimplification + FlagPushDownTopN + FlagSyncWaitStatsLoadPoint + FlagJoinReOrder + FlagPruneColumnsAgain + FlagPushDownSequence + FlagResolveExpand +) + +func setPredicatePushDownFlag(u uint64) uint64 { + u |= FlagPredicatePushDown + return u +} diff --git a/pkg/planner/core/rule/rule_init.go b/pkg/planner/core/rule/rule_init.go index d10d7f754c429..e49f3fadb9b81 100644 --- a/pkg/planner/core/rule/rule_init.go +++ b/pkg/planner/core/rule/rule_init.go @@ -23,4 +23,5 @@ import "github.com/pingcap/tidb/pkg/planner/core/rule/util" func init() { util.BuildKeyInfoPortal = buildKeyInfo + util.SetPredicatePushDownFlag = setPredicatePushDownFlag } diff --git a/pkg/planner/core/rule/util/misc.go b/pkg/planner/core/rule/util/misc.go index ca5f987da91c5..b5c0ac4340480 100644 --- a/pkg/planner/core/rule/util/misc.go +++ b/pkg/planner/core/rule/util/misc.go @@ -41,3 +41,6 @@ func ResolveColumnAndReplace(origin *expression.Column, replace map[string]*expr origin.RetType, origin.InOperand = retType, inOperand } } + +// SetPredicatePushDownFlag is a hook for other packages to set rule flag. +var SetPredicatePushDownFlag func(uint64) uint64 diff --git a/pkg/planner/core/rule_decorrelate.go b/pkg/planner/core/rule_decorrelate.go index 962632c362d30..21348041f1ca6 100644 --- a/pkg/planner/core/rule_decorrelate.go +++ b/pkg/planner/core/rule_decorrelate.go @@ -373,7 +373,7 @@ func (s *DecorrelateSolver) Optimize(ctx context.Context, p base.LogicalPlan, op } NoOptimize: // CTE's logical optimization is independent. - if _, ok := p.(*LogicalCTE); ok { + if _, ok := p.(*logicalop.LogicalCTE); ok { return p, planChanged, nil } newChildren := make([]base.LogicalPlan, 0, len(p.Children())) diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go index 426b46c4bbfc2..5f4107c09943d 100644 --- a/pkg/planner/core/rule_eliminate_projection.go +++ b/pkg/planner/core/rule_eliminate_projection.go @@ -160,7 +160,7 @@ func (pe *ProjectionEliminator) Optimize(_ context.Context, lp base.LogicalPlan, // eliminate eliminates the redundant projection in a logical plan. func (pe *ProjectionEliminator) eliminate(p base.LogicalPlan, replace map[string]*expression.Column, canEliminate bool, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { // LogicalCTE's logical optimization is independent. - if _, ok := p.(*LogicalCTE); ok { + if _, ok := p.(*logicalop.LogicalCTE); ok { return p } proj, isProj := p.(*logicalop.LogicalProjection) diff --git a/pkg/planner/core/rule_generate_column_substitute.go b/pkg/planner/core/rule_generate_column_substitute.go index 2f593849d0266..7a37bcc26b968 100644 --- a/pkg/planner/core/rule_generate_column_substitute.go +++ b/pkg/planner/core/rule_generate_column_substitute.go @@ -57,7 +57,7 @@ func (gc *GcSubstituter) Optimize(ctx context.Context, lp base.LogicalPlan, opt // For the sake of simplicity, we don't collect the stored generate column because we can't get their expressions directly. // TODO: support stored generate column. func collectGenerateColumn(lp base.LogicalPlan, exprToColumn ExprColumnMap) { - if _, ok := lp.(*LogicalCTE); ok { + if _, ok := lp.(*logicalop.LogicalCTE); ok { return } for _, child := range lp.Children() { diff --git a/pkg/planner/core/rule_join_elimination.go b/pkg/planner/core/rule_join_elimination.go index 2fa4e41cceb64..c85615bbd46a3 100644 --- a/pkg/planner/core/rule_join_elimination.go +++ b/pkg/planner/core/rule_join_elimination.go @@ -198,7 +198,7 @@ func GetDupAgnosticAggCols( func (o *OuterJoinEliminator) doOptimize(p base.LogicalPlan, aggCols []*expression.Column, parentCols []*expression.Column, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, error) { // CTE's logical optimization is independent. - if _, ok := p.(*LogicalCTE); ok { + if _, ok := p.(*logicalop.LogicalCTE); ok { return p, nil } var err error diff --git a/pkg/planner/core/rule_join_reorder.go b/pkg/planner/core/rule_join_reorder.go index 93896039eb145..1bd970609dff3 100644 --- a/pkg/planner/core/rule_join_reorder.go +++ b/pkg/planner/core/rule_join_reorder.go @@ -240,7 +240,7 @@ func (s *JoinReOrderSolver) Optimize(_ context.Context, p base.LogicalPlan, opt // optimizeRecursive recursively collects join groups and applies join reorder algorithm for each group. func (s *JoinReOrderSolver) optimizeRecursive(ctx base.PlanContext, p base.LogicalPlan, tracer *joinReorderTrace) (base.LogicalPlan, error) { - if _, ok := p.(*LogicalCTE); ok { + if _, ok := p.(*logicalop.LogicalCTE); ok { return p, nil } diff --git a/pkg/planner/core/rule_max_min_eliminate.go b/pkg/planner/core/rule_max_min_eliminate.go index c6cb256fcf0f0..571d900cef3d2 100644 --- a/pkg/planner/core/rule_max_min_eliminate.go +++ b/pkg/planner/core/rule_max_min_eliminate.go @@ -217,7 +217,7 @@ func (*MaxMinEliminator) eliminateSingleMaxMin(agg *logicalop.LogicalAggregation // eliminateMaxMin tries to convert max/min to Limit+Sort operators. func (a *MaxMinEliminator) eliminateMaxMin(p base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { // CTE's logical optimization is indenpent. - if _, ok := p.(*LogicalCTE); ok { + if _, ok := p.(*logicalop.LogicalCTE); ok { return p } newChildren := make([]base.LogicalPlan, 0, len(p.Children())) diff --git a/pkg/planner/core/rule_partition_processor.go b/pkg/planner/core/rule_partition_processor.go index 8601b44e3d64d..646860a56cdc7 100644 --- a/pkg/planner/core/rule_partition_processor.go +++ b/pkg/planner/core/rule_partition_processor.go @@ -107,7 +107,7 @@ func (s *PartitionProcessor) rewriteDataSource(lp base.LogicalPlan, opt *optimiz // Only one partition, no union all. p.SetChildren(ds) return p, nil - case *LogicalCTE: + case *logicalop.LogicalCTE: return lp, nil default: children := lp.Children() diff --git a/pkg/planner/core/rule_push_down_sequence.go b/pkg/planner/core/rule_push_down_sequence.go index ef433dbe78c85..fe471023b8481 100644 --- a/pkg/planner/core/rule_push_down_sequence.go +++ b/pkg/planner/core/rule_push_down_sequence.go @@ -62,7 +62,7 @@ func (pdss *PushDownSequenceSolver) recursiveOptimize(pushedSequence *logicalop. pushedSequence = logicalop.LogicalSequence{}.Init(lp.SCtx(), lp.QueryBlockOffset()) pushedSequence.SetChildren(append(allCTEs, mainQuery)...) return pdss.recursiveOptimize(pushedSequence, mainQuery) - case *DataSource, *logicalop.LogicalAggregation, *LogicalCTE: + case *DataSource, *logicalop.LogicalAggregation, *logicalop.LogicalCTE: pushedSequence.SetChild(pushedSequence.ChildLen()-1, pdss.recursiveOptimize(nil, lp)) return pushedSequence default: diff --git a/pkg/planner/core/rule_semi_join_rewrite.go b/pkg/planner/core/rule_semi_join_rewrite.go index 8cafc180a16dc..9b9a99fa82605 100644 --- a/pkg/planner/core/rule_semi_join_rewrite.go +++ b/pkg/planner/core/rule_semi_join_rewrite.go @@ -52,7 +52,7 @@ func (*SemiJoinRewriter) Name() string { } func (smj *SemiJoinRewriter) recursivePlan(p base.LogicalPlan) (base.LogicalPlan, error) { - if _, ok := p.(*LogicalCTE); ok { + if _, ok := p.(*logicalop.LogicalCTE); ok { return p, nil } newChildren := make([]base.LogicalPlan, 0, len(p.Children())) diff --git a/pkg/planner/util/utilfuncp/func_pointer_misc.go b/pkg/planner/util/utilfuncp/func_pointer_misc.go index c653e13eb7ea8..a48277c5ece4d 100644 --- a/pkg/planner/util/utilfuncp/func_pointer_misc.go +++ b/pkg/planner/util/utilfuncp/func_pointer_misc.go @@ -15,6 +15,8 @@ package utilfuncp import ( + "context" + "github.com/pingcap/tidb/pkg/expression" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/planner/core/base" @@ -101,14 +103,18 @@ var FindBestTask4LogicalShow func(lp base.LogicalPlan, prop *property.PhysicalPr var FindBestTask4LogicalShowDDLJobs func(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) -// ExhaustPhysicalPlans4LogicalSequence will be called by LogicalSequence in logicalOp pkg. -var ExhaustPhysicalPlans4LogicalSequence func(lp base.LogicalPlan, prop *property.PhysicalProperty) ( - []base.PhysicalPlan, bool, error) +// FindBestTask4LogicalCTE will be called by LogicalCTE in logicalOp pkg. +var FindBestTask4LogicalCTE func(lp base.LogicalPlan, prop *property.PhysicalProperty, + counter *base.PlanCounterTp, pop *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) // FindBestTask4LogicalTableDual will be called by LogicalTableDual in logicalOp pkg. var FindBestTask4LogicalTableDual func(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, opt *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) +// ExhaustPhysicalPlans4LogicalSequence will be called by LogicalSequence in logicalOp pkg. +var ExhaustPhysicalPlans4LogicalSequence func(lp base.LogicalPlan, prop *property.PhysicalProperty) ( + []base.PhysicalPlan, bool, error) + // ExhaustPhysicalPlans4LogicalSort will be called by LogicalSort in logicalOp pkg. var ExhaustPhysicalPlans4LogicalSort func(lp base.LogicalPlan, prop *property.PhysicalProperty) ( []base.PhysicalPlan, bool, error) @@ -165,6 +171,10 @@ var ExhaustPhysicalPlans4LogicalUnionAll func(lp base.LogicalPlan, prop *propert var ExhaustPhysicalPlans4LogicalExpand func(lp base.LogicalPlan, prop *property.PhysicalProperty) ( []base.PhysicalPlan, bool, error) +// ExhaustPhysicalPlans4LogicalCTE will be called by LogicalCTE in logicalOp pkg. +var ExhaustPhysicalPlans4LogicalCTE func(lp base.LogicalPlan, prop *property.PhysicalProperty) ( + []base.PhysicalPlan, bool, error) + // *************************************** physical op related ******************************************* // GetEstimatedProbeCntFromProbeParents will be called by BasePhysicalPlan in physicalOp pkg. @@ -181,3 +191,13 @@ var AttachPlan2Task func(p base.PhysicalPlan, t base.Task) base.Task // WindowIsTopN is used in DeriveTopNFromWindow rule. // todo: @arenatlx: remove it after logical_datasource is migrated to logicalop. var WindowIsTopN func(p base.LogicalPlan) (bool, uint64) + +// ****************************************** optimize portal ********************************************* + +// DoOptimize is to optimize a logical plan. +var DoOptimize func( + ctx context.Context, + sctx base.PlanContext, + flag uint64, + logic base.LogicalPlan, +) (base.LogicalPlan, base.PhysicalPlan, float64, error)