executor: improve parallel hash aggregation (#47428)
close #47427
xzhangxian1008 authored Nov 3, 2023
1 parent d99c5a5 commit a0d2409
Showing 7 changed files with 272 additions and 116 deletions.
2 changes: 1 addition & 1 deletion pkg/executor/aggregate/agg_hash_base_worker.go
@@ -84,7 +84,7 @@ func (w *baseHashAggWorker) getPartialResult(_ *stmtctx.StatementContext, groupK

func (w *baseHashAggWorker) getPartialResultSliceLenConsiderByteAlign() int {
length := len(w.aggFuncs)
- if len(w.aggFuncs) == 1 {
+ if length == 1 {
return 1
}
return length + length&1
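A note on the helper above: getPartialResultSliceLenConsiderByteAlign keeps a single aggregate function at length 1 and otherwise rounds the per-row partial-result count up to an even number via length + length&1, which, as the name suggests, is about keeping the slice byte-aligned. A tiny standalone sketch of that rounding rule (the helper name below is illustrative, not executor code):

```go
package main

import "fmt"

// padToEven mirrors the `length + length&1` expression used by
// getPartialResultSliceLenConsiderByteAlign: a single aggregate function
// stays at 1, any other odd count is rounded up to the next even number.
func padToEven(numAggFuncs int) int {
	if numAggFuncs == 1 {
		return 1
	}
	return numAggFuncs + numAggFuncs&1 // & binds tighter than +, so this adds numAggFuncs%2
}

func main() {
	for _, n := range []int{1, 2, 3, 4, 5} {
		fmt.Println(n, "->", padToEven(n)) // 1->1, 2->2, 3->4, 4->4, 5->6
	}
}
```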
39 changes: 27 additions & 12 deletions pkg/executor/aggregate/agg_hash_executor.go
@@ -104,7 +104,7 @@ type HashAggExec struct {

finishCh chan struct{}
finalOutputCh chan *AfFinalResult
- partialOutputChs []chan *HashAggIntermData
+ partialOutputChs []chan *AggPartialResultMapper
inputCh chan *HashAggInput
partialInputChs []chan *chunk.Chunk
partialWorkers []HashAggPartialWorker
@@ -264,9 +264,9 @@ func (e *HashAggExec) initForParallelExec(_ sessionctx.Context) {
for i := range e.partialInputChs {
e.partialInputChs[i] = make(chan *chunk.Chunk, 1)
}
- e.partialOutputChs = make([]chan *HashAggIntermData, finalConcurrency)
+ e.partialOutputChs = make([]chan *AggPartialResultMapper, finalConcurrency)
for i := range e.partialOutputChs {
- e.partialOutputChs[i] = make(chan *HashAggIntermData, partialConcurrency)
+ e.partialOutputChs[i] = make(chan *AggPartialResultMapper, partialConcurrency)
}

e.partialWorkers = make([]HashAggPartialWorker, partialConcurrency)
@@ -275,17 +275,30 @@ func (e *HashAggExec) initForParallelExec(_ sessionctx.Context) {

// Init partial workers.
for i := 0; i < partialConcurrency; i++ {
+ partialResultsMap := make([]AggPartialResultMapper, finalConcurrency)
+ for i := 0; i < finalConcurrency; i++ {
+ partialResultsMap[i] = make(AggPartialResultMapper)
+ }
+
w := HashAggPartialWorker{
- baseHashAggWorker: newBaseHashAggWorker(e.Ctx(), e.finishCh, e.PartialAggFuncs, e.MaxChunkSize(), e.memTracker),
- inputCh: e.partialInputChs[i],
- outputChs: e.partialOutputChs,
- giveBackCh: e.inputCh,
- globalOutputCh: e.finalOutputCh,
- partialResultsMap: make(AggPartialResultMapper),
- groupByItems: e.GroupByItems,
- chk: exec.TryNewCacheChunk(e.Children(0)),
- groupKey: make([][]byte, 0, 8),
+ baseHashAggWorker: newBaseHashAggWorker(e.Ctx(), e.finishCh, e.PartialAggFuncs, e.MaxChunkSize(), e.memTracker),
+ inputCh: e.partialInputChs[i],
+ outputChs: e.partialOutputChs,
+ giveBackCh: e.inputCh,
+ BInMaps: make([]int, finalConcurrency),
+ partialResultsBuffer: make([][]aggfuncs.PartialResult, 0, 2048),
+ globalOutputCh: e.finalOutputCh,
+ partialResultsMap: partialResultsMap,
+ groupByItems: e.GroupByItems,
+ chk: exec.TryNewCacheChunk(e.Children(0)),
+ groupKey: make([][]byte, 0, 8),
}

+ w.partialResultNumInRow = w.getPartialResultSliceLenConsiderByteAlign()
+ for i := 0; i < finalConcurrency; i++ {
+ w.BInMaps[i] = 0
+ }

// There is a bucket in the empty partialResultsMap.
failpoint.Inject("ConsumeRandomPanic", nil)
e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << w.BInMap))
@@ -309,6 +322,8 @@ func (e *HashAggExec) initForParallelExec(_ sessionctx.Context) {
w := HashAggFinalWorker{
baseHashAggWorker: newBaseHashAggWorker(e.Ctx(), e.finishCh, e.FinalAggFuncs, e.MaxChunkSize(), e.memTracker),
partialResultMap: make(AggPartialResultMapper),
+ BInMap: 0,
+ isFirstInput: true,
groupSet: groupSet,
inputCh: e.partialOutputChs[i],
outputCh: e.finalOutputCh,
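For context on the executor changes above: each partial worker now pre-builds one AggPartialResultMapper per final worker (with a matching BInMaps entry tracking each map's bucket exponent for memory accounting), and the partial→final channels carry *AggPartialResultMapper instead of *HashAggIntermData, so a final worker receives ready-made maps. The diff does not show how a group key is routed to one of those per-final-worker maps; the sketch below assumes the common approach of hashing the key modulo the number of final workers — shardIndex, the fnv hash, and the toy value type are illustrative only, not the executor's actual code:

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// shardIndex picks which final worker (and therefore which per-worker map)
// a group key belongs to. Hashing the key modulo finalConcurrency is an
// assumption for illustration; the real executor has its own routing logic.
func shardIndex(groupKey []byte, finalConcurrency int) int {
	h := fnv.New64()
	h.Write(groupKey)
	return int(h.Sum64() % uint64(finalConcurrency))
}

func main() {
	finalConcurrency := 4
	// One map of group key -> partial results per final worker, mirroring
	// partialResultsMap := make([]AggPartialResultMapper, finalConcurrency).
	shards := make([]map[string][]int64, finalConcurrency)
	for i := range shards {
		shards[i] = make(map[string][]int64)
	}

	for _, key := range [][]byte{[]byte("a"), []byte("b"), []byte("c")} {
		idx := shardIndex(key, finalConcurrency)
		shards[idx][string(key)] = append(shards[idx][string(key)], 1) // toy "partial result"
		fmt.Printf("group key %q -> final worker %d\n", key, idx)
	}
}
```

Pre-sharding in the partial workers is what lets the final worker (next file) merge whole maps instead of re-encoding and re-hashing individual group keys.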
101 changes: 53 additions & 48 deletions pkg/executor/aggregate/agg_hash_final_worker.go
@@ -19,10 +19,10 @@ import (
"time"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/executor/aggfuncs"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/hack"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/set"
"go.uber.org/zap"
@@ -43,14 +43,16 @@ type HashAggFinalWorker struct {
rowBuffer []types.Datum
mutableRow chunk.MutRow
partialResultMap AggPartialResultMapper
+ BInMap int
+ isFirstInput bool
groupSet set.StringSetWithMemoryUsage
- inputCh chan *HashAggIntermData
+ inputCh chan *AggPartialResultMapper
outputCh chan *AfFinalResult
finalResultHolderCh chan *chunk.Chunk
groupKeys [][]byte
}

- func (w *HashAggFinalWorker) getPartialInput() (input *HashAggIntermData, ok bool) {
+ func (w *HashAggFinalWorker) getPartialInput() (input *AggPartialResultMapper, ok bool) {
select {
case <-w.finishCh:
return nil, false
@@ -62,55 +64,60 @@ func (w *HashAggFinalWorker) getPartialInput() (input *HashAggIntermData, ok boo
return
}

+ func (w *HashAggFinalWorker) initBInMap() {
+ w.BInMap = 0
+ mapLen := len(w.partialResultMap)
+ for mapLen > (1<<w.BInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
+ w.BInMap++
+ }
+ }

func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) (err error) {
- var (
- input *HashAggIntermData
- ok bool
- intermDataBuffer [][]aggfuncs.PartialResult
- groupKeys []string
- sc = sctx.GetSessionVars().StmtCtx
- )
for {
waitStart := time.Now()
- input, ok = w.getPartialInput()
+ input, ok := w.getPartialInput()
if w.stats != nil {
w.stats.WaitTime += int64(time.Since(waitStart))
}
if !ok {
return nil
}
- execStart := time.Now()
- if intermDataBuffer == nil {
- intermDataBuffer = make([][]aggfuncs.PartialResult, 0, w.maxChunkSize)

+ // As the w.partialResultMap is empty when we get the first input.
+ // So it's better to directly assign the input to w.partialResultMap
+ if w.isFirstInput {
+ w.isFirstInput = false
+ w.partialResultMap = *input
+ w.initBInMap()
+ continue
}
- // Consume input in batches, size of every batch is less than w.maxChunkSize.
- for reachEnd := false; !reachEnd; {
- intermDataBuffer, groupKeys, reachEnd = input.getPartialResultBatch(sc, intermDataBuffer[:0], w.aggFuncs, w.maxChunkSize)
- groupKeysLen := len(groupKeys)
- memSize := getGroupKeyMemUsage(w.groupKeys)
- w.groupKeys = w.groupKeys[:0]
- for i := 0; i < groupKeysLen; i++ {
- w.groupKeys = append(w.groupKeys, []byte(groupKeys[i]))
- }
- failpoint.Inject("ConsumeRandomPanic", nil)
- w.memTracker.Consume(getGroupKeyMemUsage(w.groupKeys) - memSize)
- finalPartialResults := w.getPartialResult(sc, w.groupKeys, w.partialResultMap)
- allMemDelta := int64(0)
- for i, groupKey := range groupKeys {
- if !w.groupSet.Exist(groupKey) {
- allMemDelta += w.groupSet.Insert(groupKey)

+ failpoint.Inject("ConsumeRandomPanic", nil)
+
+ execStart := time.Now()
+ allMemDelta := int64(0)
+ for key, value := range *input {
+ dstVal, ok := w.partialResultMap[key]
+ if !ok {
+ // Map will expand when count > bucketNum * loadFactor. The memory usage will double.
+ if len(w.partialResultMap)+1 > (1<<w.BInMap)*hack.LoadFactorNum/hack.LoadFactorDen {
+ w.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice * (1 << w.BInMap))
+ w.BInMap++
}
- prs := intermDataBuffer[i]
- for j, af := range w.aggFuncs {
- memDelta, err := af.MergePartialResult(sctx, prs[j], finalPartialResults[i][j])
- if err != nil {
- return err
- }
- allMemDelta += memDelta
+ w.partialResultMap[key] = value
+ continue
+ }
+
+ for j, af := range w.aggFuncs {
+ memDelta, err := af.MergePartialResult(sctx, value[j], dstVal[j])
+ if err != nil {
+ return err
}
+ allMemDelta += memDelta
}
- w.memTracker.Consume(allMemDelta)
}
+ w.memTracker.Consume(allMemDelta)
+
if w.stats != nil {
w.stats.ExecTime += int64(time.Since(execStart))
w.stats.TaskNum++
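To summarize the new merge path shown above: the first incoming AggPartialResultMapper is adopted as the final worker's own map (with BInMap derived from its size via initBInMap), while later maps are merged key by key with MergePartialResult; whenever inserting a new key would push the Go map past its load factor, the worker charges the memory tracker for the doubled bucket array and bumps BInMap. A self-contained sketch of that load-factor accounting — the constants and the tracker type below are illustrative stand-ins, not the values from pkg/util/hack:

```go
package main

import "fmt"

// Illustrative stand-ins for hack.LoadFactorNum/LoadFactorDen and the
// per-bucket cost; the real constants live in pkg/util/hack.
const (
	loadFactorNum     = 13
	loadFactorDen     = 2
	bucketMemoryUsage = 64 // hypothetical bytes per bucket, for the sketch only
)

// mapTracker mimics how the final worker accounts for map growth:
// B is the bucket-count exponent (bucketNum == 1<<B). When an insert would
// exceed bucketNum*loadFactor, the bucket array doubles, so the tracked
// memory for buckets doubles as well.
type mapTracker struct {
	B        int
	consumed int64
}

func (t *mapTracker) beforeInsert(curLen int) {
	if curLen+1 > (1<<t.B)*loadFactorNum/loadFactorDen {
		t.consumed += int64(bucketMemoryUsage * (1 << t.B))
		t.B++
	}
}

func main() {
	t := &mapTracker{}
	m := make(map[string]int)
	for i := 0; i < 100; i++ {
		t.beforeInsert(len(m))
		m[fmt.Sprintf("group-%d", i)] = i
	}
	fmt.Printf("B=%d, tracked bucket bytes=%d\n", t.B, t.consumed)
}
```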
@@ -127,24 +134,21 @@ func (w *HashAggFinalWorker) loadFinalResult(sctx sessionctx.Context) {
if finished {
return
}
- execStart := time.Now()
- memSize := getGroupKeyMemUsage(w.groupKeys)
- w.groupKeys = w.groupKeys[:0]
- for groupKey := range w.groupSet.StringSet {
- w.groupKeys = append(w.groupKeys, []byte(groupKey))
- }

failpoint.Inject("ConsumeRandomPanic", nil)
- w.memTracker.Consume(getGroupKeyMemUsage(w.groupKeys) - memSize)
- partialResults := w.getPartialResult(sctx.GetSessionVars().StmtCtx, w.groupKeys, w.partialResultMap)
- for i := 0; i < len(w.groupSet.StringSet); i++ {

+ execStart := time.Now()
+ for _, results := range w.partialResultMap {
for j, af := range w.aggFuncs {
- if err := af.AppendFinalResult2Chunk(sctx, partialResults[i][j], result); err != nil {
+ if err := af.AppendFinalResult2Chunk(sctx, results[j], result); err != nil {
logutil.BgLogger().Error("HashAggFinalWorker failed to append final result to Chunk", zap.Error(err))
}
}

if len(w.aggFuncs) == 0 {
result.SetNumVirtualRows(result.NumRows() + 1)
}

if result.IsFull() {
w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh}
result, finished = w.receiveFinalResultHolder()
@@ -153,6 +157,7 @@ func (w *HashAggFinalWorker) loadFinalResult(sctx sessionctx.Context) {
}
}
}

w.outputCh <- &AfFinalResult{chk: result, giveBackCh: w.finalResultHolderCh}
if w.stats != nil {
w.stats.ExecTime += int64(time.Since(execStart))
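Taken together, the three files change the partial→final handoff: partial workers publish whole AggPartialResultMapper maps on chan *AggPartialResultMapper, the final worker adopts the first map it receives and merges the rest, and loadFinalResult then simply ranges over the merged map to emit rows instead of materializing group keys. A stripped-down sketch of that handoff, with toy types standing in for the executor's (AggPartialResultMapper is really a map from group key to []aggfuncs.PartialResult; everything below is illustrative):

```go
package main

import (
	"fmt"
	"sync"
)

// Toy stand-in for AggPartialResultMapper: group key -> partial sum
// for a single COUNT-like aggregate.
type partialMap map[string]int64

func main() {
	const partialWorkers = 3
	out := make(chan *partialMap, partialWorkers)

	// Partial workers each send a ready-made map destined for this final worker.
	var wg sync.WaitGroup
	for p := 0; p < partialWorkers; p++ {
		wg.Add(1)
		go func(p int) {
			defer wg.Done()
			m := partialMap{"a": int64(p + 1), "b": 1}
			out <- &m
		}(p)
	}
	go func() { wg.Wait(); close(out) }()

	// Final worker: adopt the first map, merge the rest key by key.
	var merged partialMap
	for input := range out {
		if merged == nil {
			merged = *input // mirrors the isFirstInput fast path
			continue
		}
		for key, value := range *input {
			merged[key] += value // stands in for aggFunc.MergePartialResult
		}
	}
	fmt.Println(merged) // map[a:6 b:3] for the toy inputs (1+2+3, 1+1+1)
}
```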