diff --git a/planner/core/BUILD.bazel b/planner/core/BUILD.bazel index 16d9524a476bf..747f35d70cbd7 100644 --- a/planner/core/BUILD.bazel +++ b/planner/core/BUILD.bazel @@ -42,6 +42,7 @@ go_library( "point_get_plan.go", "preprocess.go", "property_cols_prune.go", + "recheck_cte.go", "resolve_indices.go", "rule_aggregation_elimination.go", "rule_aggregation_push_down.go", diff --git a/planner/core/issuetest/planner_issue_test.go b/planner/core/issuetest/planner_issue_test.go index 1b5acda276873..898bf3f282208 100644 --- a/planner/core/issuetest/planner_issue_test.go +++ b/planner/core/issuetest/planner_issue_test.go @@ -90,3 +90,65 @@ func TestIssue46083(t *testing.T) { tk.MustExec("CREATE TEMPORARY TABLE v0(v1 int)") tk.MustExec("INSERT INTO v0 WITH ta2 AS (TABLE v0) TABLE ta2 FOR UPDATE OF ta2;") } + +func TestIssue47781(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t, t1, t2") + tk.MustExec("create table t (id int,name varchar(10))") + tk.MustExec("insert into t values(1,'tt')") + tk.MustExec("create table t1(id int,name varchar(10),name1 varchar(10),name2 varchar(10))") + tk.MustExec("insert into t1 values(1,'tt','ttt','tttt'),(2,'dd','ddd','dddd')") + tk.MustExec("create table t2(id int,name varchar(10),name1 varchar(10),name2 varchar(10),`date1` date)") + tk.MustExec("insert into t2 values(1,'tt','ttt','tttt','2099-12-31'),(2,'dd','ddd','dddd','2099-12-31')") + tk.MustQuery(`WITH bzzs AS ( + SELECT + count(1) AS bzn + FROM + t c +), +tmp1 AS ( + SELECT + t1.* + FROM + t1 + LEFT JOIN bzzs ON 1 = 1 + WHERE + name IN ('tt') + AND bzn <> 1 +), +tmp2 AS ( + SELECT + tmp1.*, + date('2099-12-31') AS endate + FROM + tmp1 +), +tmp3 AS ( + SELECT + * + FROM + tmp2 + WHERE + endate > CURRENT_DATE + UNION ALL + SELECT + '1' AS id, + 'ss' AS name, + 'sss' AS name1, + 'ssss' AS name2, + date('2099-12-31') AS endate + FROM + bzzs t1 + WHERE + bzn = 1 +) +SELECT + c2.id, 
+ c3.id +FROM + t2 db + LEFT JOIN tmp3 c2 ON c2.id = '1' + LEFT JOIN tmp3 c3 ON c3.id = '1';`).Check(testkit.Rows("1 1", "1 1")) +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index ec358a250b6e5..81e9252db685b 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4456,7 +4456,7 @@ func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName LimitEnd: limitEnd, pushDownPredicates: make([]expression.Expression, 0), ColumnMap: make(map[string]*expression.Column)} } var p LogicalPlan - lp := LogicalCTE{cteAsName: tn.Name, cteName: tn.Name, cte: cte.cteClass, seedStat: cte.seedStat, isOuterMostCTE: !b.buildingCTE}.Init(b.ctx, b.getSelectOffset()) + lp := LogicalCTE{cteAsName: tn.Name, cteName: tn.Name, cte: cte.cteClass, seedStat: cte.seedStat}.Init(b.ctx, b.getSelectOffset()) prevSchema := cte.seedLP.Schema().Clone() lp.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index eadb48c281aa5..a561ce702228d 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -1964,6 +1964,7 @@ type CTEClass struct { // pushDownPredicates may be push-downed by different references. 
pushDownPredicates []expression.Expression ColumnMap map[string]*expression.Column + isOuterMostCTE bool } const emptyCTEClassSize = int64(unsafe.Sizeof(CTEClass{})) @@ -1995,11 +1996,10 @@ func (cc *CTEClass) MemoryUsage() (sum int64) { type LogicalCTE struct { logicalSchemaProducer - cte *CTEClass - cteAsName model.CIStr - cteName model.CIStr - seedStat *property.StatsInfo - isOuterMostCTE bool + cte *CTEClass + cteAsName model.CIStr + cteName model.CIStr + seedStat *property.StatsInfo } // LogicalCTETable is for CTE table diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 1e877ed114f24..4b30741c26200 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -148,6 +148,9 @@ func BuildLogicalPlanForTest(ctx context.Context, sctx sessionctx.Context, node if err != nil { return nil, nil, err } + if logic, ok := p.(LogicalPlan); ok { + RecheckCTE(logic) + } return p, p.OutputNames(), err } diff --git a/planner/core/recheck_cte.go b/planner/core/recheck_cte.go new file mode 100644 index 0000000000000..d09c3eb63bec5 --- /dev/null +++ b/planner/core/recheck_cte.go @@ -0,0 +1,53 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import "github.com/pingcap/tidb/planner/funcdep" + +// RecheckCTE fills the isOuterMostCTE field for CTEs. +// It's a temporary solution until we fully use the Sequence to optimize the CTEs. 
+// This func checks whether the CTE is referenced only by the main query or not. +func RecheckCTE(p LogicalPlan) { + visited := funcdep.NewFastIntSet() + findCTEs(p, &visited, true) +} + +func findCTEs( + p LogicalPlan, + visited *funcdep.FastIntSet, + isRootTree bool, +) { + if cteReader, ok := p.(*LogicalCTE); ok { + cte := cteReader.cte + if !isRootTree { + // Set it to false since it's referenced by other CTEs. + cte.isOuterMostCTE = false + } + if visited.Has(cte.IDForStorage) { + return + } + visited.Insert(cte.IDForStorage) + // Set it when we meet it for the first time. + cte.isOuterMostCTE = isRootTree + findCTEs(cte.seedPartLogicalPlan, visited, false) + if cte.recursivePartLogicalPlan != nil { + findCTEs(cte.recursivePartLogicalPlan, visited, false) + } + return + } + for _, child := range p.Children() { + findCTEs(child, visited, isRootTree) + } +} diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index a0f9a7dd367df..01cd085d5bd0d 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -990,7 +990,7 @@ func (p *LogicalCTE) PredicatePushDown(predicates []expression.Expression, _ *lo // Doesn't support recursive CTE yet. return predicates, p.self } - if !p.isOuterMostCTE { + if !p.cte.isOuterMostCTE { return predicates, p.self } pushedPredicates := make([]expression.Expression, len(predicates)) diff --git a/planner/optimize.go b/planner/optimize.go index d5ee997057180..02620a48a4aec 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -407,6 +407,8 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in return p, names, 0, nil } + core.RecheckCTE(logic) + // Handle the logical plan statement, use cascades planner if enabled. if sctx.GetSessionVars().GetEnableCascadesPlanner() { finalPlan, cost, err := cascades.DefaultOptimizer.FindBestPlan(sctx, logic)