diff --git a/executor/cte.go b/executor/cte.go index 7e98064b1d8bd..49848e4e75ce1 100644 --- a/executor/cte.go +++ b/executor/cte.go @@ -228,6 +228,12 @@ func (e *CTEExec) Close() (err error) { } func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil && err == nil { + err = errors.Errorf("%v", r) + } + }() + failpoint.Inject("testCTESeedPanic", nil) e.curIter = 0 e.iterInTbl.SetIter(e.curIter) chks := make([]*chunk.Chunk, 0, 10) @@ -237,13 +243,13 @@ func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { } chk := newFirstChunk(e.seedExec) if err = Next(ctx, e.seedExec, chk); err != nil { - return err + return } if chk.NumRows() == 0 { break } if chk, err = e.tryDedupAndAdd(chk, e.iterInTbl, e.hashTbl); err != nil { - return err + return } chks = append(chks, chk) } @@ -251,18 +257,24 @@ func (e *CTEExec) computeSeedPart(ctx context.Context) (err error) { // Just adding is ok. for _, chk := range chks { if err = e.resTbl.Add(chk); err != nil { - return err + return } } e.curIter++ e.iterInTbl.SetIter(e.curIter) - return nil + return } func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil && err == nil { + err = errors.Errorf("%v", r) + } + }() + failpoint.Inject("testCTERecursivePanic", nil) if e.recursiveExec == nil || e.iterInTbl.NumChunks() == 0 { - return nil + return } if e.curIter > e.ctx.GetSessionVars().CTEMaxRecursionDepth { @@ -270,17 +282,17 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { } if e.limitDone(e.resTbl) { - return nil + return } for { chk := newFirstChunk(e.recursiveExec) if err = Next(ctx, e.recursiveExec, chk); err != nil { - return err + return } if chk.NumRows() == 0 { if err = e.setupTblsForNewIteration(); err != nil { - return err + return } if e.limitDone(e.resTbl) { break @@ -297,18 +309,18 @@ func (e *CTEExec) computeRecursivePart(ctx context.Context) (err error) { // Make sure iterInTbl is setup before Close/Open, // because some executors will read iterInTbl in Open() (like IndexLookupJoin). if err = e.recursiveExec.Close(); err != nil { - return err + return } if err = e.recursiveExec.Open(ctx); err != nil { - return err + return } } else { if err = e.iterOutTbl.Add(chk); err != nil { - return err + return } } } - return nil + return } // Get next chunk from resTbl for limit. diff --git a/executor/cte_test.go b/executor/cte_test.go index 5f68f140fed5e..d7dd2c6f9ae3a 100644 --- a/executor/cte_test.go +++ b/executor/cte_test.go @@ -460,3 +460,25 @@ func TestCTEsInView(t *testing.T) { tk.MustExec("use test1;") tk.MustQuery("select * from test.v;").Check(testkit.Rows("1")) } + +func TestCTEPanic(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("create table t1(c1 int)") + tk.MustExec("insert into t1 values(1), (2), (3)") + + fpPathPrefix := "github.com/pingcap/tidb/executor/" + fp := "testCTESeedPanic" + require.NoError(t, failpoint.Enable(fpPathPrefix+fp, fmt.Sprintf(`panic("%s")`, fp))) + err := tk.QueryToErr("with recursive cte1 as (select c1 from t1 union all select c1 + 1 from cte1 where c1 < 5) select t_alias_1.c1 from cte1 as t_alias_1 inner join cte1 as t_alias_2 on t_alias_1.c1 = t_alias_2.c1 order by c1") + require.Contains(t, err.Error(), fp) + require.NoError(t, failpoint.Disable(fpPathPrefix+fp)) + + fp = "testCTERecursivePanic" + require.NoError(t, failpoint.Enable(fpPathPrefix+fp, fmt.Sprintf(`panic("%s")`, fp))) + err = tk.QueryToErr("with recursive cte1 as (select c1 from t1 union all select c1 + 1 from cte1 where c1 < 5) select t_alias_1.c1 from cte1 as t_alias_1 inner join cte1 as t_alias_2 on t_alias_1.c1 = t_alias_2.c1 order by c1") + require.Contains(t, err.Error(), fp) + require.NoError(t, failpoint.Disable(fpPathPrefix+fp)) +} diff --git a/util/cteutil/storage.go b/util/cteutil/storage.go index 5fdaf4424db8c..1e6109eeb5715 100644 --- a/util/cteutil/storage.go +++ b/util/cteutil/storage.go @@ -131,13 +131,14 @@ func (s *StorageRC) DerefAndClose() (err error) { if s.refCnt < 0 { return errors.New("Storage ref count is less than zero") } else if s.refCnt == 0 { - // TODO: unreg memtracker + s.refCnt = -1 + s.done = false + s.err = nil + s.iter = 0 if err = s.rc.Close(); err != nil { return err } - if err = s.resetAll(); err != nil { - return err - } + s.rc = nil } return nil } @@ -157,7 +158,7 @@ func (s *StorageRC) SwapData(other Storage) (err error) { // Reopen impls Storage Reopen interface. func (s *StorageRC) Reopen() (err error) { - if err = s.rc.Reset(); err != nil { + if err = s.rc.Close(); err != nil { return err } s.iter = 0 @@ -267,18 +268,6 @@ func (s *StorageRC) ActionSpillForTest() *chunk.SpillDiskAction { return s.rc.ActionSpillForTest() } -func (s *StorageRC) resetAll() error { - s.refCnt = -1 - s.done = false - s.err = nil - s.iter = 0 - if err := s.rc.Reset(); err != nil { - return err - } - s.rc = nil - return nil -} - func (s *StorageRC) valid() bool { return s.refCnt > 0 && s.rc != nil }