Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Support NULLS FIRST/LAST ordering for window functions #929

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use sort option in window function AST node
  • Loading branch information
dai-chen committed Dec 12, 2020
commit 5a71b406fb4df10ae3c04711fb5bec50984cbdd9
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package com.amazon.opendistroforelasticsearch.sql.analysis;

import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC;
import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC;

import com.amazon.opendistroforelasticsearch.sql.ast.AbstractNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Alias;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
Expand Down Expand Up @@ -94,13 +91,9 @@ private List<Pair<SortOption, Expression>> analyzeSortList(WindowFunction node,
return node.getSortList()
.stream()
.map(pair -> ImmutablePair
.of(getSortOption(pair.getLeft()),
.of(pair.getLeft(),
expressionAnalyzer.analyze(pair.getRight(), context)))
.collect(Collectors.toList());
}

private SortOption getSortOption(String option) {
return "ASC".equalsIgnoreCase(option) ? DEFAULT_ASC : DEFAULT_DESC;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.tree.RelationSubquery;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Rename;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.UnresolvedPlan;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Values;
import java.util.Arrays;
Expand Down Expand Up @@ -226,7 +227,7 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) {

public UnresolvedExpression window(Function function,
List<UnresolvedExpression> partitionByList,
List<Pair<String, UnresolvedExpression>> sortList) {
List<Pair<SortOption, UnresolvedExpression>> sortList) {
return new WindowFunction(function, partitionByList, sortList);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@

import com.amazon.opendistroforelasticsearch.sql.ast.AbstractNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.ast.Node;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption;
import java.util.Collections;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.apache.commons.lang3.tuple.Pair;

@AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Getter
@RequiredArgsConstructor
@ToString
public class WindowFunction extends UnresolvedExpression {

private final Function function;
private List<UnresolvedExpression> partitionByList;
private List<Pair<String, UnresolvedExpression>> sortList;
private List<Pair<SortOption, UnresolvedExpression>> sortList;

@Override
public List<? extends Node> getChild() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ public void window_function() {
AstDSL.function("row_number"),
Collections.singletonList(AstDSL.qualifiedName("string_value")),
Collections.singletonList(
ImmutablePair.of("ASC", AstDSL.qualifiedName("integer_value")))))));
ImmutablePair.of(DEFAULT_ASC, AstDSL.qualifiedName("integer_value")))))));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void should_wrap_child_with_window_and_sort_operator_if_project_item_windowed()
AstDSL.function("row_number"),
ImmutableList.of(AstDSL.qualifiedName("string_value")),
ImmutableList.of(
ImmutablePair.of("DESC", AstDSL.qualifiedName("integer_value"))))),
ImmutablePair.of(DEFAULT_DESC, AstDSL.qualifiedName("integer_value"))))),
analysisContext));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.MathExpressionAtomContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.NotExpressionContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.NullLiteralContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OrderByElementContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OverClauseContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.QualifiedNameContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.RankingWindowFunctionContext;
Expand All @@ -47,6 +46,7 @@
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.TimeLiteralContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.TimestampLiteralContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.WindowFunctionContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.parser.ParserUtils.createSortOption;

import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AggregateFunction;
Expand All @@ -62,6 +62,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.When;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.WindowFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption;
import com.amazon.opendistroforelasticsearch.sql.common.utils.StringUtils;
import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser;
import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.AndExpressionContext;
Expand Down Expand Up @@ -170,12 +171,13 @@ public UnresolvedExpression visitWindowFunction(WindowFunctionContext ctx) {
.collect(Collectors.toList());
}

List<Pair<String, UnresolvedExpression>> sortList = Collections.emptyList();
List<Pair<SortOption, UnresolvedExpression>> sortList = Collections.emptyList();
if (overClause.orderByClause() != null) {
sortList = overClause.orderByClause()
.orderByElement()
.stream()
.map(item -> ImmutablePair.of(getOrder(item), visit(item.expression())))
.map(item -> ImmutablePair.of(
createSortOption(item), visit(item.expression())))
.collect(Collectors.toList());
}
return new WindowFunction((Function) visit(ctx.function), partitionByList, sortList);
Expand Down Expand Up @@ -324,8 +326,4 @@ private QualifiedName visitIdentifiers(List<IdentContext> identifiers) {
);
}

private String getOrder(OrderByElementContext item) {
return (item.order == null) ? "ASC" : item.order.getText();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@

package com.amazon.opendistroforelasticsearch.sql.sql.parser;

import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.NullOrder;
import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption;
import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOrder;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OrderByElementContext;

import lombok.experimental.UtilityClass;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.tree.TerminalNode;

/**
* Parser Utils Class.
Expand All @@ -35,4 +41,37 @@ public static String getTextInQuery(ParserRuleContext ctx, String queryString) {
Token stop = ctx.getStop();
return queryString.substring(start.getStartIndex(), stop.getStopIndex() + 1);
}

/**
* Create sort option from syntax tree node.
*/
public static SortOption createSortOption(OrderByElementContext orderBy) {
return new SortOption(
createSortOrder(orderBy.order),
createNullOrder(orderBy.FIRST(), orderBy.LAST()));
}

/**
* Create sort order for sort option use from ASC/DESC token.
*/
public static SortOrder createSortOrder(Token ctx) {
if (ctx == null) {
return null;
}
return SortOrder.valueOf(ctx.getText().toUpperCase());
}

/**
* Create null order for sort option use from FIRST/LAST token.
*/
public static NullOrder createNullOrder(TerminalNode first, TerminalNode last) {
if (first != null) {
return NullOrder.NULL_FIRST;
} else if (last != null) {
return NullOrder.NULL_LAST;
} else {
return null;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SelectElementContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SubqueryAsRelationContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.WindowFunctionContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.parser.ParserUtils.createSortOption;
import static com.amazon.opendistroforelasticsearch.sql.sql.parser.ParserUtils.getTextInQuery;

import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.NullOrder;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOrder;
import com.amazon.opendistroforelasticsearch.sql.common.utils.StringUtils;
import com.amazon.opendistroforelasticsearch.sql.exception.SemanticCheckException;
import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.AggregateFunctionCallContext;
Expand All @@ -48,9 +47,7 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.TerminalNode;

/**
* Query specification domain that collects basic info for a simple query.
Expand Down Expand Up @@ -222,10 +219,7 @@ public Void visitGroupByElement(GroupByElementContext ctx) {
@Override
public Void visitOrderByElement(OrderByElementContext ctx) {
orderByItems.add(visitAstExpression(ctx.expression()));
orderByOptions.add(
new SortOption(
visitSortOrder(ctx.order),
visitNullOrderClause(ctx.FIRST(), ctx.LAST())));
orderByOptions.add(createSortOption(ctx));
return super.visitOrderByElement(ctx);
}

Expand All @@ -239,23 +233,6 @@ private boolean isDistinct(SelectSpecContext ctx) {
return (ctx != null) && (ctx.DISTINCT() != null);
}

private SortOrder visitSortOrder(Token ctx) {
if (ctx == null) {
return null;
}
return SortOrder.valueOf(ctx.getText().toUpperCase());
}

private NullOrder visitNullOrderClause(TerminalNode first, TerminalNode last) {
if (first != null) {
return NullOrder.NULL_FIRST;
} else if (last != null) {
return NullOrder.NULL_LAST;
} else {
return null;
}
}

private UnresolvedExpression visitAstExpression(ParseTree tree) {
return expressionBuilder.visit(tree);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.timestampLiteral;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.when;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.window;
import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.NullOrder.NULL_LAST;
import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOrder.ASC;
import static com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.junit.jupiter.api.Assertions.assertEquals;

import com.amazon.opendistroforelasticsearch.sql.ast.Node;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType;
import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort.SortOption;
import com.amazon.opendistroforelasticsearch.sql.common.antlr.CaseInsensitiveCharStream;
import com.amazon.opendistroforelasticsearch.sql.common.antlr.SyntaxAnalysisErrorListener;
import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLLexer;
Expand Down Expand Up @@ -254,7 +258,7 @@ public void canBuildWindowFunction() {
window(
function("RANK"),
ImmutableList.of(qualifiedName("state")),
ImmutableList.of(ImmutablePair.of("ASC", qualifiedName("age")))),
ImmutableList.of(ImmutablePair.of(new SortOption(null, null), qualifiedName("age")))),
buildExprAst("RANK() OVER (PARTITION BY state ORDER BY age)"));
}

Expand All @@ -264,10 +268,21 @@ public void canBuildWindowFunctionWithoutPartitionBy() {
window(
function("DENSE_RANK"),
ImmutableList.of(),
ImmutableList.of(ImmutablePair.of("DESC", qualifiedName("age")))),
ImmutableList.of(ImmutablePair.of(new SortOption(DESC, null), qualifiedName("age")))),
buildExprAst("DENSE_RANK() OVER (ORDER BY age DESC)"));
}

@Test
public void canBuildWindowFunctionWithNullOrderSpecified() {
assertEquals(
window(
function("DENSE_RANK"),
ImmutableList.of(),
ImmutableList.of(ImmutablePair.of(
new SortOption(ASC, NULL_LAST), qualifiedName("age")))),
buildExprAst("DENSE_RANK() OVER (ORDER BY age ASC NULLS LAST)"));
}

@Test
public void canBuildWindowFunctionWithoutOrderBy() {
assertEquals(
Expand Down