diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java index 34d152fef3..81c21ce8c4 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java @@ -290,6 +290,9 @@ public LogicalPlan visitEval(Eval node, AnalysisContext context) { @Override public LogicalPlan visitSort(Sort node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); + ExpressionReferenceOptimizer optimizer = + new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); + // the first options is {"count": "integer"} Integer count = (Integer) node.getOptions().get(0).getValue().getValue(); List> sortList = @@ -298,7 +301,8 @@ public LogicalPlan visitSort(Sort node, AnalysisContext context) { sortField -> { // the first options is {"asc": "true/false"} Boolean asc = (Boolean) sortField.getFieldArgs().get(0).getValue().getValue(); - Expression expression = expressionAnalyzer.analyze(sortField, context); + Expression expression = optimizer.optimize( + expressionAnalyzer.analyze(sortField.getField(), context), context); return ImmutablePair.of( asc ? SortOption.DEFAULT_ASC : SortOption.DEFAULT_DESC, expression); }) diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java index 9ca9636d03..6c378f7004 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java @@ -26,7 +26,6 @@ import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function; import com.amazon.opendistroforelasticsearch.sql.ast.expression.In; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Interval; -import com.amazon.opendistroforelasticsearch.sql.ast.expression.IntervalUnit; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Let; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Map; @@ -228,27 +227,27 @@ public static UnresolvedArgument unresolvedArg(String argName, UnresolvedExpress return new UnresolvedArgument(argName, argValue); } - public static UnresolvedExpression field(UnresolvedExpression field) { + public Field field(UnresolvedExpression field) { return new Field((QualifiedName) field); } - public static Field field(String field) { + public Field field(String field) { return new Field(field); } - public static UnresolvedExpression field(UnresolvedExpression field, Argument... fieldArgs) { - return new Field((QualifiedName) field, Arrays.asList(fieldArgs)); + public Field field(UnresolvedExpression field, Argument... fieldArgs) { + return new Field(field, Arrays.asList(fieldArgs)); } - public static Field field(String field, Argument... fieldArgs) { + public Field field(String field, Argument... fieldArgs) { return new Field(field, Arrays.asList(fieldArgs)); } - public static UnresolvedExpression field(UnresolvedExpression field, List fieldArgs) { - return new Field((QualifiedName) field, fieldArgs); + public Field field(UnresolvedExpression field, List fieldArgs) { + return new Field(field, fieldArgs); } - public static Field field(String field, List fieldArgs) { + public Field field(String field, List fieldArgs) { return new Field(field, fieldArgs); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/Field.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/Field.java index c8152f280d..198b1c6d94 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/Field.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/expression/Field.java @@ -29,7 +29,7 @@ @EqualsAndHashCode(callSuper = false) @AllArgsConstructor public class Field extends UnresolvedExpression { - private QualifiedName field; + private UnresolvedExpression field; private List fieldArgs = Collections.emptyList(); public Field(QualifiedName field) { diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java index e3c4e9c5a8..bf428728c0 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/AnalyzerTest.java @@ -40,6 +40,7 @@ import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL; import com.amazon.opendistroforelasticsearch.sql.ast.tree.RareTopN.CommandType; +import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort; import com.amazon.opendistroforelasticsearch.sql.exception.SemanticCheckException; import com.amazon.opendistroforelasticsearch.sql.expression.DSL; import com.amazon.opendistroforelasticsearch.sql.expression.config.ExpressionConfig; @@ -49,6 +50,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Collections; import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -255,6 +257,42 @@ public void project_values() { ); } + @SuppressWarnings("unchecked") + @Test + public void sort_with_aggregator() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.sort( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.relation("test"), + ImmutableList.of( + DSL.named( + "avg(integer_value)", + dsl.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), + 0, + // Aggregator in Sort AST node is replaced with reference by expression optimizer + Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("avg(integer_value)", DOUBLE))), + DSL.named("string_value", DSL.ref("string_value", STRING))), + AstDSL.project( + AstDSL.sort( + AstDSL.agg( + AstDSL.relation("test"), + ImmutableList.of( + AstDSL.alias( + "avg(integer_value)", + function("avg", qualifiedName("integer_value")))), + emptyList(), + ImmutableList.of(AstDSL.alias("string_value", qualifiedName("string_value"))), + emptyList() + ), + ImmutableList.of(argument("count", intLiteral(0))), + field( + function("avg", qualifiedName("integer_value")), + argument("asc", booleanLiteral(true)))), + AstDSL.alias("string_value", qualifiedName("string_value")))); + } + @SuppressWarnings("unchecked") @Test public void window_function() { diff --git a/docs/category.json b/docs/category.json index 33f98c7966..aa71477dd6 100644 --- a/docs/category.json +++ b/docs/category.json @@ -18,6 +18,7 @@ "user/dql/expressions.rst", "user/general/identifiers.rst", "user/general/values.rst", + "user/dql/basics.rst", "user/dql/functions.rst", "user/dql/window.rst", "user/beyond/partiql.rst", diff --git a/docs/user/dql/basics.rst b/docs/user/dql/basics.rst index 082abe9aee..443cce0284 100644 --- a/docs/user/dql/basics.rst +++ b/docs/user/dql/basics.rst @@ -929,6 +929,30 @@ Result set: | Quility| +--------+ +Example 3: Ordering by Aggregate Functions +------------------------------------------ + +Aggregate functions are allowed to be used in ``ORDER BY`` clause. You can reference it by same function call or its alias or ordinal in select list:: + + od> SELECT gender, MAX(age) FROM accounts GROUP BY gender ORDER BY MAX(age) DESC; + fetched rows / total rows = 2/2 + +----------+------------+ + | gender | MAX(age) | + |----------+------------| + | M | 36 | + | F | 28 | + +----------+------------+ + +Even if it's not present in ``SELECT`` clause, it can be also used as follows:: + + od> SELECT gender, MIN(age) FROM accounts GROUP BY gender ORDER BY MAX(age) DESC; + fetched rows / total rows = 2/2 + +----------+------------+ + | gender | MIN(age) | + |----------+------------| + | M | 32 | + | F | 28 | + +----------+------------+ LIMIT ===== diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 5a46ccbcbc..93447cd49b 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -128,6 +128,9 @@ task integTestWithNewEngine(type: RestIntegTestTask) { // Skip this IT to avoid breaking tests due to inconsistency in JDBC schema exclude 'com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.class' + + // Skip this IT because all assertions are against explain output + exclude 'com/amazon/opendistroforelasticsearch/sql/legacy/OrderIT.class' } } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/connection/JDBCConnection.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/connection/JDBCConnection.java index f6e2c3ec89..299a0631fd 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/connection/JDBCConnection.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/connection/JDBCConnection.java @@ -127,7 +127,8 @@ public void insert(String tableName, String[] columnNames, List batch) public DBResult select(String query) { try (Statement stmt = connection.createStatement()) { ResultSet resultSet = stmt.executeQuery(query); - DBResult result = new DBResult(databaseName); + DBResult result = isOrderByQuery(query) + ? DBResult.resultInOrder(databaseName) : DBResult.result(databaseName); populateMetaData(resultSet, result); populateData(resultSet, result); return result; @@ -200,6 +201,10 @@ private String mapToJDBCType(String esType) { } } + private boolean isOrderByQuery(String query) { + return query.trim().toUpperCase().contains("ORDER BY"); + } + /** * Setter for unit test mock */ diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/resultset/DBResult.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/resultset/DBResult.java index c8036e1a75..b50a6a2771 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/resultset/DBResult.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/runner/resultset/DBResult.java @@ -16,14 +16,13 @@ package com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset; import com.amazon.opendistroforelasticsearch.sql.legacy.utils.StringUtils; +import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashSet; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; @@ -60,11 +59,19 @@ public class DBResult { private final Collection dataRows; /** - * By default treat both columns and data rows in order. This makes sense for typical query - * with specific column names in SELECT but without ORDER BY. + * In theory, a result set is a multi-set (bag) that allows duplicate and doesn't + * have order. */ - public DBResult(String databaseName) { - this(databaseName, new ArrayList<>(), new HashSet<>()); + public static DBResult result(String databaseName) { + return new DBResult(databaseName, new ArrayList<>(), HashMultiset.create()); + } + + /** + * But for queries with ORDER BY clause, we want to preserve the original order of data rows + * so we can check if the order is correct. + */ + public static DBResult resultInOrder(String databaseName) { + return new DBResult(databaseName, new ArrayList<>(), new ArrayList<>()); } public DBResult(String databaseName, Collection schema, Collection rows) { @@ -97,10 +104,13 @@ public String getDatabaseName() { } /** - * Flatten for simplifying json generated + * Flatten for simplifying json generated. */ public Collection> getDataRows() { - return dataRows.stream().map(Row::getValues).collect(Collectors.toSet()); + Collection> values = isDataRowOrdered() + ? new ArrayList<>() : HashMultiset.create(); + dataRows.stream().map(Row::getValues).forEach(values::add); + return values; } /** @@ -124,6 +134,9 @@ private String diffSchema(DBResult other) { } private String diffDataRows(DBResult other) { + if (isDataRowOrdered()) { + return diff("Data row", (List) dataRows, (List) other.dataRows); + } List thisRows = sort(dataRows); List otherRows = sort(other.dataRows); return diff("Data row", thisRows, otherRows); @@ -160,6 +173,14 @@ private static int findFirstDifference(List list1, List list2) { return -1; } + /** + * Is data row a list that represent original order of data set + * which doesn't/shouldn't sort again. + */ + private boolean isDataRowOrdered() { + return (dataRows instanceof List); + } + /** * Convert a collection to list and sort and return this new list. */ diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/DBResultTest.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/DBResultTest.java index ce52bcbf27..bbb3ae4e72 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/DBResultTest.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/DBResultTest.java @@ -22,6 +22,7 @@ import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.DBResult; import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.Row; import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.Type; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.Arrays; @@ -49,6 +50,36 @@ public void dbResultWithDifferentColumnShouldNotEqual() { assertNotEquals(result1, result2); } + @Test + public void dbResultWithSameRowsInDifferentOrderShouldEqual() { + DBResult result1 = DBResult.result("DB 1"); + result1.addColumn("name", "VARCHAR"); + result1.addRow(new Row(ImmutableList.of("test-1"))); + result1.addRow(new Row(ImmutableList.of("test-2"))); + + DBResult result2 = DBResult.result("DB 2"); + result2.addColumn("name", "VARCHAR"); + result2.addRow(new Row(ImmutableList.of("test-2"))); + result2.addRow(new Row(ImmutableList.of("test-1"))); + + assertEquals(result1, result2); + } + + @Test + public void dbResultInOrderWithSameRowsInDifferentOrderShouldNotEqual() { + DBResult result1 = DBResult.resultInOrder("DB 1"); + result1.addColumn("name", "VARCHAR"); + result1.addRow(new Row(ImmutableList.of("test-1"))); + result1.addRow(new Row(ImmutableList.of("test-2"))); + + DBResult result2 = DBResult.resultInOrder("DB 2"); + result2.addColumn("name", "VARCHAR"); + result2.addRow(new Row(ImmutableList.of("test-2"))); + result2.addRow(new Row(ImmutableList.of("test-1"))); + + assertNotEquals(result1, result2); + } + @Test public void dbResultWithDifferentColumnTypeShouldNotEqual() { DBResult result1 = new DBResult("DB 1", Arrays.asList(new Type("age", "FLOAT")), emptyList()); @@ -89,4 +120,22 @@ public void shouldExplainDataRowsDifference() { ); } + @Test + public void shouldExplainDataRowsOrderDifference() { + DBResult result1 = DBResult.resultInOrder("DB 1"); + result1.addColumn("name", "VARCHAR"); + result1.addRow(new Row(ImmutableList.of("hello"))); + result1.addRow(new Row(ImmutableList.of("world"))); + + DBResult result2 = DBResult.resultInOrder("DB 2"); + result2.addColumn("name", "VARCHAR"); + result2.addRow(new Row(ImmutableList.of("world"))); + result2.addRow(new Row(ImmutableList.of("hello"))); + + assertEquals( + "Data row at [0] is different: this=[Row(values=[hello])], other=[Row(values=[world])]", + result1.diff(result2) + ); + } + } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/JDBCConnectionTest.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/JDBCConnectionTest.java index 9a616ab82b..057ebdb795 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/JDBCConnectionTest.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/correctness/tests/JDBCConnectionTest.java @@ -26,6 +26,8 @@ import com.amazon.opendistroforelasticsearch.sql.correctness.runner.connection.JDBCConnection; import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.DBResult; import com.amazon.opendistroforelasticsearch.sql.correctness.runner.resultset.Type; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import java.sql.Connection; @@ -123,10 +125,10 @@ public void testSelectQuery() throws SQLException { result.getSchema() ); assertEquals( - Sets.newHashSet( + HashMultiset.create(ImmutableList.of( Arrays.asList("John", 25), Arrays.asList("Hank", 30) - ), + )), result.getDataRows() ); } @@ -170,11 +172,11 @@ public void testSelectQueryWithFloatInResultSet() throws SQLException { result.getSchema() ); assertEquals( - Sets.newHashSet( + HashMultiset.create(ImmutableList.of( Arrays.asList("John", 25.13), Arrays.asList("Hank", 30.46), Arrays.asList("Allen", 15.1) - ), + )), result.getDataRows() ); } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/JdbcTestIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/JdbcTestIT.java index 90beb2dece..26ec7f733d 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/JdbcTestIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/JdbcTestIT.java @@ -22,6 +22,7 @@ import org.json.JSONArray; import org.json.JSONObject; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; public class JdbcTestIT extends SQLIntegTestCase { @@ -117,6 +118,7 @@ public void stringOperatorNameCaseInsensitiveTest() { ); } + @Ignore("DATE_FORMAT function signature changed in new engine") @Test public void dateFunctionNameCaseInsensitiveTest() { assertTrue( diff --git a/integ-test/src/test/resources/correctness/bugfixes/123.txt b/integ-test/src/test/resources/correctness/bugfixes/123.txt new file mode 100644 index 0000000000..b31e9e20ad --- /dev/null +++ b/integ-test/src/test/resources/correctness/bugfixes/123.txt @@ -0,0 +1,8 @@ +SELECT Origin FROM kibana_sample_data_flights ORDER BY LOWER(Origin) +SELECT Origin FROM kibana_sample_data_flights ORDER BY LOWER(Origin) DESC +SELECT Origin FROM kibana_sample_data_flights ORDER BY SUBSTRING(Origin, 3, 3) +SELECT Origin FROM kibana_sample_data_flights ORDER BY SUBSTRING(Origin, 3, 3) DESC +SELECT Origin, FlightNum FROM kibana_sample_data_flights ORDER BY SUBSTRING(Origin, 3, 3), LOWER(FlightNum) +SELECT AVG(FlightTimeMin) FROM kibana_sample_data_flights ORDER BY SUM(FlightTimeMin) +SELECT OriginWeather, AVG(FlightTimeMin) FROM kibana_sample_data_flights GROUP BY OriginWeather ORDER BY SUM(FlightTimeMin) +SELECT FlightDelay, MIN(FlightTimeMin) FROM kibana_sample_data_flights GROUP BY FlightDelay ORDER BY MAX(FlightTimeMin) diff --git a/integ-test/src/test/resources/correctness/bugfixes/277.txt b/integ-test/src/test/resources/correctness/bugfixes/277.txt new file mode 100644 index 0000000000..bf793e2af7 --- /dev/null +++ b/integ-test/src/test/resources/correctness/bugfixes/277.txt @@ -0,0 +1,8 @@ +SELECT COUNT(FlightNum) FROM kibana_sample_data_flights GROUP BY FlightDelay ORDER BY COUNT(FlightNum) +SELECT COUNT(FlightNum) AS cnt FROM kibana_sample_data_flights GROUP BY FlightDelay ORDER BY cnt +SELECT COUNT(FlightNum) FROM kibana_sample_data_flights GROUP BY FlightDelay ORDER BY 1 +SELECT OriginWeather, AVG(FlightTimeMin) FROM kibana_sample_data_flights GROUP BY OriginWeather ORDER BY AVG(FlightTimeMin) +SELECT OriginWeather, AVG(FlightTimeMin) FROM kibana_sample_data_flights GROUP BY OriginWeather ORDER BY SUM(FlightDelayMin) +SELECT OriginWeather, AVG(FlightTimeMin), SUM(FlightDelayMin) FROM kibana_sample_data_flights GROUP BY OriginWeather ORDER BY AVG(FlightTimeMin), SUM(FlightDelayMin) +SELECT OriginWeather, AVG(FlightTimeMin), SUM(FlightDelayMin) AS s FROM kibana_sample_data_flights GROUP BY OriginWeather ORDER BY AVG(FlightTimeMin), s +SELECT OriginWeather, AVG(FlightTimeMin) AS a, SUM(FlightDelayMin) FROM kibana_sample_data_flights GROUP BY OriginWeather ORDER BY a, SUM(FlightDelayMin) diff --git a/integ-test/src/test/resources/correctness/bugfixes/674.txt b/integ-test/src/test/resources/correctness/bugfixes/674.txt new file mode 100644 index 0000000000..947e3ec76c --- /dev/null +++ b/integ-test/src/test/resources/correctness/bugfixes/674.txt @@ -0,0 +1,2 @@ +SELECT OriginCountry, OriginCityName FROM kibana_sample_data_flights GROUP BY OriginCountry, OriginCityName ORDER BY OriginCityName DESC +SELECT FlightDelay, OriginCountry, OriginCityName FROM kibana_sample_data_flights GROUP BY FlightDelay, OriginCountry, OriginCityName ORDER BY OriginCityName DESC, OriginCountry diff --git a/integ-test/src/test/resources/correctness/queries/orderby.txt b/integ-test/src/test/resources/correctness/queries/orderby.txt new file mode 100644 index 0000000000..bba7ea4a40 --- /dev/null +++ b/integ-test/src/test/resources/correctness/queries/orderby.txt @@ -0,0 +1,13 @@ +SELECT FlightNum FROM kibana_sample_data_flights ORDER BY FlightNum +SELECT FlightNum FROM kibana_sample_data_flights ORDER BY FlightNum ASC +SELECT FlightNum FROM kibana_sample_data_flights ORDER BY FlightNum DESC +SELECT FlightNum, AvgTicketPrice FROM kibana_sample_data_flights ORDER BY FlightNum, AvgTicketPrice +SELECT FlightNum, AvgTicketPrice FROM kibana_sample_data_flights ORDER BY FlightNum DESC, AvgTicketPrice +SELECT FlightNum, AvgTicketPrice FROM kibana_sample_data_flights ORDER BY FlightNum, AvgTicketPrice DESC +SELECT FlightNum, AvgTicketPrice FROM kibana_sample_data_flights ORDER BY FlightNum DESC, AvgTicketPrice DESC +SELECT OriginCountry FROM kibana_sample_data_flights GROUP BY OriginCountry ORDER BY OriginCountry +SELECT OriginCountry FROM kibana_sample_data_flights GROUP BY OriginCountry ORDER BY OriginCountry DESC +SELECT FlightDelay, OriginWeather FROM kibana_sample_data_flights GROUP BY FlightDelay, OriginWeather ORDER BY FlightDelay, OriginWeather +SELECT FlightDelay, OriginWeather FROM kibana_sample_data_flights GROUP BY FlightDelay, OriginWeather ORDER BY FlightDelay DESC, OriginWeather +SELECT FlightDelay, OriginWeather FROM kibana_sample_data_flights GROUP BY FlightDelay, OriginWeather ORDER BY FlightDelay, OriginWeather DESC +SELECT FlightDelay, OriginWeather FROM kibana_sample_data_flights GROUP BY FlightDelay, OriginWeather ORDER BY FlightDelay DESC, OriginWeather DESC diff --git a/sql/src/main/antlr/OpenDistroSQLParser.g4 b/sql/src/main/antlr/OpenDistroSQLParser.g4 index 9fe582d655..6779332a56 100644 --- a/sql/src/main/antlr/OpenDistroSQLParser.g4 +++ b/sql/src/main/antlr/OpenDistroSQLParser.g4 @@ -79,6 +79,7 @@ fromClause : FROM tableName (AS? alias)? (whereClause)? (groupByClause)? + (orderByClause)? // Place it under FROM for now but actually not necessary ex. A UNION B ORDER BY ; whereClause diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilder.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilder.java index c88641da38..594d3e23e2 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilder.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilder.java @@ -19,11 +19,8 @@ import static java.util.Collections.emptyList; import com.amazon.opendistroforelasticsearch.sql.ast.Node; -import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL; import com.amazon.opendistroforelasticsearch.sql.ast.expression.AggregateFunction; -import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType; -import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal; -import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Alias; import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression; import com.amazon.opendistroforelasticsearch.sql.ast.tree.Aggregation; import com.amazon.opendistroforelasticsearch.sql.ast.tree.UnresolvedPlan; @@ -35,6 +32,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; import org.antlr.v4.runtime.tree.ParseTree; @@ -114,17 +112,11 @@ private UnresolvedPlan buildImplicitAggregation() { } private List replaceGroupByItemIfAliasOrOrdinal() { - List groupByItems = new ArrayList<>(); - for (UnresolvedExpression expr : querySpec.getGroupByItems()) { - if (isIntegerLiteral(expr)) { - groupByItems.add(getSelectItemByOrdinal(expr)); - } else if (isSelectAlias(expr)) { - groupByItems.add(getSelectItemByAlias(expr)); - } else { - groupByItems.add(AstDSL.alias(expr.toString(), expr)); - } - } - return groupByItems; + return querySpec.getGroupByItems() + .stream() + .map(querySpec::replaceIfAliasOrOrdinal) + .map(expr -> new Alias(expr.toString(), expr)) + .collect(Collectors.toList()); } private Optional findNonAggregatedItemInSelect() { @@ -148,35 +140,4 @@ private boolean isNonAggregatedExpression(UnresolvedExpression expr) { .allMatch(child -> isNonAggregatedExpression((UnresolvedExpression) child)); } - private boolean isIntegerLiteral(UnresolvedExpression expr) { - if (!(expr instanceof Literal)) { - return false; - } - - if (((Literal) expr).getType() != DataType.INTEGER) { - throw new SemanticCheckException(StringUtils.format( - "Non-integer constant [%s] found in GROUP BY clause", expr)); - } - return true; - } - - private UnresolvedExpression getSelectItemByOrdinal(UnresolvedExpression expr) { - int ordinal = (Integer) ((Literal) expr).getValue(); - if (ordinal <= 0 || ordinal > querySpec.getSelectItems().size()) { - throw new SemanticCheckException(StringUtils.format( - "Group by ordinal [%d] is out of bound of select item list", ordinal)); - } - final UnresolvedExpression groupExpr = querySpec.getSelectItems().get(ordinal - 1); - return AstDSL.alias(groupExpr.toString(), groupExpr); - } - - private boolean isSelectAlias(UnresolvedExpression expr) { - return (expr instanceof QualifiedName) - && (querySpec.getSelectItemsByAlias().containsKey(expr.toString())); - } - - private UnresolvedExpression getSelectItemByAlias(UnresolvedExpression expr) { - final UnresolvedExpression groupExpr = querySpec.getSelectItemsByAlias().get(expr.toString()); - return AstDSL.alias(groupExpr.toString(), groupExpr); - } } diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilder.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilder.java index 4a32e1a073..aaa2632986 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilder.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilder.java @@ -115,6 +115,10 @@ public UnresolvedPlan visitFromClause(FromClauseContext ctx) { result = aggregation.attach(result); } + if (ctx.orderByClause() != null) { + AstSortBuilder sortBuilder = new AstSortBuilder(context.peek()); + result = sortBuilder.visit(ctx.orderByClause()).attach(result); + } return result; } diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstSortBuilder.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstSortBuilder.java new file mode 100644 index 0000000000..880c8a9bb8 --- /dev/null +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstSortBuilder.java @@ -0,0 +1,74 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.sql.parser; + +import static com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType.BOOLEAN; +import static com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OrderByClauseContext; + +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Argument; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression; +import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort; +import com.amazon.opendistroforelasticsearch.sql.ast.tree.UnresolvedPlan; +import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParserBaseVisitor; +import com.amazon.opendistroforelasticsearch.sql.sql.parser.context.QuerySpecification; +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; +import lombok.RequiredArgsConstructor; + +/** + * AST sort builder that builds Sort AST node from ORDER BY clause. During this process, the item + * in order by may be replaced by item in project list if it's an alias or ordinal. This is same as + * GROUP BY building process. + */ +@RequiredArgsConstructor +public class AstSortBuilder extends OpenDistroSQLParserBaseVisitor { + + private final QuerySpecification querySpec; + + @Override + public UnresolvedPlan visitOrderByClause(OrderByClauseContext ctx) { + return new Sort( + ImmutableList.of(new Argument("count", new Literal(0, INTEGER))), + createSortFields() + ); + } + + private List createSortFields() { + List fields = new ArrayList<>(); + List items = querySpec.getOrderByItems(); + List options = querySpec.getOrderByOptions(); + for (int i = 0; i < items.size(); i++) { + fields.add( + new Field( + querySpec.replaceIfAliasOrOrdinal(items.get(i)), + createSortArgument(options.get(i)))); + } + return fields; + } + + private List createSortArgument(String option) { + return ImmutableList.of( + new Argument( + "asc", + new Literal("ASC".equalsIgnoreCase(option), BOOLEAN))); + } + +} diff --git a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java index 2b803bfb9e..fa640c25d4 100644 --- a/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java +++ b/sql/src/main/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/context/QuerySpecification.java @@ -17,12 +17,17 @@ package com.amazon.opendistroforelasticsearch.sql.sql.parser.context; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.GroupByElementContext; +import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OrderByElementContext; import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SelectElementContext; import static com.amazon.opendistroforelasticsearch.sql.sql.parser.ParserUtils.getTextInQuery; import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.QualifiedName; import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression; import com.amazon.opendistroforelasticsearch.sql.common.utils.StringUtils; +import com.amazon.opendistroforelasticsearch.sql.exception.SemanticCheckException; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.AggregateFunctionCallContext; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.QuerySpecificationContext; import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParserBaseVisitor; @@ -36,8 +41,6 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; -import org.antlr.v4.runtime.ParserRuleContext; -import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; /** @@ -74,10 +77,20 @@ public class QuerySpecification { private final Set aggregators = new HashSet<>(); /** - * Items in GROUP BY clause that may be simple field name or nested in scalar function call. + * Items in GROUP BY clause that may be: + * 1) Simple field name + * 2) Field nested in scalar function call + * 3) Ordinal that points to expression in SELECT + * 4) Alias that points to expression in SELECT. */ private final List groupByItems = new ArrayList<>(); + /** + * Items in ORDER BY clause that may be different forms as above and its options. + */ + private final List orderByItems = new ArrayList<>(); + private final List orderByOptions = new ArrayList<>(); + /** * Collect all query information in the parse tree excluding info in sub-query). * @param query query spec node in parse tree @@ -86,6 +99,52 @@ public void collect(QuerySpecificationContext query, String queryString) { query.accept(new QuerySpecificationCollector(queryString)); } + /** + * Replace unresolved expression if it's an alias or ordinal that represents + * an actual expression in SELECT list. + * @param expr item to be replaced + * @return select item that the given expr represents + */ + public UnresolvedExpression replaceIfAliasOrOrdinal(UnresolvedExpression expr) { + if (isIntegerLiteral(expr)) { + return getSelectItemByOrdinal(expr); + } else if (isSelectAlias(expr)) { + return getSelectItemByAlias(expr); + } else { + return expr; + } + } + + private boolean isIntegerLiteral(UnresolvedExpression expr) { + if (!(expr instanceof Literal)) { + return false; + } + + if (((Literal) expr).getType() != DataType.INTEGER) { + throw new SemanticCheckException(StringUtils.format( + "Non-integer constant [%s] found in ordinal", expr)); + } + return true; + } + + private UnresolvedExpression getSelectItemByOrdinal(UnresolvedExpression expr) { + int ordinal = (Integer) ((Literal) expr).getValue(); + if (ordinal <= 0 || ordinal > selectItems.size()) { + throw new SemanticCheckException(StringUtils.format( + "Ordinal [%d] is out of bound of select item list", ordinal)); + } + return selectItems.get(ordinal - 1); + } + + private boolean isSelectAlias(UnresolvedExpression expr) { + return (expr instanceof QualifiedName) + && (selectItemsByAlias.containsKey(expr.toString())); + } + + private UnresolvedExpression getSelectItemByAlias(UnresolvedExpression expr) { + return selectItemsByAlias.get(expr.toString()); + } + /* * Query specification collector that visits a parse tree to collect query info. * Most visit methods only collect info and returns nothing. However, one exception is @@ -124,6 +183,13 @@ public Void visitGroupByElement(GroupByElementContext ctx) { return super.visitGroupByElement(ctx); } + @Override + public Void visitOrderByElement(OrderByElementContext ctx) { + orderByItems.add(visitAstExpression(ctx.expression())); + orderByOptions.add((ctx.order == null) ? "ASC" : ctx.order.getText()); + return super.visitOrderByElement(ctx); + } + @Override public Void visitAggregateFunctionCall(AggregateFunctionCallContext ctx) { aggregators.add(AstDSL.alias(getTextInQuery(ctx, queryString), visitAstExpression(ctx))); diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilderTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilderTest.java index 21c5372ea6..627b1c70f2 100644 --- a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilderTest.java +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstAggregationBuilderTest.java @@ -139,7 +139,7 @@ void should_report_error_for_non_integer_ordinal_in_group_by() { SemanticCheckException error = assertThrows(SemanticCheckException.class, () -> buildAggregation("SELECT state AS s FROM test GROUP BY 1.5")); assertEquals( - "Non-integer constant [1.5] found in GROUP BY clause", + "Non-integer constant [1.5] found in ordinal", error.getMessage()); } @@ -182,11 +182,11 @@ void should_report_error_for_non_aggregated_item_in_select_if_no_group_by() { void should_report_error_for_group_by_ordinal_out_of_bound_of_select_list() { SemanticCheckException error1 = assertThrows(SemanticCheckException.class, () -> buildAggregation("SELECT age, AVG(balance) FROM tests GROUP BY 0")); - assertEquals("Group by ordinal [0] is out of bound of select item list", error1.getMessage()); + assertEquals("Ordinal [0] is out of bound of select item list", error1.getMessage()); SemanticCheckException error2 = assertThrows(SemanticCheckException.class, () -> buildAggregation("SELECT age, AVG(balance) FROM tests GROUP BY 3")); - assertEquals("Group by ordinal [3] is out of bound of select item list", error2.getMessage()); + assertEquals("Ordinal [3] is out of bound of select item list", error2.getMessage()); } private Matcher hasGroupByItems(UnresolvedExpression... exprs) { diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilderTest.java index fbe4658d01..b4d45379b1 100644 --- a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstBuilderTest.java @@ -19,14 +19,17 @@ import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.agg; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.aggregate; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.alias; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.argument; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.booleanLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.doubleLiteral; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.field; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.filter; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.function; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.project; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.relation; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.sort; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.stringLiteral; import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.values; import static java.util.Collections.emptyList; @@ -193,7 +196,7 @@ public void can_build_where_clause() { } @Test - public void can_build_group_by_clause() { + public void can_build_group_by_field_name() { assertEquals( project( agg( @@ -208,7 +211,7 @@ public void can_build_group_by_clause() { } @Test - public void can_build_group_by_with_function() { + public void can_build_group_by_function() { assertEquals( project( agg( @@ -223,7 +226,7 @@ public void can_build_group_by_with_function() { } @Test - public void can_build_group_by_with_uppercase_function() { + public void can_build_group_by_uppercase_function() { assertEquals( project( agg( @@ -238,7 +241,7 @@ public void can_build_group_by_with_uppercase_function() { } @Test - public void can_build_group_by_with_alias() { + public void can_build_group_by_alias() { assertEquals( project( agg( @@ -253,7 +256,7 @@ public void can_build_group_by_with_alias() { } @Test - public void can_build_group_by_with_ordinal() { + public void can_build_group_by_ordinal() { assertEquals( project( agg( @@ -281,6 +284,70 @@ public void can_build_implicit_group_by_clause() { buildAST("SELECT AVG(age) FROM test")); } + @Test + public void can_build_order_by_field_name() { + assertEquals( + project( + sort( + relation("test"), + ImmutableList.of(argument("count", intLiteral(0))), + field("name", argument("asc", booleanLiteral(true)))), + alias("name", qualifiedName("name"))), + buildAST("SELECT name FROM test ORDER BY name")); + } + + @Test + public void can_build_order_by_function() { + assertEquals( + project( + sort( + relation("test"), + ImmutableList.of(argument("count", intLiteral(0))), + field( + function("ABS", qualifiedName("name")), + argument("asc", booleanLiteral(true)))), + alias("name", qualifiedName("name"))), + buildAST("SELECT name FROM test ORDER BY ABS(name)")); + } + + @Test + public void can_build_order_by_alias() { + assertEquals( + project( + sort( + relation("test"), + ImmutableList.of(argument("count", intLiteral(0))), + field("name", argument("asc", booleanLiteral(true)))), + alias("name", qualifiedName("name"), "n")), + buildAST("SELECT name AS n FROM test ORDER BY n ASC")); + } + + @Test + public void can_build_order_by_ordinal() { + assertEquals( + project( + sort( + relation("test"), + ImmutableList.of(argument("count", intLiteral(0))), + field("name", argument("asc", booleanLiteral(false)))), + alias("name", qualifiedName("name"))), + buildAST("SELECT name FROM test ORDER BY 1 DESC")); + } + + @Test + public void can_build_order_by_multiple_field_names() { + assertEquals( + project( + sort( + relation("test"), + ImmutableList.of(argument("count", intLiteral(0))), + field("name", argument("asc", booleanLiteral(true))), + field("age", argument("asc", booleanLiteral(false)))), + alias("name", qualifiedName("name")), + alias("age", qualifiedName("age"))), + buildAST("SELECT name, age FROM test ORDER BY name, age DESC")); + } + private UnresolvedPlan buildAST(String query) { ParseTree parseTree = parser.parse(query); return parseTree.accept(new AstBuilder(query)); diff --git a/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstSortBuilderTest.java b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstSortBuilderTest.java new file mode 100644 index 0000000000..f055e62c9b --- /dev/null +++ b/sql/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/parser/AstSortBuilderTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.opendistroforelasticsearch.sql.sql.parser; + +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.argument; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.booleanLiteral; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.field; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intLiteral; +import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; + +import com.amazon.opendistroforelasticsearch.sql.ast.tree.Sort; +import com.amazon.opendistroforelasticsearch.sql.ast.tree.UnresolvedPlan; +import com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OrderByClauseContext; +import com.amazon.opendistroforelasticsearch.sql.sql.parser.context.QuerySpecification; +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +class AstSortBuilderTest { + + @Mock + private QuerySpecification querySpec; + + @Mock + private OrderByClauseContext orderByClause; + + @Mock + private UnresolvedPlan child; + + @Test + void can_build_sort_node() { + doAnswer(returnsFirstArg()).when(querySpec).replaceIfAliasOrOrdinal(any()); + when(querySpec.getOrderByItems()).thenReturn(ImmutableList.of(qualifiedName("name"))); + when(querySpec.getOrderByOptions()).thenReturn(ImmutableList.of("ASC")); + + AstSortBuilder sortBuilder = new AstSortBuilder(querySpec); + assertEquals( + new Sort( + child, // has to mock and attach child otherwise Guava ImmutableList NPE in getChild() + ImmutableList.of(argument("count", intLiteral(0))), + ImmutableList.of(field("name", argument("asc", booleanLiteral(true))))), + sortBuilder.visitOrderByClause(orderByClause).attach(child)); + } + +} \ No newline at end of file