Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Support COUNT star and literal in new engine (#802)
Browse files Browse the repository at this point in the history
* Support count star and literal

* Add UT

* Add UT

* Add UT

* Add comparison IT and doctest

* Ignore failed IT

* Change javadoc

* Add UT for in-memory count aggregator
  • Loading branch information
dai-chen authored Oct 30, 2020
1 parent d1b8a7b commit a718f58
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.amazon.opendistroforelasticsearch.sql.analysis.symbol.Symbol;
import com.amazon.opendistroforelasticsearch.sql.ast.AbstractNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AggregateFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AllFields;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.And;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Compare;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.EqualTo;
Expand Down Expand Up @@ -168,6 +169,14 @@ public Expression visitField(Field node, AnalysisContext context) {
return visitIdentifier(attr, context);
}

@Override
public Expression visitAllFields(AllFields node, AnalysisContext context) {
// Convert to string literal for argument in COUNT(*), because there is no difference between
// COUNT(*) and COUNT(literal). For SELECT *, its select expression analyzer will expand * to
// the right field name list by itself.
return DSL.literal("*");
}

@Override
public Expression visitQualifiedName(QualifiedName node, AnalysisContext context) {
QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.amazon.opendistroforelasticsearch.sql.analysis.symbol.Namespace;
import com.amazon.opendistroforelasticsearch.sql.analysis.symbol.Symbol;
import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AllFields;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.DataType;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.common.antlr.SyntaxCheckException;
Expand Down Expand Up @@ -121,6 +122,13 @@ public void interval() {
AstDSL.intervalLiteral(1L, DataType.LONG, "DAY"));
}

@Test
public void all_fields() {
assertAnalyzeEqual(
DSL.literal("*"),
AllFields.of());
}

@Test
public void skip_struct_data_type() {
SyntaxCheckException exception =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ public void count_with_null() {
assertEquals(2, result.value());
}

@Test
public void count_star_with_null_and_missing() {
ExprValue result = aggregation(dsl.count(DSL.literal("*")), tuples_with_null_and_missing);
assertEquals(3, result.value());
}

@Test
public void count_literal_with_null_and_missing() {
ExprValue result = aggregation(dsl.count(DSL.literal(1)), tuples_with_null_and_missing);
assertEquals(3, result.value());
}

@Test
public void valueOf() {
ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
Expand Down
8 changes: 8 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ The aggregation could has expression as arguments::
| M | 202 |
+----------+--------+

COUNT Aggregations
------------------

Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments such as ``*`` or literals like ``1``. The meaning of these different forms are as follows:

1. ``COUNT(field)`` will count only if given field (or expression) is not null or missing in the input rows.
2. ``COUNT(*)`` will count the number of all its input rows.
3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count.

HAVING Clause
=============
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@

package com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.aggregation.dsl;

import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;

import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.serialization.ExpressionSerializer;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.expression.LiteralExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.ReferenceExpression;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.NamedAggregator;
import java.util.List;
import org.elasticsearch.search.aggregations.AggregationBuilder;
Expand Down Expand Up @@ -66,7 +70,7 @@ public AggregationBuilder visitNamedAggregator(NamedAggregator node,
case "sum":
return make(AggregationBuilders.sum(name), expression);
case "count":
return make(AggregationBuilders.count(name), expression);
return make(AggregationBuilders.count(name), replaceStarOrLiteral(expression));
case "min":
return make(AggregationBuilders.min(name), expression);
case "max":
Expand All @@ -81,4 +85,21 @@ private ValuesSourceAggregationBuilder<?> make(ValuesSourceAggregationBuilder<?>
Expression expression) {
return helper.build(expression, builder::field, builder::script);
}

/**
* Replace star or literal with Elasticsearch metadata field "_index". Because:
* 1) Analyzer already converts * to string literal, literal check here can handle
* both COUNT(*) and COUNT(1).
* 2) Value count aggregation on _index counts all docs (after filter), therefore
* it has same semantics as COUNT(*) or COUNT(1).
* @param countArg count function argument
* @return Reference to _index if literal, otherwise return original argument expression
*/
private Expression replaceStarOrLiteral(Expression countArg) {
if (countArg instanceof LiteralExpression) {
return new ReferenceExpression("_index", INTEGER);
}
return countArg;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.aggregation.dsl;

import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.expression.DSL.literal;
import static com.amazon.opendistroforelasticsearch.sql.expression.DSL.named;
import static com.amazon.opendistroforelasticsearch.sql.expression.DSL.ref;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -109,6 +110,38 @@ void should_build_count_aggregation() {
new CountAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER)))));
}

@Test
void should_build_count_star_aggregation() {
assertEquals(
"{\n"
+ " \"count(*)\" : {\n"
+ " \"value_count\" : {\n"
+ " \"field\" : \"_index\"\n"
+ " }\n"
+ " }\n"
+ "}",
buildQuery(
Arrays.asList(
named("count(*)",
new CountAggregator(Arrays.asList(literal("*")), INTEGER)))));
}

@Test
void should_build_count_other_literal_aggregation() {
assertEquals(
"{\n"
+ " \"count(1)\" : {\n"
+ " \"value_count\" : {\n"
+ " \"field\" : \"_index\"\n"
+ " }\n"
+ " }\n"
+ "}",
buildQuery(
Arrays.asList(
named("count(1)",
new CountAggregator(Arrays.asList(literal(1)), INTEGER)))));
}

@Test
void should_build_min_aggregation() {
assertEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ public void aggregationFunctionInSelect() throws IOException {
}
}

@Ignore("In MySQL and our new engine, the original text in SELECT is used as final column name")
@Test
public void aggregationFunctionInSelectCaseCheck() throws IOException {
JSONObject response = executeQuery(
Expand Down
8 changes: 8 additions & 0 deletions integ-test/src/test/resources/correctness/queries/groupby.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
SELECT COUNT(*) FROM kibana_sample_data_flights
SELECT COUNT(1) FROM kibana_sample_data_flights
SELECT COUNT('hello') FROM kibana_sample_data_flights
SELECT SUM(FlightDelayMin) FROM kibana_sample_data_flights
SELECT AVG(FlightDelayMin) FROM kibana_sample_data_flights
SELECT MIN(FlightDelayMin) FROM kibana_sample_data_flights
SELECT MAX(FlightDelayMin) FROM kibana_sample_data_flights
SELECT count(*), Avg(FlightDelayMin), sUm(FlightDelayMin) FROM kibana_sample_data_flights
SELECT COUNT(*) AS cnt, AVG(FlightDelayMin) AS a, SUM(FlightDelayMin) AS s FROM kibana_sample_data_flights
SELECT COUNT(*) FROM kibana_sample_data_flights WHERE FlightTimeMin > 0
SELECT COUNT(1) FROM kibana_sample_data_flights WHERE FlightTimeMin > 0
SELECT SUM(FlightDelayMin) FROM kibana_sample_data_flights WHERE FlightTimeMin > 0
SELECT AVG(FlightDelayMin) FROM kibana_sample_data_flights WHERE FlightTimeMin > 0
SELECT COUNT(*) FROM kibana_sample_data_flights GROUP BY OriginCountry
SELECT COUNT(FlightNum) FROM kibana_sample_data_flights GROUP BY OriginCountry
SELECT COUNT(FlightDelay) FROM kibana_sample_data_flights GROUP BY OriginCountry
Expand Down
4 changes: 2 additions & 2 deletions sql/src/main/antlr/OpenDistroSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ scalarFunctionName
;

aggregateFunction
: functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET
/*| COUNT LR_BRACKET (STAR | functionArg) RR_BRACKET */
: functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET #regularAggregateFunctionCall
| COUNT LR_BRACKET STAR RR_BRACKET #countStarFunctionCall
;

aggregationFunctionName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.LIKE;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.NOT_LIKE;
import static com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName.REGEXP;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.AggregateFunctionContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.BinaryComparisonPredicateContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.BooleanContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.CountStarFunctionCallContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.DateLiteralContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.IdentsAsQualifiedNameContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.IsNullPredicateContext;
Expand All @@ -36,6 +36,7 @@
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.OverClauseContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.RankingWindowFunctionContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.RegexpPredicateContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.RegularAggregateFunctionCallContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.ScalarFunctionCallContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SignedDecimalContext;
import static com.amazon.opendistroforelasticsearch.sql.sql.antlr.parser.OpenDistroSQLParser.SignedRealContext;
Expand All @@ -46,6 +47,7 @@

import com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AggregateFunction;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.AllFields;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.And;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Interval;
Expand Down Expand Up @@ -160,12 +162,18 @@ public UnresolvedExpression visitRankingWindowFunction(RankingWindowFunctionCont
}

@Override
public UnresolvedExpression visitAggregateFunction(AggregateFunctionContext ctx) {
public UnresolvedExpression visitRegularAggregateFunctionCall(
RegularAggregateFunctionCallContext ctx) {
return new AggregateFunction(
ctx.functionName.getText(),
visitFunctionArg(ctx.functionArg()));
}

@Override
public UnresolvedExpression visitCountStarFunctionCall(CountStarFunctionCallContext ctx) {
return new AggregateFunction("COUNT", AllFields.of());
}

@Override
public UnresolvedExpression visitIsNullPredicate(IsNullPredicateContext ctx) {
return new Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,23 @@ public void can_build_where_clause() {
);
}

@Test
public void can_build_count_star_and_count_literal() {
assertEquals(
project(
agg(
relation("test"),
ImmutableList.of(
alias("COUNT(*)", aggregate("COUNT", AllFields.of())),
alias("COUNT(1)", aggregate("COUNT", intLiteral(1)))),
emptyList(),
emptyList(),
emptyList()),
alias("COUNT(*)", aggregate("COUNT", AllFields.of())),
alias("COUNT(1)", aggregate("COUNT", intLiteral(1)))),
buildAST("SELECT COUNT(*), COUNT(1) FROM test"));
}

@Test
public void can_build_group_by_field_name() {
assertEquals(
Expand Down

0 comments on commit a718f58

Please sign in to comment.