Skip to content

Commit

Permalink
*: fix wrong result when to concurrency merge global stats (#48852) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Feb 19, 2024
1 parent b19d950 commit 24213f2
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 73 deletions.
1 change: 1 addition & 0 deletions statistics/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ go_library(
"analyze_jobs.go",
"builder.go",
"cmsketch.go",
"cmsketch_util.go",
"column.go",
"estimate.go",
"feedback.go",
Expand Down
7 changes: 4 additions & 3 deletions statistics/cmsketch.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ func NewTopN(n int) *TopN {
// 3. `[]*Histogram` are the partition-level histograms which just delete some values when we merge the global-level topN.
func MergePartTopN2GlobalTopN(loc *time.Location, version int, topNs []*TopN, n uint32, hists []*Histogram,
isIndex bool, kiiled *uint32) (*TopN, []TopNMeta, []*Histogram, error) {
if checkEmptyTopNs(topNs) {
if CheckEmptyTopNs(topNs) {
return nil, nil, hists, nil
}
partNum := len(topNs)
Expand Down Expand Up @@ -835,7 +835,7 @@ func MergePartTopN2GlobalTopN(loc *time.Location, version int, topNs []*TopN, n
// The output parameters are the newly generated TopN structure and the remaining numbers.
// Notice: The n can be 0. So n has no default value, we must explicitly specify this value.
func MergeTopN(topNs []*TopN, n uint32) (*TopN, []TopNMeta) {
if checkEmptyTopNs(topNs) {
if CheckEmptyTopNs(topNs) {
return nil, nil
}
// Different TopN structures may hold the same value, we have to merge them.
Expand All @@ -860,7 +860,8 @@ func MergeTopN(topNs []*TopN, n uint32) (*TopN, []TopNMeta) {
return getMergedTopNFromSortedSlice(sorted, n)
}

func checkEmptyTopNs(topNs []*TopN) bool {
// CheckEmptyTopNs checks whether all TopNs are empty.
func CheckEmptyTopNs(topNs []*TopN) bool {
count := uint64(0)
for _, topN := range topNs {
count += topN.TotalCount()
Expand Down
74 changes: 74 additions & 0 deletions statistics/cmsketch_util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package statistics

import (
"time"

"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/hack"
)

// DatumMapCache caches the datum that corresponds to an encoded value,
// keyed by the string form of that encoded value.
// The cached datum is used to find the value in the histogram.
// The cache is a plain map and performs no locking of its own.
type DatumMapCache struct {
// datumMap maps an encoded value to the datum stored for it by Put.
datumMap map[hack.MutableString]types.Datum
}

// NewDatumMapCache returns an empty DatumMapCache whose underlying map is
// already allocated, so it is ready for Get and Put.
func NewDatumMapCache() *DatumMapCache {
	cache := new(DatumMapCache)
	cache.datumMap = make(map[hack.MutableString]types.Datum)
	return cache
}

// Get looks up the datum stored under key.
// ok reports whether an entry for key exists; when it does not, val is the
// zero Datum.
func (d *DatumMapCache) Get(key hack.MutableString) (val types.Datum, ok bool) {
	cached, found := d.datumMap[key]
	return cached, found
}

// Put converts val into a datum with topNMetaToDatum and stores the result
// in the cache under encodedVal.
//
// tp, isIndex and loc are forwarded unchanged to topNMetaToDatum. If the
// conversion fails, nothing is stored and the conversion error is returned
// together with whatever datum the conversion produced.
func (d *DatumMapCache) Put(val TopNMeta, encodedVal hack.MutableString,
	tp byte, isIndex bool, loc *time.Location) (dat types.Datum, err error) {
	decoded, decodeErr := topNMetaToDatum(val, tp, isIndex, loc)
	if decodeErr == nil {
		// Only successfully converted values are remembered.
		d.datumMap[encodedVal] = decoded
		return decoded, nil
	}
	return decoded, decodeErr
}

// topNMetaToDatum converts the encoded value of a TopN entry into a datum.
//
// When isIndex is true the encoded bytes are used as-is and wrapped in a
// bytes datum. Otherwise the bytes are decoded according to tp:
//   - types.IsTypeTime(tp):  codec.DecodeAsDateTime, using loc
//   - types.IsTypeFloat(tp): codec.DecodeAsFloat32
//   - anything else:         codec.DecodeOne
//
// The decoder's error is returned unchanged, together with whatever datum the
// decoder produced.
func topNMetaToDatum(val TopNMeta,
	tp byte, isIndex bool, loc *time.Location) (dat types.Datum, err error) {
	if isIndex {
		dat.SetBytes(val.Encoded)
		return dat, nil
	}
	// The decoders assign straight into the named results so that the value
	// returned below is always the one they produced (no shadowed err).
	switch {
	case types.IsTypeTime(tp):
		// Handle date time values specially since they are encoded to int and we'll get int values if using DecodeOne.
		_, dat, err = codec.DecodeAsDateTime(val.Encoded, tp, loc)
	case types.IsTypeFloat(tp):
		_, dat, err = codec.DecodeAsFloat32(val.Encoded, tp)
	default:
		_, dat, err = codec.DecodeOne(val.Encoded)
	}
	return dat, err
}
1 change: 1 addition & 0 deletions statistics/handle/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ go_library(
"//util/chunk",
"//util/codec",
"//util/collate",
"//util/hack",
"//util/logutil",
"//util/mathutil",
"//util/memory",
Expand Down
23 changes: 13 additions & 10 deletions statistics/handle/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/mathutil"
"github.com/pingcap/tidb/util/memory"
Expand Down Expand Up @@ -856,6 +857,9 @@ func (h *Handle) mergePartitionStats2GlobalStats(sc sessionctx.Context,
func mergeGlobalStatsTopN(sc sessionctx.Context, wrapper *statistics.StatsWrapper,
timeZone *time.Location, version int, n uint32, isIndex bool) (*statistics.TopN,
[]statistics.TopNMeta, []*statistics.Histogram, error) {
if statistics.CheckEmptyTopNs(wrapper.AllTopN) {
return nil, nil, wrapper.AllHg, nil
}
mergeConcurrency := sc.GetSessionVars().AnalyzePartitionMergeConcurrency
killed := &sc.GetSessionVars().Killed
// use original method if concurrency equals 1 or for version1
Expand Down Expand Up @@ -895,10 +899,10 @@ func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wra
taskNum := len(tasks)
taskCh := make(chan *statistics.TopnStatsMergeTask, taskNum)
respCh := make(chan *statistics.TopnStatsMergeResponse, taskNum)
worker := statistics.NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killed)
for i := 0; i < mergeConcurrency; i++ {
worker := statistics.NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killed)
wg.Run(func() {
worker.Run(timeZone, isIndex, n, version)
worker.Run(timeZone, isIndex, version)
})
}
for _, task := range tasks {
Expand All @@ -924,17 +928,16 @@ func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wra
}

// fetch the response from each worker and merge them into global topn stats
sorted := make([]statistics.TopNMeta, 0, mergeConcurrency)
leftTopn := make([]statistics.TopNMeta, 0)
for _, resp := range resps {
if resp.TopN != nil {
sorted = append(sorted, resp.TopN.TopN...)
}
leftTopn = append(leftTopn, resp.PopedTopn...)
counter := worker.Result()
numTop := len(counter)
sorted := make([]statistics.TopNMeta, 0, numTop)
for value, cnt := range counter {
data := hack.Slice(string(value))
sorted = append(sorted, statistics.TopNMeta{Encoded: data, Count: uint64(cnt)})
}

globalTopN, popedTopn := statistics.GetMergedTopNFromSortedSlice(sorted, n)
return globalTopN, statistics.SortTopnMeta(append(leftTopn, popedTopn...)), wrapper.AllHg, nil
return globalTopN, popedTopn, wrapper.AllHg, nil
}

func (h *Handle) getTableByPhysicalID(is infoschema.InfoSchema, physicalID int64) (table.Table, bool) {
Expand Down
92 changes: 32 additions & 60 deletions statistics/merge_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/hack"
)

Expand All @@ -45,8 +43,11 @@ type topnStatsMergeWorker struct {
respCh chan<- *TopnStatsMergeResponse
// the stats in the wrapper should only be read during the worker
statsWrapper *StatsWrapper
// Different TopN structures may hold the same value, we have to merge them.
counter map[hack.MutableString]float64
// shardMutex is used to protect `statsWrapper.AllHg`
shardMutex []sync.Mutex
mu sync.Mutex
}

// NewTopnStatsMergeWorker returns topn merge worker
Expand All @@ -56,8 +57,9 @@ func NewTopnStatsMergeWorker(
wrapper *StatsWrapper,
killed *uint32) *topnStatsMergeWorker {
worker := &topnStatsMergeWorker{
taskCh: taskCh,
respCh: respCh,
taskCh: taskCh,
respCh: respCh,
counter: make(map[hack.MutableString]float64),
}
worker.statsWrapper = wrapper
worker.shardMutex = make([]sync.Mutex, len(wrapper.AllHg))
Expand All @@ -81,33 +83,23 @@ func NewTopnStatsMergeTask(start, end int) *TopnStatsMergeTask {

// TopnStatsMergeResponse indicates topn merge worker response
type TopnStatsMergeResponse struct {
Err error
TopN *TopN
PopedTopn []TopNMeta
Err error
}

// Run runs topn merge like statistics.MergePartTopN2GlobalTopN
func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool,
n uint32,
version int) {
func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, version int) {
for task := range worker.taskCh {
start := task.start
end := task.end
checkTopNs := worker.statsWrapper.AllTopN[start:end]
allTopNs := worker.statsWrapper.AllTopN
allHists := worker.statsWrapper.AllHg
resp := &TopnStatsMergeResponse{}
if checkEmptyTopNs(checkTopNs) {
worker.respCh <- resp
return
}
partNum := len(allTopNs)
// Different TopN structures may hold the same value, we have to merge them.
counter := make(map[hack.MutableString]float64)

// datumMap is used to store the mapping from the string type to datum type.
// The datum is used to find the value in the histogram.
datumMap := make(map[hack.MutableString]types.Datum)

datumMap := NewDatumMapCache()
for i, topN := range checkTopNs {
if atomic.LoadUint32(worker.killed) == 1 {
resp.Err = errors.Trace(ErrQueryInterrupted)
Expand All @@ -119,12 +111,15 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool,
}
for _, val := range topN.TopN {
encodedVal := hack.String(val.Encoded)
_, exists := counter[encodedVal]
counter[encodedVal] += float64(val.Count)
worker.mu.Lock()
_, exists := worker.counter[encodedVal]
worker.counter[encodedVal] += float64(val.Count)
if exists {
worker.mu.Unlock()
// We have already calculated the encodedVal from the histogram, so just continue to next topN value.
continue
}
worker.mu.Unlock()
// We need to check whether the value corresponding to encodedVal is contained in other partition-level stats.
// 1. Check the topN first.
// 2. If the topN doesn't contain the value corresponding to encodedVal. We should check the histogram.
Expand All @@ -138,59 +133,36 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool,
continue
}
// Get the encodedVal from the hists[j]
datum, exists := datumMap[encodedVal]
datum, exists := datumMap.Get(encodedVal)
if !exists {
// If the datumMap does not have the encodedVal datum,
// we should generate the datum based on the encoded value.
// This part is copied from the function MergePartitionHist2GlobalHist.
var d types.Datum
if isIndex {
d.SetBytes(val.Encoded)
} else {
var err error
if types.IsTypeTime(allHists[0].Tp.GetType()) {
// handle datetime values specially since they are encoded to int and we'll get int values if using DecodeOne.
_, d, err = codec.DecodeAsDateTime(val.Encoded, allHists[0].Tp.GetType(), timeZone)
} else if types.IsTypeFloat(allHists[0].Tp.GetType()) {
_, d, err = codec.DecodeAsFloat32(val.Encoded, allHists[0].Tp.GetType())
} else {
_, d, err = codec.DecodeOne(val.Encoded)
}
if err != nil {
resp.Err = err
worker.respCh <- resp
return
}
d, err := datumMap.Put(val, encodedVal, allHists[0].Tp.GetType(), isIndex, timeZone)
if err != nil {
resp.Err = err
worker.respCh <- resp
return
}
datumMap[encodedVal] = d
datum = d
}
worker.shardMutex[j].Lock()
// Get the row count which the value is equal to the encodedVal from histogram.
count, _ := allHists[j].equalRowCount(datum, isIndex)
if count != 0 {
counter[encodedVal] += count
// Remove the value corresponding to encodedVal from the histogram.
worker.shardMutex[j].Lock()
worker.statsWrapper.AllHg[j].BinarySearchRemoveVal(TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)})
worker.shardMutex[j].Unlock()
}
worker.shardMutex[j].Unlock()
if count != 0 {
worker.mu.Lock()
worker.counter[encodedVal] += count
worker.mu.Unlock()
}
}
}
}

numTop := len(counter)
if numTop == 0 {
worker.respCh <- resp
continue
}
sorted := make([]TopNMeta, 0, numTop)
for value, cnt := range counter {
data := hack.Slice(string(value))
sorted = append(sorted, TopNMeta{Encoded: data, Count: uint64(cnt)})
}
globalTopN, leftTopN := getMergedTopNFromSortedSlice(sorted, n)
resp.TopN = globalTopN
resp.PopedTopn = leftTopN
worker.respCh <- resp
}
}

// Result returns the worker's counter map, which Run fills with the
// accumulated count for each encoded TopN value.
// The map is returned directly (not copied) and without taking worker.mu.
func (worker *topnStatsMergeWorker) Result() map[hack.MutableString]float64 {
return worker.counter
}

0 comments on commit 24213f2

Please sign in to comment.