executor: refactor hashjoin part6 #39531

Merged
11 commits merged on Dec 1, 2022
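
This part of the refactor hoists session-scoped state (sessCtx, allocPool, memTracker, diskTracker) off HashJoinExec and probeWorker into the shared hashJoinCtx, and gives buildWorker a back-pointer to that context so fetchBuildSideRows and buildHashTableForList can move off the executor. A minimal compiling sketch of the target layout follows; the stand-in types are illustrative assumptions, where the real fields use sessionctx.Context, chunk.Allocator, memory.Tracker, and disk.Tracker:

```go
package main

// Stand-ins for the TiDB types involved (illustrative only).
type sessionCtx struct{}
type allocator struct{}
type tracker struct{}

// hashJoinCtx now owns the state shared by every worker of one executor.
type hashJoinCtx struct {
	sessCtx     *sessionCtx // formerly duplicated on each probeWorker
	allocPool   *allocator  // formerly reached through HashJoinExec.AllocPool
	memTracker  *tracker    // moved here from HashJoinExec
	diskTracker *tracker    // moved here from HashJoinExec
}

// probeWorker drops its own sessCtx field and reads session state
// through the shared context instead.
type probeWorker struct {
	hashJoinCtx *hashJoinCtx
	workerID    uint
}

// buildWorker gains the back-pointer that lets build-side methods
// take it, rather than *HashJoinExec, as their receiver.
type buildWorker struct {
	hashJoinCtx *hashJoinCtx
}

func main() {
	ctx := &hashJoinCtx{sessCtx: &sessionCtx{}, allocPool: &allocator{}}
	_ = &probeWorker{hashJoinCtx: ctx, workerID: 0}
	_ = &buildWorker{hashJoinCtx: ctx}
}
```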
executor/benchmark_test.go (5 changes: 3 additions & 2 deletions)
@@ -913,6 +913,7 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor)
e := &HashJoinExec{
baseExecutor: newBaseExecutor(testCase.ctx, joinSchema, 5, innerExec, outerExec),
hashJoinCtx: &hashJoinCtx{
+ sessCtx: testCase.ctx,
joinType: testCase.joinType, // 0 for InnerJoin, 1 for LeftOuterJoin, 2 for RightOuterJoin
isOuterJoin: false,
useOuterToBuild: testCase.useOuterToBuild,
@@ -936,13 +937,13 @@ func prepare4HashJoin(testCase *hashJoinTestCase, innerExec, outerExec Executor)
for i := uint(0); i < e.concurrency; i++ {
e.probeWorkers[i] = &probeWorker{
workerID: i,
- sessCtx: e.ctx,
hashJoinCtx: e.hashJoinCtx,
joiner: newJoiner(testCase.ctx, e.joinType, true, defaultValues,
nil, lhsTypes, rhsTypes, childrenUsedSchema, false),
probeKeyColIdx: probeKeysColIdx,
}
}
+ e.buildWorker.hashJoinCtx = e.hashJoinCtx
memLimit := int64(-1)
if testCase.disk {
memLimit = 1
@@ -1200,7 +1201,7 @@ func benchmarkBuildHashTable(b *testing.B, casTest *hashJoinTestCase, dataSource
close(innerResultCh)

b.StartTimer()
- if err := exec.buildHashTableForList(innerResultCh); err != nil {
+ if err := exec.buildWorker.buildHashTableForList(innerResultCh); err != nil {
b.Fatal(err)
}

executor/builder.go (5 changes: 3 additions & 2 deletions)
@@ -1417,12 +1417,14 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo
probeWorkers: make([]*probeWorker, v.Concurrency),
buildWorker: &buildWorker{},
hashJoinCtx: &hashJoinCtx{
+ sessCtx: b.ctx,
isOuterJoin: v.JoinType.IsOuterJoin(),
useOuterToBuild: v.UseOuterToBuild,
joinType: v.JoinType,
concurrency: v.Concurrency,
},
}
+ e.hashJoinCtx.allocPool = e.AllocPool
defaultValues := v.DefaultValues
lhsTypes, rhsTypes := retTypes(leftExec), retTypes(rightExec)
if v.InnerChildIdx == 1 {
@@ -1494,13 +1496,12 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo
e.probeWorkers[i] = &probeWorker{
hashJoinCtx: e.hashJoinCtx,
workerID: i,
- sessCtx: e.ctx,
joiner: newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, lhsTypes, rhsTypes, childrenUsedSchema, isNAJoin),
probeKeyColIdx: probeKeyColIdx,
probeNAKeyColIdx: probeNAKeColIdx,
}
}
- e.buildWorker.buildKeyColIdx, e.buildWorker.buildNAKeyColIdx, e.buildWorker.buildSideExec = buildKeyColIdx, buildNAKeyColIdx, buildSideExec
+ e.buildWorker.buildKeyColIdx, e.buildWorker.buildNAKeyColIdx, e.buildWorker.buildSideExec, e.buildWorker.hashJoinCtx = buildKeyColIdx, buildNAKeyColIdx, buildSideExec, e.hashJoinCtx
e.hashJoinCtx.isNullAware = isNAJoin
executorCountHashJoinExec.Inc()

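The two builder.go additions above hand the same hashJoinCtx pointer to the executor, the build worker, and every probe worker. A self-contained toy demonstration of why that matters (names are illustrative, not TiDB's): state flipped through one worker's context is immediately visible through the others', which is exactly what the atomic finished flag in join.go relies on.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// joinCtx stands in for hashJoinCtx: exactly one instance per executor.
type joinCtx struct {
	finished atomic.Bool
}

type worker struct {
	ctx *joinCtx
	id  int
}

func main() {
	ctx := &joinCtx{}
	build := &worker{ctx: ctx, id: 0}
	probe := &worker{ctx: ctx, id: 1}

	// Simulate the build side failing and marking the join finished.
	build.ctx.finished.Store(true)

	// The probe worker observes the same context instance and can bail out,
	// mirroring the w.hashJoinCtx.finished.Load() checks in join.go.
	fmt.Println(probe.ctx.finished.Load()) // true
}
```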
executor/join.go (95 changes: 49 additions & 46 deletions)
@@ -47,6 +47,8 @@ var (
)

type hashJoinCtx struct {
+ sessCtx sessionctx.Context
+ allocPool chunk.Allocator
// concurrency is the number of partition, build and join workers.
concurrency uint
joinResultCh chan *hashjoinWorkerResult
@@ -65,6 +67,8 @@ type hashJoinCtx struct {
buildTypes []*types.FieldType
outerFilter expression.CNFExprs
isNullAware bool
+ memTracker *memory.Tracker // track memory usage.
+ diskTracker *disk.Tracker // track disk usage.
}

// probeSideTupleFetcher reads tuples from probeSideExec and sends them to probeWorkers.
@@ -79,7 +83,6 @@ type probeSideTupleFetcher struct {

type probeWorker struct {
hashJoinCtx *hashJoinCtx
- sessCtx sessionctx.Context
workerID uint

probeKeyColIdx []int
@@ -102,6 +105,7 @@ type probeWorker struct {
}

type buildWorker struct {
+ hashJoinCtx *hashJoinCtx
buildSideExec Executor
buildKeyColIdx []int
buildNAKeyColIdx []int
@@ -116,11 +120,8 @@ type HashJoinExec struct {
probeWorkers []*probeWorker
buildWorker *buildWorker

- worker util.WaitGroupWrapper
- waiter util.WaitGroupWrapper
- memTracker *memory.Tracker // track memory usage.
- diskTracker *disk.Tracker // track disk usage.
+ workerWg util.WaitGroupWrapper
+ waiterWg util.WaitGroupWrapper

prepared bool
}
@@ -169,7 +170,7 @@ func (e *HashJoinExec) Close() error {
}
e.probeSideTupleFetcher.probeChkResourceCh = nil
terror.Call(e.rowContainer.Close)
- e.waiter.Wait()
+ e.waiterWg.Wait()
}
e.outerMatchedStatus = e.outerMatchedStatus[:0]
for _, w := range e.probeWorkers {
@@ -198,14 +199,14 @@ func (e *HashJoinExec) Open(ctx context.Context) error {
return err
}
e.prepared = false
- e.memTracker = memory.NewTracker(e.id, -1)
- e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)
+ e.hashJoinCtx.memTracker = memory.NewTracker(e.id, -1)
+ e.hashJoinCtx.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker)

e.diskTracker = disk.NewTracker(e.id, -1)
e.diskTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.DiskTracker)

- e.worker = util.WaitGroupWrapper{}
- e.waiter = util.WaitGroupWrapper{}
+ e.workerWg = util.WaitGroupWrapper{}
+ e.waiterWg = util.WaitGroupWrapper{}
e.closeCh = make(chan struct{})
e.finished.Store(false)

@@ -295,7 +296,7 @@ func (fetcher *probeSideTupleFetcher) wait4BuildSide() (emptyBuild bool, err err

// fetchBuildSideRows fetches all rows from the build side executor and appends them
// to e.buildSideResult.
- func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
+ func (w *buildWorker) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) {
defer close(chkCh)
var err error
failpoint.Inject("issue30289", func(val failpoint.Value) {
@@ -305,12 +306,13 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu
return
}
})
+ sessVars := w.hashJoinCtx.sessCtx.GetSessionVars()
for {
- if e.finished.Load() {
+ if w.hashJoinCtx.finished.Load() {
return
}
- chk := e.ctx.GetSessionVars().GetNewChunkWithCapacity(e.buildWorker.buildSideExec.base().retFieldTypes, e.ctx.GetSessionVars().MaxChunkSize, e.ctx.GetSessionVars().MaxChunkSize, e.AllocPool)
- err = Next(ctx, e.buildWorker.buildSideExec, chk)
+ chk := sessVars.GetNewChunkWithCapacity(w.buildSideExec.base().retFieldTypes, sessVars.MaxChunkSize, sessVars.MaxChunkSize, w.hashJoinCtx.allocPool)
+ err = Next(ctx, w.buildSideExec, chk)
if err != nil {
errCh <- errors.Trace(err)
return
@@ -323,7 +325,7 @@ func (e *HashJoinExec) fetchBuildSideRows(ctx context.Context, chkCh chan<- *chu
select {
case <-doneCh:
return
- case <-e.closeCh:
+ case <-w.hashJoinCtx.closeCh:
return
case chkCh <- chk:
}
@@ -366,19 +368,19 @@ func (e *HashJoinExec) initializeForProbe() {

func (e *HashJoinExec) fetchAndProbeHashTable(ctx context.Context) {
e.initializeForProbe()
- e.worker.RunWithRecover(func() {
+ e.workerWg.RunWithRecover(func() {
defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End()
e.probeSideTupleFetcher.fetchProbeSideChunks(ctx, e.maxChunkSize)
}, e.probeSideTupleFetcher.handleProbeSideFetcherPanic)

for i := uint(0); i < e.concurrency; i++ {
workerID := i
- e.worker.RunWithRecover(func() {
+ e.workerWg.RunWithRecover(func() {
defer trace.StartRegion(ctx, "HashJoinWorker").End()
e.probeWorkers[workerID].runJoinWorker()
}, e.probeWorkers[workerID].handleProbeWorkerPanic)
}
- e.waiter.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil)
+ e.waiterWg.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil)
}

func (fetcher *probeSideTupleFetcher) handleProbeSideFetcherPanic(r interface{}) {
Expand Down Expand Up @@ -439,14 +441,14 @@ func (w *probeWorker) handleUnmatchedRowsFromHashTable() {
}

func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() {
- e.worker.Wait()
+ e.workerWg.Wait()
if e.useOuterToBuild {
// Concurrently handle unmatched rows from the hash table at the tail
for i := uint(0); i < e.concurrency; i++ {
var workerID = i
- e.worker.RunWithRecover(func() { e.probeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic)
+ e.workerWg.RunWithRecover(func() { e.probeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic)
}
- e.worker.Wait()
+ e.workerWg.Wait()
}
close(e.joinResultCh)
}
@@ -954,7 +956,7 @@ func (w *probeWorker) getNewJoinResult() (bool, *hashjoinWorkerResult) {
func (w *probeWorker) join2Chunk(probeSideChk *chunk.Chunk, hCtx *hashContext, joinResult *hashjoinWorkerResult,
selected []bool) (ok bool, _ *hashjoinWorkerResult) {
var err error
- selected, err = expression.VectorizedFilter(w.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected)
+ selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(probeSideChk), selected)
if err != nil {
joinResult.err = err
return false, joinResult
@@ -994,7 +996,7 @@ func (w *probeWorker) join2Chunk(probeSideChk *chunk.Chunk, hCtx *hashContext, j
}

for i := range selected {
- killed := atomic.LoadUint32(&w.sessCtx.GetSessionVars().Killed) == 1
+ killed := atomic.LoadUint32(&w.hashJoinCtx.sessCtx.GetSessionVars().Killed) == 1
failpoint.Inject("killedInJoin2Chunk", func(val failpoint.Value) {
if val.(bool) {
killed = true
@@ -1060,7 +1062,7 @@ func (w *probeWorker) join2ChunkForOuterHashJoin(probeSideChk *chunk.Chunk, hCtx
}
}
for i := 0; i < probeSideChk.NumRows(); i++ {
- killed := atomic.LoadUint32(&w.sessCtx.GetSessionVars().Killed) == 1
+ killed := atomic.LoadUint32(&w.hashJoinCtx.sessCtx.GetSessionVars().Killed) == 1
failpoint.Inject("killedInJoin2ChunkForOuterHashJoin", func(val failpoint.Value) {
if val.(bool) {
killed = true
@@ -1110,7 +1112,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) {
for i := uint(0); i < e.concurrency; i++ {
e.probeWorkers[i].rowIters = chunk.NewIterator4Slice([]chunk.Row{}).(*chunk.Iterator4Slice)
}
- e.worker.RunWithRecover(func() {
+ e.workerWg.RunWithRecover(func() {
defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End()
e.fetchAndBuildHashTable(ctx)
}, e.handleFetchAndBuildHashTablePanic)
@@ -1153,10 +1155,10 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
buildSideResultCh := make(chan *chunk.Chunk, 1)
doneCh := make(chan struct{})
fetchBuildSideRowsOk := make(chan error, 1)
- e.worker.RunWithRecover(
+ e.workerWg.RunWithRecover(
func() {
defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End()
- e.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh)
+ e.buildWorker.fetchBuildSideRows(ctx, buildSideResultCh, fetchBuildSideRowsOk, doneCh)
},
func(r interface{}) {
if r != nil {
@@ -1167,7 +1169,7 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
)

// TODO: Parallel build hash table. Currently not supported because `unsafeHashTable` is not thread-safe.
- err := e.buildHashTableForList(buildSideResultCh)
+ err := e.buildWorker.buildHashTableForList(buildSideResultCh)
if err != nil {
e.buildFinished <- errors.Trace(err)
close(doneCh)
@@ -1185,41 +1187,42 @@ func (e *HashJoinExec) fetchAndBuildHashTable(ctx context.Context) {
}

// buildHashTableForList builds hash table from `list`.
- func (e *HashJoinExec) buildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error {
+ func (w *buildWorker) buildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error {
var err error
var selected []bool
- e.rowContainer.GetMemTracker().AttachTo(e.memTracker)
- e.rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult)
- e.rowContainer.GetDiskTracker().AttachTo(e.diskTracker)
- e.rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult)
+ rowContainer := w.hashJoinCtx.rowContainer
+ rowContainer.GetMemTracker().AttachTo(w.hashJoinCtx.memTracker)
+ rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult)
+ rowContainer.GetDiskTracker().AttachTo(w.hashJoinCtx.diskTracker)
+ rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult)
if variable.EnableTmpStorageOnOOM.Load() {
- actionSpill := e.rowContainer.ActionSpill()
+ actionSpill := rowContainer.ActionSpill()
failpoint.Inject("testRowContainerSpill", func(val failpoint.Value) {
if val.(bool) {
- actionSpill = e.rowContainer.rowContainer.ActionSpillForTest()
+ actionSpill = rowContainer.rowContainer.ActionSpillForTest()
defer actionSpill.(*chunk.SpillDiskAction).WaitForTest()
}
})
- e.ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill)
+ w.hashJoinCtx.sessCtx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill)
}
for chk := range buildSideResultCh {
- if e.finished.Load() {
+ if w.hashJoinCtx.finished.Load() {
return nil
}
- if !e.useOuterToBuild {
- err = e.rowContainer.PutChunk(chk, e.isNullEQ)
+ if !w.hashJoinCtx.useOuterToBuild {
+ err = rowContainer.PutChunk(chk, w.hashJoinCtx.isNullEQ)
} else {
var bitMap = bitmap.NewConcurrentBitmap(chk.NumRows())
- e.outerMatchedStatus = append(e.outerMatchedStatus, bitMap)
- e.memTracker.Consume(bitMap.BytesConsumed())
- if len(e.outerFilter) == 0 {
- err = e.rowContainer.PutChunk(chk, e.isNullEQ)
+ w.hashJoinCtx.outerMatchedStatus = append(w.hashJoinCtx.outerMatchedStatus, bitMap)
+ w.hashJoinCtx.memTracker.Consume(bitMap.BytesConsumed())
+ if len(w.hashJoinCtx.outerFilter) == 0 {
+ err = w.hashJoinCtx.rowContainer.PutChunk(chk, w.hashJoinCtx.isNullEQ)
} else {
- selected, err = expression.VectorizedFilter(e.ctx, e.outerFilter, chunk.NewIterator4Chunk(chk), selected)
+ selected, err = expression.VectorizedFilter(w.hashJoinCtx.sessCtx, w.hashJoinCtx.outerFilter, chunk.NewIterator4Chunk(chk), selected)
if err != nil {
return err
}
- err = e.rowContainer.PutChunkSelected(chk, selected, e.isNullEQ)
+ err = rowContainer.PutChunkSelected(chk, selected, w.hashJoinCtx.isNullEQ)
}
}
failpoint.Inject("ConsumeRandomPanic", nil)
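
The worker/waiter fields renamed to workerWg/waiterWg above are util.WaitGroupWrapper values; the diff only shows them being used, so here is a minimal reimplementation of the RunWithRecover pattern they provide, written as an assumption about the wrapper's shape rather than a copy of TiDB's util package:

```go
package main

import (
	"fmt"
	"sync"
)

// waitGroupWrapper is a minimal stand-in for util.WaitGroupWrapper as used
// above: a sync.WaitGroup plus a spawn helper with panic recovery.
type waitGroupWrapper struct {
	sync.WaitGroup
}

// RunWithRecover runs exec in a new goroutine; if exec panics, recoverFn
// receives the recovered value instead of the panic killing the process.
func (w *waitGroupWrapper) RunWithRecover(exec func(), recoverFn func(r interface{})) {
	w.Add(1)
	go func() {
		defer func() {
			if r := recover(); r != nil && recoverFn != nil {
				recoverFn(r)
			}
			w.Done()
		}()
		exec()
	}()
}

func main() {
	var workerWg waitGroupWrapper
	workerWg.RunWithRecover(
		func() { panic("join worker failed") },
		func(r interface{}) { fmt.Println("recovered:", r) },
	)
	workerWg.Wait() // same shape as e.workerWg.Wait() in waitJoinWorkersAndCloseResultChan
}
```

The recover callback plays the role of handlers like handleProbeWorkerPanic and handleJoinWorkerPanic in the diff above, which turn a worker panic into an error surfaced through the join's result channel rather than a crash.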