diff --git a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java index 68bfd695e75e..448faf9633f3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/window/WindowPartition.java @@ -160,7 +160,7 @@ private Range getFrameRange(FrameInfo frameInfo) return new Range(-1, -1); } - if (frameInfo.getType() == RANGE && emptyFrame(frameInfo, peerGroupStart, peerGroupEnd - 1)) { + if (frameInfo.getType() == RANGE && emptyFrame(frameInfo, peerGroupStart, endPosition - (peerGroupEnd - 1))) { return new Range(-1, -1); } @@ -174,6 +174,9 @@ private Range getFrameRange(FrameInfo frameInfo) else if (frameInfo.getType() == RANGE && frameInfo.getStartType() == PRECEDING) { frameStart = precedingStartRange(getStartValue(frameInfo)); } + else if (frameInfo.getType() == RANGE && frameInfo.getStartType() == FOLLOWING) { + frameStart = followingRange(rowPosition, endPosition, getStartValue(frameInfo)); + } else if (frameInfo.getStartType() == PRECEDING) { frameStart = preceding(rowPosition, getStartValue(frameInfo)); } @@ -194,6 +197,9 @@ else if (frameInfo.getType() == RANGE) { else if (frameInfo.getType() == RANGE && frameInfo.getEndType() == PRECEDING) { frameEnd = precedingEndRange(getEndValue(frameInfo)); } + else if (frameInfo.getType() == RANGE && frameInfo.getEndType() == FOLLOWING) { + frameEnd = followingRange(peerGroupEnd, endPosition + 1, getEndValue(frameInfo)) - 1; + } else if (frameInfo.getEndType() == PRECEDING) { frameEnd = preceding(rowPosition, getEndValue(frameInfo)); } @@ -228,6 +234,33 @@ private int precedingStartRange(long startValue) return peerGroupStartIndices.get(toIntExact(peerGroupStartIndex - startValue)); } + private int followingRange(int followingPeerGroupStart, int endPosition, long value) + { + if (value == 0) { + return followingPeerGroupStart; + } + // TODO: Optimize this to *not* look for peers often, probably have pageIndex keep the peer groups + int followingPeerGroupEnd = 0; + int currentValue = 0; + while (currentValue < value) { + boolean peerFound = false; + followingPeerGroupEnd = followingPeerGroupStart + 1; + while ((followingPeerGroupEnd < partitionEnd) && pagesIndex.positionEqualsPosition(peerGroupHashStrategy, followingPeerGroupStart, followingPeerGroupEnd)) { + followingPeerGroupEnd++; + peerFound = true; + } + if (followingPeerGroupEnd >= partitionEnd) { + return endPosition; + } + + if (!peerFound) { + currentValue++; + } + followingPeerGroupStart++; + } + return followingPeerGroupEnd; + } + private boolean emptyFrame(FrameInfo frameInfo, int rowPosition, int position) { FrameBound.Type startType = frameInfo.getStartType(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index c6272c0e76f9..034a56098ad8 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -187,7 +187,6 @@ import static com.facebook.presto.sql.tree.FrameBound.Type.PRECEDING; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING; import static com.facebook.presto.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; -import static com.facebook.presto.sql.tree.WindowFrame.Type.RANGE; import static com.facebook.presto.type.UnknownType.UNKNOWN; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -1260,9 +1259,6 @@ private void analyzeWindowFrame(WindowFrame frame) if ((startType == FOLLOWING) && (endType == CURRENT_ROW)) { throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame starting from FOLLOWING cannot end with CURRENT ROW"); } - if ((frame.getType() == RANGE) && ((startType == FOLLOWING) || (endType == FOLLOWING))) { - throw new SemanticException(INVALID_WINDOW_FRAME, frame, "Window frame RANGE FOLLOWING is only supported with UNBOUNDED"); - } } private void analyzeHaving(QuerySpecification node, Scope scope) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/window/TestAggregateWindowFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/window/TestAggregateWindowFunction.java index 199d2efa3601..028a34b18eaf 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/window/TestAggregateWindowFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/window/TestAggregateWindowFunction.java @@ -641,4 +641,128 @@ public void testSumRangePrecedingBounded() .row(null, null, null) .build()); } + + @Test + public void testSumRangeFollowingBounded() + { + assertWindowQueryWithNulls("sum(orderkey) OVER (ORDER BY orderstatus " + + "RANGE BETWEEN current row AND 1 FOLLOWING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) + .row(3L, "F", 48L) + .row(5L, "F", 48L) + .row(6L, "F", 48L) + .row(null, "F", 48L) + .row(34L, "O", 42L) + .row(null, "O", 42L) + .row(1L, null, 8L) + .row(7L, null, 8L) + .row(null, null, 8L) + .row(null, null, 8L) + .build()); + + assertWindowQueryWithNulls("sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN current row AND 1 FOLLOWING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) + .row(3L, "F", 8L) + .row(5L, "F", 11L) + .row(6L, "F", 6L) + .row(null, "F", null) + .row(34L, "O", 34L) + .row(null, "O", null) + .row(1L, null, 8L) + .row(7L, null, 7L) + .row(null, null, null) + .row(null, null, null) + .build()); + + assertWindowQueryWithNulls("sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN current row AND 0 FOLLOWING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) + .row(3L, "F", 3L) + .row(5L, "F", 5L) + .row(6L, "F", 6L) + .row(null, "F", null) + .row(34L, "O", 34L) + .row(null, "O", null) + .row(1L, null, 1L) + .row(7L, null, 7L) + .row(null, null, null) + .row(null, null, null) + .build()); + + assertWindowQueryWithNulls("sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 1 FOLLOWING AND 2 FOLLOWING)", + resultBuilder(TEST_SESSION, INTEGER, VARCHAR, BIGINT) + .row(3L, "F", 11L) + .row(5L, "F", 6L) + .row(6L, "F", null) + .row(null, "F", null) + .row(34L, "O", null) + .row(null, "O", null) + .row(1L, null, 7L) + .row(7L, null, null) + .row(null, null, null) + .row(null, null, null) + .build()); + + assertWindowQueryWithNulls("sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 0 FOLLOWING AND UNBOUNDED FOLLOWING)", + resultBuilder(TEST_SESSION, INTEGER, VARCHAR, BIGINT) + .row(3L, "F", 14L) + .row(5L, "F", 11L) + .row(6L, "F", 6L) + .row(null, "F", null) + .row(34L, "O", 34L) + .row(null, "O", null) + .row(1L, null, 8L) + .row(7L, null, 7L) + .row(null, null, null) + .row(null, null, null) + .build()); + + assertWindowQueryWithNulls("sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 2 FOLLOWING AND UNBOUNDED FOLLOWING)", + resultBuilder(TEST_SESSION, INTEGER, VARCHAR, BIGINT) + .row(3L, "F", 6L) + .row(5L, "F", null) + .row(6L, "F", null) + .row(null, "F", null) + .row(34L, "O", null) + .row(null, "O", null) + .row(1L, null, null) + .row(7L, null, null) + .row(null, null, null) + .row(null, null, null) + .build()); + + assertWindowQueryWithNulls("sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 4 FOLLOWING AND UNBOUNDED FOLLOWING)", + resultBuilder(TEST_SESSION, INTEGER, VARCHAR, BIGINT) + .row(3L, "F", null) + .row(5L, "F", null) + .row(6L, "F", null) + .row(null, "F", null) + .row(34L, "O", null) + .row(null, "O", null) + .row(1L, null, null) + .row(7L, null, null) + .row(null, null, null) + .row(null, null, null) + .build()); + + assertWindowQueryWithNulls("sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 3 FOLLOWING AND 2 FOLLOWING)", + resultBuilder(TEST_SESSION, INTEGER, VARCHAR, BIGINT) + .row(3L, "F", null) + .row(5L, "F", null) + .row(6L, "F", null) + .row(null, "F", null) + .row(34L, "O", null) + .row(null, "O", null) + .row(1L, null, null) + .row(7L, null, null) + .row(null, null, null) + .row(null, null, null) + .build()); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/window/TestFirstValueFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/window/TestFirstValueFunction.java index 64dcd318301e..2c2ab5bb08f7 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/window/TestFirstValueFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/window/TestFirstValueFunction.java @@ -142,5 +142,19 @@ public void testFirstValueBounded() .row(null, null, 1L) .row(null, null, 1L) .build()); + assertWindowQueryWithNulls("first_value(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) + .row(3L, "F", 3L) + .row(5L, "F", 3L) + .row(6L, "F", 3L) + .row(null, "F", 5L) + .row(34L, "O", 34L) + .row(null, "O", 34L) + .row(1L, null, 1L) + .row(7L, null, 1L) + .row(null, null, 1L) + .row(null, null, 1L) + .build()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/window/TestLastValueFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/window/TestLastValueFunction.java index 5274d76f12a7..57cbcdc8b9dd 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/window/TestLastValueFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/window/TestLastValueFunction.java @@ -141,5 +141,19 @@ public void testLastValueBounded() .row(null, null, 7L) .row(null, null, 7L) .build()); + assertWindowQueryWithNulls("last_value(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 2 PRECEDING AND 1 FOLLOWING)", + resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) + .row(3L, "F", 5L) + .row(5L, "F", 6L) + .row(6L, "F", null) + .row(null, "F", null) + .row(34L, "O", null) + .row(null, "O", null) + .row(1L, null, 7L) + .row(7L, null, null) + .row(null, null, null) + .row(null, null, null) + .build()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/window/TestMultipleWindowSpecifications.java b/presto-main/src/test/java/com/facebook/presto/operator/window/TestMultipleWindowSpecifications.java index 289575cf36ec..fafbad350ada 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/window/TestMultipleWindowSpecifications.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/window/TestMultipleWindowSpecifications.java @@ -222,34 +222,34 @@ public void testDisjointWindowSpecifications() public void testMultipleWindowSpecificationsWithRange() { // Intersection previous to current row - assertWindowQueryWithNulls("count(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey RANGE BETWEEN 3 PRECEDING AND 2 PRECEDING), " + + assertWindowQueryWithNulls("count(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey RANGE BETWEEN 3 PRECEDING AND 1 FOLLOWING), " + "sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey RANGE BETWEEN 2 PRECEDING AND CURRENT ROW)", resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT, BIGINT) - .row(3L, "F", 0L, 3L) - .row(5L, "F", 0L, 8L) - .row(6L, "F", 1L, 14L) - .row(null, "F", 2L, 11L) - .row(34L, "O", 0L, 34L) - .row(null, "O", 0L, 34L) - .row(1L, null, 0L, 1L) - .row(7L, null, 0L, 8L) - .row(null, null, 1L, 8L) - .row(null, null, 1L, 8L) + .row(3L, "F", 2L, 3L) + .row(5L, "F", 3L, 8L) + .row(6L, "F", 3L, 14L) + .row(null, "F", 3L, 11L) + .row(34L, "O", 1L, 34L) + .row(null, "O", 1L, 34L) + .row(1L, null, 2L, 1L) + .row(7L, null, 2L, 8L) + .row(null, null, 2L, 8L) + .row(null, null, 2L, 8L) .build()); // Disjoint - assertWindowQueryWithNulls("count(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey RANGE BETWEEN 3 PRECEDING AND 2 PRECEDING), " + + assertWindowQueryWithNulls("count(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey RANGE BETWEEN 3 PRECEDING AND 2 FOLLOWING), " + "sum(orderkey) OVER (PARTITION BY orderstatus ORDER BY orderkey RANGE BETWEEN 1 PRECEDING AND CURRENT ROW)", resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT, BIGINT) - .row(3L, "F", 0L, 3L) - .row(5L, "F", 0L, 8L) - .row(6L, "F", 1L, 11L) - .row(null, "F", 2L, 6L) - .row(34L, "O", 0L, 34L) - .row(null, "O", 0L, 34L) - .row(1L, null, 0L, 1L) - .row(7L, null, 0L, 8L) - .row(null, null, 1L, 7L) - .row(null, null, 1L, 7L) + .row(3L, "F", 3L, 3L) + .row(5L, "F", 3L, 8L) + .row(6L, "F", 3L, 11L) + .row(null, "F", 3L, 6L) + .row(34L, "O", 1L, 34L) + .row(null, "O", 1L, 34L) + .row(1L, null, 2L, 1L) + .row(7L, null, 2L, 8L) + .row(null, null, 2L, 7L) + .row(null, null, 2L, 7L) .build()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/window/TestNthValueFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/window/TestNthValueFunction.java index 81dc6e98553d..90df84df2b70 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/window/TestNthValueFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/window/TestNthValueFunction.java @@ -147,12 +147,12 @@ public void testNthValueBounded() .row(null, null, null) .build()); - assertWindowQueryWithNulls("nth_value(orderkey, 4) OVER (PARTITION BY orderstatus ORDER BY orderkey " + - "RANGE BETWEEN 2 PRECEDING AND 1 PRECEDING)", + assertWindowQueryWithNulls("nth_value(orderkey, 3) OVER (PARTITION BY orderstatus ORDER BY orderkey " + + "RANGE BETWEEN 2 PRECEDING AND 1 FOLLOWING)", resultBuilder(TEST_SESSION, BIGINT, VARCHAR, BIGINT) .row(3L, "F", null) - .row(5L, "F", null) - .row(6L, "F", null) + .row(5L, "F", 6L) + .row(6L, "F", 6L) .row(null, "F", null) .row(34L, "O", null) .row(null, "O", null) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index 189e21c397d1..6b4445d3c72b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -645,8 +645,6 @@ public void testInvalidWindowFrame() assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN CURRENT ROW AND 5 PRECEDING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN 2 FOLLOWING AND 5 PRECEDING)"); assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (ROWS BETWEEN 2 FOLLOWING AND CURRENT ROW)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN CURRENT ROW AND 5 FOLLOWING)"); - assertFails(INVALID_WINDOW_FRAME, "SELECT rank() OVER (RANGE BETWEEN 2 PRECEDING AND 5 FOLLOWING)"); assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS 0.5 PRECEDING)"); assertFails(TYPE_MISMATCH, "SELECT rank() OVER (ROWS 'foo' PRECEDING)");