From f98fe483616d1a59ca343ffacc4c031fd54dcb2a Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Fri, 21 May 2021 14:11:39 +0200 Subject: [PATCH] Support row pattern measures over window: grammmar and AST --- .../sql/analyzer/ExpressionAnalyzer.java | 7 + .../io/trino/sql/analyzer/TestAnalyzer.java | 14 ++ .../antlr4/io/trino/sql/parser/SqlBase.g4 | 1 + .../io/trino/sql/ExpressionFormatter.java | 7 + .../java/io/trino/sql/parser/AstBuilder.java | 7 + .../java/io/trino/sql/tree/AstVisitor.java | 5 + .../sql/tree/DefaultTraversalVisitor.java | 9 + .../io/trino/sql/tree/ExpressionRewriter.java | 5 + .../sql/tree/ExpressionTreeRewriter.java | 187 ++++++++++-------- .../io/trino/sql/tree/WindowOperation.java | 111 +++++++++++ .../io/trino/sql/parser/TestSqlParser.java | 37 ++++ 11 files changed, 310 insertions(+), 80 deletions(-) create mode 100644 core/trino-parser/src/main/java/io/trino/sql/tree/WindowOperation.java diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java index 6c8aeb51c843b..87c47c1702332 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java @@ -122,6 +122,7 @@ import io.trino.sql.tree.TryExpression; import io.trino.sql.tree.WhenClause; import io.trino.sql.tree.WindowFrame; +import io.trino.sql.tree.WindowOperation; import io.trino.type.FunctionType; import io.trino.type.TypeCoercion; import io.trino.type.UnknownType; @@ -1365,6 +1366,12 @@ private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type bou frameBoundCalculations.put(NodeRef.of(offsetValue), function); } + @Override + protected Type visitWindowOperation(WindowOperation node, StackableAstVisitorContext context) + { + throw semanticException(NOT_SUPPORTED, node, "Row pattern measures over window not yet supported"); + } + public List getCallArgumentTypes(List arguments, StackableAstVisitorContext context) { ImmutableList.Builder argumentTypesBuilder = ImmutableList.builder(); diff --git a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index 8b5ae17d6da94..eb5efd6685945 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -1384,6 +1384,20 @@ public void testWindowFrameWithPatternRecognition() .hasMessage("line 1:200: Pattern recognition in window not yet supported"); } + @Test + public void testMeasureOverWindow() + { + // in-line window specification + assertFails("SELECT last_z OVER () FROM (VALUES 1) t(z) ") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:8: Row pattern measures over window not yet supported"); + + // named window reference + assertFails("SELECT last_z OVER w FROM (VALUES 1) t(z) WINDOW w AS ()") + .hasErrorCode(NOT_SUPPORTED) + .hasMessage("line 1:8: Row pattern measures over window not yet supported"); + } + @Test public void testDistinctInWindowFunctionParameter() { diff --git a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 index 6cccc1400ccb4..ca2fd0905811e 100644 --- a/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 +++ b/core/trino-parser/src/main/antlr4/io/trino/sql/parser/SqlBase.g4 @@ -410,6 +410,7 @@ primaryExpression | qualifiedName '(' ASTERISK ')' filter? over? #functionCall | processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)? (ORDER BY sortItem (',' sortItem)*)? ')' filter? (nullTreatment? over)? #functionCall + | identifier over #measure | identifier '->' expression #lambda | '(' (identifier (',' identifier)*)? ')' '->' expression #lambda | '(' query ')' #subqueryExpression diff --git a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index f44ceab60cae9..6f5852b0387f1 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -94,6 +94,7 @@ import io.trino.sql.tree.WhenClause; import io.trino.sql.tree.Window; import io.trino.sql.tree.WindowFrame; +import io.trino.sql.tree.WindowOperation; import io.trino.sql.tree.WindowReference; import io.trino.sql.tree.WindowSpecification; @@ -425,6 +426,12 @@ protected String visitFunctionCall(FunctionCall node, Void context) return builder.toString(); } + @Override + protected String visitWindowOperation(WindowOperation node, Void context) + { + return process(node.getName(), context) + " OVER " + formatWindow(node.getWindow()); + } + @Override protected String visitLambdaExpression(LambdaExpression node, Void context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 02bc83e8d729a..58f7bfe364206 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -222,6 +222,7 @@ import io.trino.sql.tree.Window; import io.trino.sql.tree.WindowDefinition; import io.trino.sql.tree.WindowFrame; +import io.trino.sql.tree.WindowOperation; import io.trino.sql.tree.WindowReference; import io.trino.sql.tree.WindowSpecification; import io.trino.sql.tree.With; @@ -2096,6 +2097,12 @@ else if (processingMode.FINAL() != null) { visit(context.expression(), Expression.class)); } + @Override + public Node visitMeasure(SqlBaseParser.MeasureContext context) + { + return new WindowOperation(getLocation(context), (Identifier) visit(context.identifier()), (Window) visit(context.over())); + } + @Override public Node visitLambda(SqlBaseParser.LambdaContext context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java index e01dde330c2d2..9c75ace98bbd3 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/AstVisitor.java @@ -297,6 +297,11 @@ protected R visitProcessingMode(ProcessingMode node, C context) return visitNode(node, context); } + protected R visitWindowOperation(WindowOperation node, C context) + { + return visitExpression(node, context); + } + protected R visitLambdaExpression(LambdaExpression node, C context) { return visitExpression(node, context); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java index bcf17bbe71fde..72b11e34d49ff 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/DefaultTraversalVisitor.java @@ -209,6 +209,15 @@ protected Void visitFunctionCall(FunctionCall node, C context) return null; } + @Override + protected Void visitWindowOperation(WindowOperation node, C context) + { + process(node.getName(), context); + process((Node) node.getWindow(), context); + + return null; + } + @Override protected Void visitGroupingOperation(GroupingOperation node, C context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java index 7e18d1a17b1e8..e22135458a202 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionRewriter.java @@ -105,6 +105,11 @@ public Expression rewriteFunctionCall(FunctionCall node, C context, ExpressionTr return rewriteExpression(node, context, treeRewriter); } + public Expression rewriteWindowOperation(WindowOperation node, C context, ExpressionTreeRewriter treeRewriter) + { + return rewriteExpression(node, context, treeRewriter); + } + public Expression rewriteLambdaExpression(LambdaExpression node, C context, ExpressionTreeRewriter treeRewriter) { return rewriteExpression(node, context, treeRewriter); diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java index 7b1d61db8cace..2e07659be2679 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/ExpressionTreeRewriter.java @@ -484,95 +484,22 @@ public Expression visitFunctionCall(FunctionCall node, Context context) filter = Optional.of(newFilterExpression); } - Optional rewrittenWindow = node.getWindow(); - if (node.getWindow().isPresent()) { - Window window = node.getWindow().get(); - - if (window instanceof WindowReference) { - WindowReference windowReference = (WindowReference) window; - Identifier rewrittenName = rewrite(windowReference.getName(), context.get()); - if (windowReference.getName() != rewrittenName) { - rewrittenWindow = Optional.of(new WindowReference(rewrittenName)); - } - } - else if (window instanceof WindowSpecification) { - WindowSpecification windowSpecification = (WindowSpecification) window; - Optional existingWindowName = windowSpecification.getExistingWindowName().map(name -> rewrite(name, context.get())); - - List partitionBy = rewrite(windowSpecification.getPartitionBy(), context); - - Optional orderBy = Optional.empty(); - if (windowSpecification.getOrderBy().isPresent()) { - orderBy = Optional.of(rewriteOrderBy(windowSpecification.getOrderBy().get(), context)); - } - - Optional rewrittenFrame = windowSpecification.getFrame(); - if (rewrittenFrame.isPresent()) { - WindowFrame frame = rewrittenFrame.get(); - - FrameBound start = frame.getStart(); - if (start.getValue().isPresent()) { - Expression value = rewrite(start.getValue().get(), context.get()); - if (value != start.getValue().get()) { - start = new FrameBound(start.getType(), value); - } - } - - Optional rewrittenEnd = frame.getEnd(); - if (rewrittenEnd.isPresent()) { - Optional value = rewrittenEnd.get().getValue(); - if (value.isPresent()) { - Expression rewrittenValue = rewrite(value.get(), context.get()); - if (rewrittenValue != value.get()) { - rewrittenEnd = Optional.of(new FrameBound(rewrittenEnd.get().getType(), rewrittenValue)); - } - } - } - - // Frame properties for row pattern matching are not rewritten. They are planned as parts of - // PatternRecognitionNode, and shouldn't be accessed past the Planner phase. - // There are nested expressions in Measures and VariableDefinitions. They are not rewritten by default. - // Rewriting them requires special handling of DereferenceExpression, aware of pattern labels. - if (!frame.getMeasures().isEmpty() || - frame.getAfterMatchSkipTo().isPresent() || - frame.getPatternSearchMode().isPresent() || - frame.getPattern().isPresent() || - !frame.getSubsets().isEmpty() || - !frame.getVariableDefinitions().isEmpty()) { - throw new UnsupportedOperationException("Cannot rewrite pattern recognition clauses in window"); - } - - if ((frame.getStart() != start) || !sameElements(frame.getEnd(), rewrittenEnd)) { - rewrittenFrame = Optional.of(new WindowFrame( - frame.getType(), - start, - rewrittenEnd, - frame.getMeasures(), - frame.getAfterMatchSkipTo(), - frame.getPatternSearchMode(), - frame.getPattern(), - frame.getSubsets(), - frame.getVariableDefinitions())); - } - } - - if (!sameElements(windowSpecification.getExistingWindowName(), existingWindowName) || - !sameElements(windowSpecification.getPartitionBy(), partitionBy) || - !sameElements(windowSpecification.getOrderBy(), orderBy) || - !sameElements(windowSpecification.getFrame(), rewrittenFrame)) { - rewrittenWindow = Optional.of(new WindowSpecification(existingWindowName, partitionBy, orderBy, rewrittenFrame)); - } + Optional window = node.getWindow(); + if (window.isPresent()) { + Window rewrittenWindow = rewriteWindow(window.get(), context); + if (rewrittenWindow != window.get()) { + window = Optional.of(rewrittenWindow); } } List arguments = rewrite(node.getArguments(), context); - if (!sameElements(node.getArguments(), arguments) || !sameElements(rewrittenWindow, node.getWindow()) + if (!sameElements(node.getArguments(), arguments) || !sameElements(window, node.getWindow()) || !sameElements(filter, node.getFilter())) { return new FunctionCall( node.getLocation(), node.getName(), - rewrittenWindow, + window, filter, node.getOrderBy().map(orderBy -> rewriteOrderBy(orderBy, context)), node.isDistinct(), @@ -609,6 +536,106 @@ private List rewriteSortItems(List sortItems, Context con return rewrittenSortItems.build(); } + private Window rewriteWindow(Window window, Context context) + { + if (window instanceof WindowReference) { + WindowReference windowReference = (WindowReference) window; + Identifier rewrittenName = rewrite(windowReference.getName(), context.get()); + if (windowReference.getName() != rewrittenName) { + return new WindowReference(rewrittenName); + } + return window; + } + + WindowSpecification windowSpecification = (WindowSpecification) window; + Optional existingWindowName = windowSpecification.getExistingWindowName().map(name -> rewrite(name, context.get())); + + List partitionBy = rewrite(windowSpecification.getPartitionBy(), context); + + Optional orderBy = Optional.empty(); + if (windowSpecification.getOrderBy().isPresent()) { + orderBy = Optional.of(rewriteOrderBy(windowSpecification.getOrderBy().get(), context)); + } + + Optional rewrittenFrame = windowSpecification.getFrame(); + if (rewrittenFrame.isPresent()) { + WindowFrame frame = rewrittenFrame.get(); + + FrameBound start = frame.getStart(); + if (start.getValue().isPresent()) { + Expression value = rewrite(start.getValue().get(), context.get()); + if (value != start.getValue().get()) { + start = new FrameBound(start.getType(), value); + } + } + + Optional rewrittenEnd = frame.getEnd(); + if (rewrittenEnd.isPresent()) { + Optional value = rewrittenEnd.get().getValue(); + if (value.isPresent()) { + Expression rewrittenValue = rewrite(value.get(), context.get()); + if (rewrittenValue != value.get()) { + rewrittenEnd = Optional.of(new FrameBound(rewrittenEnd.get().getType(), rewrittenValue)); + } + } + } + + // Frame properties for row pattern matching are not rewritten. They are planned as parts of + // PatternRecognitionNode, and shouldn't be accessed past the Planner phase. + // There are nested expressions in Measures and VariableDefinitions. They are not rewritten by default. + // Rewriting them requires special handling of DereferenceExpression, aware of pattern labels. + if (!frame.getMeasures().isEmpty() || + frame.getAfterMatchSkipTo().isPresent() || + frame.getPatternSearchMode().isPresent() || + frame.getPattern().isPresent() || + !frame.getSubsets().isEmpty() || + !frame.getVariableDefinitions().isEmpty()) { + throw new UnsupportedOperationException("cannot rewrite pattern recognition clauses in window"); + } + + if ((frame.getStart() != start) || !sameElements(frame.getEnd(), rewrittenEnd)) { + rewrittenFrame = Optional.of(new WindowFrame( + frame.getType(), + start, + rewrittenEnd, + frame.getMeasures(), + frame.getAfterMatchSkipTo(), + frame.getPatternSearchMode(), + frame.getPattern(), + frame.getSubsets(), + frame.getVariableDefinitions())); + } + } + + if (!sameElements(windowSpecification.getExistingWindowName(), existingWindowName) || + !sameElements(windowSpecification.getPartitionBy(), partitionBy) || + !sameElements(windowSpecification.getOrderBy(), orderBy) || + !sameElements(windowSpecification.getFrame(), rewrittenFrame)) { + return new WindowSpecification(existingWindowName, partitionBy, orderBy, rewrittenFrame); + } + return window; + } + + @Override + protected Expression visitWindowOperation(WindowOperation node, Context context) + { + if (!context.isDefaultRewrite()) { + Expression result = rewriter.rewriteWindowOperation(node, context.get(), ExpressionTreeRewriter.this); + if (result != null) { + return result; + } + } + + Identifier name = rewrite(node.getName(), context.get()); + Window window = rewriteWindow(node.getWindow(), context); + + if (name != node.getName() || window != node.getWindow()) { + return new WindowOperation(name, window); + } + + return node; + } + @Override protected Expression visitLambdaExpression(LambdaExpression node, Context context) { diff --git a/core/trino-parser/src/main/java/io/trino/sql/tree/WindowOperation.java b/core/trino-parser/src/main/java/io/trino/sql/tree/WindowOperation.java new file mode 100644 index 0000000000000..3ee1abc8822c2 --- /dev/null +++ b/core/trino-parser/src/main/java/io/trino/sql/tree/WindowOperation.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.tree; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * Represents a call over a window: + *
+ *     classifier OVER (...)
+ * 
+ * There are two types of window calls supported in Trino: + * - function calls + * - row pattern measures + * This class captures row pattern measures only. A function call over a window + * is represented as `FunctionCall` having a `Window` member. + * // TODO refactor `FunctionCall` so that it does not contain `Window`, and instead represent a windowed function call as `WindowOperation` + */ +public class WindowOperation + extends Expression +{ + private final Identifier name; + private final Window window; + + public WindowOperation(Identifier name, Window window) + { + this(Optional.empty(), name, window); + } + + public WindowOperation(NodeLocation location, Identifier name, Window window) + { + this(Optional.of(location), name, window); + } + + private WindowOperation(Optional location, Identifier name, Window window) + { + super(location); + requireNonNull(name, "name is null"); + requireNonNull(window, "window is null"); + checkArgument(window instanceof WindowReference || window instanceof WindowSpecification, "unexpected window: " + window.getClass().getSimpleName()); + + this.name = name; + this.window = window; + } + + public Identifier getName() + { + return name; + } + + public Window getWindow() + { + return window; + } + + @Override + public R accept(AstVisitor visitor, C context) + { + return visitor.visitWindowOperation(this, context); + } + + @Override + public List getChildren() + { + return ImmutableList.of(name, (Node) window); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + WindowOperation o = (WindowOperation) obj; + return Objects.equals(name, o.name) && + Objects.equals(window, o.window); + } + + @Override + public int hashCode() + { + return Objects.hash(name, window); + } + + @Override + public boolean shallowEquals(Node other) + { + return sameClass(this, other); + } +} diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 3c6a1ace68a6b..3735cbeef8ea8 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -182,6 +182,7 @@ import io.trino.sql.tree.WhenClause; import io.trino.sql.tree.WindowDefinition; import io.trino.sql.tree.WindowFrame; +import io.trino.sql.tree.WindowOperation; import io.trino.sql.tree.WindowReference; import io.trino.sql.tree.WindowSpecification; import io.trino.sql.tree.With; @@ -3303,6 +3304,42 @@ public void testWindowFrameWithPatternRecognition() ImmutableList.of())); } + @Test + public void testMeasureOverWindow() + { + assertThat(expression("last_z OVER (" + + " MEASURES z AS last_z " + + " ROWS CURRENT ROW " + + " PATTERN (A) " + + " DEFINE a AS true " + + " )")) + .isEqualTo(new WindowOperation( + location(1, 1), + new Identifier(location(1, 1), "last_z", false), + new WindowSpecification( + location(1, 41), + Optional.empty(), + ImmutableList.of(), + Optional.empty(), + Optional.of(new WindowFrame( + location(1, 41), + ROWS, + new FrameBound(location(1, 94), CURRENT_ROW), + Optional.empty(), + ImmutableList.of(new MeasureDefinition( + location(1, 50), + new Identifier(location(1, 50), "z", false), + new Identifier(location(1, 55), "last_z", false))), + Optional.empty(), + Optional.empty(), + Optional.of(new PatternVariable(location(1, 142), new Identifier(location(1, 142), "A", false))), + ImmutableList.of(), + ImmutableList.of(new VariableDefinition( + location(1, 179), + new Identifier(location(1, 179), "a", false), + new BooleanLiteral(location(1, 184), "true")))))))); + } + @Test public void testUpdate() {