diff --git a/executor/executor.go b/executor/executor.go index b072ce1c2fc02..8109ffcc16a15 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -28,6 +28,7 @@ import ( "github.com/cznic/mathutil" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/model" @@ -1426,6 +1427,12 @@ type UnionExec struct { results []*chunk.Chunk wg sync.WaitGroup initialized bool + mu struct { + *sync.Mutex + maxOpenedChildID int + } + + childInFlightForTest int32 } // unionWorkerResult stores the result for a union worker. @@ -1445,12 +1452,11 @@ func (e *UnionExec) waitAllFinished() { // Open implements the Executor Open interface. func (e *UnionExec) Open(ctx context.Context) error { - if err := e.baseExecutor.Open(ctx); err != nil { - return err - } e.stopFetchData.Store(false) e.initialized = false e.finished = make(chan struct{}) + e.mu.Mutex = &sync.Mutex{} + e.mu.maxOpenedChildID = -1 return nil } @@ -1496,6 +1502,19 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.wg.Done() }() for childID := range e.childIDChan { + e.mu.Lock() + if childID > e.mu.maxOpenedChildID { + e.mu.maxOpenedChildID = childID + } + e.mu.Unlock() + if err := e.children[childID].Open(ctx); err != nil { + result.err = err + e.stopFetchData.Store(true) + e.resultPool <- result + } + failpoint.Inject("issue21441", func() { + atomic.AddInt32(&e.childInFlightForTest, 1) + }) for { if e.stopFetchData.Load().(bool) { return @@ -1510,12 +1529,20 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.resourcePools[workerID] <- result.chk break } + failpoint.Inject("issue21441", func() { + if int(atomic.LoadInt32(&e.childInFlightForTest)) > e.concurrency { + panic("the count of child in flight is larger than e.concurrency unexpectedly") + } + }) e.resultPool <- result if result.err != nil { e.stopFetchData.Store(true) return } } + failpoint.Inject("issue21441", func() { + atomic.AddInt32(&e.childInFlightForTest, -1) + }) } } @@ -1554,7 +1581,15 @@ func (e *UnionExec) Close() error { for range e.childIDChan { } } - return e.baseExecutor.Close() + // We do not need to acquire the e.mu.Lock since all the resultPuller can be + // promised to exit when reaching here (e.childIDChan been closed). + var firstErr error + for i := 0; i <= e.mu.maxOpenedChildID; i++ { + if err := e.children[i].Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr } // ResetContextOfStmt resets the StmtContext and session variables. diff --git a/executor/executor_test.go b/executor/executor_test.go index 39067c4064115..0d21d69fd877f 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -6415,6 +6415,36 @@ func (s *testSuite) TestOOMActionPriority(c *C) { c.Assert(action.GetPriority(), Equals, int64(memory.DefLogPriority)) } +func (s *testSuite) TestIssue21441(c *C) { + failpoint.Enable("github.com/pingcap/tidb/executor/issue21441", `return`) + defer failpoint.Disable("github.com/pingcap/tidb/executor/issue21441") + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec(`insert into t values(1),(2),(3)`) + tk.Se.GetSessionVars().InitChunkSize = 1 + tk.Se.GetSessionVars().MaxChunkSize = 1 + sql := ` +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t` + tk.MustQuery(sql).Sort().Check(testkit.Rows( + "1", "1", "1", "1", "1", "1", "1", "1", + "2", "2", "2", "2", "2", "2", "2", "2", + "3", "3", "3", "3", "3", "3", "3", "3", + )) + + tk.MustQuery("select a from (" + sql + ") t order by a limit 4").Check(testkit.Rows("1", "1", "1", "1")) + tk.MustQuery("select a from (" + sql + ") t order by a limit 7, 4").Check(testkit.Rows("1", "2", "2", "2")) +} + func (s *testSuite) Test17780(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go index 553981e4b6e7a..05d0a3afbc196 100644 --- a/store/tikv/snapshot.go +++ b/store/tikv/snapshot.go @@ -470,11 +470,15 @@ func (s *tikvSnapshot) IterReverse(k kv.Key) (kv.Iterator, error) { func (s *tikvSnapshot) SetOption(opt kv.Option, val interface{}) { switch opt { case kv.ReplicaRead: + s.mu.Lock() s.replicaRead = val.(kv.ReplicaReadType) + s.mu.Unlock() case kv.Priority: s.priority = kvPriorityToCommandPri(val.(int)) case kv.TaskID: + s.mu.Lock() s.taskID = val.(uint64) + s.mu.Unlock() case kv.CollectRuntimeStats: s.mu.Lock() s.mu.stats = val.(*SnapshotRuntimeStats)