diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java index f23be3efa7..bc0146fc22 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzer.java @@ -23,6 +23,7 @@ import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function; import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal; +import com.amazon.opendistroforelasticsearch.sql.ast.expression.Or; import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedAttribute; import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; @@ -87,6 +88,14 @@ public Expression visitAnd(And node, AnalysisContext context) { return dsl.and(context.peek(), left, right); } + @Override + public Expression visitOr(Or node, AnalysisContext context) { + Expression left = node.getLeft().accept(this, context); + Expression right = node.getRight().accept(this, context); + + return dsl.or(context.peek(), left, right); + } + @Override public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) { Optional builtinFunctionName = BuiltinFunctionName.of(node.getFuncName()); diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java index 906b469d5a..76c6990460 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/dsl/AstDSL.java @@ -122,7 +122,7 @@ public static Literal nullLiteral() { return literal(null, DataType.NULL); } - public static UnresolvedExpression map(String origin, String target) { + public static Map map(String origin, String target) { return new Map(new Field(origin), new Field(target)); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/tree/Rename.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/tree/Rename.java index b6745b2459..9b7d0d5cf6 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/tree/Rename.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/ast/tree/Rename.java @@ -20,6 +20,7 @@ import com.google.common.collect.ImmutableList; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.RequiredArgsConstructor; import lombok.ToString; import java.util.List; @@ -27,6 +28,7 @@ @ToString @EqualsAndHashCode(callSuper = false) @Getter +@RequiredArgsConstructor public class Rename extends UnresolvedPlan { private final List renameList; private UnresolvedPlan child; diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java index c61cdd0805..29f37eb92f 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java @@ -22,9 +22,8 @@ import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository; -import lombok.RequiredArgsConstructor; - import java.util.Arrays; +import lombok.RequiredArgsConstructor; @RequiredArgsConstructor public class DSL { @@ -123,4 +122,9 @@ public Aggregator sum(Environment env, Expression... expre return (Aggregator) repository.compile(BuiltinFunctionName.SUM.getName(), Arrays.asList(expressions), env); } + + public Aggregator count(Environment env, Expression... expressions) { + return (Aggregator) + repository.compile(BuiltinFunctionName.COUNT.getName(), Arrays.asList(expressions), env); + } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java index ffc06c5a59..60ff968fde 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java @@ -41,6 +41,7 @@ public class AggregatorFunction { public static void register(BuiltinFunctionRepository repository) { repository.register(avg()); repository.register(sum()); + repository.register(count()); } private static FunctionResolver avg() { @@ -54,6 +55,31 @@ private static FunctionResolver avg() { ); } + private static FunctionResolver count() { + FunctionName functionName = BuiltinFunctionName.COUNT.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.INTEGER)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.LONG)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.FLOAT)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.DOUBLE)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.STRING)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.STRUCT)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.ARRAY)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(ExprType.BOOLEAN)), + arguments -> new CountAggregator(arguments, ExprType.INTEGER)) + .build() + ); + } + private static FunctionResolver sum() { FunctionName functionName = BuiltinFunctionName.SUM.getName(); return new FunctionResolver( diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregator.java new file mode 100644 index 0000000000..eeae8d1a20 --- /dev/null +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregator.java @@ -0,0 +1,69 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.aggregation; + +import static com.amazon.opendistroforelasticsearch.sql.utils.ExpressionUtils.format; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprType; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; +import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.CountAggregator.CountState; +import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; +import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple; +import java.util.List; +import java.util.Locale; + +public class CountAggregator extends Aggregator { + + public CountAggregator(List arguments, ExprType returnType) { + super(BuiltinFunctionName.COUNT.getName(), arguments, returnType); + } + + @Override + public CountAggregator.CountState create() { + return new CountState(); + } + + @Override + public CountState iterate(BindingTuple tuple, CountState state) { + Expression expression = getArguments().get(0); + ExprValue value = expression.valueOf(tuple); + if (!(value.isNull() || value.isMissing())) { + state.count++; + } + return state; + } + + @Override + public String toString() { + return String.format(Locale.ROOT, "count(%s)", format(getArguments())); + } + + /** Count State. */ + protected class CountState implements AggregationState { + private int count; + + public CountState() { + this.count = 0; + } + + @Override + public ExprValue result() { + return ExprValueUtils.integerValue(count); + } + } +} diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java index 46d5d8437b..29ecb26fd6 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java @@ -33,7 +33,8 @@ public enum BuiltinFunctionName { /** Aggregation Function. */ AVG(FunctionName.of("avg")), - SUM(FunctionName.of("sum")); + SUM(FunctionName.of("sum")), + COUNT(FunctionName.of("count")); private final FunctionName name; diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java index 6f1831a94c..84986b8097 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionAnalyzerTest.java @@ -47,6 +47,14 @@ public void and() { ); } + @Test + public void or() { + assertAnalyzeEqual( + dsl.or(typeEnv, DSL.ref("boolean_value"), DSL.literal(LITERAL_TRUE)), + AstDSL.or(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true)) + ); + } + @Test public void undefined_var_semantic_check_failed() { SemanticCheckException exception = assertThrows(SemanticCheckException.class, diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java index 5621c5d705..5ddf8db89d 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java @@ -15,9 +15,14 @@ package com.amazon.opendistroforelasticsearch.sql.expression.aggregation; +import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.booleanValue; +import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.collectionValue; +import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.tupleValue; + import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; @@ -25,42 +30,69 @@ public class AggregationTest extends ExpressionTestBase { - protected static List tuples = Arrays.asList( - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2, - "long_value", 2L, - "string_value", "m", - "double_value", 2d, - "float_value", 2f)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, - "long_value", 1L, - "string_value", "f", - "double_value", 1d, - "float_value", 1f)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3, - "long_value", 3L, - "string_value", "m", - "double_value", 3d, - "float_value", 3f)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 4, - "long_value", 4L, - "string_value", "f", - "double_value", 4d, - "float_value", 4f))); + protected static List tuples = + Arrays.asList( + ExprValueUtils.tupleValue( + new ImmutableMap.Builder() + .put("integer_value", 2) + .put("long_value", 2L) + .put("string_value", "m") + .put("double_value", 2d) + .put("float_value", 2f) + .put("boolean_value", true) + .put("struct_value", ImmutableMap.of("str", 1)) + .put("array_value", ImmutableList.of(1)) + .build()), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "integer_value", + 1, + "long_value", + 1L, + "string_value", + "f", + "double_value", + 1d, + "float_value", + 1f)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "integer_value", + 3, + "long_value", + 3L, + "string_value", + "m", + "double_value", + 3d, + "float_value", + 3f)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "integer_value", + 4, + "long_value", + 4L, + "string_value", + "f", + "double_value", + 4d, + "float_value", + 4f))); - protected static List tuples_with_null_and_missing = Arrays.asList( - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2, - "string_value", "m", - "double_value", 3d)), - ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1, - "string_value", "f", - "double_value", 4d)), - ExprValueUtils.tupleValue(Collections.singletonMap("double_value", null))); + protected static List tuples_with_null_and_missing = + Arrays.asList( + ExprValueUtils.tupleValue( + ImmutableMap.of("integer_value", 2, "string_value", "m", "double_value", 3d)), + ExprValueUtils.tupleValue( + ImmutableMap.of("integer_value", 1, "string_value", "f", "double_value", 4d)), + ExprValueUtils.tupleValue(Collections.singletonMap("double_value", null))); - protected ExprValue aggregation(Aggregator aggregator, List tuples) { - AggregationState state = aggregator.create(); - for (ExprValue tuple : tuples) { - aggregator.iterate(tuple.bindingTuples(), state); - } - return state.result(); + protected ExprValue aggregation(Aggregator aggregator, List tuples) { + AggregationState state = aggregator.create(); + for (ExprValue tuple : tuples) { + aggregator.iterate(tuple.bindingTuples(), state); } + return state.result(); + } } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregatorTest.java new file mode 100644 index 0000000000..b98b4ef88e --- /dev/null +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregatorTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; +import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException; +import com.amazon.opendistroforelasticsearch.sql.expression.DSL; +import org.junit.jupiter.api.Test; + +class CountAggregatorTest extends AggregationTest { + + @Test + public void count_integer_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("integer_value")), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_long_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("long_value")), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_float_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("float_value")), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_double_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("double_value")), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_arithmetic_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, + dsl.multiply(typeEnv, DSL.ref("integer_value"), DSL.literal(ExprValueUtils.integerValue(10)))), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_string_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("string_value")), tuples); + assertEquals(4, result.value()); + } + + @Test + public void count_boolean_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("boolean_value")), tuples); + assertEquals(1, result.value()); + } + + @Test + public void count_struct_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("struct_value")), tuples); + assertEquals(1, result.value()); + } + + @Test + public void count_array_field_expression() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("array_value")), tuples); + assertEquals(1, result.value()); + } + + @Test + public void count_with_missing() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("integer_value")), + tuples_with_null_and_missing); + assertEquals(2, result.value()); + } + + @Test + public void count_with_null() { + ExprValue result = aggregation(dsl.count(typeEnv, DSL.ref("double_value")), + tuples_with_null_and_missing); + assertEquals(2, result.value()); + } + + @Test + public void valueOf() { + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, + () -> dsl.count(typeEnv, DSL.ref("double_value")).valueOf(valueEnv())); + assertEquals("can't evaluate on aggregator: count", exception.getMessage()); + } + + @Test + public void test_to_string() { + Aggregator countAggregator = dsl.count(typeEnv, DSL.ref("integer_value")); + assertEquals("count(integer_value)", countAggregator.toString()); + } +} \ No newline at end of file diff --git a/ppl/src/main/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilder.java index 08d5696cb4..8838ea7204 100644 --- a/ppl/src/main/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilder.java @@ -108,7 +108,7 @@ public UnresolvedPlan visitFieldsCommand(FieldsCommandContext ctx) { /** Rename command */ @Override public UnresolvedPlan visitRenameCommand(RenameCommandContext ctx) { - return new Project( + return new Rename( new ArrayList<>( Collections.singletonList( new Map( diff --git a/ppl/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilderTest.java index 733cfac064..c97aed6522 100644 --- a/ppl/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/parser/AstBuilderTest.java @@ -134,7 +134,7 @@ public void testFieldsCommandWithExcludeArguments() { @Test public void testRenameCommand() { assertEqual("source=t | rename f as g", - project( + rename( relation("t"), map("f", "g") ));