diff --git a/executor/index_merge_reader.go b/executor/index_merge_reader.go index 60828bd514ac4..cd27a5049b8b5 100644 --- a/executor/index_merge_reader.go +++ b/executor/index_merge_reader.go @@ -133,6 +133,8 @@ func (e *IndexMergeReaderExecutor) Open(ctx context.Context) (err error) { } e.finished = make(chan struct{}) e.resultCh = make(chan *lookupTableTask, atomic.LoadInt32(&LookupTableTaskChannelSize)) + e.memTracker = memory.NewTracker(e.id, -1) + e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) return nil } @@ -520,7 +522,7 @@ func (e *IndexMergeReaderExecutor) startIndexMergeTableScanWorker(ctx context.Co finished: e.finished, indexMergeExec: e, tblPlans: e.tblPlans, - memTracker: memory.NewTracker(memory.LabelForSimpleTask, -1), + memTracker: e.memTracker, } ctx1, cancel := context.WithCancel(ctx) go func() { diff --git a/executor/index_merge_reader_test.go b/executor/index_merge_reader_test.go index 51620b861a94a..47fe9e148b531 100644 --- a/executor/index_merge_reader_test.go +++ b/executor/index_merge_reader_test.go @@ -17,6 +17,8 @@ package executor_test import ( "fmt" "math/rand" + "regexp" + "strconv" "strings" . "github.com/pingcap/check" @@ -170,3 +172,38 @@ func (s *testSuite1) TestPartitionTableRandomIndexMerge(c *C) { tk.MustQuery("select /*+ USE_INDEX_MERGE(tpk, a, b) */ * from tpk where " + cond).Sort().Check(result) } } + +func (test *testSerialSuite2) TestIndexMergeReaderMemTracker(c *C) { + tk := testkit.NewTestKit(c, test.store) + tk.MustExec("use test;") + tk.MustExec("create table t1(c1 int, c2 int, c3 int, key(c1), key(c2), key(c3));") + + insertStr := "insert into t1 values(0, 0, 0)" + rowNum := 1000 + for i := 0; i < rowNum; i++ { + insertStr += fmt.Sprintf(" ,(%d, %d, %d)", i, i, i) + } + insertStr += ";" + memTracker := tk.Se.GetSessionVars().StmtCtx.MemTracker + + tk.MustExec(insertStr) + + oriMaxUsage := memTracker.MaxConsumed() + + // We select all rows in t1, so the mem usage is more clear. + tk.MustQuery("select /*+ use_index_merge(t1) */ * from t1 where c1 > 1 or c2 > 1") + + newMaxUsage := memTracker.MaxConsumed() + c.Assert(newMaxUsage, Greater, oriMaxUsage) + + res := tk.MustQuery("explain analyze select /*+ use_index_merge(t1) */ * from t1 where c1 > 1 or c2 > 1") + c.Assert(len(res.Rows()), Equals, 4) + // Parse "xxx KB" and check it's greater than 0. + memStr := res.Rows()[0][7].(string) + re, err := regexp.Compile("[0-9]+ KB") + c.Assert(err, IsNil) + c.Assert(re.MatchString(memStr), IsTrue) + bytes, err := strconv.ParseFloat(memStr[:len(memStr)-3], 32) + c.Assert(err, IsNil) + c.Assert(bytes, Greater, 0.0) +}