From 77ce2b7558e833619ac62ce79ec064e4377d2026 Mon Sep 17 00:00:00 2001 From: yuzefovich Date: Tue, 26 Jun 2018 15:29:43 -0400 Subject: [PATCH] sql: Add efficient min, max, sum, avg when used as window functions. Adds linear-time implementations of min, max, sum, and avg (using sliding window approach) instead of naive quadratic version. Addresses: #26464. Bonus: min and max are an order of magnitude faster than PG (when window frame doesn't include the whole partition). Release note (performance improvement): min, max, sum, avg now take linear time when used for aggregation as window functions. --- pkg/sql/logictest/testdata/logic_test/window | 46 ++- pkg/sql/sem/builtins/aggregate_builtins.go | 35 +- pkg/sql/sem/builtins/window_builtins.go | 33 +- pkg/sql/sem/builtins/window_frame_builtins.go | 324 ++++++++++++++++++ .../builtins/window_frame_builtins_test.go | 242 +++++++++++++ pkg/sql/sem/tree/window_funcs.go | 12 + pkg/sql/window.go | 13 +- 7 files changed, 679 insertions(+), 26 deletions(-) create mode 100644 pkg/sql/sem/builtins/window_frame_builtins.go create mode 100644 pkg/sql/sem/builtins/window_frame_builtins_test.go diff --git a/pkg/sql/logictest/testdata/logic_test/window b/pkg/sql/logictest/testdata/logic_test/window index 7c73fa1cef67..9e96e5913a95 100644 --- a/pkg/sql/logictest/testdata/logic_test/window +++ b/pkg/sql/logictest/testdata/logic_test/window @@ -1708,7 +1708,6 @@ Tablet iPad 700.00 NULL Tablet Kindle Fire 150.00 NULL Tablet Samsung 200.00 NULL - query TRRR SELECT product_name, price, min(price) OVER (PARTITION BY group_name ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS min_over_three, max(price) OVER (PARTITION BY group_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_over_partition FROM products ORDER BY group_id; ---- @@ -1723,3 +1722,48 @@ Dell 800.00 700.00 1200.00 iPad 700.00 150.00 700.00 Kindle Fire 150.00 150.00 700.00 Samsung 200.00 150.00 700.00 + +query TTRT +SELECT group_name, product_name, price, min(price) OVER (PARTITION BY group_name ROWS CURRENT ROW) AS min_over_single_row FROM products ORDER BY group_id; +---- +Smartphone Microsoft Lumia 200.00 200.00 +Smartphone HTC One 400.00 400.00 +Smartphone Nexus 500.00 500.00 +Smartphone iPhone 900.00 900.00 +Laptop HP Elite 1200.00 1200.00 +Laptop Lenovo Thinkpad 700.00 700.00 +Laptop Sony VAIO 700.00 700.00 +Laptop Dell 800.00 800.00 +Tablet iPad 700.00 700.00 +Tablet Kindle Fire 150.00 150.00 +Tablet Samsung 200.00 200.00 + +query TTRR +SELECT group_name, product_name, price, avg(price) OVER (PARTITION BY group_name ROWS BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING) AS running_avg FROM products ORDER BY group_id; +---- +Smartphone Microsoft Lumia 200.00 600.00 +Smartphone HTC One 400.00 700.00 +Smartphone Nexus 500.00 900.00 +Smartphone iPhone 900.00 NULL +Laptop HP Elite 1200.00 733.33333333333333333 +Laptop Lenovo Thinkpad 700.00 750.00 +Laptop Sony VAIO 700.00 800.00 +Laptop Dell 800.00 NULL +Tablet iPad 700.00 175.00 +Tablet Kindle Fire 150.00 200.00 +Tablet Samsung 200.00 NULL + +query TRRRRR +SELECT product_name, price, min(price) OVER (PARTITION BY group_name ROWS UNBOUNDED PRECEDING), max(price) OVER (PARTITION BY group_name ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING), sum(price) OVER (PARTITION BY group_name ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING), avg(price) OVER (PARTITION BY group_name ROWS CURRENT ROW) FROM products ORDER BY group_id; +---- +Microsoft Lumia 200.00 200.00 400.00 2000.00 200.00 +HTC One 400.00 200.00 500.00 2000.00 400.00 +Nexus 500.00 200.00 900.00 1800.00 500.00 +iPhone 900.00 200.00 900.00 1400.00 900.00 +HP Elite 1200.00 1200.00 1200.00 3400.00 1200.00 +Lenovo Thinkpad 700.00 700.00 1200.00 3400.00 700.00 +Sony VAIO 700.00 700.00 1200.00 2200.00 700.00 +Dell 800.00 700.00 1200.00 1500.00 800.00 +iPad 700.00 700.00 700.00 1050.00 700.00 +Kindle Fire 150.00 150.00 700.00 1050.00 150.00 +Samsung 200.00 150.00 700.00 350.00 200.00 diff --git a/pkg/sql/sem/builtins/aggregate_builtins.go b/pkg/sql/sem/builtins/aggregate_builtins.go index da914c3f483a..2be923089e78 100644 --- a/pkg/sql/sem/builtins/aggregate_builtins.go +++ b/pkg/sql/sem/builtins/aggregate_builtins.go @@ -156,7 +156,9 @@ var aggregates = map[string]builtinDefinition{ ReturnType: tree.FixedReturnType(types.Int), AggregateFunc: newCountRowsAggregate, WindowFunc: func(params []types.T, evalCtx *tree.EvalContext) tree.WindowFunc { - return newFramableAggregateWindow(newCountRowsAggregate(params, evalCtx)) + return newFramableAggregateWindow(newCountRowsAggregate(params, evalCtx), func(evalCtx *tree.EvalContext) tree.AggregateFunc { + return newCountRowsAggregate(params, evalCtx) + }) }, Info: "Calculates the number of rows.", }, @@ -318,7 +320,36 @@ func makeAggOverloadWithReturnType( ReturnType: retType, AggregateFunc: f, WindowFunc: func(params []types.T, evalCtx *tree.EvalContext) tree.WindowFunc { - return newFramableAggregateWindow(f(params, evalCtx)) + aggWindowFunc := f(params, evalCtx) + switch w := aggWindowFunc.(type) { + case *MinAggregate: + min := &slidingWindowFunc{} + min.sw = makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { + return -a.Compare(evalCtx, b) + }) + return min + case *MaxAggregate: + max := &slidingWindowFunc{} + max.sw = makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { + return a.Compare(evalCtx, b) + }) + return max + case *intSumAggregate: + return &slidingWindowSumFunc{agg: aggWindowFunc} + case *decimalSumAggregate: + return &slidingWindowSumFunc{agg: aggWindowFunc} + case *floatSumAggregate: + return &slidingWindowSumFunc{agg: aggWindowFunc} + case *intervalSumAggregate: + return &slidingWindowSumFunc{agg: aggWindowFunc} + case *avgAggregate: + // w.agg is a sum aggregate. + return &avgWindowFunc{sum: slidingWindowSumFunc{agg: w.agg}} + } + + return newFramableAggregateWindow(aggWindowFunc, func(evalCtx *tree.EvalContext) tree.AggregateFunc { + return f(params, evalCtx) + }) }, Info: info, } diff --git a/pkg/sql/sem/builtins/window_builtins.go b/pkg/sql/sem/builtins/window_builtins.go index 76cf70408af1..1488bf3b9ed6 100644 --- a/pkg/sql/sem/builtins/window_builtins.go +++ b/pkg/sql/sem/builtins/window_builtins.go @@ -178,15 +178,26 @@ func (w *aggregateWindowFunc) Close(ctx context.Context, evalCtx *tree.EvalConte w.agg.Close(ctx) } +// ShouldReset sets shouldReset to true if w is framableAggregateWindowFunc. +func ShouldReset(w tree.WindowFunc) { + if f, ok := w.(*framableAggregateWindowFunc); ok { + f.shouldReset = true + } +} + // framableAggregateWindowFunc is a wrapper around aggregateWindowFunc that allows // to reset the aggregate by creating a new instance via a provided constructor. +// shouldReset indicates whether the resetting behavior is desired. type framableAggregateWindowFunc struct { agg *aggregateWindowFunc aggConstructor func(*tree.EvalContext) tree.AggregateFunc + shouldReset bool } -func newFramableAggregateWindow(agg tree.AggregateFunc) tree.WindowFunc { - return &framableAggregateWindowFunc{agg: &aggregateWindowFunc{agg: agg}} +func newFramableAggregateWindow( + agg tree.AggregateFunc, aggConstructor func(*tree.EvalContext) tree.AggregateFunc, +) tree.WindowFunc { + return &framableAggregateWindowFunc{agg: &aggregateWindowFunc{agg: agg}, aggConstructor: aggConstructor} } func (w *framableAggregateWindowFunc) Compute( @@ -195,12 +206,12 @@ func (w *framableAggregateWindowFunc) Compute( if !wfr.FirstInPeerGroup() { return w.agg.peerRes, nil } - if w.aggConstructor == nil { - // No constructor is given, so we use default approach. + if !w.shouldReset { + // We should not reset, so we will use the same aggregateWindowFunc. return w.agg.Compute(ctx, evalCtx, wfr) } - // When aggConstructor is provided, we want to dispose of the old aggregate function + // We should reset the aggregate, so we dispose of the old aggregate function // and construct a new one for the computation. w.agg.Close(ctx, evalCtx) *w.agg = aggregateWindowFunc{w.aggConstructor(evalCtx), tree.DNull} @@ -231,18 +242,6 @@ func (w *framableAggregateWindowFunc) Close(ctx context.Context, evalCtx *tree.E w.agg.Close(ctx, evalCtx) } -// AddAggregateConstructorToFramableAggregate adds provided constructor to framableAggregateWindowFunc -// so that aggregates can be 'reset' when computing values over a window frame. -func AddAggregateConstructorToFramableAggregate( - windowFunc tree.WindowFunc, aggConstructor func(*tree.EvalContext) tree.AggregateFunc, -) { - // We only want to add aggConstructor to framableAggregateWindowFunc's since - // all non-aggregates builtins specific to window functions support framing "natively". - if framableAgg, ok := windowFunc.(*framableAggregateWindowFunc); ok { - framableAgg.aggConstructor = aggConstructor - } -} - // rowNumberWindow computes the number of the current row within its partition, // counting from 1. type rowNumberWindow struct{} diff --git a/pkg/sql/sem/builtins/window_frame_builtins.go b/pkg/sql/sem/builtins/window_frame_builtins.go new file mode 100644 index 000000000000..ed30b5c375e8 --- /dev/null +++ b/pkg/sql/sem/builtins/window_frame_builtins.go @@ -0,0 +1,324 @@ +// Copyright 2018 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package builtins + +import ( + "context" + "fmt" + "strings" + + "github.com/cockroachdb/apd" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroach/pkg/util/duration" +) + +// indexedValue combines a value from the row with the index of that row. +type indexedValue struct { + value tree.Datum + idx int +} + +// RingBufferInitialSize defines the initial size of the ring buffer. +const RingBufferInitialSize = 8 + +// ringBuffer is a deque of indexedValue's maintained over a ring buffer. +type ringBuffer struct { + values []*indexedValue + head int // the index of the front of the deque. + tail int // the index of the first position right after the end of the deque. + + nonEmpty bool // indicates whether the deque is empty, necessary to distinguish + // between an empty deque and a deque that uses all of its capacity. +} + +// len returns number of indexedValue's in the deque. +func (r *ringBuffer) len() int { + if !r.nonEmpty { + return 0 + } + if r.head < r.tail { + return r.tail - r.head + } else if r.head == r.tail { + return cap(r.values) + } else { + return cap(r.values) + r.tail - r.head + } +} + +// add adds value to the end of the deque +// and doubles it's underlying slice if necessary. +func (r *ringBuffer) add(value *indexedValue) { + if cap(r.values) == 0 { + r.values = make([]*indexedValue, RingBufferInitialSize) + r.values[0] = value + r.tail = 1 + } else { + if r.len() == cap(r.values) { + newValues := make([]*indexedValue, 2*cap(r.values)) + if r.head < r.tail { + copy(newValues[:r.len()], r.values[r.head:r.tail]) + } else { + copy(newValues[:cap(r.values)-r.head], r.values[r.head:]) + copy(newValues[cap(r.values)-r.head:r.len()], r.values[:r.tail]) + } + r.head = 0 + r.tail = cap(r.values) + r.values = newValues + } + r.values[r.tail] = value + r.tail = (r.tail + 1) % cap(r.values) + } + r.nonEmpty = true +} + +// get returns indexedValue at position pos in the deque (zero-based). +func (r *ringBuffer) get(pos int) *indexedValue { + if !r.nonEmpty || pos < 0 || pos >= r.len() { + panic("unexpected behavior: index out of bounds") + } + return r.values[(pos+r.head)%cap(r.values)] +} + +// removeHead removes a single element from the front of the deque. +func (r *ringBuffer) removeHead() { + if r.len() == 0 { + panic("removing head from empty ring buffer") + } + r.values[r.head] = nil + r.head = (r.head + 1) % cap(r.values) + if r.head == r.tail { + r.nonEmpty = false + } +} + +// removeTail removes a single element from the end of the deque. +func (r *ringBuffer) removeTail() { + if r.len() == 0 { + panic("removing tail from empty ring buffer") + } + lastPos := (cap(r.values) + r.tail - 1) % cap(r.values) + r.values[lastPos] = nil + r.tail = lastPos + if r.tail == r.head { + r.nonEmpty = false + } +} + +// slidingWindow maintains a deque of values along with corresponding indices +// based on cmp function: +// for Min behavior, cmp = -a.Compare(b) +// for Max behavior, cmp = a.Compare(b) +// +// It assumes that the frame bounds will never go back, i.e. non-decreasing sequences +// of frame start and frame end indices. +type slidingWindow struct { + values ringBuffer + evalCtx *tree.EvalContext + cmp func(*tree.EvalContext, tree.Datum, tree.Datum) int +} + +func makeSlidingWindow( + evalCtx *tree.EvalContext, cmp func(*tree.EvalContext, tree.Datum, tree.Datum) int, +) *slidingWindow { + return &slidingWindow{ + values: ringBuffer{}, + evalCtx: evalCtx, + cmp: cmp, + } +} + +// add first removes all values that are "smaller or equal" (depending on cmp) +// from the end of the deque and then appends 'iv' to the end. This way, the deque +// always contains unique values sorted in descending order of their "priority" +// (when we encounter duplicates, we always keep the one with the largest idx). +func (sw *slidingWindow) add(iv *indexedValue) { + for i := sw.values.len() - 1; i >= 0; i-- { + if sw.cmp(sw.evalCtx, sw.values.get(i).value, iv.value) > 0 { + break + } + sw.values.removeTail() + } + sw.values.add(iv) +} + +// removeAllBefore removes all values from the beginning of the deque that have indices +// smaller than given 'idx'. +// This operation corresponds to shifting the start of the frame up to 'idx'. +func (sw *slidingWindow) removeAllBefore(idx int) { + for i := 0; i < sw.values.len() && i < idx; i++ { + if sw.values.get(i).idx >= idx { + break + } + sw.values.removeHead() + } +} + +func (sw *slidingWindow) string() string { + var builder strings.Builder + for i := 0; i < sw.values.len(); i++ { + builder.WriteString(fmt.Sprintf("(%v, %v)\t", sw.values.get(i).value, sw.values.get(i).idx)) + } + return builder.String() +} + +type slidingWindowFunc struct { + sw *slidingWindow + prevEnd int +} + +// Compute implements WindowFunc interface. +func (w *slidingWindowFunc) Compute( + _ context.Context, _ *tree.EvalContext, wfr *tree.WindowFrameRun, +) (tree.Datum, error) { + start, end := wfr.FrameStartIdx(), wfr.FrameEndIdx() + + // We need to discard all values that are no longer in the frame. + w.sw.removeAllBefore(start) + + // We need to add all values that just entered the frame and have not been added yet. + for idx := max(w.prevEnd, start); idx < end; idx++ { + w.sw.add(&indexedValue{wfr.ArgsByRowIdx(idx)[0], idx}) + } + w.prevEnd = end + + if w.sw.values.len() == 0 { + // Spec: the frame is empty, so we return NULL. + return tree.DNull, nil + } + + // The datum with "highest priority" within the frame is at the very front of the deque. + return w.sw.values.get(0).value, nil +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +// Close implements WindowFunc interface. +func (w *slidingWindowFunc) Close(context.Context, *tree.EvalContext) { + w.sw = nil +} + +// slidingWindowSumFunc applies sliding window approach to summation over a frame. +// It assumes that the frame bounds will never go back, i.e. non-decreasing sequences +// of frame start and frame end indices. +type slidingWindowSumFunc struct { + agg tree.AggregateFunc // one of the four SumAggregates + prevStart, prevEnd int +} + +// removeAllBefore subtracts the values from all the rows that are no longer in the frame. +func (w *slidingWindowSumFunc) removeAllBefore( + ctx context.Context, wfr *tree.WindowFrameRun, +) error { + for idx := w.prevStart; idx < wfr.FrameStartIdx() && idx < w.prevEnd; idx++ { + value := wfr.ArgsByRowIdx(idx)[0] + switch v := value.(type) { + case *tree.DInt: + return w.agg.Add(ctx, tree.NewDInt(-*v)) + case *tree.DDecimal: + d := tree.DDecimal{} + d.Neg(&v.Decimal) + return w.agg.Add(ctx, &d) + case *tree.DFloat: + return w.agg.Add(ctx, tree.NewDFloat(-*v)) + case *tree.DInterval: + return w.agg.Add(ctx, &tree.DInterval{Duration: duration.Duration{}.Sub(v.Duration)}) + default: + return pgerror.NewErrorf(pgerror.CodeInternalError, "unexpected value %v", v) + } + } + return nil +} + +// Compute implements WindowFunc interface. +func (w *slidingWindowSumFunc) Compute( + ctx context.Context, evalCtx *tree.EvalContext, wfr *tree.WindowFrameRun, +) (tree.Datum, error) { + start, end := wfr.FrameStartIdx(), wfr.FrameEndIdx() + + // We need to discard all values that are no longer in the frame. + err := w.removeAllBefore(ctx, wfr) + if err != nil { + return tree.DNull, err + } + + // We need to sum all values that just entered the frame and have not been added yet. + for idx := max(w.prevEnd, start); idx < end; idx++ { + err = w.agg.Add(ctx, wfr.ArgsByRowIdx(idx)[0]) + if err != nil { + return tree.DNull, err + } + } + + w.prevStart = start + w.prevEnd = end + return w.agg.Result() +} + +// Close implements WindowFunc interface. +func (w *slidingWindowSumFunc) Close(ctx context.Context, _ *tree.EvalContext) { + w.agg.Close(ctx) +} + +// avgWindowFunc uses slidingWindowSumFunc to compute average over a frame. +type avgWindowFunc struct { + sum slidingWindowSumFunc +} + +// Compute implements WindowFunc interface. +func (w *avgWindowFunc) Compute( + ctx context.Context, evalCtx *tree.EvalContext, wfr *tree.WindowFrameRun, +) (tree.Datum, error) { + if wfr.FrameSize() == 0 { + // Spec: the frame is empty, so we return NULL. + return tree.DNull, nil + } + + var sum tree.Datum + var err error + sum, err = w.sum.Compute(ctx, evalCtx, wfr) + if err != nil { + return nil, err + } + + switch t := sum.(type) { + case *tree.DFloat: + return tree.NewDFloat(*t / tree.DFloat(wfr.FrameSize())), nil + case *tree.DDecimal: + var avg tree.DDecimal + count := apd.New(int64(wfr.FrameSize()), 0) + _, err := tree.DecimalCtx.Quo(&avg.Decimal, &t.Decimal, count) + return &avg, err + case *tree.DInt: + dd := tree.DDecimal{} + dd.SetCoefficient(int64(*t)) + var avg tree.DDecimal + count := apd.New(int64(wfr.FrameSize()), 0) + _, err := tree.DecimalCtx.Quo(&avg.Decimal, &dd.Decimal, count) + return &avg, err + default: + return nil, pgerror.NewErrorf(pgerror.CodeInternalError, "unexpected SUM result type: %s", t) + } +} + +// Close implements WindowFunc interface. +func (w *avgWindowFunc) Close(ctx context.Context, evalCtx *tree.EvalContext) { + w.sum.Close(ctx, evalCtx) +} diff --git a/pkg/sql/sem/builtins/window_frame_builtins_test.go b/pkg/sql/sem/builtins/window_frame_builtins_test.go new file mode 100644 index 000000000000..94c5872067a8 --- /dev/null +++ b/pkg/sql/sem/builtins/window_frame_builtins_test.go @@ -0,0 +1,242 @@ +// Copyright 2018 The Cockroach Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. See the License for the specific language governing +// permissions and limitations under the License. + +package builtins + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "testing" + + "github.com/cockroachdb/cockroach/pkg/settings/cluster" + "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" +) + +const maxCount = 1000 +const maxInt = 1000000 +const maxOffset = 100 + +func testSlidingWindow(t *testing.T, count int) { + evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) + defer evalCtx.Stop(context.Background()) + wfr := makeTestWindowFrameRun(count) + wfr.Frame = &tree.WindowFrame{ + Mode: tree.ROWS, + Bounds: tree.WindowFrameBounds{ + StartBound: &tree.WindowFrameBound{BoundType: tree.ValuePreceding}, + EndBound: &tree.WindowFrameBound{BoundType: tree.ValueFollowing}, + }, + } + testMin(t, evalCtx, wfr) + testMax(t, evalCtx, wfr) + testSumAndAvg(t, evalCtx, wfr) +} + +func testMin(t *testing.T, evalCtx *tree.EvalContext, wfr *tree.WindowFrameRun) { + for offset := 0; offset < maxOffset; offset += int(rand.Int31n(maxOffset / 10)) { + wfr.StartBoundOffset = offset + wfr.EndBoundOffset = offset + min := &slidingWindowFunc{} + min.sw = makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { + return -a.Compare(evalCtx, b) + }) + for wfr.RowIdx = 0; wfr.RowIdx < wfr.PartitionSize(); wfr.RowIdx++ { + res, err := min.Compute(evalCtx.Ctx(), evalCtx, wfr) + if err != nil { + t.Errorf("Unexpected error received when getting min from sliding window: %+v", err) + } + minResult, _ := tree.AsDInt(res) + naiveMin := tree.DInt(maxInt) + for idx := wfr.RowIdx - offset; idx <= wfr.RowIdx+offset; idx++ { + if idx < 0 || idx >= wfr.PartitionSize() { + continue + } + el, _ := tree.AsDInt(wfr.Rows[idx].Row[0]) + if el < naiveMin { + naiveMin = el + } + } + if minResult != naiveMin { + t.Errorf("Min sliding window returned wrong result: expected %+v, found %+v", naiveMin, minResult) + t.Errorf("partitionSize: %+v idx: %+v offset: %+v", wfr.PartitionSize(), wfr.RowIdx, offset) + t.Errorf(min.sw.string()) + t.Errorf(partitionToString(wfr.Rows)) + panic("") + } + } + } +} + +func testMax(t *testing.T, evalCtx *tree.EvalContext, wfr *tree.WindowFrameRun) { + for offset := 0; offset < maxOffset; offset += int(rand.Int31n(maxOffset / 10)) { + wfr.StartBoundOffset = offset + wfr.EndBoundOffset = offset + max := &slidingWindowFunc{} + max.sw = makeSlidingWindow(evalCtx, func(evalCtx *tree.EvalContext, a, b tree.Datum) int { + return a.Compare(evalCtx, b) + }) + for wfr.RowIdx = 0; wfr.RowIdx < wfr.PartitionSize(); wfr.RowIdx++ { + res, err := max.Compute(evalCtx.Ctx(), evalCtx, wfr) + if err != nil { + t.Errorf("Unexpected error received when getting max from sliding window: %+v", err) + } + maxResult, _ := tree.AsDInt(res) + naiveMax := tree.DInt(-maxInt) + for idx := wfr.RowIdx - offset; idx <= wfr.RowIdx+offset; idx++ { + if idx < 0 || idx >= wfr.PartitionSize() { + continue + } + el, _ := tree.AsDInt(wfr.Rows[idx].Row[0]) + if el > naiveMax { + naiveMax = el + } + } + if maxResult != naiveMax { + t.Errorf("Max sliding window returned wrong result: expected %+v, found %+v", naiveMax, maxResult) + t.Errorf("partitionSize: %+v idx: %+v offset: %+v", wfr.PartitionSize(), wfr.RowIdx, offset) + t.Errorf(max.sw.string()) + t.Errorf(partitionToString(wfr.Rows)) + panic("") + } + } + } +} + +func testSumAndAvg(t *testing.T, evalCtx *tree.EvalContext, wfr *tree.WindowFrameRun) { + for offset := 0; offset < maxOffset; offset += int(rand.Int31n(maxOffset / 10)) { + wfr.StartBoundOffset = offset + wfr.EndBoundOffset = offset + sum := &slidingWindowSumFunc{agg: &intSumAggregate{}} + avg := &avgWindowFunc{sum: slidingWindowSumFunc{agg: &intSumAggregate{}}} + for wfr.RowIdx = 0; wfr.RowIdx < wfr.PartitionSize(); wfr.RowIdx++ { + res, err := sum.Compute(evalCtx.Ctx(), evalCtx, wfr) + if err != nil { + t.Errorf("Unexpected error received when getting sum from sliding window: %+v", err) + } + sumResult := tree.DDecimal{Decimal: res.(*tree.DDecimal).Decimal} + res, err = avg.Compute(evalCtx.Ctx(), evalCtx, wfr) + if err != nil { + t.Errorf("Unexpected error received when getting avg from sliding window: %+v", err) + } + avgResult := tree.DDecimal{Decimal: res.(*tree.DDecimal).Decimal} + naiveSum := int64(0) + for idx := wfr.RowIdx - offset; idx <= wfr.RowIdx+offset; idx++ { + if idx < 0 || idx >= wfr.PartitionSize() { + continue + } + el, _ := tree.AsDInt(wfr.Rows[idx].Row[0]) + naiveSum += int64(el) + } + s, err := sumResult.Int64() + if err != nil { + t.Errorf("Unexpected error received when converting sum from DDecimal to int64: %+v", err) + } + if s != naiveSum { + t.Errorf("Sum sliding window returned wrong result: expected %+v, found %+v", naiveSum, s) + t.Errorf("partitionSize: %+v idx: %+v offset: %+v", wfr.PartitionSize(), wfr.RowIdx, offset) + t.Errorf(partitionToString(wfr.Rows)) + panic("") + } + a, err := avgResult.Float64() + if err != nil { + t.Errorf("Unexpected error received when converting avg from DDecimal to float64: %+v", err) + } + if a != float64(naiveSum)/float64(wfr.FrameSize()) { + t.Errorf("Sum sliding window returned wrong result: expected %+v, found %+v", float64(naiveSum)/float64(wfr.FrameSize()), a) + t.Errorf("partitionSize: %+v idx: %+v offset: %+v", wfr.PartitionSize(), wfr.RowIdx, offset) + t.Errorf(partitionToString(wfr.Rows)) + panic("") + } + } + } +} + +func makeTestWindowFrameRun(count int) *tree.WindowFrameRun { + return &tree.WindowFrameRun{ + Rows: makeTestPartition(count), + ArgIdxStart: 0, + ArgCount: 1, + } +} + +func makeTestPartition(count int) []tree.IndexedRow { + partition := make([]tree.IndexedRow, count) + for idx := 0; idx < count; idx++ { + partition[idx] = tree.IndexedRow{Idx: idx, Row: tree.Datums{tree.NewDInt(tree.DInt(rand.Int31n(maxInt)))}} + } + return partition +} + +func partitionToString(partition []tree.IndexedRow) string { + var buf bytes.Buffer + buf.WriteString("\n=====Partition=====\n") + for idx := 0; idx < len(partition); idx++ { + buf.WriteString(fmt.Sprintf("%v\n", partition[idx])) + } + buf.WriteString("====================\n") + return buf.String() +} + +func testRingBuffer(t *testing.T, count int) { + evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings()) + defer evalCtx.Stop(context.Background()) + partition := makeTestPartition(count) + ring := ringBuffer{} + naiveBuffer := make([]*indexedValue, 0, count) + for idx, row := range partition { + if ring.len() != len(naiveBuffer) { + t.Errorf("Ring ring returned incorrect len: expected %v, found %v", len(naiveBuffer), ring.len()) + panic("") + } + + op := rand.Float64() + if op < 0.5 { + iv := &indexedValue{idx: idx, value: row.Row[0]} + ring.add(iv) + naiveBuffer = append(naiveBuffer, iv) + } else if op < 0.75 { + if len(naiveBuffer) > 0 { + ring.removeHead() + naiveBuffer = naiveBuffer[1:] + } + } else { + if len(naiveBuffer) > 0 { + ring.removeTail() + naiveBuffer = naiveBuffer[:len(naiveBuffer)-1] + } + } + + for pos, iv := range naiveBuffer { + res := ring.get(pos) + if res.idx != iv.idx || res.value.Compare(evalCtx, iv.value) != 0 { + t.Errorf("Ring buffer returned incorrect value: expected %+v, found %+v", iv, res) + panic("") + } + } + } +} + +func TestSlidingWindow(t *testing.T) { + for count := 1; count <= maxCount; count += int(rand.Int31n(maxCount / 10)) { + testSlidingWindow(t, count) + } +} + +func TestRingBuffer(t *testing.T) { + for count := 1; count <= maxCount; count++ { + testRingBuffer(t, count) + } +} diff --git a/pkg/sql/sem/tree/window_funcs.go b/pkg/sql/sem/tree/window_funcs.go index e9ec8eba1f03..c1cac5292504 100644 --- a/pkg/sql/sem/tree/window_funcs.go +++ b/pkg/sql/sem/tree/window_funcs.go @@ -90,6 +90,18 @@ func (wfr WindowFrameRun) FrameStartIdx() int { } } +// IsDefaultFrame returns whether a frame equivalent to the default frame +// is being used (default is RANGE UNBOUNDED PRECEDING). +func (wfr WindowFrameRun) IsDefaultFrame() bool { + if wfr.Frame == nil { + return true + } + if wfr.Frame.Bounds.StartBound.BoundType == UnboundedPreceding { + return wfr.Frame.Bounds.EndBound == nil || wfr.Frame.Bounds.EndBound.BoundType == CurrentRow + } + return false +} + // FrameEndIdx returns the index of the first row after the frame. func (wfr WindowFrameRun) FrameEndIdx() int { if wfr.Frame == nil { diff --git a/pkg/sql/window.go b/pkg/sql/window.go index 27d303c23b28..70768e61e721 100644 --- a/pkg/sql/window.go +++ b/pkg/sql/window.go @@ -732,10 +732,6 @@ func (n *windowNode) computeWindows(ctx context.Context, evalCtx *tree.EvalConte builtin := windowFn.expr.GetWindowConstructor()(evalCtx) defer builtin.Close(ctx, evalCtx) - // In order to calculate aggregates over a particular window frame, - // we need a way to 'reset' the aggregate, so this constructor will be used for that. - aggConstructor := windowFn.expr.GetAggregateConstructor() - var peerGrouper peerGroupChecker if windowFn.columnOrdering != nil { // If an ORDER BY clause is provided, order the partition and use the @@ -769,8 +765,13 @@ func (n *windowNode) computeWindows(ctx context.Context, evalCtx *tree.EvalConte frameRun.ArgCount = windowFn.argCount frameRun.RowIdx = 0 - if frameRun.Frame != nil { - builtins.AddAggregateConstructorToFramableAggregate(builtin, aggConstructor) + if !frameRun.IsDefaultFrame() { + // We have a custom frame not equivalent to default one, so if we have + // an aggregate function, we want to reset it for each row. + // Not resetting is an optimization since we're not computing + // the result over the whole frame but only as a result of the current + // row and previous results of aggregation. + builtins.ShouldReset(builtin) } for frameRun.RowIdx < len(partition) {