diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner.go b/pkg/sql/colexec/colexecjoin/mergejoiner.go
index 7c92b9abac34..c881de2df59c 100644
--- a/pkg/sql/colexec/colexecjoin/mergejoiner.go
+++ b/pkg/sql/colexec/colexecjoin/mergejoiner.go
@@ -12,7 +12,6 @@ package colexecjoin
 
 import (
     "context"
-    "math"
     "unsafe"
 
     "github.com/cockroachdb/cockroach/pkg/col/coldata"
@@ -247,9 +246,8 @@ type mjBufferedGroupState struct {
     // rightFirstTuple is the first tuple of the right buffered group. It is set
     // only in case the right buffered group spans more than one input batch.
     rightFirstTuple []coldata.Vec
-    // rightScratchBatch is a scratch space for copying the tuples out of the
-    // right input batches before enqueueing them into the spilling queue.
-    rightScratchBatch coldata.Batch
+    // scratchSel is a scratch selection vector initialized only when needed.
+    scratchSel []int
 
     // helper is the building facility for the cross join of the buffered group.
     helper *crossJoinerBase
@@ -708,29 +706,45 @@ func (o *mergeJoinBase) appendToRightBufferedGroup(sel []int, groupStartIdx int,
         return
     }
 
-    // We don't impose any memory limits on the scratch batch because we rely on
-    // the inputs to the merge joiner to produce reasonably sized batches.
-    const maxBatchMemSize = math.MaxInt64
-    o.bufferedGroup.rightScratchBatch, _ = o.unlimitedAllocator.ResetMaybeReallocate(
-        sourceTypes, o.bufferedGroup.rightScratchBatch, groupLength, maxBatchMemSize,
-    )
-    // TODO(yuzefovich): SpillingQueue.Enqueue deep-copies the batch too. Think
-    // through whether the copy here can be avoided altogether.
-    o.unlimitedAllocator.PerformOperation(o.bufferedGroup.rightScratchBatch.ColVecs(), func() {
-        for colIdx := range sourceTypes {
-            o.bufferedGroup.rightScratchBatch.ColVec(colIdx).Copy(
-                coldata.SliceArgs{
-                    Src:         o.proberState.rBatch.ColVec(colIdx),
-                    Sel:         sel,
-                    DestIdx:     0,
-                    SrcStartIdx: groupStartIdx,
-                    SrcEndIdx:   groupStartIdx + groupLength,
-                },
+    // Update the selection on the probing batch to only include tuples from the
+    // buffered group.
+    rBatch, rLength := o.proberState.rBatch, o.proberState.rLength
+    rSel := rBatch.Selection()
+    rBatchHasSel := rSel != nil
+    // No need to modify the batch if the whole batch is part of the buffered
+    // group.
+    needToModify := groupStartIdx != 0 || groupLength != rLength
+    if needToModify {
+        if rBatchHasSel {
+            // Since rBatch already has a selection vector which we'll be
+            // modifying, we need to copy the original.
+            o.bufferedGroup.scratchSel = colexecutils.EnsureSelectionVectorLength(o.bufferedGroup.scratchSel, rLength)
+            copy(o.bufferedGroup.scratchSel, rSel)
+            // Now we need to shift elements in range
+            // [groupStartIdx; groupStartIdx+groupLength) to the beginning of
+            // the selection vector and then update the length of the batch
+            // accordingly.
+            copy(rSel[:groupLength], rSel[groupStartIdx:groupStartIdx+groupLength])
+            rBatch.SetLength(groupLength)
+        } else {
+            // Since rBatch doesn't have a selection vector, we will set the
+            // selection vector to include tuples in range
+            // [groupStartIdx; groupStartIdx+groupLength).
+            colexecutils.UpdateBatchState(
+                rBatch, groupLength, true, /* usesSel */
+                colexecutils.DefaultSelectionVector[groupStartIdx:groupStartIdx+groupLength],
             )
         }
-        o.bufferedGroup.rightScratchBatch.SetLength(groupLength)
-    })
-    bufferedTuples.Enqueue(o.Ctx, o.bufferedGroup.rightScratchBatch)
+    }
+
+    bufferedTuples.Enqueue(o.Ctx, rBatch)
+
+    // If we had to modify the batch, then restore the original state now.
+    if needToModify {
+        colexecutils.UpdateBatchState(
+            rBatch, rLength, rBatchHasSel, o.bufferedGroup.scratchSel,
+        )
+    }
 }
 
 // sourceFinished returns true if either of input sources has no more rows.
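The new code above avoids the per-column deep copy into rightScratchBatch (SpillingQueue.Enqueue already deep-copies, as the removed TODO noted) by temporarily narrowing the probing batch and restoring it after enqueueing. Below is a minimal self-contained sketch of that shift-and-restore pattern; the batch type is a toy stand-in invented for this sketch, not the real coldata.Batch API, which goes through SetLength and colexecutils.UpdateBatchState instead.

package main

import "fmt"

// batch is a toy stand-in for coldata.Batch: a selection vector plus a
// logical length.
type batch struct {
	sel    []int
	length int
}

func main() {
	// A batch whose selection vector exposes five tuples.
	b := &batch{sel: []int{0, 2, 4, 6, 8}, length: 5}
	groupStartIdx, groupLength := 1, 3

	// Save the original selection so the batch can be restored afterwards.
	scratchSel := make([]int, b.length)
	copy(scratchSel, b.sel)

	// Shift the group's entries to the front and shrink the length, so any
	// consumer that honors sel/length sees only the buffered group.
	copy(b.sel[:groupLength], b.sel[groupStartIdx:groupStartIdx+groupLength])
	b.length = groupLength
	fmt.Println(b.sel[:b.length]) // [2 4 6]

	// ... here the narrowed batch would be enqueued; the spilling queue
	// deep-copies it, so mutating the batch afterwards is safe ...

	// Restore the original selection vector and length.
	copy(b.sel, scratchSel)
	b.length = len(scratchSel)
	fmt.Println(b.sel[:b.length]) // [0 2 4 6 8]
}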
diff --git a/pkg/sql/colexec/colexecutils/utils.go b/pkg/sql/colexec/colexecutils/utils.go
index 22a7a2c08371..d1415d2f348c 100644
--- a/pkg/sql/colexec/colexecutils/utils.go
+++ b/pkg/sql/colexec/colexecutils/utils.go
@@ -310,3 +310,14 @@ func UpdateBatchState(batch coldata.Batch, length int, usesSel bool, sel []int)
     // in the selection vector to maintain invariants (like for flat bytes).
     batch.SetLength(length)
 }
+
+// DefaultSelectionVector contains all integers in [0, coldata.MaxBatchSize)
+// range.
+var DefaultSelectionVector []int
+
+func init() {
+    DefaultSelectionVector = make([]int, coldata.MaxBatchSize)
+    for i := range DefaultSelectionVector {
+        DefaultSelectionVector[i] = i
+    }
+}
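A short sketch of what the new exported helper buys: because DefaultSelectionVector holds the identity mapping, any contiguous tuple range can be exposed as a selection vector by slicing it, with no per-call allocation. The names below (maxBatchSize, identitySel) are local stand-ins for coldata.MaxBatchSize and DefaultSelectionVector.

package main

import "fmt"

// maxBatchSize stands in for coldata.MaxBatchSize (1024 by default, though
// it is configurable in CockroachDB).
const maxBatchSize = 1024

// identitySel mirrors DefaultSelectionVector: all integers in
// [0, maxBatchSize), built once at startup.
var identitySel = func() []int {
	s := make([]int, maxBatchSize)
	for i := range s {
		s[i] = i
	}
	return s
}()

func main() {
	// Expose tuples [3, 7) of a batch that has no selection vector: slicing
	// the shared identity vector yields the selection {3, 4, 5, 6} without
	// allocating.
	groupStartIdx, groupEndIdx := 3, 7
	fmt.Println(identitySel[groupStartIdx:groupEndIdx]) // [3 4 5 6]
}

The shared backing array does mean callers must treat such slices as read-only.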
diff --git a/pkg/sql/colexec/hash_aggregator.eg.go b/pkg/sql/colexec/hash_aggregator.eg.go
index 1ad31c9d9346..e60eadae1d1f 100644
--- a/pkg/sql/colexec/hash_aggregator.eg.go
+++ b/pkg/sql/colexec/hash_aggregator.eg.go
@@ -75,17 +75,6 @@ const _ = "template_findSplit"
 // input tuples are processed before emitting any data.
 const _ = "template_getNext"
 
-// defaultSelectionVector contains all integers in [0, coldata.MaxBatchSize)
-// range.
-var defaultSelectionVector []int
-
-func init() {
-    defaultSelectionVector = make([]int, coldata.MaxBatchSize)
-    for i := range defaultSelectionVector {
-        defaultSelectionVector[i] = i
-    }
-}
-
 func (op *hashAggregator) Next() coldata.Batch {
     if len(op.spec.OrderedGroupCols) > 0 {
         return getNext_true(op)
@@ -486,7 +475,7 @@ func getNext_true(op *hashAggregator) coldata.Batch {
             }
         }
         if op.curOutputBucketIdx >= len(op.buckets) {
-            if op.bufferingState.pendingBatch.Length() > 0 {
+            if l := op.bufferingState.pendingBatch.Length(); l > 0 {
                 // Clear the buckets.
                 op.state = hashAggregatorBuffering
                 op.resetBucketsAndTrackingState(op.Ctx)
@@ -497,13 +486,15 @@ func getNext_true(op *hashAggregator) coldata.Batch {
                 // in the buffering state, since it only contains tuples that still
                 // need to be aggregated, so we do not need to reset to the original
                 // batch state.
-                if op.inputTrackingState.tuples != nil && op.bufferingState.unprocessedIdx < op.bufferingState.pendingBatch.Length() {
+                if op.inputTrackingState.tuples != nil && op.bufferingState.unprocessedIdx < l {
                     sel := op.bufferingState.pendingBatch.Selection()
                     if sel != nil {
-                        copy(sel, sel[op.bufferingState.unprocessedIdx:op.bufferingState.pendingBatch.Length()])
-                        op.bufferingState.pendingBatch.SetLength(op.bufferingState.pendingBatch.Length() - op.bufferingState.unprocessedIdx)
+                        copy(sel, sel[op.bufferingState.unprocessedIdx:l])
+                        op.bufferingState.pendingBatch.SetLength(l - op.bufferingState.unprocessedIdx)
                     } else {
-                        colexecutils.UpdateBatchState(op.bufferingState.pendingBatch, op.bufferingState.pendingBatch.Length()-op.bufferingState.unprocessedIdx, true, defaultSelectionVector[op.bufferingState.unprocessedIdx:op.bufferingState.pendingBatch.Length()])
+                        colexecutils.UpdateBatchState(
+                            op.bufferingState.pendingBatch, l-op.bufferingState.unprocessedIdx, true,
+                            colexecutils.DefaultSelectionVector[op.bufferingState.unprocessedIdx:l])
                     }
                     op.inputTrackingState.tuples.Enqueue(op.Ctx, op.bufferingState.pendingBatch)
                     // We modified pendingBatch to only contain unprocessed
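A condensed sketch of the trimming logic above, using toy stand-ins (batch here is not the real coldata.Batch, and keepUnprocessed is a hypothetical helper): keep only the tuples at positions [unprocessedIdx, l), either by shifting an existing selection vector in place or, when there is none, by installing a slice of the identity vector.

package main

import "fmt"

// batch is a toy stand-in for coldata.Batch.
type batch struct {
	sel    []int // nil means the batch has no selection vector
	length int
}

var identitySel = func() []int {
	s := make([]int, 1024)
	for i := range s {
		s[i] = i
	}
	return s
}()

// keepUnprocessed trims b to the tuples at positions [unprocessedIdx, l),
// mirroring the two branches above: shift an existing selection vector in
// place, or install a slice of the identity vector when there is none.
func keepUnprocessed(b *batch, unprocessedIdx int) {
	l := b.length // read the length once, as the patched code now does
	if b.sel != nil {
		copy(b.sel, b.sel[unprocessedIdx:l])
	} else {
		b.sel = identitySel[unprocessedIdx:l]
	}
	b.length = l - unprocessedIdx
}

func main() {
	// Without a selection vector: positions 2..5 remain.
	b := &batch{length: 6}
	keepUnprocessed(b, 2)
	fmt.Println(b.sel[:b.length], b.length) // [2 3 4 5] 4

	// With a selection vector: entries are shifted to the front.
	b = &batch{sel: []int{1, 3, 5, 7}, length: 4}
	keepUnprocessed(b, 1)
	fmt.Println(b.sel[:b.length], b.length) // [3 5 7] 3
}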
diff --git a/pkg/sql/colexec/hash_aggregator_tmpl.go b/pkg/sql/colexec/hash_aggregator_tmpl.go
index 804f9646c11a..039557e36d9b 100644
--- a/pkg/sql/colexec/hash_aggregator_tmpl.go
+++ b/pkg/sql/colexec/hash_aggregator_tmpl.go
@@ -362,7 +362,7 @@ func getNext(op *hashAggregator, partialOrder bool) coldata.Batch {
         }
         if op.curOutputBucketIdx >= len(op.buckets) {
             if partialOrder {
-                if op.bufferingState.pendingBatch.Length() > 0 {
+                if l := op.bufferingState.pendingBatch.Length(); l > 0 {
                     // Clear the buckets.
                     op.state = hashAggregatorBuffering
                     op.resetBucketsAndTrackingState(op.Ctx)
@@ -373,13 +373,15 @@ func getNext(op *hashAggregator, partialOrder bool) coldata.Batch {
                     // in the buffering state, since it only contains tuples that still
                     // need to be aggregated, so we do not need to reset to the original
                     // batch state.
-                    if op.inputTrackingState.tuples != nil && op.bufferingState.unprocessedIdx < op.bufferingState.pendingBatch.Length() {
+                    if op.inputTrackingState.tuples != nil && op.bufferingState.unprocessedIdx < l {
                         sel := op.bufferingState.pendingBatch.Selection()
                         if sel != nil {
-                            copy(sel, sel[op.bufferingState.unprocessedIdx:op.bufferingState.pendingBatch.Length()])
-                            op.bufferingState.pendingBatch.SetLength(op.bufferingState.pendingBatch.Length() - op.bufferingState.unprocessedIdx)
+                            copy(sel, sel[op.bufferingState.unprocessedIdx:l])
+                            op.bufferingState.pendingBatch.SetLength(l - op.bufferingState.unprocessedIdx)
                         } else {
-                            colexecutils.UpdateBatchState(op.bufferingState.pendingBatch, op.bufferingState.pendingBatch.Length()-op.bufferingState.unprocessedIdx, true, defaultSelectionVector[op.bufferingState.unprocessedIdx:op.bufferingState.pendingBatch.Length()])
+                            colexecutils.UpdateBatchState(
+                                op.bufferingState.pendingBatch, l-op.bufferingState.unprocessedIdx, true,
+                                colexecutils.DefaultSelectionVector[op.bufferingState.unprocessedIdx:l])
                         }
                         op.inputTrackingState.tuples.Enqueue(op.Ctx, op.bufferingState.pendingBatch)
                         // We modified pendingBatch to only contain unprocessed
@@ -407,17 +409,6 @@ func getNext(op *hashAggregator, partialOrder bool) coldata.Batch {
     }
 }
 
-// defaultSelectionVector contains all integers in [0, coldata.MaxBatchSize)
-// range.
-var defaultSelectionVector []int
-
-func init() {
-    defaultSelectionVector = make([]int, coldata.MaxBatchSize)
-    for i := range defaultSelectionVector {
-        defaultSelectionVector[i] = i
-    }
-}
-
 func (op *hashAggregator) Next() coldata.Batch {
     if len(op.spec.OrderedGroupCols) > 0 {
         return getNext(op, true)
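hash_aggregator.eg.go is generated from hash_aggregator_tmpl.go by execgen, which is why the hunks in the two files mirror each other, and why the duplicated defaultSelectionVector had to be deleted from both in favor of the single exported colexecutils.DefaultSelectionVector. A rough sketch of the specialization pattern at work, with hypothetical names (the real expansion of getNext into getNext_true and getNext_false is done by the code generator):

package main

import "fmt"

// Template form: getNext branches on a parameter that is constant per caller.
func getNext(partialOrder bool) string {
	if partialOrder {
		return "track pending batch; flush buckets when a group ends"
	}
	return "aggregate all input, then emit"
}

// Generated forms: one specialized copy per constant value, so the branch
// disappears from the hot path.
func getNextTrue() string  { return "track pending batch; flush buckets when a group ends" }
func getNextFalse() string { return "aggregate all input, then emit" }

func main() {
	fmt.Println(getNext(true) == getNextTrue())   // true
	fmt.Println(getNext(false) == getNextFalse()) // true
}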