diff --git a/pkg/sql/distsqlrun/windower.go b/pkg/sql/distsqlrun/windower.go index 7a0572be5691..18bb1b6b46f2 100644 --- a/pkg/sql/distsqlrun/windower.go +++ b/pkg/sql/distsqlrun/windower.go @@ -132,8 +132,10 @@ type windower struct { windowFns []*windowFunc populated bool - bucketToPartitionIdx map[string]int - rowsInBucketEmitted map[string]int + buckets []string + bucketToPartitionIdx []int + bucketIter int + rowsInBucketEmitted int windowValues [][][]tree.Datum outputRow sqlbase.EncDatumRow } @@ -161,8 +163,6 @@ func newWindower( w.partitionBy = spec.PartitionBy w.windowFns = make([]*windowFunc, 0, len(windowFns)) w.outputTypes = make([]sqlbase.ColumnType, 0, len(w.inputTypes)) - w.bucketToPartitionIdx = make(map[string]int) - w.rowsInBucketEmitted = make(map[string]int) // inputColIdx is the index of the column that should be processed next. inputColIdx := 0 @@ -368,11 +368,14 @@ func (w *windower) computeWindowFunctions(ctx context.Context, evalCtx *tree.Eva w.windowValues = make([][][]tree.Datum, len(w.windowFns)) partitions := make([]indexedRows, len(w.encodedPartitions)) + w.buckets = make([]string, 0, len(w.encodedPartitions)) + w.bucketToPartitionIdx = make([]int, 0, len(w.encodedPartitions)) partitionIdx := 0 for bucket, encodedPartition := range w.encodedPartitions { // We want to fix some order of iteration over encoded partitions // to be consistent. - w.bucketToPartitionIdx[bucket] = partitionIdx + w.buckets = append(w.buckets, bucket) + w.bucketToPartitionIdx = append(w.bucketToPartitionIdx, partitionIdx) rows := make([]indexedRow, 0, len(encodedPartition)) for idx := 0; idx < len(encodedPartition); idx++ { rows = append(rows, indexedRow{idx: idx, row: encodedPartition[idx]}) @@ -510,19 +513,14 @@ func (w *windower) computeWindowFunctions(ctx context.Context, evalCtx *tree.Eva // populateNextOutputRow combines results of computing window functions with // non-argument columns of the input row to produce an output row. func (w *windower) populateNextOutputRow() bool { - for bucket, encodedPartition := range w.encodedPartitions { - if w.rowsInBucketEmitted[bucket] == len(encodedPartition) { - // All output rows corresponding to partition 'bucket' have been fully - // emitted already, so we skip it. - continue - } - + if w.bucketIter < len(w.encodedPartitions) { // We reuse the same EncDatumRow since caller of Next() should've copied it. w.outputRow = w.outputRow[:0] - // rowIdx is the index of the next row to be emitted from partition 'bucket'. - rowIdx := w.rowsInBucketEmitted[bucket] - inputRow := encodedPartition[rowIdx] - partitionIdx := w.bucketToPartitionIdx[bucket] + // rowIdx is the index of the next row to be emitted from partition with + // hash w.buckets[w.bucketIter]. + rowIdx := w.rowsInBucketEmitted + inputRow := w.encodedPartitions[w.buckets[w.bucketIter]][rowIdx] + partitionIdx := w.bucketToPartitionIdx[w.bucketIter] inputColIdx := 0 for windowFnIdx, windowFn := range w.windowFns { // We simply pass through columns in [inputColIdx, windowFn.argIdxStart). @@ -535,8 +533,15 @@ func (w *windower) populateNextOutputRow() bool { } // We simply pass through all columns after all arguments to window functions. w.outputRow = append(w.outputRow, inputRow[inputColIdx:]...) - w.rowsInBucketEmitted[bucket] = rowIdx + 1 + w.rowsInBucketEmitted++ + if w.rowsInBucketEmitted == len(w.encodedPartitions[w.buckets[w.bucketIter]]) { + // We have emitted all rows from the current bucket, so we advance the + // iterator. + w.bucketIter++ + w.rowsInBucketEmitted = 0 + } return true + } return false }