executor: refactor hashjoin part6 (#39531)
ref #39061
XuHuaiyu authored Dec 1, 2022
1 parent 3f3e102 commit 7dedfab
Showing 3 changed files with 55 additions and 50 deletions.
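Taken together, the hunks below move the per-join shared state (session context, chunk allocator, memory/disk trackers) onto hashJoinCtx, drop probeWorker's private sessCtx, give buildWorker a pointer back to the shared context, move the build-side methods onto buildWorker, and rename the executor's wait groups. A condensed sketch of the post-refactor layout, reconstructed from the executor/join.go hunks in this commit (field lists abbreviated, not the full definitions):

type hashJoinCtx struct {
	sessCtx     sessionctx.Context // new here; previously reached as probeWorker.sessCtx / e.ctx
	allocPool   chunk.Allocator    // new here; populated from e.AllocPool by the builder
	concurrency uint
	memTracker  *memory.Tracker // track memory usage.
	diskTracker *disk.Tracker   // track disk usage.
	// ... joinType, useOuterToBuild, outerFilter, rowContainer, finished, closeCh, etc.
}

type probeWorker struct {
	hashJoinCtx *hashJoinCtx // sessCtx is now reached via hashJoinCtx.sessCtx
	workerID    uint
	// ... probeKeyColIdx, joiner, rowIters, etc.
}

type buildWorker struct {
	hashJoinCtx      *hashJoinCtx // new back-pointer; lets the build-side code live on this type
	buildSideExec    Executor
	buildKeyColIdx   []int
	buildNAKeyColIdx []int
}

type HashJoinExec struct {
	baseExecutor
	*hashJoinCtx // embedded shared context, as wired up by the builder and benchmark hunks

	probeSideTupleFetcher *probeSideTupleFetcher
	probeWorkers          []*probeWorker
	buildWorker           *buildWorker

	workerWg util.WaitGroupWrapper // renamed from worker
	waiterWg util.WaitGroupWrapper // renamed from waiter

	prepared bool
}

With this layout, fetchBuildSideRows and buildHashTableForList move from HashJoinExec onto buildWorker, which is why the benchmark and builder hunks below wire buildWorker.hashJoinCtx explicitly.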
5 changes: 3 additions & 2 deletions executor/benchmark_test.go
@@ -913,6 +913,7 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor)
e := &HashJoinExec{
baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, 5, innerExec, outerExec),
hashJoinCtx: &hashJoinCtx{
sessCtx: testCase.ctx,
joinType: testCase.joinType, // 0 for InnerJoin, 1 for LeftOutersJoin, 2 for RightOuterJoin
isOuterJoin: false,
useOuterToBuild: testCase.useOuterToBuild,
@@ -936,13 +937,13 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor)
for i := uint(0); i < e.concurrency; i++ {
e.probeWorkers[i] = &probeWorker{
workerID: i,
sessCtx: e.ctx,
hashJoinCtx: e.hashJoinCtx,
joiner: newJoiner(testCase.ctx, e.joinType, true, defaultValues,
nil, lhsTypes, rhsTypes, childrenUsedSchema, false),
probeKeyColIdx: probeKeysColIdx,
}
}
e.buildWorker.hashJoinCtx = e.hashJoinCtx
memLimit := int64(-1)
if testCase.disk {
memLimit = 1
@@ -1200,7 +1201,7 @@ func benchmarkBuildHashTable(b *testing.B, casTest *hashJoinTestCase, dataSource
close(innerResultCh)

b.StartTimer()
if err := exec.buildHashTableForList(innerResultCh); err != nil {
if err := exec.buildWorker.buildHashTableForList(innerResultCh); err != nil {
b.Fatal(err)
}

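The benchmark change above is the test-side consequence of the move: benchmarkBuildHashTable now builds the hash table through the build worker, which only works because prepare4HashJoin wires the shared hashJoinCtx onto both worker kinds first. A minimal sketch of that wiring, condensed from the hunks above (most fields and error handling omitted):

// In prepare4HashJoin: one shared context, created with the executor ...
e := &HashJoinExec{
	baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, 5, innerExec, outerExec),
	hashJoinCtx: &hashJoinCtx{
		sessCtx:  testCase.ctx, // the session handle now lives on the shared context
		joinType: testCase.joinType,
	},
}
// ... shared by every probe worker ...
for i := uint(0); i < e.concurrency; i++ {
	e.probeWorkers[i] = &probeWorker{workerID: i, hashJoinCtx: e.hashJoinCtx}
}
// ... and by the single build worker:
e.buildWorker.hashJoinCtx = e.hashJoinCtx

// In benchmarkBuildHashTable: the hash table is built through the worker,
// no longer through HashJoinExec itself.
if err := exec.buildWorker.buildHashTableForList(innerResultCh); err != nil {
	b.Fatal(err)
}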
5 changes: 3 additions & 2 deletions executor/builder.go
@@ -1417,12 +1417,14 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo
probeWorkers: make([]*probeWorker, v.Concurrency),
buildWorker: &buildWorker{},
hashJoinCtx: &hashJoinCtx{
sessCtx: b.ctx,
isOuterJoin: v.JoinType.IsOuterJoin(),
useOuterToBuild: v.UseOuterToBuild,
joinType: v.JoinType,
concurrency: v.Concurrency,
},
}
e.hashJoinCtx.allocPool = e.AllocPool
defaultValues := v.DefaultValues
lhsTypes, rhsTypes := retTypes(leftExec), retTypes(rightExec)
if v.InnerChildIdx == 1 {
@@ -1494,13 +1496,12 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo
e.probeWorkers[i] = &probeWorker{
hashJoinCtx: e.hashJoinCtx,
workerID: i,
sessCtx: e.ctx,
joiner: newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, lhsTypes, rhsTypes, childrenUsedSchema, isNAJoin),
probeKeyColIdx: probeKeyColIdx,
probeNAKeyColIdx: probeNAKeColIdx,
}
}
e.buildWorker.buildKeyColIdx, e.buildWorker.buildNAKeyColIdx, e.buildWorker.buildSideExec = buildKeyColIdx, buildNAKeyColIdx, buildSideExec
e.buildWorker.buildKeyColIdx, e.buildWorker.buildNAKeyColIdx, e.buildWorker.buildSideExec, e.buildWorker.hashJoinCtx = buildKeyColIdx, buildNAKeyColIdx, buildSideExec, e.hashJoinCtx
e.hashJoinCtx.isNullAware = isNAJoin
executorCountHashJoinExec.Inc()

95 changes: 49 additions & 46 deletions executor/join.go
@@ -47,6 +47,8 @@ var (
)

type hashJoinCtx struct {
sessCtx sessionctx.Context
allocPool chunk.Allocator
// concurrency is the number of partition, build and join workers.
concurrency uint
joinResultCh chan *hashjoinWorkerResult
@@ -65,6 +67,8 @@ type hashJoinCtx struct {
buildTypes []*types.FieldType
outerFilter expression.CNFExprs
isNullAware bool
memTracker *memory.Tracker // track memory usage.
diskTracker *disk.Tracker // track disk usage.
}

// probeSideTupleFetcher reads tuples from probeSideExec and send them to probeWorkers.
@@ -79,7 +83,6 @@ type probeSideTupleFetcher struct {

type probeWorker struct {
hashJoinCtx *hashJoinCtx
sessCtx sessionctx.Context
workerID uint

probeKeyColIdx []int
@@ -102,6 +105,7 @@ type probeWorker struct {
}

type buildWorker struct {
hashJoinCtx *hashJoinCtx
buildSideExec Executor
buildKeyColIdx []int
buildNAKeyColIdx []int
@@ -116,11 +120,8 @@ type HashJoinExec struct {
probeWorkers []*probeWorker
buildWorker *buildWorker

worker util.WaitGroupWrapper
waiter util.WaitGroupWrapper

memTracker *memory.Tracker // track memory usage.
diskTracker *disk.Tracker // track disk usage.
workerWg util.WaitGroupWrapper
waiterWg util.WaitGroupWrapper

prepared bool
}
@@ -169,7 +170,7 @@ func (e *HashJoinExec) Close() error {
}
e.probeSideTupleFetcher.probeChkResourceCh = nil
terror.Call(e.rowContainer.Close)
e.waiter.Wait()
e.waiterWg.Wait()
}
e.outerMatchedStatus = e.outerMatchedStatus[:0]
for _, w := range e.probeWorkers {
@@ -198,14 +199,14 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
return err
}
e.prepared = false
e.memTracker = memory.NewTracker(e.id, -1)
e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
e.hashJoinCtx.memTracker = memory.NewTracker(e.id, -1)
e.hashJoinCtx.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)

e.diskTracker = disk.NewTracker(e.id, -1)
e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker)

e.worker = util.WaitGroupWrapper{}
e.waiter = util.WaitGroupWrapper{}
e.workerWg = util.WaitGroupWrapper{}
e.waiterWg = util.WaitGroupWrapper{}
e.closeCh = make(chan struct{})
e.finished.Store(false)

@@ -295,7 +296,7 @@ func (fetcher *probeSideTupleFetcher) wait4BuildSide() (emptyBuild bool, err err

// fetchBuildSideRows fetches all rows from build side executor, and append them
// to e.buildSideResult.
func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
defer close(chkCh)
var err error
failpoint.Inject("issue30289", func(val failpoint.Value) {
@@ -305,12 +306,13 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu
return
}
})
sessVars := w.hashJoinCtx.sessCtx.GetSessionVars()
for {
if e.finished.Load() {
if w.hashJoinCtx.finished.Load() {
return
}
chk := e.ctx.GetSessionVars().GetNewChunkWithCapacity(e.buildWorker.buildSideExec.base().retFieldTypes, e.ctx.GetSessionVars().MaxChunkSize, e.ctx.GetSessionVars().MaxChunkSize, e.AllocPool)
err = Next(ctx, e.buildWorker.buildSideExec, chk)
chk := sessVars.GetNewChunkWithCapacity(w.buildSideExec.base().retFieldTypes, sessVars.MaxChunkSize, sessVars.MaxChunkSize, w.hashJoinCtx.allocPool)
err = Next(ctx, w.buildSideExec, chk)
if err != nil {
errCh <- errors.Trace(err)
return
@@ -323,7 +325,7 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu
select {
case <-doneCh:
return
case <-e.closeCh:
case <-w.hashJoinCtx.closeCh:
return
case chkCh <- chk:
}
@@ -366,19 +368,19 @@ func (e *HashJoinExec) initializeForProbe() {

func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) {
e.initializeForProbe()
e.worker.RunWithRecover(func() {
e.workerWg.RunWithRecover(func() {
defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End()
e.probeSideTupleFetcher.fetchProbeSideChunks(ctx, e.maxChunkSize)
}, e.probeSideTupleFetcher.handleProbeSideFetcherPanic)

for i := uint(0); i < e.concurrency; i++ {
workerID := i
e.worker.RunWithRecover(func() {
e.workerWg.RunWithRecover(func() {
defer trace.StartRegion(ctx, "HashJoinWorker").End()
e.probeWorkers[workerID].runJoinWorker()
}, e.probeWorkers[workerID].handleProbeWorkerPanic)
}
e.waiter.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil)
e.waiterWg.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil)
}

func (fetcher *probeSideTupleFetcher) handleProbeSideFetcherPanic(r interface{}) {
Expand Down Expand Up @@ -439,14 +441,14 @@ func (w *probeWorker) handleUnmatchedRowsFromHashTable() {
}

func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() {
e.worker.Wait()
e.workerWg.Wait()
if e.useOuterToBuild {
// Concurrently handling unmatched rows from the hash table at the tail
for i := uint(0); i < e.concurrency; i++ {
var workerID = i
e.worker.RunWithRecover(func() { e.probeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic)
e.workerWg.RunWithRecover(func() { e.probeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic)
}
e.worker.Wait()
e.workerWg.Wait()
}
close(e.joinResultCh)
}
@@ -954,7 +956,7 @@ func (w *probeWorker) getNewJoinResult() (bool, *hashjoinWorkerResult) {
func (w *probeWorker) join2Chunk(probeSideChk *chunk.Chunk, hCtx *hashContext, joinResult *hashjoinWorkerResult,
selected []bool) (ok bool, _ *hashjoinWorkerResult) {
var err error
selected, err = expression.VectorizedFilter(w.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected)
selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected)
if err != nil {
joinResult.err = err
return false, joinResult
@@ -994,7 +996,7 @@ func (w *probeWorker) join2Chunk(probeSideChk *chunk.Chunk, hCtx *hashContext, j
}

for i := range selected {
killed := atomic.LoadUint32(&w.sessCtx.GetSessionVars().Killed) == 1
killed := atomic.LoadUint32(&w.hashJoinCtx.sessCtx.GetSessionVars().Killed) == 1
failpoint.Inject("killedInJoin2Chunk", func(val failpoint.Value) {
if val.(bool) {
killed = true
@@ -1060,7 +1062,7 @@ func (w *probeWorker) join2ChunkForOuterHashJoin(probeSideChk *chunk.Chunk, hCtx
}
}
for i := 0; i < probeSideChk.NumRows(); i++ {
killed := atomic.LoadUint32(&w.sessCtx.GetSessionVars().Killed) == 1
killed := atomic.LoadUint32(&w.hashJoinCtx.sessCtx.GetSessionVars().Killed) == 1
failpoint.Inject("killedInJoin2ChunkForOuterHashJoin", func(val failpoint.Value) {
if val.(bool) {
killed = true
@@ -1110,7 +1112,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
for i := uint(0); i < e.concurrency; i++ {
e.probeWorkers[i].rowIters = chunk.NewIterator4Slice([]chunk.Row{}).(*chunk.Iterator4Slice)
}
e.worker.RunWithRecover(func() {
e.workerWg.RunWithRecover(func() {
defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End()
e.fetchAndBuildHashTable(ctx)
}, e.handleFetchAndBuildHashTablePanic)
@@ -1153,10 +1155,10 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
buildSideResultCh := make(chan *chunk.Chunk, 1)
doneCh := make(chan struct{})
fetchBuildSideRowsOk := make(chan error, 1)
e.worker.RunWithRecover(
e.workerWg.RunWithRecover(
func() {
defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End()
e.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh)
e.buildWorker.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh)
},
func(r interface{}) {
if r != nil {
@@ -1167,7 +1169,7 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
)

// TODO: Parallel build hash table. Currently not support because `unsafeHashTable` is not thread-safe.
err := e.buildHashTableForList(buildSideResultCh)
err := e.buildWorker.buildHashTableForList(buildSideResultCh)
if err != nil {
e.buildFinished <- errors.Trace(err)
close(doneCh)
@@ -1185,41 +1187,42 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
}

// buildHashTableForList builds hash table from `list`.
func (e *HashJoinExec) buildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error {
func (w *buildWorker) buildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error {
var err error
var selected []bool
e.rowContainer.GetMemTracker().AttachTo(e.memTracker)
e.rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult)
e.rowContainer.GetDiskTracker().AttachTo(e.diskTracker)
e.rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult)
rowContainer := w.hashJoinCtx.rowContainer
rowContainer.GetMemTracker().AttachTo(w.hashJoinCtx.memTracker)
rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult)
rowContainer.GetDiskTracker().AttachTo(w.hashJoinCtx.diskTracker)
rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult)
if variable.EnableTmpStorageOnOOM.Load() {
actionSpill := e.rowContainer.ActionSpill()
actionSpill := rowContainer.ActionSpill()
failpoint.Inject("testRowContainerSpill", func(val failpoint.Value) {
if val.(bool) {
actionSpill = e.rowContainer.rowContainer.ActionSpillForTest()
actionSpill = rowContainer.rowContainer.ActionSpillForTest()
defer actionSpill.(*chunk.SpillDiskAction).WaitForTest()
}
})
e.ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill)
w.hashJoinCtx.sessCtx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill)
}
for chk := range buildSideResultCh {
if e.finished.Load() {
if w.hashJoinCtx.finished.Load() {
return nil
}
if !e.useOuterToBuild {
err = e.rowContainer.PutChunk(chk, e.isNullEQ)
if !w.hashJoinCtx.useOuterToBuild {
err = rowContainer.PutChunk(chk, w.hashJoinCtx.isNullEQ)
} else {
var bitMap = bitmap.NewConcurrentBitmap(chk.NumRows())
e.outerMatchedStatus = append(e.outerMatchedStatus, bitMap)
e.memTracker.Consume(bitMap.BytesConsumed())
if len(e.outerFilter) == 0 {
err = e.rowContainer.PutChunk(chk, e.isNullEQ)
w.hashJoinCtx.outerMatchedStatus = append(w.hashJoinCtx.outerMatchedStatus, bitMap)
w.hashJoinCtx.memTracker.Consume(bitMap.BytesConsumed())
if len(w.hashJoinCtx.outerFilter) == 0 {
err = w.hashJoinCtx.rowContainer.PutChunk(chk, w.hashJoinCtx.isNullEQ)
} else {
selected, err = expression.VectorizedFilter(e.ctx, e.outerFilter, chunk.NewIterator4Chunk(chk), selected)
selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(chk), selected)
if err != nil {
return err
}
err = e.rowContainer.PutChunkSelected(chk, selected, e.isNullEQ)
err = rowContainer.PutChunkSelected(chk, selected, w.hashJoinCtx.isNullEQ)
}
}
failpoint.Inject("ConsumeRandomPanic", nil)
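The net effect inside the worker code is a single access path: with sessCtx and allocPool hanging off hashJoinCtx, the relocated methods never reach back into HashJoinExec. A condensed sketch of the pattern fetchBuildSideRows now follows, abridged from the hunk above (failpoint and channel-send handling reduced to a comment):

func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
	defer close(chkCh)
	// Everything session-scoped is reached through the shared context,
	// replacing the old e.ctx / e.AllocPool accesses on HashJoinExec.
	sessVars := w.hashJoinCtx.sessCtx.GetSessionVars()
	for {
		if w.hashJoinCtx.finished.Load() {
			return
		}
		chk := sessVars.GetNewChunkWithCapacity(
			w.buildSideExec.base().retFieldTypes,
			sessVars.MaxChunkSize, sessVars.MaxChunkSize,
			w.hashJoinCtx.allocPool) // the allocator also comes from the shared context
		if err := Next(ctx, w.buildSideExec, chk); err != nil {
			errCh <- errors.Trace(err)
			return
		}
		// ... forward chk on chkCh, or stop on doneCh / w.hashJoinCtx.closeCh (see the hunk above).
	}
}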
