diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/AbstractExprValue.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/AbstractExprValue.java index 5774c314b4..cc240456b9 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/AbstractExprValue.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/AbstractExprValue.java @@ -30,7 +30,7 @@ public abstract class AbstractExprValue implements ExprValue { public int compareTo(ExprValue other) { if (this.isNull() || this.isMissing() || other.isNull() || other.isMissing()) { throw new IllegalStateException( - String.format("[BUG] Unreachable, Comparing with NULL or MISSING is undefined")); + String.format("[BUG] Unreachable, Comparing with NULL or MISSING is undefined")); } if ((this.isNumber() && other.isNumber()) || this.type() == other.type()) { return compare(other); diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprTimestampValue.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprTimestampValue.java index 3623e838c3..4a81543457 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprTimestampValue.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/data/model/ExprTimestampValue.java @@ -101,12 +101,12 @@ public String toString() { @Override public int compare(ExprValue other) { - return timestamp.compareTo(other.timestampValue()); + return timestamp.compareTo(other.timestampValue().atZone(ZONE).toInstant()); } @Override public boolean equal(ExprValue other) { - return timestamp.equals(other.timestampValue()); + return timestamp.equals(other.timestampValue().atZone(ZONE).toInstant()); } @Override diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java index e193633293..495e429767 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/DSL.java @@ -453,6 +453,14 @@ public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } + public Aggregator min(Expression... expressions) { + return aggregate(BuiltinFunctionName.MIN, expressions); + } + + public Aggregator max(Expression... expressions) { + return aggregate(BuiltinFunctionName.MAX, expressions); + } + private FunctionExpression function(BuiltinFunctionName functionName, Expression... expressions) { return (FunctionExpression) repository.compile( functionName.getName(), Arrays.asList(expressions)); diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java index d75beae3f4..a09a2c6833 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregatorFunction.java @@ -17,12 +17,16 @@ import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.ARRAY; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRUCT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository; @@ -52,6 +56,8 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(avg()); repository.register(sum()); repository.register(count()); + repository.register(min()); + repository.register(max()); } private static FunctionResolver avg() { @@ -106,4 +112,57 @@ private static FunctionResolver sum() { .build() ); } + + private static FunctionResolver min() { + FunctionName functionName = BuiltinFunctionName.MIN.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), + arguments -> new MinAggregator(arguments, INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), + arguments -> new MinAggregator(arguments, LONG)) + .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), + arguments -> new MinAggregator(arguments, FLOAT)) + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> new MinAggregator(arguments, DOUBLE)) + .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), + arguments -> new MinAggregator(arguments, STRING)) + .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), + arguments -> new MinAggregator(arguments, DATE)) + .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), + arguments -> new MinAggregator(arguments, DATETIME)) + .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), + arguments -> new MinAggregator(arguments, TIME)) + .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), + arguments -> new MinAggregator(arguments, TIMESTAMP)) + .build()); + } + + private static FunctionResolver max() { + FunctionName functionName = BuiltinFunctionName.MAX.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), + arguments -> new MaxAggregator(arguments, INTEGER)) + .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), + arguments -> new MaxAggregator(arguments, LONG)) + .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), + arguments -> new MaxAggregator(arguments, FLOAT)) + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> new MaxAggregator(arguments, DOUBLE)) + .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), + arguments -> new MaxAggregator(arguments, STRING)) + .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), + arguments -> new MaxAggregator(arguments, DATE)) + .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), + arguments -> new MaxAggregator(arguments, DATETIME)) + .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), + arguments -> new MaxAggregator(arguments, TIME)) + .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), + arguments -> new MaxAggregator(arguments, TIMESTAMP)) + .build() + ); + } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregator.java index e8682c9db9..917141f8aa 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregator.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregator.java @@ -46,9 +46,7 @@ public AvgState create() { public AvgState iterate(BindingTuple tuple, AvgState state) { Expression expression = getArguments().get(0); ExprValue value = expression.valueOf(tuple); - if (value.isNull() || value.isMissing()) { - state.isNullResult = true; - } else { + if (!(value.isNull() || value.isMissing())) { state.count++; state.total += ExprValueUtils.getDoubleValue(value); } @@ -63,19 +61,18 @@ public String toString() { /** * Average State. */ - protected class AvgState implements AggregationState { + protected static class AvgState implements AggregationState { private int count; private double total; - private boolean isNullResult = false; - public AvgState() { + AvgState() { this.count = 0; this.total = 0d; } @Override public ExprValue result() { - return isNullResult ? ExprNullValue.of() : ExprValueUtils.doubleValue(total / count); + return count == 0 ? ExprNullValue.of() : ExprValueUtils.doubleValue(total / count); } } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregator.java index 95c930fbae..0e641ea8e8 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregator.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/CountAggregator.java @@ -56,10 +56,10 @@ public String toString() { /** * Count State. */ - protected class CountState implements AggregationState { + protected static class CountState implements AggregationState { private int count; - public CountState() { + CountState() { this.count = 0; } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MaxAggregator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MaxAggregator.java new file mode 100644 index 0000000000..2800b40fce --- /dev/null +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MaxAggregator.java @@ -0,0 +1,72 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.aggregation; + +import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_NULL; +import static com.amazon.opendistroforelasticsearch.sql.utils.ExpressionUtils.format; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType; +import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; +import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple; +import java.util.List; + +public class MaxAggregator extends Aggregator { + + public MaxAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.MAX.getName(), arguments, returnType); + } + + @Override + public MaxState create() { + return new MaxState(); + } + + @Override + public MaxState iterate(BindingTuple tuple, MaxState state) { + Expression expression = getArguments().get(0); + ExprValue value = expression.valueOf(tuple); + if (!(value.isNull() || value.isMissing())) { + state.max(value); + } + return state; + } + + @Override + public String toString() { + return String.format("max(%s)", format(getArguments())); + } + + protected static class MaxState implements AggregationState { + private ExprValue maxResult; + + MaxState() { + maxResult = LITERAL_NULL; + } + + public void max(ExprValue value) { + maxResult = maxResult.isNull() ? value + : (maxResult.compareTo(value) > 0) + ? maxResult : value; + } + + @Override + public ExprValue result() { + return maxResult; + } + } +} diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MinAggregator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MinAggregator.java new file mode 100644 index 0000000000..7149b51eca --- /dev/null +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MinAggregator.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.aggregation; + +import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_NULL; +import static com.amazon.opendistroforelasticsearch.sql.utils.ExpressionUtils.format; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType; +import com.amazon.opendistroforelasticsearch.sql.expression.Expression; +import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName; +import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple; +import java.util.List; + +/** + * The minimum aggregator aggregate the value evaluated by the expression. + * If the expression evaluated result is NULL or MISSING, then the result is NULL. + */ +public class MinAggregator extends Aggregator { + + public MinAggregator(List arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.MIN.getName(), arguments, returnType); + } + + + @Override + public MinState create() { + return new MinState(); + } + + @Override + public MinState iterate(BindingTuple tuple, MinState state) { + Expression expression = getArguments().get(0); + ExprValue value = expression.valueOf(tuple); + if (!(value.isNull() || value.isMissing())) { + state.min(value); + } + return state; + } + + @Override + public String toString() { + return String.format("min(%s)", format(getArguments())); + } + + protected static class MinState implements AggregationState { + private ExprValue minResult; + + MinState() { + minResult = LITERAL_NULL; + } + + public void min(ExprValue value) { + minResult = minResult.isNull() ? value + : (minResult.compareTo(value) < 0) + ? minResult : value; + } + + @Override + public ExprValue result() { + return minResult; + } + } +} diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregator.java index 73a93d210d..32d85ab92c 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregator.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregator.java @@ -56,9 +56,8 @@ public SumState create() { public SumState iterate(BindingTuple tuple, SumState state) { Expression expression = getArguments().get(0); ExprValue value = expression.valueOf(tuple); - if (value.isNull() || value.isMissing()) { - state.isNullResult = true; - } else { + if (!(value.isNull() || value.isMissing())) { + state.isEmptyCollection = false; state.add(value); } return state; @@ -72,15 +71,16 @@ public String toString() { /** * Sum State. */ - protected class SumState implements AggregationState { + protected static class SumState implements AggregationState { private final ExprCoreType type; private ExprValue sumResult; - private boolean isNullResult = false; + private boolean isEmptyCollection; - public SumState(ExprCoreType type) { + SumState(ExprCoreType type) { this.type = type; sumResult = ExprValueUtils.integerValue(0); + isEmptyCollection = true; } /** @@ -108,7 +108,7 @@ public void add(ExprValue value) { @Override public ExprValue result() { - return isNullResult ? ExprNullValue.of() : sumResult; + return isEmptyCollection ? ExprNullValue.of() : sumResult; } } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java index b58bcc2a83..f693779e2f 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/function/BuiltinFunctionName.java @@ -111,6 +111,8 @@ public enum BuiltinFunctionName { AVG(FunctionName.of("avg")), SUM(FunctionName.of("sum")), COUNT(FunctionName.of("count")), + MIN(FunctionName.of("min")), + MAX(FunctionName.of("max")), /** * Text Functions. diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java index ec42a97b2c..6488a4e8cf 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AggregationTest.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; public class AggregationTest extends ExpressionTestBase { @@ -38,9 +39,13 @@ public class AggregationTest extends ExpressionTestBase { .put("boolean_value", true) .put("struct_value", ImmutableMap.of("str", 1)) .put("array_value", ImmutableList.of(1)) + .put("date_value", "2000-01-01") + .put("datetime_value", "2020-01-01 12:00:00") + .put("time_value", "12:00:00") + .put("timestamp_value", "2020-01-01 12:00:00") .build()), ExprValueUtils.tupleValue( - ImmutableMap.of( + Map.of( "integer_value", 1, "long_value", @@ -50,9 +55,17 @@ public class AggregationTest extends ExpressionTestBase { "double_value", 1d, "float_value", - 1f)), + 1f, + "date_value", + "2020-01-01", + "datetime_value", + "2020-01-01 00:00:00", + "time_value", + "00:00:00", + "timestamp_value", + "2020-01-01 00:00:00")), ExprValueUtils.tupleValue( - ImmutableMap.of( + Map.of( "integer_value", 3, "long_value", @@ -62,19 +75,35 @@ public class AggregationTest extends ExpressionTestBase { "double_value", 3d, "float_value", - 3f)), + 3f, + "date_value", + "1970-01-01", + "datetime_value", + "1970-01-01 19:00:00", + "time_value", + "19:00:00", + "timestamp_value", + "1970-01-01 19:00:00")), ExprValueUtils.tupleValue( - ImmutableMap.of( + Map.of( "integer_value", 4, "long_value", 4L, "string_value", - "f", + "n", "double_value", 4d, "float_value", - 4f))); + 4f, + "date_value", + "2040-01-01", + "datetime_value", + "2040-01-01 07:00:00", + "time_value", + "07:00:00", + "timestamp_value", + "2040-01-01 07:00:00"))); protected static List tuples_with_null_and_missing = Arrays.asList( @@ -84,6 +113,12 @@ public class AggregationTest extends ExpressionTestBase { ImmutableMap.of("integer_value", 1, "string_value", "f", "double_value", 4d)), ExprValueUtils.tupleValue(Collections.singletonMap("double_value", null))); + protected static List tuples_with_all_null_or_missing = + Arrays.asList( + ExprValueUtils.tupleValue(Collections.singletonMap("integer_value", null)), + ExprValueUtils.tupleValue(Collections.singletonMap("double", null)), + ExprValueUtils.tupleValue(Collections.singletonMap("string_value", null))); + protected ExprValue aggregation(Aggregator aggregator, List tuples) { AggregationState state = aggregator.create(); for (ExprValue tuple : tuples) { diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregatorTest.java index f90d061dde..5c261b267c 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregatorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/AvgAggregatorTest.java @@ -47,13 +47,20 @@ public void avg_arithmetic_expression() { public void avg_with_missing() { ExprValue result = aggregation(dsl.avg(DSL.ref("integer_value", INTEGER)), tuples_with_null_and_missing); - assertTrue(result.isNull()); + assertEquals(1.5, result.value()); } @Test public void avg_with_null() { ExprValue result = aggregation(dsl.avg(DSL.ref("double_value", DOUBLE)), tuples_with_null_and_missing); + assertEquals(3.5, result.value()); + } + + @Test + public void avg_with_all_missing_or_null() { + ExprValue result = + aggregation(dsl.avg(DSL.ref("integer_value", INTEGER)), tuples_with_all_null_or_missing); assertTrue(result.isNull()); } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MaxAggregatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MaxAggregatorTest.java new file mode 100644 index 0000000000..37a8291f65 --- /dev/null +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MaxAggregatorTest.java @@ -0,0 +1,145 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.aggregation; + +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRUCT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; +import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException; +import com.amazon.opendistroforelasticsearch.sql.expression.DSL; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; + +public class MaxAggregatorTest extends AggregationTest { + + @Test + public void test_max_integer() { + ExprValue result = aggregation(dsl.max(DSL.ref("integer_value", INTEGER)), tuples); + assertEquals(4, result.value()); + } + + @Test + public void test_max_long() { + ExprValue result = aggregation(dsl.max(DSL.ref("long_value", LONG)), tuples); + assertEquals(4L, result.value()); + } + + @Test + public void test_max_float() { + ExprValue result = aggregation(dsl.max(DSL.ref("float_value", FLOAT)), tuples); + assertEquals(4F, result.value()); + } + + @Test + public void test_max_double() { + ExprValue result = aggregation(dsl.max(DSL.ref("double_value", DOUBLE)), tuples); + assertEquals(4D, result.value()); + } + + @Test + public void test_max_string() { + ExprValue result = aggregation(dsl.max(DSL.ref("string_value", STRING)), tuples); + assertEquals("n", result.value()); + } + + @Test + public void test_max_date() { + ExprValue result = aggregation(dsl.max(DSL.ref("date_value", DATE)), tuples); + assertEquals("2040-01-01", result.value()); + } + + @Test + public void test_max_datetime() { + ExprValue result = aggregation(dsl.max(DSL.ref("datetime_value", DATETIME)), tuples); + assertEquals("2040-01-01 07:00:00", result.value()); + } + + @Test + public void test_max_time() { + ExprValue result = aggregation(dsl.max(DSL.ref("time_value", TIME)), tuples); + assertEquals("19:00:00", result.value()); + } + + @Test + public void test_max_timestamp() { + ExprValue result = aggregation(dsl.max(DSL.ref("timestamp_value", TIMESTAMP)), tuples); + assertEquals("2040-01-01 07:00:00", result.value()); + } + + @Test + public void test_max_arithmetic_expression() { + ExprValue result = aggregation( + dsl.max(dsl.add(DSL.ref("integer_value", INTEGER), + DSL.literal(ExprValueUtils.integerValue(0)))), tuples); + assertEquals(4, result.value()); + } + + @Test + public void test_max_null() { + ExprValue result = + aggregation(dsl.max(DSL.ref("double_value", DOUBLE)), tuples_with_null_and_missing); + assertEquals(4.0, result.value()); + } + + @Test + public void test_max_missing() { + ExprValue result = + aggregation(dsl.max(DSL.ref("integer_value", INTEGER)), tuples_with_null_and_missing); + assertEquals(2, result.value()); + } + + @Test + public void test_max_all_missing_or_null() { + ExprValue result = + aggregation(dsl.max(DSL.ref("integer_value", INTEGER)), tuples_with_all_null_or_missing); + assertTrue(result.isNull()); + } + + @Test + public void test_value_of() { + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, + () -> dsl.max(DSL.ref("double_value", DOUBLE)).valueOf(valueEnv())); + assertEquals("can't evaluate on aggregator: max", exception.getMessage()); + } + + @Test + public void test_to_string() { + Aggregator maxAggregator = dsl.max(DSL.ref("integer_value", INTEGER)); + assertEquals("max(integer_value)", maxAggregator.toString()); + } + + @Test + public void test_nested_to_string() { + Aggregator maxAggregator = dsl.max(dsl.add(DSL.ref("integer_value", INTEGER), + DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals(String.format("max(+(%s, %d))", DSL.ref("integer_value", INTEGER), 10), + maxAggregator.toString()); + } +} diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MinAggregatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MinAggregatorTest.java new file mode 100644 index 0000000000..925d406aac --- /dev/null +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/MinAggregatorTest.java @@ -0,0 +1,145 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.expression.aggregation; + +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRUCT; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIME; +import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; +import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; +import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException; +import com.amazon.opendistroforelasticsearch.sql.expression.DSL; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; + +public class MinAggregatorTest extends AggregationTest { + + @Test + public void test_min_integer() { + ExprValue result = aggregation(dsl.min(DSL.ref("integer_value", INTEGER)), tuples); + assertEquals(1, result.value()); + } + + @Test + public void test_min_long() { + ExprValue result = aggregation(dsl.min(DSL.ref("long_value", LONG)), tuples); + assertEquals(1L, result.value()); + } + + @Test + public void test_min_float() { + ExprValue result = aggregation(dsl.min(DSL.ref("float_value", FLOAT)), tuples); + assertEquals(1F, result.value()); + } + + @Test + public void test_min_double() { + ExprValue result = aggregation(dsl.min(DSL.ref("double_value", DOUBLE)), tuples); + assertEquals(1D, result.value()); + } + + @Test + public void test_min_string() { + ExprValue result = aggregation(dsl.min(DSL.ref("string_value", STRING)), tuples); + assertEquals("f", result.value()); + } + + @Test + public void test_min_date() { + ExprValue result = aggregation(dsl.min(DSL.ref("date_value", DATE)), tuples); + assertEquals("1970-01-01", result.value()); + } + + @Test + public void test_min_datetime() { + ExprValue result = aggregation(dsl.min(DSL.ref("datetime_value", DATETIME)), tuples); + assertEquals("1970-01-01 19:00:00", result.value()); + } + + @Test + public void test_min_time() { + ExprValue result = aggregation(dsl.min(DSL.ref("time_value", TIME)), tuples); + assertEquals("00:00:00", result.value()); + } + + @Test + public void test_min_timestamp() { + ExprValue result = aggregation(dsl.min(DSL.ref("timestamp_value", TIMESTAMP)), tuples); + assertEquals("1970-01-01 19:00:00", result.value()); + } + + @Test + public void test_min_arithmetic_expression() { + ExprValue result = aggregation( + dsl.min(dsl.add(DSL.ref("integer_value", INTEGER), + DSL.literal(ExprValueUtils.integerValue(0)))), tuples); + assertEquals(1, result.value()); + } + + @Test + public void test_min_null() { + ExprValue result = + aggregation(dsl.min(DSL.ref("double_value", DOUBLE)), tuples_with_null_and_missing); + assertEquals(3.0, result.value()); + } + + @Test + public void test_min_missing() { + ExprValue result = + aggregation(dsl.min(DSL.ref("integer_value", INTEGER)), tuples_with_null_and_missing); + assertEquals(1, result.value()); + } + + @Test + public void test_min_all_missing_or_null() { + ExprValue result = + aggregation(dsl.min(DSL.ref("integer_value", INTEGER)), tuples_with_all_null_or_missing); + assertTrue(result.isNull()); + } + + @Test + public void test_value_of() { + ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, + () -> dsl.min(DSL.ref("double_value", DOUBLE)).valueOf(valueEnv())); + assertEquals("can't evaluate on aggregator: min", exception.getMessage()); + } + + @Test + public void test_to_string() { + Aggregator minAggregator = dsl.min(DSL.ref("integer_value", INTEGER)); + assertEquals("min(integer_value)", minAggregator.toString()); + } + + @Test + public void test_nested_to_string() { + Aggregator minAggregator = dsl.min(dsl.add(DSL.ref("integer_value", INTEGER), + DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals(String.format("min(+(%s, %d))", DSL.ref("integer_value", INTEGER), 10), + minAggregator.toString()); + } +} diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregatorTest.java index 4697a7c7ff..8ba8b01b9e 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregatorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/aggregation/SumAggregatorTest.java @@ -86,13 +86,20 @@ public void sum_string_field_expression() { public void sum_with_missing() { ExprValue result = aggregation(dsl.sum(DSL.ref("integer_value", INTEGER)), tuples_with_null_and_missing); - assertTrue(result.isNull()); + assertEquals(3, result.value()); } @Test public void sum_with_null() { ExprValue result = aggregation(dsl.sum(DSL.ref("double_value", DOUBLE)), tuples_with_null_and_missing); + assertEquals(7.0, result.value()); + } + + @Test + public void sum_with_all_missing_or_null() { + ExprValue result = + aggregation(dsl.sum(DSL.ref("double_value", DOUBLE)), tuples_with_all_null_or_missing); assertTrue(result.isNull()); } diff --git a/docs/experiment/ppl/cmd/stats.rst b/docs/experiment/ppl/cmd/stats.rst index 3bdf67c818..d4755c00a5 100644 --- a/docs/experiment/ppl/cmd/stats.rst +++ b/docs/experiment/ppl/cmd/stats.rst @@ -24,6 +24,10 @@ The following table catalogs the aggregation functions and also indicates how ea +----------+-------------+-------------+ | AVG | Ignore | Ignore | +----------+-------------+-------------+ +| MAX | Ignore | Ignore | ++----------+-------------+-------------+ +| MIN | Ignore | Ignore | ++----------+-------------+-------------+ Syntax @@ -84,3 +88,34 @@ PPL query:: | M | 33.666666666666664 | 101 | +----------+--------------------+------------+ +Example 4: Calculate the maximum of a field +=========================================== + +The example calculates the max age of all the accounts. + +PPL query:: + + od> source=accounts | stats max(age); + fetched rows / total rows = 1/1 + +------------+ + | max(age) | + |------------| + | 36 | + +------------+ + +Example 5: Calculate the maximum and minimum of a field by group +================================================================ + +The example calculates the max and min age values of all the accounts group by gender. + +PPL query:: + + od> source=accounts | stats max(age), min(age) by gender; + fetched rows / total rows = 2/2 + +----------+------------+------------+ + | gender | min(age) | max(age) | + |----------+------------+------------| + | F | 28 | 28 | + | M | 32 | 36 | + +----------+------------+------------+ + diff --git a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/data/value/ElasticsearchExprValueFactory.java b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/data/value/ElasticsearchExprValueFactory.java index 95cf4bc670..252b864acc 100644 --- a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/data/value/ElasticsearchExprValueFactory.java +++ b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/data/value/ElasticsearchExprValueFactory.java @@ -166,7 +166,7 @@ public ExprValue construct(String field, Object value) { return constructBoolean((Boolean) value); } else if (type.equals(TIMESTAMP)) { if (value instanceof Number) { - return constructTimestamp((Long) value); + return constructTimestamp(((Number) value).longValue()); } else if (value instanceof Instant) { return constructTimestamp((Instant) value); } else { diff --git a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 1481161602..508f50a062 100644 --- a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -67,6 +67,10 @@ public AggregationBuilder visitNamedAggregator(NamedAggregator node, return make(AggregationBuilders.sum(name), expression); case "count": return make(AggregationBuilders.count(name), expression); + case "min": + return make(AggregationBuilders.min(name), expression); + case "max": + return make(AggregationBuilders.max(name), expression); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); diff --git a/elasticsearch/src/test/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/elasticsearch/src/test/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 094b74235c..ced6799f43 100644 --- a/elasticsearch/src/test/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/elasticsearch/src/test/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -27,6 +27,8 @@ import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.serialization.ExpressionSerializer; import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.AvgAggregator; import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.CountAggregator; +import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.MaxAggregator; +import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.MinAggregator; import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.NamedAggregator; import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.SumAggregator; import com.amazon.opendistroforelasticsearch.sql.expression.function.FunctionName; @@ -107,15 +109,47 @@ void should_build_count_aggregation() { new CountAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_min_aggregation() { + assertEquals( + "{\n" + + " \"min(age)\" : {\n" + + " \"min\" : {\n" + + " \"field\" : \"age\"\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("min(age)", + new MinAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_max_aggregation() { + assertEquals( + "{\n" + + " \"max(age)\" : {\n" + + " \"max\" : {\n" + + " \"field\" : \"age\"\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("max(age)", + new MaxAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + @Test void should_throw_exception_for_unsupported_aggregator() { - when(aggregator.getFunctionName()).thenReturn(new FunctionName("max")); + when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); when(aggregator.getArguments()).thenReturn(Arrays.asList(ref("age", INTEGER))); IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> buildQuery(Arrays.asList(named("count(age)", - aggregator)))); - assertEquals("unsupported aggregator max", exception.getMessage()); + assertThrows(IllegalStateException.class, + () -> buildQuery(Arrays.asList(named("unsupported_agg(age)", aggregator)))); + assertEquals("unsupported aggregator unsupported_agg", exception.getMessage()); } @Test diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/StatsCommandIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/StatsCommandIT.java index 3d9f5e9f4d..530c8b21db 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/StatsCommandIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/ppl/StatsCommandIT.java @@ -58,7 +58,23 @@ public void testStatsCount() throws IOException { verifyDataRows(response, rows(1000)); } - // TODO: each stats aggregate function should be tested here when implemented + @Test + public void testStatsMin() throws IOException { + JSONObject response = executeQuery(String.format( + "source=%s | stats min(age)", + TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("min(age)", null, "long")); + verifyDataRows(response, rows(20)); + } + + @Test + public void testStatsMax() throws IOException { + JSONObject response = executeQuery(String.format( + "source=%s | stats max(age)", + TEST_INDEX_ACCOUNT)); + verifySchema(response, schema("max(age)", null, "long")); + verifyDataRows(response, rows(40)); + } @Test public void testStatsNested() throws IOException { @@ -96,4 +112,24 @@ public void testGroupByNullValue() throws IOException { rows(null, 36) ); } + + + @Test + public void testStatsWithNull() throws IOException { + JSONObject response = + executeQuery(String.format( + "source=%s | stats avg(age)", + TEST_INDEX_BANK_WITH_NULL_VALUES)); + verifySchema(response, schema("avg(age)", null, "double")); + verifyDataRows(response, rows(33.166666666666664)); + } + + @Test + public void testStatsWithMissing() throws IOException { + JSONObject response = executeQuery(String.format( + "source=%s | stats avg(balance)", + TEST_INDEX_BANK_WITH_NULL_VALUES)); + verifySchema(response, schema("avg(balance)", null, "double")); + verifyDataRows(response, rows(31082.25)); + } } diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt new file mode 100644 index 0000000000..3a2081d9a8 --- /dev/null +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -0,0 +1,7 @@ +SELECT COUNT(AvgTicketPrice) FROM kibana_sample_data_flights +SELECT AVG(AvgTicketPrice) FROM kibana_sample_data_flights +SELECT SUM(AvgTicketPrice) FROM kibana_sample_data_flights +SELECT MAX(AvgTicketPrice) FROM kibana_sample_data_flights +SELECT MAX(timestamp) FROM kibana_sample_data_flights +SELECT MIN(AvgTicketPrice) FROM kibana_sample_data_flights +SELECT MIN(timestamp) FROM kibana_sample_data_flights \ No newline at end of file diff --git a/ppl/src/main/antlr/OpenDistroPPLParser.g4 b/ppl/src/main/antlr/OpenDistroPPLParser.g4 index 70846fb1d3..0fb63c6923 100644 --- a/ppl/src/main/antlr/OpenDistroPPLParser.g4 +++ b/ppl/src/main/antlr/OpenDistroPPLParser.g4 @@ -129,7 +129,7 @@ statsFunction ; statsFunctionName - : AVG | COUNT | SUM + : AVG | COUNT | SUM | MIN | MAX ; percentileAggFunction diff --git a/sql/src/main/antlr/OpenDistroSQLParser.g4 b/sql/src/main/antlr/OpenDistroSQLParser.g4 index 95cecea393..3dd86ad11b 100644 --- a/sql/src/main/antlr/OpenDistroSQLParser.g4 +++ b/sql/src/main/antlr/OpenDistroSQLParser.g4 @@ -217,10 +217,14 @@ scalarFunctionName ; aggregateFunction - : functionName=(AVG | SUM) LR_BRACKET functionArg RR_BRACKET + : functionName=aggregationFunctionName LR_BRACKET functionArg RR_BRACKET /*| COUNT LR_BRACKET (STAR | functionArg) RR_BRACKET */ ; +aggregationFunctionName + : AVG | COUNT | SUM | MIN | MAX + ; + mathematicalFunctionName : ABS | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI | POW | POWER | RAND | ROUND | SIGN | SQRT | TRUNCATE