diff --git a/cmd/explaintest/r/explain_cte.result b/cmd/explaintest/r/explain_cte.result index 4092fe73fbd07..36eda33222fd6 100644 --- a/cmd/explaintest/r/explain_cte.result +++ b/cmd/explaintest/r/explain_cte.result @@ -7,19 +7,19 @@ insert into t2 values(1, 0), (2, 1); explain with cte(a) as (select 1) select * from cte; id estRows task access object operator info CTEFullScan_8 1.00 root CTE:cte data:CTE_0 -CTE_0 1.00 root None Recursive CTE +CTE_0 1.00 root Non-Recursive CTE └─Projection_6(Seed Part) 1.00 root 1->Column#1 └─TableDual_7 1.00 root rows:1 explain with cte(a) as (select c1 from t1) select * from cte; id estRows task access object operator info CTEFullScan_11 1.00 root CTE:cte data:CTE_0 -CTE_0 1.00 root None Recursive CTE +CTE_0 1.00 root Non-Recursive CTE └─TableReader_8(Seed Part) 10000.00 root data:TableFullScan_7 └─TableFullScan_7 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo explain with cte(a,b,c,d) as (select * from t1, t2) select * from cte; id estRows task access object operator info CTEFullScan_18 1.00 root CTE:cte data:CTE_0 -CTE_0 1.00 root None Recursive CTE +CTE_0 1.00 root Non-Recursive CTE └─HashJoin_10(Seed Part) 100000000.00 root CARTESIAN inner join ├─TableReader_17(Build) 10000.00 root data:TableFullScan_16 │ └─TableFullScan_16 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo @@ -46,7 +46,7 @@ CTE_0 1.00 root Recursive CTE explain with cte(a) as (with recursive cte1(a) as (select 1 union select a + 1 from cte1 where a < 10) select * from cte1) select * from cte; id estRows task access object operator info CTEFullScan_21 1.00 root CTE:cte data:CTE_0 -CTE_0 1.00 root None Recursive CTE +CTE_0 1.00 root Non-Recursive CTE └─CTEFullScan_20(Seed Part) 1.00 root CTE:cte1 data:CTE_1 CTE_1 1.00 root Recursive CTE ├─Projection_15(Seed Part) 1.00 root 1->Column#2 @@ -70,7 +70,7 @@ id estRows task access object operator info HashJoin_17 1.00 root CARTESIAN inner join ├─CTEFullScan_27(Build) 1.00 root CTE:t2 data:CTE_0 └─CTEFullScan_26(Probe) 1.00 root CTE:t1 data:CTE_0 -CTE_0 1.00 root None Recursive CTE +CTE_0 1.00 root Non-Recursive CTE └─CTEFullScan_25(Seed Part) 1.00 root CTE:cte1 data:CTE_1 CTE_1 1.00 root Recursive CTE ├─Projection_20(Seed Part) 1.00 root 1->Column#2 @@ -102,7 +102,7 @@ HashJoin_12 0.64 root CARTESIAN inner join │ └─CTEFullScan_22 1.00 root CTE:q1 data:CTE_0 └─Selection_14(Probe) 0.80 root eq(test.t1.c1, 1) └─CTEFullScan_20 1.00 root CTE:q data:CTE_0 -CTE_0 1.00 root None Recursive CTE +CTE_0 1.00 root Non-Recursive CTE └─TableReader_17(Seed Part) 10000.00 root data:TableFullScan_16 └─TableFullScan_16 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo explain with recursive cte(a,b) as (select 1, concat('a', 1) union select a+1, concat(b, 1) from cte where a < 5) select * from cte; @@ -127,3 +127,78 @@ CTE_0 1.00 root Recursive CTE └─Projection_27(Recursive Part) 0.80 root plus(Column#5, 1)->Column#7 └─Selection_28 0.80 root eq(Column#5, 0) └─CTETable_29 1.00 root Scan on CTE_0 +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 1) select * from cte1; +id estRows task access object operator info +CTEFullScan_19 1.00 root CTE:cte1 data:CTE_0 +CTE_0 1.00 root Recursive CTE, limit(offset:0, count:1) +├─TableReader_14(Seed Part) 10000.00 root data:TableFullScan_13 +│ └─TableFullScan_13 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─Projection_17(Recursive Part) 1.00 root cast(plus(test.t1.c1, 1), int(11))->test.t1.c1 + └─CTETable_18 1.00 root Scan on CTE_0 +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 100 offset 100) select * from cte1; +id estRows task access object operator info +CTEFullScan_19 1.00 root CTE:cte1 data:CTE_0 +CTE_0 1.00 root Recursive CTE, limit(offset:100, count:100) +├─TableReader_14(Seed Part) 10000.00 root data:TableFullScan_13 +│ └─TableFullScan_13 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─Projection_17(Recursive Part) 1.00 root cast(plus(test.t1.c1, 1), int(11))->test.t1.c1 + └─CTETable_18 1.00 root Scan on CTE_0 +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 0 offset 0) select * from cte1; +id estRows task access object operator info +CTEFullScan_19 1.00 root CTE:cte1 data:CTE_0 +CTE_0 1.00 root Recursive CTE, limit(offset:0, count:0) +├─TableReader_14(Seed Part) 10000.00 root data:TableFullScan_13 +│ └─TableFullScan_13 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─Projection_17(Recursive Part) 1.00 root cast(plus(test.t1.c1, 1), int(11))->test.t1.c1 + └─CTETable_18 1.00 root Scan on CTE_0 +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 1) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1; +id estRows task access object operator info +HashJoin_18 0.64 root inner join, equal:[eq(test.t1.c1, test.t1.c1)] +├─Selection_29(Build) 0.80 root not(isnull(test.t1.c1)) +│ └─CTEFullScan_30 1.00 root CTE:dt2 data:CTE_0 +└─Selection_20(Probe) 0.80 root not(isnull(test.t1.c1)) + └─CTEFullScan_28 1.00 root CTE:dt1 data:CTE_0 +CTE_0 1.00 root Recursive CTE, limit(offset:0, count:1) +├─TableReader_23(Seed Part) 10000.00 root data:TableFullScan_22 +│ └─TableFullScan_22 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─Projection_26(Recursive Part) 1.00 root cast(plus(test.t1.c1, 1), int(11))->test.t1.c1 + └─CTETable_27 1.00 root Scan on CTE_0 +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 0 offset 0) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1; +id estRows task access object operator info +HashJoin_18 0.64 root inner join, equal:[eq(test.t1.c1, test.t1.c1)] +├─Selection_29(Build) 0.80 root not(isnull(test.t1.c1)) +│ └─CTEFullScan_30 1.00 root CTE:dt2 data:CTE_0 +└─Selection_20(Probe) 0.80 root not(isnull(test.t1.c1)) + └─CTEFullScan_28 1.00 root CTE:dt1 data:CTE_0 +CTE_0 1.00 root Recursive CTE, limit(offset:0, count:0) +├─TableReader_23(Seed Part) 10000.00 root data:TableFullScan_22 +│ └─TableFullScan_22 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo +└─Projection_26(Recursive Part) 1.00 root cast(plus(test.t1.c1, 1), int(11))->test.t1.c1 + └─CTETable_27 1.00 root Scan on CTE_0 +explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 1) select * from cte1; +id estRows task access object operator info +CTEFullScan_34 1.00 root CTE:cte1 data:CTE_0 +CTE_0 1.00 root Non-Recursive CTE +└─Limit_21(Seed Part) 1.00 root offset:0, count:1 + └─HashAgg_22 1.00 root group by:Column#11, funcs:firstrow(Column#11)->Column#11 + └─Union_23 20000.00 root + ├─TableReader_26 10000.00 root data:TableFullScan_25 + │ └─TableFullScan_25 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─IndexReader_33 10000.00 root index:IndexFullScan_32 + └─IndexFullScan_32 10000.00 cop[tikv] table:t2, index:c1(c1) keep order:false, stats:pseudo +explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 100 offset 100) select * from cte1; +id estRows task access object operator info +CTEFullScan_34 1.00 root CTE:cte1 data:CTE_0 +CTE_0 1.00 root Non-Recursive CTE +└─Limit_21(Seed Part) 100.00 root offset:100, count:100 + └─HashAgg_22 200.00 root group by:Column#11, funcs:firstrow(Column#11)->Column#11 + └─Union_23 20000.00 root + ├─TableReader_26 10000.00 root data:TableFullScan_25 + │ └─TableFullScan_25 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo + └─IndexReader_33 10000.00 root index:IndexFullScan_32 + └─IndexFullScan_32 10000.00 cop[tikv] table:t2, index:c1(c1) keep order:false, stats:pseudo +explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 0 offset 0) select * from cte1; +id estRows task access object operator info +CTEFullScan_18 1.00 root CTE:cte1 data:CTE_0 +CTE_0 1.00 root Non-Recursive CTE +└─TableDual_17(Seed Part) 0.00 root rows:0 diff --git a/cmd/explaintest/t/explain_cte.test b/cmd/explaintest/t/explain_cte.test index 50032776f85dd..c657ad5c68898 100644 --- a/cmd/explaintest/t/explain_cte.test +++ b/cmd/explaintest/t/explain_cte.test @@ -29,3 +29,16 @@ explain with q(a,b) as (select * from t1) select /*+ merge(q) no_merge(q1) */ * # explain with cte(a,b) as (select * from t1) select (select 1 from cte limit 1) from cte; explain with recursive cte(a,b) as (select 1, concat('a', 1) union select a+1, concat(b, 1) from cte where a < 5) select * from cte; explain select * from t1 dt where exists(with recursive qn as (select c1*0+1 as b union all select b+1 from qn where b=0) select * from qn where b=1); + +# recursive limit +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 1) select * from cte1; +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 100 offset 100) select * from cte1; +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 0 offset 0) select * from cte1; + +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 1) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1; +explain with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 0 offset 0) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1; + +# non-recursive limit +explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 1) select * from cte1; +explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 100 offset 100) select * from cte1; +explain with recursive cte1(c1) as (select c1 from t1 union select c1 from t2 limit 0 offset 0) select * from cte1; diff --git a/executor/builder.go b/executor/builder.go index 6827e18da55ba..5e24ae8b880c0 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -4192,6 +4192,9 @@ func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) Executor { chkIdx: 0, isDistinct: v.CTE.IsDistinct, sel: sel, + hasLimit: v.CTE.HasLimit, + limitBeg: v.CTE.LimitBeg, + limitEnd: v.CTE.LimitEnd, } } diff --git a/executor/cte.go b/executor/cte.go index 055163dc17e4f..fb0aaccb3d512 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -79,6 +79,13 @@ type CTEExec struct { hCtx *hashContext sel []int + // Limit related info. + hasLimit bool + limitBeg uint64 + limitEnd uint64 + cursor uint64 + meetFirstBatch bool + memTracker *memory.Tracker diskTracker *disk.Tracker } @@ -131,8 +138,8 @@ func (e *CTEExec) Open(ctx context.Context) (err error) { func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { req.Reset() e.resTbl.Lock() + defer e.resTbl.Unlock() if !e.resTbl.Done() { - defer e.resTbl.Unlock() resAction := setupCTEStorageTracker(e.resTbl, e.ctx, e.memTracker, e.diskTracker) iterInAction := setupCTEStorageTracker(e.iterInTbl, e.ctx, e.memTracker, e.diskTracker) iterOutAction := setupCTEStorageTracker(e.iterOutTbl, e.ctx, e.memTracker, e.diskTracker) @@ -160,10 +167,11 @@ func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { return err } e.resTbl.SetDone() - } else { - e.resTbl.Unlock() } + if e.hasLimit { + return e.nextChunkLimit(req) + } if e.chkIdx < e.resTbl.NumChunks() { res, err := e.resTbl.GetChunk(e.chkIdx) if err != nil { @@ -205,6 +213,9 @@ func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { defer close(e.iterInTbl.GetBegCh()) chks := make([]*chunk.Chunk, 0, 10) for { + if e.limitDone(e.iterInTbl) { + break + } chk := newFirstChunk(e.seedExec) if err = Next(ctx, e.seedExec, chk); err != nil { return err @@ -239,6 +250,10 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { return ErrCTEMaxRecursionDepth.GenWithStackByArgs(e.curIter) } + if e.limitDone(e.resTbl) { + return nil + } + for { chk := newFirstChunk(e.recursiveExec) if err = Next(ctx, e.recursiveExec, chk); err != nil { @@ -248,6 +263,9 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { if err = e.setupTblsForNewIteration(); err != nil { return err } + if e.limitDone(e.resTbl) { + break + } if e.iterInTbl.NumChunks() == 0 { break } @@ -274,6 +292,51 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { return nil } +// Get next chunk from resTbl for limit. +func (e *CTEExec) nextChunkLimit(req *chunk.Chunk) error { + if !e.meetFirstBatch { + for e.chkIdx < e.resTbl.NumChunks() { + res, err := e.resTbl.GetChunk(e.chkIdx) + if err != nil { + return err + } + e.chkIdx++ + numRows := uint64(res.NumRows()) + if newCursor := e.cursor + numRows; newCursor >= e.limitBeg { + e.meetFirstBatch = true + begInChk, endInChk := e.limitBeg-e.cursor, numRows + if newCursor > e.limitEnd { + endInChk = e.limitEnd - e.cursor + } + e.cursor += endInChk + if begInChk == endInChk { + break + } + tmpChk := res.CopyConstructSel() + req.Append(tmpChk, int(begInChk), int(endInChk)) + return nil + } + e.cursor += numRows + } + } + if e.chkIdx < e.resTbl.NumChunks() && e.cursor < e.limitEnd { + res, err := e.resTbl.GetChunk(e.chkIdx) + if err != nil { + return err + } + e.chkIdx++ + numRows := uint64(res.NumRows()) + if e.cursor+numRows > e.limitEnd { + numRows = e.limitEnd - e.cursor + req.Append(res.CopyConstructSel(), 0, int(numRows)+1) + } else { + req.SwapColumns(res.CopyConstructSel()) + } + e.cursor += numRows + } + return nil +} + func (e *CTEExec) setupTblsForNewIteration() (err error) { num := e.iterOutTbl.NumChunks() chks := make([]*chunk.Chunk, 0, num) @@ -322,6 +385,8 @@ func (e *CTEExec) reset() { e.curIter = 0 e.chkIdx = 0 e.hashTbl = nil + e.cursor = 0 + e.meetFirstBatch = false } func (e *CTEExec) reopenTbls() (err error) { @@ -332,6 +397,11 @@ func (e *CTEExec) reopenTbls() (err error) { return e.iterInTbl.Reopen() } +// Check if tbl meets the requirement of limit. +func (e *CTEExec) limitDone(tbl cteutil.Storage) bool { + return e.hasLimit && uint64(tbl.NumRows()) >= e.limitEnd +} + func setupCTEStorageTracker(tbl cteutil.Storage, ctx sessionctx.Context, parentMemTracker *memory.Tracker, parentDiskTracker *disk.Tracker) (actionSpill *chunk.SpillDiskAction) { memTracker := tbl.GetMemTracker() diff --git a/executor/cte_test.go b/executor/cte_test.go index d6e212484ad29..aa7c2804425c8 100644 --- a/executor/cte_test.go +++ b/executor/cte_test.go @@ -237,3 +237,140 @@ func (test *CTETestSuite) TestCTEMaxRecursionDepth(c *check.C) { rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;") rows.Check(testkit.Rows("1", "2")) } + +func (test *CTETestSuite) TestCTEWithLimit(c *check.C) { + tk := testkit.NewTestKit(c, test.store) + tk.MustExec("use test;") + + // Basic recursive tests. + rows := tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 0) select * from cte1") + rows.Check(testkit.Rows("1", "2", "3", "4", "5")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 1) select * from cte1") + rows.Check(testkit.Rows("2", "3", "4", "5", "6")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 10) select * from cte1") + rows.Check(testkit.Rows("11", "12", "13", "14", "15")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 995) select * from cte1") + rows.Check(testkit.Rows("996", "997", "998", "999", "1000")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 5 offset 6) select * from cte1;") + rows.Check(testkit.Rows("7", "8", "9", "10", "11")) + + // Test with cte_max_recursion_depth + tk.MustExec("set cte_max_recursion_depth=2;") + rows = tk.MustQuery("with recursive cte1(c1) as (select 0 union select c1 + 1 from cte1 limit 1 offset 2) select * from cte1;") + rows.Check(testkit.Rows("2")) + + err := tk.QueryToErr("with recursive cte1(c1) as (select 0 union select c1 + 1 from cte1 limit 1 offset 3) select * from cte1;") + c.Assert(err, check.NotNil) + c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 3 iterations. Try increasing @@cte_max_recursion_depth to a larger value") + + tk.MustExec("set cte_max_recursion_depth=1000;") + rows = tk.MustQuery("with recursive cte1(c1) as (select 0 union select c1 + 1 from cte1 limit 5 offset 996) select * from cte1;") + rows.Check(testkit.Rows("996", "997", "998", "999", "1000")) + + err = tk.QueryToErr("with recursive cte1(c1) as (select 0 union select c1 + 1 from cte1 limit 5 offset 997) select * from cte1;") + c.Assert(err, check.NotNil) + c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1001 iterations. Try increasing @@cte_max_recursion_depth to a larger value") + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 0 offset 1) select * from cte1") + rows.Check(testkit.Rows()) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 0 offset 10) select * from cte1") + rows.Check(testkit.Rows()) + + // Test join. + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 2 offset 1) select * from cte1 dt1 join cte1 dt2 order by dt1.c1, dt2.c1;") + rows.Check(testkit.Rows("2 2", "2 3", "3 2", "3 3")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 2 offset 1) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1 order by dt1.c1, dt1.c1;") + rows.Check(testkit.Rows("2 2", "3 3")) + + // Test subquery. + // Different with mysql, maybe it's mysql bug?(https://bugs.mysql.com/bug.php?id=103890&thanks=4) + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 2 offset 1) select c1 from cte1 where c1 in (select 2);") + rows.Check(testkit.Rows("2")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 2 offset 1) select c1 from cte1 dt where c1 in (select c1 from cte1 where 1 = dt.c1 - 1);") + rows.Check(testkit.Rows("2")) + + // Test Apply. + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 2 offset 1) select c1 from cte1 where cte1.c1 = (select dt1.c1 from cte1 dt1 where dt1.c1 = cte1.c1);") + rows.Check(testkit.Rows("2", "3")) + + // Recursive tests with table. + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("insert into t1 values(1), (2), (3);") + + // Error: ERROR 1221 (HY000): Incorrect usage of UNION and LIMIT. + // Limit can only be at the end of SQL stmt. + err = tk.ExecToErr("with recursive cte1(c1) as (select c1 from t1 limit 1 offset 1 union select c1 + 1 from cte1 limit 0 offset 1) select * from cte1") + c.Assert(err.Error(), check.Equals, "[planner:1221]Incorrect usage of UNION and LIMIT") + + // Basic non-recusive tests. + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2 order by 1 limit 1 offset 1) select * from cte1") + rows.Check(testkit.Rows("2")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2 order by 1 limit 0 offset 1) select * from cte1") + rows.Check(testkit.Rows()) + + rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2 order by 1 limit 2 offset 0) select * from cte1") + rows.Check(testkit.Rows("1", "2")) + + // Test with table. + tk.MustExec("drop table if exists t1;") + insertStr := "insert into t1 values(0)" + for i := 1; i < 5000; i++ { + insertStr += fmt.Sprintf(", (%d)", i) + } + + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int);") + tk.MustExec(insertStr) + + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 1) select * from cte1") + rows.Check(testkit.Rows("0")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 1 offset 100) select * from cte1") + rows.Check(testkit.Rows("100")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select c1 + 1 c1 from cte1 limit 5 offset 100) select * from cte1") + rows.Check(testkit.Rows("100", "101", "102", "103", "104")) + + // Basic non-recursive tests. + rows = tk.MustQuery("with cte1 as (select c1 from t1 limit 2 offset 1) select * from cte1") + rows.Check(testkit.Rows("1", "2")) + + rows = tk.MustQuery("with cte1 as (select c1 from t1 limit 2 offset 1) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1") + rows.Check(testkit.Rows("1 1", "2 2")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select 2 limit 0 offset 1) select * from cte1") + rows.Check(testkit.Rows()) + + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select 2 limit 0 offset 1) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1") + rows.Check(testkit.Rows()) + + // rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 union select 2 limit 5 offset 100) select * from cte1") + // rows.Check(testkit.Rows("100", "101", "102", "103", "104")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 limit 3 offset 100) select * from cte1") + rows.Check(testkit.Rows("100", "101", "102")) + + rows = tk.MustQuery("with recursive cte1(c1) as (select c1 from t1 limit 3 offset 100) select * from cte1 dt1 join cte1 dt2 on dt1.c1 = dt2.c1") + rows.Check(testkit.Rows("100 100", "101 101", "102 102")) + + // Test limit 0. + tk.MustExec("set cte_max_recursion_depth = 0;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("insert into t1 values(0);") + rows = tk.MustQuery("with recursive cte1 as (select 1/c1 c1 from t1 union select c1 + 1 c1 from cte1 where c1 < 2 limit 0) select * from cte1;") + rows.Check(testkit.Rows()) + // MySQL err: ERROR 1365 (22012): Division by 0. Because it gives error when computing 1/c1. + err = tk.QueryToErr("with recursive cte1 as (select 1/c1 c1 from t1 union select c1 + 1 c1 from cte1 where c1 < 2 limit 1) select * from cte1;") + c.Assert(err, check.NotNil) + c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value") +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 9a9a9351a92c1..e575bb79e7135 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -3674,8 +3674,28 @@ func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName } b.handleHelper.pushMap(nil) + + hasLimit := false + limitBeg := uint64(0) + limitEnd := uint64(0) + if cte.limitLP != nil { + hasLimit = true + switch x := cte.limitLP.(type) { + case *LogicalLimit: + limitBeg = x.Offset + limitEnd = x.Offset + x.Count + case *LogicalTableDual: + // Beg and End will both be 0. + default: + return nil, errors.Errorf("invalid type for limit plan: %v", cte.limitLP) + } + } + var p LogicalPlan - lp := LogicalCTE{cteAsName: tn.Name, cte: &CTEClass{IsDistinct: cte.isDistinct, seedPartLogicalPlan: cte.seedLP, recursivePartLogicalPlan: cte.recurLP, IDForStorage: cte.storageID, optFlag: cte.optFlag}}.Init(b.ctx, b.getSelectOffset()) + lp := LogicalCTE{cteAsName: tn.Name, cte: &CTEClass{IsDistinct: cte.isDistinct, seedPartLogicalPlan: cte.seedLP, + recursivePartLogicalPlan: cte.recurLP, IDForStorage: cte.storageID, + optFlag: cte.optFlag, HasLimit: hasLimit, LimitBeg: limitBeg, + LimitEnd: limitEnd}}.Init(b.ctx, b.getSelectOffset()) lp.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) p = lp p.SetOutputNames(cte.seedLP.OutputNames()) @@ -5907,6 +5927,9 @@ func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNo if x.OrderBy != nil { return ErrNotSupportedYet.GenWithStackByArgs("ORDER BY over UNION in recursive Common Table Expression") } + // Limit clause is for the whole CTE instead of only for the seed part. + oriLimit := x.Limit + x.Limit = nil // Check union type. if afterOpr != nil { @@ -5936,6 +5959,7 @@ func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNo // Rebuild the plan. i-- b.buildingRecursivePartForCTE = true + x.Limit = oriLimit continue } if err != nil { @@ -5984,6 +6008,15 @@ func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNo } // 4. Finally, we get the seed part plan and recursive part plan. cInfo.recurLP = recurPart + // Only need to handle limit if x is SetOprStmt. + if x.Limit != nil { + limit, err := b.buildLimit(cInfo.seedLP, x.Limit) + if err != nil { + return err + } + limit.SetChildren(limit.Children()[:0]...) + cInfo.limitLP = limit + } return nil default: p, err := b.buildResultSetNode(ctx, x) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 16489edc02cba..7186813a0fe04 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -1183,7 +1183,10 @@ type CTEClass struct { // storageID for this CTE. IDForStorage int // optFlag is the optFlag for the whole CTE. - optFlag uint64 + optFlag uint64 + HasLimit bool + LimitBeg uint64 + LimitEnd uint64 } // LogicalCTE is for CTE. diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 74c70e1fce3c9..a80dd96a28259 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -1463,10 +1463,16 @@ type CTEDefinition PhysicalCTE // ExplainInfo overrides the ExplainInfo func (p *CTEDefinition) ExplainInfo() string { + var res string if p.RecurPlan != nil { - return "Recursive CTE" + res = "Recursive CTE" + } else { + res = "Non-Recursive CTE" + } + if p.CTE.HasLimit { + res += fmt.Sprintf(", limit(offset:%v, count:%v)", p.CTE.LimitBeg, p.CTE.LimitEnd-p.CTE.LimitBeg) } - return "None Recursive CTE" + return res } // ExplainID overrides the ExplainID. diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index b743204a1db91..b287c42c8b4f7 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -432,6 +432,7 @@ type cteInfo struct { // enterSubquery and recursiveRef are used to check "recursive table must be referenced only once, and not in any subquery". enterSubquery bool recursiveRef bool + limitLP LogicalPlan } // PlanBuilder builds Plan from an ast.Node. diff --git a/util/cteutil/storage.go b/util/cteutil/storage.go index 9d42b1a11c015..d2607892db62c 100644 --- a/util/cteutil/storage.go +++ b/util/cteutil/storage.go @@ -63,6 +63,9 @@ type Storage interface { // NumChunks return chunk number of the underlying storage. NumChunks() int + // NumRows return row number of the underlying storage. + NumRows() int + // Storage is not thread-safe. // By using Lock(), users can achieve the purpose of ensuring thread safety. Lock() @@ -200,6 +203,11 @@ func (s *StorageRC) NumChunks() int { return s.rc.NumChunks() } +// NumRows impls Storage NumRows interface. +func (s *StorageRC) NumRows() int { + return s.rc.NumRow() +} + // Lock impls Storage Lock interface. func (s *StorageRC) Lock() { s.mu.Lock()