diff --git a/pkg/statistics/handle/globalstats/merge_worker.go b/pkg/statistics/handle/globalstats/merge_worker.go index e4d9b880cc1c5..ea00a6dbbefc7 100644 --- a/pkg/statistics/handle/globalstats/merge_worker.go +++ b/pkg/statistics/handle/globalstats/merge_worker.go @@ -44,8 +44,11 @@ type topnStatsMergeWorker struct { respCh chan<- *TopnStatsMergeResponse // the stats in the wrapper should only be read during the worker statsWrapper *StatsWrapper + // Different TopN structures may hold the same value, we have to merge them. + counter map[hack.MutableString]float64 // shardMutex is used to protect `statsWrapper.AllHg` shardMutex []sync.Mutex + mu sync.Mutex } // NewTopnStatsMergeWorker returns topn merge worker @@ -55,8 +58,9 @@ func NewTopnStatsMergeWorker( wrapper *StatsWrapper, killed *uint32) *topnStatsMergeWorker { worker := &topnStatsMergeWorker{ - taskCh: taskCh, - respCh: respCh, + taskCh: taskCh, + respCh: respCh, + counter: make(map[hack.MutableString]float64), } worker.statsWrapper = wrapper worker.shardMutex = make([]sync.Mutex, len(wrapper.AllHg)) @@ -80,15 +84,11 @@ func NewTopnStatsMergeTask(start, end int) *TopnStatsMergeTask { // TopnStatsMergeResponse indicates topn merge worker response type TopnStatsMergeResponse struct { - Err error - TopN *statistics.TopN - PopedTopn []statistics.TopNMeta + Err error } // Run runs topn merge like statistics.MergePartTopN2GlobalTopN -func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, - n uint32, - version int) { +func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, version int) { for task := range worker.taskCh { start := task.start end := task.end @@ -96,17 +96,12 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, allTopNs := worker.statsWrapper.AllTopN allHists := worker.statsWrapper.AllHg resp := &TopnStatsMergeResponse{} - if statistics.CheckEmptyTopNs(checkTopNs) { - worker.respCh <- resp - return - } + partNum := len(allTopNs) - // Different TopN structures may hold the same value, we have to merge them. - counter := make(map[hack.MutableString]float64) + // datumMap is used to store the mapping from the string type to datum type. // The datum is used to find the value in the histogram. datumMap := statistics.NewDatumMapCache() - for i, topN := range checkTopNs { i = i + start if atomic.LoadUint32(worker.killed) == 1 { @@ -119,12 +114,15 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, } for _, val := range topN.TopN { encodedVal := hack.String(val.Encoded) - _, exists := counter[encodedVal] - counter[encodedVal] += float64(val.Count) + worker.mu.Lock() + _, exists := worker.counter[encodedVal] + worker.counter[encodedVal] += float64(val.Count) if exists { + worker.mu.Unlock() // We have already calculated the encodedVal from the histogram, so just continue to next topN value. continue } + worker.mu.Unlock() // We need to check whether the value corresponding to encodedVal is contained in other partition-level stats. // 1. Check the topN first. // 2. If the topN doesn't contain the value corresponding to encodedVal. We should check the histogram. @@ -148,31 +146,26 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, } datum = d } + worker.shardMutex[j].Lock() // Get the row count which the value is equal to the encodedVal from histogram. count, _ := allHists[j].EqualRowCount(nil, datum, isIndex) if count != 0 { - counter[encodedVal] += count // Remove the value corresponding to encodedVal from the histogram. - worker.shardMutex[j].Lock() worker.statsWrapper.AllHg[j].BinarySearchRemoveVal(statistics.TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)}) - worker.shardMutex[j].Unlock() + } + worker.shardMutex[j].Unlock() + if count != 0 { + worker.mu.Lock() + worker.counter[encodedVal] += count + worker.mu.Unlock() } } } } - numTop := len(counter) - if numTop == 0 { - worker.respCh <- resp - continue - } - sorted := make([]statistics.TopNMeta, 0, numTop) - for value, cnt := range counter { - data := hack.Slice(string(value)) - sorted = append(sorted, statistics.TopNMeta{Encoded: data, Count: uint64(cnt)}) - } - globalTopN, leftTopN := statistics.GetMergedTopNFromSortedSlice(sorted, n) - resp.TopN = globalTopN - resp.PopedTopn = leftTopN worker.respCh <- resp } } + +func (worker *topnStatsMergeWorker) Result() map[hack.MutableString]float64 { + return worker.counter +} diff --git a/pkg/statistics/handle/globalstats/topn.go b/pkg/statistics/handle/globalstats/topn.go index 19782f76b2616..7edc838b99911 100644 --- a/pkg/statistics/handle/globalstats/topn.go +++ b/pkg/statistics/handle/globalstats/topn.go @@ -30,6 +30,9 @@ import ( func mergeGlobalStatsTopN(gp *gp.Pool, sc sessionctx.Context, wrapper *StatsWrapper, timeZone *time.Location, version int, n uint32, isIndex bool) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) { + if statistics.CheckEmptyTopNs(wrapper.AllTopN) { + return nil, nil, wrapper.AllHg, nil + } mergeConcurrency := sc.GetSessionVars().AnalyzePartitionMergeConcurrency killed := &sc.GetSessionVars().Killed // use original method if concurrency equals 1 or for version1 @@ -69,12 +72,12 @@ func MergeGlobalStatsTopNByConcurrency(gp *gp.Pool, mergeConcurrency, mergeBatch taskNum := len(tasks) taskCh := make(chan *TopnStatsMergeTask, taskNum) respCh := make(chan *TopnStatsMergeResponse, taskNum) + worker := NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killed) for i := 0; i < mergeConcurrency; i++ { - worker := NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killed) wg.Add(1) gp.Go(func() { defer wg.Done() - worker.Run(timeZone, isIndex, n, version) + worker.Run(timeZone, isIndex, version) }) } for _, task := range tasks { @@ -83,8 +86,6 @@ func MergeGlobalStatsTopNByConcurrency(gp *gp.Pool, mergeConcurrency, mergeBatch close(taskCh) wg.Wait() close(respCh) - resps := make([]*TopnStatsMergeResponse, 0) - // handle Error hasErr := false errMsg := make([]string, 0) @@ -93,27 +94,21 @@ func MergeGlobalStatsTopNByConcurrency(gp *gp.Pool, mergeConcurrency, mergeBatch hasErr = true errMsg = append(errMsg, resp.Err.Error()) } - resps = append(resps, resp) } if hasErr { return nil, nil, nil, errors.New(strings.Join(errMsg, ",")) } // fetch the response from each worker and merge them into global topn stats - sorted := make([]statistics.TopNMeta, 0, mergeConcurrency) - leftTopn := make([]statistics.TopNMeta, 0) - for _, resp := range resps { - if resp.TopN != nil { - sorted = append(sorted, resp.TopN.TopN...) - } - leftTopn = append(leftTopn, resp.PopedTopn...) + counter := worker.Result() + numTop := len(counter) + sorted := make([]statistics.TopNMeta, 0, numTop) + for value, cnt := range counter { + data := hack.Slice(string(value)) + sorted = append(sorted, statistics.TopNMeta{Encoded: data, Count: uint64(cnt)}) } - globalTopN, popedTopn := statistics.GetMergedTopNFromSortedSlice(sorted, n) - - result := append(leftTopn, popedTopn...) - statistics.SortTopnMeta(result) - return globalTopN, result, wrapper.AllHg, nil + return globalTopN, popedTopn, wrapper.AllHg, nil } // MergePartTopN2GlobalTopN is used to merge the partition-level topN to global-level topN. @@ -124,13 +119,19 @@ func MergeGlobalStatsTopNByConcurrency(gp *gp.Pool, mergeConcurrency, mergeBatch // // The output parameters: // 1. `*TopN` is the final global-level topN. -// 2. `[]TopNMeta` is the left topN value from the partition-level TopNs, but is not placed to global-level TopN. We should put them back to histogram latter. -// 3. `[]*Histogram` are the partition-level histograms which just delete some values when we merge the global-level topN. -func MergePartTopN2GlobalTopN(loc *time.Location, version int, topNs []*statistics.TopN, n uint32, hists []*statistics.Histogram, - isIndex bool, killed *uint32) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) { - if statistics.CheckEmptyTopNs(topNs) { - return nil, nil, hists, nil - } +// 2. `[]TopNMeta` is the left topN value from the partition-level TopNs, +// but is not placed to global-level TopN. We should put them back to histogram latter. +// 3. `[]*Histogram` are the partition-level histograms which +// just delete some values when we merge the global-level topN. +func MergePartTopN2GlobalTopN( + loc *time.Location, + version int, + topNs []*statistics.TopN, + n uint32, + hists []*statistics.Histogram, + isIndex bool, + killed *uint32, +) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) { partNum := len(topNs) // Different TopN structures may hold the same value, we have to merge them. counter := make(map[hack.MutableString]float64)