// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package colexec

import (
"context"
"github.com/cockroachdb/cockroach/pkg/col/coldata"
"github.com/cockroachdb/cockroach/pkg/col/coltypes"
"github.com/cockroachdb/cockroach/pkg/col/coltypes/typeconv"
"github.com/cockroachdb/cockroach/pkg/sql/colexecbase"
"github.com/cockroachdb/cockroach/pkg/sql/colexecbase/colexecerror"
"github.com/cockroachdb/cockroach/pkg/sql/colmem"
"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/errors"
)

// hashAggregatorState represents the state of the hash aggregator operator.
type hashAggregatorState int

const (
// hashAggregatorBuffering is the state in which the hashAggregator is
// buffering up its inputs.
hashAggregatorBuffering hashAggregatorState = iota
// hashAggregatorAggregating is the state in which the hashAggregator is
// performing aggregation on its buffered inputs. After aggregation is done,
// the input buffer used in the hashAggregatorBuffering phase is reset and
// ready to be reused.
hashAggregatorAggregating
// hashAggregatorOutputting is the state in which the hashAggregator is
// writing its aggregation results to the output buffer after it has
// exhausted all inputs and finished aggregating.
hashAggregatorOutputting
// hashAggregatorDone is the state in which the hashAggregator has finished
// writing to the output buffer.
hashAggregatorDone
)

// hashAggregator is an operator that performs aggregation based on the
// specified grouping columns. It performs aggregation in an online fashion: it
// buffers the input up to batchTupleLimit tuples, hashes each buffered tuple,
// and groups the tuples with the same hash code together. An aggregation
// function is then lazily created for each group, and the tuples in that
// group are passed into it. After all input is exhausted, the operator begins
// to write the result into an output buffer. The output row ordering of this
// operator is arbitrary.
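//
// The state machine of this operator is as follows:
//
//   buffering -> aggregating -> buffering -> ... -> outputting -> done
//
// i.e. the operator alternates between buffering a portion of the input and
// aggregating it until the input is exhausted, at which point it emits the
// accumulated results one output batch at a time (see Next).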
type hashAggregator struct {
OneInputNode
allocator *colmem.Allocator
aggCols [][]uint32
aggTypes [][]types.T
aggFuncs []execinfrapb.AggregatorSpec_Func
inputTypes []types.T
inputPhysTypes []coltypes.T
outputTypes []types.T
// aggFuncMap stores the mapping from hash code to a vector of aggregation
// functions. Each aggregation function is stored along with the key that
// corresponds to the group it operates on; this is how hash collisions are
// handled.
aggFuncMap hashAggFuncMap
// batchTupleLimit limits the number of tuples the aggregator will buffer
// before it starts to perform aggregation.
batchTupleLimit int
// state stores the current state of hashAggregator.
state hashAggregatorState
scratch struct {
*appendOnlyBufferedBatch
// sels stores the intermediate selection vector for each hash code. It
// is maintained in such a way that if there are no tuples with a
// particular hashCode in the batch, the corresponding int slice has
// length 0. Also, the onlineAgg() method resets all modified slices to
// zero length once it is done processing all tuples in the batch; this
// allows us to avoid resetting the slices for all possible hash codes.
//
// Instead of having a map from hashCode to []int (which could result
// in having many int slices), we are using a constant number of such
// slices and have a "map" from hashCode to a "slot" in sels that does
// the "translation." The key insight here is that we will have at most
// batchTupleLimit (plus, possibly, a constant excess) different
// hashCodes at once.
sels [][]int
// hashCodeForSelsSlot stores the hashCode that corresponds to a slot
// in sels slice. For example, if we have tuples with the following
// hashCodes = {0, 2, 0, 0, 1, 2, 1}, then we will have:
// hashCodeForSelsSlot = {0, 2, 1}
// sels[0] = {0, 2, 3}
// sels[1] = {1, 5}
// sels[2] = {4, 6}
hashCodeForSelsSlot []uint64
// group is a boolean vector in which "true" represents the beginning of a
// group in the column. It is shared among all aggregation functions. Since
// hashAggregator manually manages the mapping between input groups and
// their corresponding aggregation functions, group is set to all false to
// prevent premature materialization of the aggregation result in the
// aggregation function. However, an aggregation function expects at least
// one group in its input batches (that is, at least one "true" in the
// group vector corresponding to the selection vector). Therefore, before
// the first invocation of the .Compute() method, the element of the group
// vector that corresponds to the first value of the selection vector is
// set to true so that the aggregation function initializes properly. After
// .Compute() finishes, it is set back to false so the same group vector
// can be reused by other aggregation functions.
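//
// For example, if the selection vector of matched tuples for a newly
// created group is {4, 6}, then group[4] is set to true right before the
// group's first .Compute() call and is reset to false right after it (see
// onlineAgg).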
group []bool
}
// keyMapping stores the key values for each aggregation group. It is an
// appendOnlyBufferedBatch because in the worst case, where all keys in the
// grouping columns are distinct, we need to store every single key in the
// input.
keyMapping *appendOnlyBufferedBatch
output struct {
coldata.Batch
// pendingOutput indicates if there is more data that needs to be returned.
pendingOutput bool
// resumeHashCode is the hash code that hashAggregator should start reading
// from on the next iteration of Next().
resumeHashCode uint64
// resumeIdx is the index of the vector corresponding to the resumeHashCode
// that hashAggregator should start reading from on the next iteration of Next().
resumeIdx int
}
testingKnobs struct {
// numOfHashBuckets is the number of hash buckets that each tuple will be
// assigned to. When it is 0, the hash aggregator will not enforce a
// maximum number of hash buckets. It is used to test hash collisions.
numOfHashBuckets uint64
}
// groupCols stores the indices of the grouping columns.
groupCols []uint32
// groupTypes stores the types of the grouping columns.
groupTypes []types.T
// hashBuffer stores hash values for each tuple in the buffered batch.
hashBuffer []uint64
alloc hashAggFuncsAlloc
cancelChecker CancelChecker
decimalScratch decimalOverloadScratch
}

var _ colexecbase.Operator = &hashAggregator{}

// NewHashAggregator creates a hash aggregator on the given grouping columns.
// The input specifications to this function are the same as those of the
// NewOrderedAggregator function.
func NewHashAggregator(
allocator *colmem.Allocator,
input colexecbase.Operator,
typs []types.T,
aggFns []execinfrapb.AggregatorSpec_Func,
groupCols []uint32,
aggCols [][]uint32,
) (colexecbase.Operator, error) {
aggTyps := extractAggTypes(aggCols, typs)
outputTypes, err := makeAggregateFuncsOutputTypes(aggTyps, aggFns)
if err != nil {
return nil, errors.AssertionFailedf(
"this error should have been checked in isAggregateSupported\n%+v", err,
)
}
groupTypes := make([]types.T, len(groupCols))
for i, colIdx := range groupCols {
groupTypes[i] = typs[colIdx]
}
// We picked this value as the result of our benchmarks.
tupleLimit := coldata.BatchSize() * 2
inputPhysTypes, err := typeconv.FromColumnTypes(typs)
return &hashAggregator{
OneInputNode: NewOneInputNode(input),
allocator: allocator,
aggCols: aggCols,
aggFuncs: aggFns,
aggTypes: aggTyps,
aggFuncMap: make(hashAggFuncMap),
batchTupleLimit: tupleLimit,
state: hashAggregatorBuffering,
inputTypes: typs,
inputPhysTypes: inputPhysTypes,
outputTypes: outputTypes,
groupCols: groupCols,
groupTypes: groupTypes,
}, err
}
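
// Init implements the colexecbase.Operator interface.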
func (op *hashAggregator) Init() {
op.input.Init()
op.output.Batch = op.allocator.NewMemBatch(op.outputTypes)
// We allocate an additional coldata.BatchSize() for the scratch buffer and
// hashBuffer to accommodate the case where the number of buffered tuples
// exceeds op.batchTupleLimit. This is possible because we perform the check
// only after appending the input tuples to the scratch buffer.
maxBufferedTuples := op.batchTupleLimit + coldata.BatchSize()
op.scratch.appendOnlyBufferedBatch = newAppendOnlyBufferedBatch(
op.allocator, op.inputTypes, maxBufferedTuples,
)
op.scratch.sels = make([][]int, maxBufferedTuples)
op.scratch.hashCodeForSelsSlot = make([]uint64, maxBufferedTuples)
op.scratch.group = make([]bool, maxBufferedTuples)
// Eventually, op.keyMapping will contain as many tuples as there are
// groups in the input, but we don't know that number upfront, so we
// allocate it with some reasonably sized constant capacity.
op.keyMapping = newAppendOnlyBufferedBatch(
op.allocator, op.groupTypes, op.batchTupleLimit,
)
op.hashBuffer = make([]uint64, maxBufferedTuples)
}
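
// Next implements the colexecbase.Operator interface.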
func (op *hashAggregator) Next(ctx context.Context) coldata.Batch {
for {
switch op.state {
case hashAggregatorBuffering:
op.scratch.ResetInternalBatch()
op.scratch.SetLength(0)
// Buffering up input batches.
if done := op.bufferBatch(ctx); done {
op.state = hashAggregatorOutputting
continue
}
op.buildSelectionForEachHashCode(ctx)
op.state = hashAggregatorAggregating
case hashAggregatorAggregating:
op.onlineAgg()
op.state = hashAggregatorBuffering
case hashAggregatorOutputting:
curOutputIdx := 0
op.output.ResetInternalBatch()
// If there is pending output, we try to finish outputting the aggregation
// result in the same bucket. If we cannot finish, we update resumeIdx and
// return the current batch.
if op.output.pendingOutput {
remainingAggFuncs := op.aggFuncMap[op.output.resumeHashCode][op.output.resumeIdx:]
for groupIdx, aggFunc := range remainingAggFuncs {
if curOutputIdx < coldata.BatchSize() {
for fnIdx, fn := range aggFunc.fns {
fn.SetOutputIndex(curOutputIdx)
// Passing a zero batch into an aggregation function causes it to
// flush the agg result to the output batch at curOutputIdx.
fn.Compute(coldata.ZeroBatch, op.aggCols[fnIdx])
}
} else {
op.output.resumeIdx = op.output.resumeIdx + groupIdx
op.output.SetLength(curOutputIdx)
return op.output
}
curOutputIdx++
}
delete(op.aggFuncMap, op.output.resumeHashCode)
}
op.output.pendingOutput = false
for aggHashCode, aggFuncs := range op.aggFuncMap {
for groupIdx, aggFunc := range aggFuncs {
if curOutputIdx < coldata.BatchSize() {
for fnIdx, fn := range aggFunc.fns {
fn.SetOutputIndex(curOutputIdx)
fn.Compute(coldata.ZeroBatch, op.aggCols[fnIdx])
}
} else {
// If the current batch is full, we record where we left off
// and then return the current batch.
op.output.resumeIdx = groupIdx
op.output.resumeHashCode = aggHashCode
op.output.pendingOutput = true
op.output.SetLength(curOutputIdx)
return op.output
}
curOutputIdx++
}
delete(op.aggFuncMap, aggHashCode)
}
op.state = hashAggregatorDone
op.output.SetLength(curOutputIdx)
return op.output
case hashAggregatorDone:
return coldata.ZeroBatch
default:
colexecerror.InternalError("hash aggregator in unhandled state")
// This code is unreachable, but the compiler cannot infer that.
return nil
}
}
}

// bufferBatch buffers up batches from the input source until the number of
// buffered tuples reaches batchTupleLimit. It returns true when the hash
// aggregator has consumed all batches from the input.
func (op *hashAggregator) bufferBatch(ctx context.Context) bool {
for op.scratch.Length() < op.batchTupleLimit {
b := op.input.Next(ctx)
batchSize := b.Length()
if batchSize == 0 {
break
}
op.allocator.PerformOperation(op.scratch.ColVecs(), func() {
op.scratch.append(b, 0 /* startIdx */, batchSize)
})
}
return op.scratch.Length() == 0
}
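
// buildSelectionForEachHashCode hashes the buffered tuples on the grouping
// columns and distributes the tuple indices into op.scratch.sels so that each
// slot of sels contains the indices of all buffered tuples that share one
// particular hash code (see the comment on op.scratch.sels for an example).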
func (op *hashAggregator) buildSelectionForEachHashCode(ctx context.Context) {
nKeys := op.scratch.Length()
hashBuffer := op.hashBuffer[:nKeys]
initHash(hashBuffer, nKeys, defaultInitHashValue)
for _, colIdx := range op.groupCols {
rehash(ctx,
hashBuffer,
&op.inputTypes[colIdx],
op.scratch.ColVec(int(colIdx)),
nKeys,
nil, /* sel */
op.cancelChecker,
op.decimalScratch)
}
if op.testingKnobs.numOfHashBuckets != 0 {
finalizeHash(hashBuffer, nKeys, op.testingKnobs.numOfHashBuckets)
}
// Note: we don't need to reset any of the slices in op.scratch.sels since
// they all are of zero length here (see the comment for op.scratch.sels
// for context).
// We can use selIdx to index into op.scratch since op.scratch never has
// a selection vector.
op.scratch.hashCodeForSelsSlot = op.scratch.hashCodeForSelsSlot[:0]
for selIdx, hashCode := range hashBuffer {
selsSlot := -1
for slot, hash := range op.scratch.hashCodeForSelsSlot {
if hash == hashCode {
// We have already seen a tuple with the same hashCode
// previously, so we will append into the same sels slot.
selsSlot = slot
break
}
}
if selsSlot < 0 {
// This is the first tuple in hashBuffer with this hashCode, so we
// will add this tuple to the next available sels slot.
selsSlot = len(op.scratch.hashCodeForSelsSlot)
op.scratch.hashCodeForSelsSlot = append(op.scratch.hashCodeForSelsSlot, hashCode)
}
op.scratch.sels[selsSlot] = append(op.scratch.sels[selsSlot], selIdx)
}
}

// onlineAgg probes aggFuncMap using the built sels and lazily creates
// aggregation functions for groups that don't exist yet. It then calls
// Compute() on each aggregation function to perform the aggregation.
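//
// For example, using the hash codes from the comment on op.scratch.sels
// (hashCodeForSelsSlot = {0, 2, 1} and sels = {{0, 2, 3}, {1, 5}, {4, 6}}),
// onlineAgg first aggregates tuples 0, 2, 3 into the group(s) for hash code
// 0, then tuples 1, 5 for hash code 2, and finally tuples 4, 6 for hash
// code 1.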
func (op *hashAggregator) onlineAgg() {
for selsSlot, hashCode := range op.scratch.hashCodeForSelsSlot {
remaining := op.scratch.sels[selsSlot]
var anyMatched bool
// Stage 1: Probe aggregate functions for each hash code and perform
// aggregation.
if aggFuncs, ok := op.aggFuncMap[hashCode]; ok {
for _, aggFunc := range aggFuncs {
// We write the selection vector of matched tuples directly into the
// selection vector of op.scratch and the selection vector of unmatched
// tuples into 'remaining'. 'remaining' reuses the underlying memory
// allocated for 'sel' to avoid extra allocation and copying.
anyMatched, remaining = aggFunc.match(
remaining, op.scratch, op.groupCols, op.groupTypes, op.keyMapping,
op.scratch.group[:len(remaining)], false, /* firstDefiniteMatch */
)
if anyMatched {
aggFunc.compute(op.scratch, op.aggCols)
}
}
} else {
// No aggregate functions exist for this hashCode, so create a new entry.
// Since we don't expect a lot of hash collisions, we only allocate a
// small amount of memory here.
op.aggFuncMap[hashCode] = make([]*hashAggFuncs, 0, 1)
}
// Stage 2: Build aggregate functions that don't exist yet, then perform
// aggregation using the newly created aggregate functions.
for len(remaining) > 0 {
// Record the selection vector index of the beginning of the group.
groupStartIdx := remaining[0]
// Build new agg functions.
keyIdx := op.keyMapping.Length()
aggFunc := op.alloc.newHashAggFuncs()
aggFunc.keyIdx = keyIdx
// Store the key of the current aggregating group into keyMapping.
op.allocator.PerformOperation(op.keyMapping.ColVecs(), func() {
for keyIdx, colIdx := range op.groupCols {
// TODO(azhng): Try to preallocate enough memory so instead of
// .Append() we can use execgen.SET to improve the
// performance.
op.keyMapping.ColVec(keyIdx).Append(coldata.SliceArgs{
Src: op.scratch.ColVec(int(colIdx)),
ColType: op.inputPhysTypes[colIdx],
DestIdx: aggFunc.keyIdx,
SrcStartIdx: groupStartIdx,
SrcEndIdx: groupStartIdx + 1,
})
}
op.keyMapping.SetLength(keyIdx + 1)
})
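// Note: the error is ignored here because the same aggregation spec has
// already been validated when the operator was constructed (see the error
// handling in NewHashAggregator).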
aggFunc.fns, _ = makeAggregateFuncs(op.allocator, op.aggTypes, op.aggFuncs)
op.aggFuncMap[hashCode] = append(op.aggFuncMap[hashCode], aggFunc)
// Select the rest of the tuples that match the current key. We don't
// need to check whether there is any match since 'remaining[0]' will
// always be matched.
_, remaining = aggFunc.match(
remaining, op.scratch, op.groupCols, op.groupTypes, op.keyMapping,
op.scratch.group[:len(remaining)], true, /* firstDefiniteMatch */
)
// Hack required to get the aggregation functions working. See the
// 'scratch.group' field comment in hashAggregator for more details.
op.scratch.group[groupStartIdx] = true
aggFunc.init(op.scratch.group, op.output.Batch)
aggFunc.compute(op.scratch, op.aggCols)
op.scratch.group[groupStartIdx] = false
}
// We have processed all tuples with this hashCode, so we should reset
// the length of the corresponding slice.
op.scratch.sels[selsSlot] = op.scratch.sels[selsSlot][:0]
}
}

// reset resets the hashAggregator for another run. Primarily used for
// benchmarks.
func (op *hashAggregator) reset(ctx context.Context) {
if r, ok := op.input.(resetter); ok {
r.reset(ctx)
}
op.aggFuncMap = hashAggFuncMap{}
op.state = hashAggregatorBuffering
op.output.ResetInternalBatch()
op.output.SetLength(0)
op.output.pendingOutput = false
op.scratch.ResetInternalBatch()
op.scratch.SetLength(0)
op.keyMapping.ResetInternalBatch()
op.keyMapping.SetLength(0)
}

// hashAggFuncs stores the aggregation functions for the corresponding
// aggregating group.
type hashAggFuncs struct {
// keyIdx is the index of the key of the current aggregating group, which
// is stored in the hashAggregator's keyMapping batch.
keyIdx int
fns []aggregateFunc
}

type hashAggFuncMap map[uint64][]*hashAggFuncs
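
// init initializes each wrapped aggregation function to write its results
// into the corresponding column of the given output batch.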
func (v *hashAggFuncs) init(group []bool, b coldata.Batch) {
for fnIdx, fn := range v.fns {
fn.Init(group, b.ColVec(fnIdx))
}
}
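
// compute feeds the batch b into each wrapped aggregation function.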
func (v *hashAggFuncs) compute(b coldata.Batch, aggCols [][]uint32) {
for fnIdx, fn := range v.fns {
fn.Compute(b, aggCols[fnIdx])
}
}

const hashAggFuncsAllocSize = 16

// hashAggFuncsAlloc is a utility struct that batches allocations of
// hashAggFuncs.
type hashAggFuncsAlloc struct {
buf []hashAggFuncs
}
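
// newHashAggFuncs returns a new hashAggFuncs struct, allocating a fresh
// batch of hashAggFuncsAllocSize structs only when the current batch has
// been used up.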
func (a *hashAggFuncsAlloc) newHashAggFuncs() *hashAggFuncs {
if len(a.buf) == 0 {
a.buf = make([]hashAggFuncs, hashAggFuncsAllocSize)
}
ret := &a.buf[0]
a.buf = a.buf[1:]
return ret
}