diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java index 44e0350d31..932a1f3b0c 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprValueUtils.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.data.model; import java.time.Instant; diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index c30ca13351..a24eeca1c1 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -69,6 +69,14 @@ private static DefaultFunctionResolver avg() { new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), (functionProperties, arguments) -> new AvgAggregator(arguments, DOUBLE)) + .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), + (functionProperties, arguments) -> new AvgAggregator(arguments, DATE)) + .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), + (functionProperties, arguments) -> new AvgAggregator(arguments, DATETIME)) + .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), + (functionProperties, arguments) -> new AvgAggregator(arguments, TIME)) + .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), + (functionProperties, arguments) -> new AvgAggregator(arguments, TIMESTAMP)) .build() ); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java index cadfdee87d..a899a6b45b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AvgAggregator.java @@ -6,14 +6,23 @@ package org.opensearch.sql.expression.aggregation; +import static java.time.temporal.ChronoUnit.MILLIS; import static org.opensearch.sql.utils.ExpressionUtils.format; +import java.time.Instant; +import java.time.LocalTime; import java.util.List; import java.util.Locale; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDatetimeValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprTimeValue; +import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.function.BuiltinFunctionName; @@ -23,20 +32,39 @@ */ public class AvgAggregator extends Aggregator { + /** + * To process different data types in different ways, we need to store the type. + * Input data has the same type as the result.
+ */ + private final ExprCoreType dataType; + public AvgAggregator(List arguments, ExprCoreType returnType) { super(BuiltinFunctionName.AVG.getName(), arguments, returnType); + dataType = returnType; } @Override public AvgState create() { - return new AvgState(); + switch (dataType) { + case DATE: + return new DateAvgState(); + case DATETIME: + return new DateTimeAvgState(); + case TIMESTAMP: + return new TimestampAvgState(); + case TIME: + return new TimeAvgState(); + case DOUBLE: + return new DoubleAvgState(); + default: //unreachable code - we don't expose signatures for unsupported types + throw new IllegalArgumentException( + String.format("avg aggregation over %s type is not supported", dataType)); + } } @Override protected AvgState iterate(ExprValue value, AvgState state) { - state.count++; - state.total += ExprValueUtils.getDoubleValue(value); - return state; + return state.iterate(value); } @Override @@ -47,18 +75,117 @@ public String toString() { /** * Average State. */ - protected static class AvgState implements AggregationState { - private int count; - private double total; + protected abstract static class AvgState implements AggregationState { + protected ExprValue count; + protected ExprValue total; AvgState() { - this.count = 0; - this.total = 0d; + this.count = new ExprIntegerValue(0); + this.total = new ExprDoubleValue(0D); + } + + @Override + public abstract ExprValue result(); + + protected AvgState iterate(ExprValue value) { + count = DSL.add(DSL.literal(count), DSL.literal(1)).valueOf(); + return this; + } + } + + protected static class DoubleAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + return DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf(); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value)).valueOf(); + return super.iterate(value); + } + } + + protected static class DateAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprDateValue( + new ExprTimestampValue(Instant.ofEpochMilli( + DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())) + .dateValue()); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())) + .valueOf(); + return super.iterate(value); + } + } + + protected static class DateTimeAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprDatetimeValue( + new ExprTimestampValue(Instant.ofEpochMilli( + DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())) + .datetimeValue()); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())) + .valueOf(); + return super.iterate(value); + } + } + + protected static class TimestampAvgState extends AvgState { + @Override + public ExprValue result() { + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprTimestampValue(Instant.ofEpochMilli( + DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue())); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), DSL.literal(value.timestampValue().toEpochMilli())) + 
.valueOf(); + return super.iterate(value); } + } + protected static class TimeAvgState extends AvgState { @Override public ExprValue result() { - return count == 0 ? ExprNullValue.of() : ExprValueUtils.doubleValue(total / count); + if (0 == count.integerValue()) { + return ExprNullValue.of(); + } + + return new ExprTimeValue(LocalTime.MIN.plus( + DSL.divide(DSL.literal(total), DSL.literal(count)).valueOf().longValue(), MILLIS)); + } + + @Override + protected AvgState iterate(ExprValue value) { + total = DSL.add(DSL.literal(total), + DSL.literal(MILLIS.between(LocalTime.MIN, value.timeValue()))).valueOf(); + return super.iterate(value); } } } diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java index 1fb7a1061c..b3b0052bc3 100644 --- a/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/AvgAggregatorTest.java @@ -9,9 +9,18 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.DATE; +import static org.opensearch.sql.data.type.ExprCoreType.DATETIME; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.data.type.ExprCoreType.TIME; +import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.List; import org.junit.jupiter.api.Test; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -62,6 +71,61 @@ public void avg_with_all_missing_or_null() { assertTrue(result.isNull()); } + @Test + public void avg_numeric_no_values() { + ExprValue result = aggregation(DSL.avg(DSL.ref("dummy", INTEGER)), List.of()); + assertTrue(result.isNull()); + } + + @Test + public void avg_date_no_values() { + ExprValue result = aggregation(DSL.avg(DSL.ref("dummy", DATE)), List.of()); + assertTrue(result.isNull()); + } + + @Test + public void avg_datetime_no_values() { + ExprValue result = aggregation(DSL.avg(DSL.ref("dummy", DATETIME)), List.of()); + assertTrue(result.isNull()); + } + + @Test + public void avg_timestamp_no_values() { + ExprValue result = aggregation(DSL.avg(DSL.ref("dummy", TIMESTAMP)), List.of()); + assertTrue(result.isNull()); + } + + @Test + public void avg_time_no_values() { + ExprValue result = aggregation(DSL.avg(DSL.ref("dummy", TIME)), List.of()); + assertTrue(result.isNull()); + } + + @Test + public void avg_date() { + ExprValue result = aggregation(DSL.avg(DSL.date(DSL.ref("date_value", STRING))), tuples); + assertEquals(LocalDate.of(2007, 7, 2), result.dateValue()); + } + + @Test + public void avg_datetime() { + var result = aggregation(DSL.avg(DSL.datetime(DSL.ref("datetime_value", STRING))), tuples); + assertEquals(LocalDateTime.of(2012, 7, 2, 3, 30), result.datetimeValue()); + } + + @Test + public void avg_time() { + ExprValue result = aggregation(DSL.avg(DSL.time(DSL.ref("time_value", STRING))), tuples); + assertEquals(LocalTime.of(9, 30), result.timeValue()); + } + + @Test + public void avg_timestamp() { + var result = 
aggregation(DSL.avg(DSL.timestamp(DSL.ref("timestamp_value", STRING))), tuples); + assertEquals(TIMESTAMP, result.type()); + assertEquals(LocalDateTime.of(2012, 7, 2, 3, 30), result.datetimeValue()); + } + @Test public void valueOf() { ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, @@ -69,6 +133,14 @@ public void valueOf() { assertEquals("can't evaluate on aggregator: avg", exception.getMessage()); } + @Test + public void avg_on_unsupported_type() { + var aggregator = new AvgAggregator(List.of(DSL.ref("string", STRING)), STRING); + var exception = assertThrows(IllegalArgumentException.class, + () -> aggregator.create()); + assertEquals("avg aggregation over STRING type is not supported", exception.getMessage()); + } + @Test public void test_to_string() { Aggregator avgAggregator = DSL.avg(DSL.ref("integer_value", INTEGER)); diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 275666e7ba..d0cbb28f62 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -163,7 +163,7 @@ SUM Description >>>>>>>>>>> -Usage: SUM(expr). Returns the sum of expr. +Usage: SUM(expr). Returns the sum of `expr`. `expr` can be any of the numeric data types. Example:: @@ -182,7 +182,7 @@ AVG Description >>>>>>>>>>> -Usage: AVG(expr). Returns the average value of expr. +Usage: AVG(expr). Returns the average value of `expr`. `expr` can be any numeric or datetime data type. Datetime aggregation is performed with millisecond precision. Example:: @@ -201,7 +201,7 @@ MAX Description >>>>>>>>>>> -Usage: MAX(expr). Returns the maximum value of expr. +Usage: MAX(expr). Returns the maximum value of `expr`. `expr` can be any numeric or datetime data type. Datetime aggregation is performed with millisecond precision. Example:: @@ -219,7 +219,7 @@ MIN Description >>>>>>>>>>> -Usage: MIN(expr). Returns the minimum value of expr. +Usage: MIN(expr). Returns the minimum value of `expr`. `expr` can be any numeric or datetime data type. Datetime aggregation is performed with millisecond precision.
Example:: diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java index 95f5b5e3e4..487699cf79 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -162,7 +162,7 @@ public void testInMemoryAggregationOnAllValuesAndOnNotNullReturnsSameResult() th } @Test - public void testPushDownAggregationOnNullValuesReturnsNull() throws IOException { + public void testPushDownAggregationOnNullNumericValuesReturnsNull() throws IOException { var response = executeQuery(String.format("SELECT " + "max(int0), min(int0), avg(int0) from %s where int0 IS NULL;", TEST_INDEX_CALCS)); verifySchema(response, @@ -172,6 +172,61 @@ public void testPushDownAggregationOnNullValuesReturnsNull() throws IOException verifyDataRows(response, rows(null, null, null)); } + @Test + public void testPushDownAggregationOnNullDateTimeValuesFromTableReturnsNull() throws IOException { + var response = executeQuery(String.format("SELECT " + + "max(datetime1), min(datetime1), avg(datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(datetime1)", null, "timestamp"), + schema("min(datetime1)", null, "timestamp"), + schema("avg(datetime1)", null, "timestamp")); + verifyDataRows(response, rows(null, null, null)); + } + + @Test + public void testPushDownAggregationOnNullDateValuesReturnsNull() throws IOException { + var response = executeQuery(String.format("SELECT " + + "max(CAST(NULL AS date)), min(CAST(NULL AS date)), avg(CAST(NULL AS date)) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(CAST(NULL AS date))", null, "date"), + schema("min(CAST(NULL AS date))", null, "date"), + schema("avg(CAST(NULL AS date))", null, "date")); + verifyDataRows(response, rows(null, null, null)); + } + + @Test + public void testPushDownAggregationOnNullTimeValuesReturnsNull() throws IOException { + var response = executeQuery(String.format("SELECT " + + "max(CAST(NULL AS time)), min(CAST(NULL AS time)), avg(CAST(NULL AS time)) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(CAST(NULL AS time))", null, "time"), + schema("min(CAST(NULL AS time))", null, "time"), + schema("avg(CAST(NULL AS time))", null, "time")); + verifyDataRows(response, rows(null, null, null)); + } + + @Test + public void testPushDownAggregationOnNullTimeStampValuesReturnsNull() throws IOException { + var response = executeQuery(String.format("SELECT " + + "max(CAST(NULL AS timestamp)), min(CAST(NULL AS timestamp)), avg(CAST(NULL AS timestamp)) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(CAST(NULL AS timestamp))", null, "timestamp"), + schema("min(CAST(NULL AS timestamp))", null, "timestamp"), + schema("avg(CAST(NULL AS timestamp))", null, "timestamp")); + verifyDataRows(response, rows(null, null, null)); + } + + @Test + public void testPushDownAggregationOnNullDateTimeValuesReturnsNull() throws IOException { + var response = executeQuery(String.format("SELECT " + + "max(datetime(NULL)), min(datetime(NULL)), avg(datetime(NULL)) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(datetime(NULL))", null, "datetime"), + schema("min(datetime(NULL))", null, "datetime"), + schema("avg(datetime(NULL))", null, "datetime")); + verifyDataRows(response, rows(null, null, null)); + } + @Test public void testPushDownAggregationOnAllValuesAndOnNotNullReturnsSameResult() throws IOException { var 
responseNotNulls = executeQuery(String.format("SELECT " @@ -225,6 +280,303 @@ public void testPushDownAndInMemoryAggregationReturnTheSameResult() throws IOExc } } + @Test + public void testMinIntegerPushedDown() throws IOException { + var response = executeQuery(String.format("SELECT min(int2)" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("min(int2)", null, "integer")); + verifyDataRows(response, rows(-9)); + } + + @Test + public void testMaxIntegerPushedDown() throws IOException { + var response = executeQuery(String.format("SELECT max(int2)" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("max(int2)", null, "integer")); + verifyDataRows(response, rows(9)); + } + + @Test + public void testAvgIntegerPushedDown() throws IOException { + var response = executeQuery(String.format("SELECT avg(int2)" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("avg(int2)", null, "double")); + verifyDataRows(response, rows(-0.8235294117647058D)); + } + + @Test + public void testMinDoublePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT min(num3)" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("min(num3)", null, "double")); + verifyDataRows(response, rows(-19.96D)); + } + + @Test + public void testMaxDoublePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT max(num3)" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("max(num3)", null, "double")); + verifyDataRows(response, rows(12.93D)); + } + + @Test + public void testAvgDoublePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT avg(num3)" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("avg(num3)", null, "double")); + verifyDataRows(response, rows(-6.12D)); + } + + @Test + public void testMinIntegerInMemory() throws IOException { + var response = executeQuery(String.format("SELECT min(int2)" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("min(int2) OVER(PARTITION BY datetime1)", null, "integer")); + verifySome(response.getJSONArray("datarows"), rows(-9)); + } + + @Test + public void testMaxIntegerInMemory() throws IOException { + var response = executeQuery(String.format("SELECT max(int2)" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(int2) OVER(PARTITION BY datetime1)", null, "integer")); + verifySome(response.getJSONArray("datarows"), rows(9)); + } + + @Test + public void testAvgIntegerInMemory() throws IOException { + var response = executeQuery(String.format("SELECT avg(int2)" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("avg(int2) OVER(PARTITION BY datetime1)", null, "double")); + verifySome(response.getJSONArray("datarows"), rows(-0.8235294117647058D)); + } + + @Test + public void testMinDoubleInMemory() throws IOException { + var response = executeQuery(String.format("SELECT min(num3)" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("min(num3) OVER(PARTITION BY datetime1)", null, "double")); + verifySome(response.getJSONArray("datarows"), rows(-19.96D)); + } + + @Test + public void testMaxDoubleInMemory() throws IOException { + var response = executeQuery(String.format("SELECT max(num3)" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(num3) OVER(PARTITION BY datetime1)",
null, "double")); + verifySome(response.getJSONArray("datarows"), rows(12.93D)); + } + + @Test + public void testAvgDoubleInMemory() throws IOException { + var response = executeQuery(String.format("SELECT avg(num3)" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("avg(num3) OVER(PARTITION BY datetime1)", null, "double")); + verifySome(response.getJSONArray("datarows"), rows(-6.12D)); + } + + @Test + public void testMaxDatePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT max(CAST(date0 AS date))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("max(CAST(date0 AS date))", null, "date")); + verifyDataRows(response, rows("2004-06-19")); + } + + @Test + public void testAvgDatePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT avg(CAST(date0 AS date))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("avg(CAST(date0 AS date))", null, "date")); + verifyDataRows(response, rows("1992-04-23")); + } + + @Test + public void testMinDateTimePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT min(datetime(CAST(time0 AS STRING)))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("min(datetime(CAST(time0 AS STRING)))", null, "datetime")); + verifyDataRows(response, rows("1899-12-30 21:07:32")); + } + + @Test + public void testMaxDateTimePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT max(datetime(CAST(time0 AS STRING)))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("max(datetime(CAST(time0 AS STRING)))", null, "datetime")); + verifyDataRows(response, rows("1900-01-01 20:36:00")); + } + + @Test + public void testAvgDateTimePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT avg(datetime(CAST(time0 AS STRING)))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("avg(datetime(CAST(time0 AS STRING)))", null, "datetime")); + verifyDataRows(response, rows("1900-01-01 03:35:00.236")); + } + + @Test + public void testMinTimePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT min(CAST(time1 AS time))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("min(CAST(time1 AS time))", null, "time")); + verifyDataRows(response, rows("00:05:57")); + } + + @Test + public void testMaxTimePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT max(CAST(time1 AS time))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("max(CAST(time1 AS time))", null, "time")); + verifyDataRows(response, rows("22:50:16")); + } + + @Test + public void testAvgTimePushedDown() throws IOException { + var response = executeQuery(String.format("SELECT avg(CAST(time1 AS time))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("avg(CAST(time1 AS time))", null, "time")); + verifyDataRows(response, rows("13:06:36.25")); + } + + @Test + public void testMinTimeStampPushedDown() throws IOException { + var response = executeQuery(String.format("SELECT min(CAST(datetime0 AS timestamp))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("min(CAST(datetime0 AS timestamp))", null, "timestamp")); + verifyDataRows(response, rows("2004-07-04 22:49:28")); + } + + @Test + public void testMaxTimeStampPushedDown() throws IOException { + var response = executeQuery(String.format("SELECT max(CAST(datetime0 AS 
timestamp))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("max(CAST(datetime0 AS timestamp))", null, "timestamp")); + verifyDataRows(response, rows("2004-08-02 07:59:23")); + } + + @Test + public void testAvgTimeStampPushedDown() throws IOException { + var response = executeQuery(String.format("SELECT avg(CAST(datetime0 AS timestamp))" + + " from %s", TEST_INDEX_CALCS)); + verifySchema(response, schema("avg(CAST(datetime0 AS timestamp))", null, "timestamp")); + verifyDataRows(response, rows("2004-07-20 10:38:09.705")); + } + + @Test + public void testMinDateInMemory() throws IOException { + var response = executeQuery(String.format("SELECT min(CAST(date0 AS date))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("min(CAST(date0 AS date)) OVER(PARTITION BY datetime1)", null, "date")); + verifySome(response.getJSONArray("datarows"), rows("1972-07-04")); + } + + @Test + public void testMaxDateInMemory() throws IOException { + var response = executeQuery(String.format("SELECT max(CAST(date0 AS date))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(CAST(date0 AS date)) OVER(PARTITION BY datetime1)", null, "date")); + verifySome(response.getJSONArray("datarows"), rows("2004-06-19")); + } + + @Test + public void testAvgDateInMemory() throws IOException { + var response = executeQuery(String.format("SELECT avg(CAST(date0 AS date))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("avg(CAST(date0 AS date)) OVER(PARTITION BY datetime1)", null, "date")); + verifySome(response.getJSONArray("datarows"), rows("1992-04-23")); + } + + @Test + public void testMinDateTimeInMemory() throws IOException { + var response = executeQuery(String.format("SELECT min(datetime(CAST(time0 AS STRING)))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("min(datetime(CAST(time0 AS STRING))) OVER(PARTITION BY datetime1)", null, "datetime")); + verifySome(response.getJSONArray("datarows"), rows("1899-12-30 21:07:32")); + } + + @Test + public void testMaxDateTimeInMemory() throws IOException { + var response = executeQuery(String.format("SELECT max(datetime(CAST(time0 AS STRING)))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(datetime(CAST(time0 AS STRING))) OVER(PARTITION BY datetime1)", null, "datetime")); + verifySome(response.getJSONArray("datarows"), rows("1900-01-01 20:36:00")); + } + + @Test + public void testAvgDateTimeInMemory() throws IOException { + var response = executeQuery(String.format("SELECT avg(datetime(CAST(time0 AS STRING)))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("avg(datetime(CAST(time0 AS STRING))) OVER(PARTITION BY datetime1)", null, "datetime")); + verifySome(response.getJSONArray("datarows"), rows("1900-01-01 03:35:00.236")); + } + + @Test + public void testMinTimeInMemory() throws IOException { + var response = executeQuery(String.format("SELECT min(CAST(time1 AS time))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("min(CAST(time1 AS time)) OVER(PARTITION BY datetime1)", null, "time")); + verifySome(response.getJSONArray("datarows"), rows("00:05:57")); + } + + @Test + public void testMaxTimeInMemory() throws IOException { + var response = executeQuery(String.format("SELECT max(CAST(time1 AS time))" + 
+ " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(CAST(time1 AS time)) OVER(PARTITION BY datetime1)", null, "time")); + verifySome(response.getJSONArray("datarows"), rows("22:50:16")); + } + + @Test + public void testAvgTimeInMemory() throws IOException { + var response = executeQuery(String.format("SELECT avg(CAST(time1 AS time))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("avg(CAST(time1 AS time)) OVER(PARTITION BY datetime1)", null, "time")); + verifySome(response.getJSONArray("datarows"), rows("13:06:36.25")); + } + + @Test + public void testMinTimeStampInMemory() throws IOException { + var response = executeQuery(String.format("SELECT min(CAST(datetime0 AS timestamp))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("min(CAST(datetime0 AS timestamp)) OVER(PARTITION BY datetime1)", null, "timestamp")); + verifySome(response.getJSONArray("datarows"), rows("2004-07-04 22:49:28")); + } + + @Test + public void testMaxTimeStampInMemory() throws IOException { + var response = executeQuery(String.format("SELECT max(CAST(datetime0 AS timestamp))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("max(CAST(datetime0 AS timestamp)) OVER(PARTITION BY datetime1)", null, "timestamp")); + verifySome(response.getJSONArray("datarows"), rows("2004-08-02 07:59:23")); + } + + @Test + public void testAvgTimeStampInMemory() throws IOException { + var response = executeQuery(String.format("SELECT avg(CAST(datetime0 AS timestamp))" + + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS)); + verifySchema(response, + schema("avg(CAST(datetime0 AS timestamp)) OVER(PARTITION BY datetime1)", null, "timestamp")); + verifySome(response.getJSONArray("datarows"), rows("2004-07-20 10:38:09.705")); + } + protected JSONObject executeQuery(String query) throws IOException { Request request = new Request("POST", QUERY_API_ENDPOINT); request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScript.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScript.java index 6236d7bb32..2a371afaa3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScript.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScript.java @@ -6,6 +6,11 @@ package org.opensearch.sql.opensearch.storage.script.aggregation; +import static java.time.temporal.ChronoUnit.MILLIS; + +import java.time.LocalTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; import java.util.Map; import lombok.EqualsAndHashCode; import org.apache.lucene.index.LeafReaderContext; @@ -13,8 +18,10 @@ import org.opensearch.search.lookup.SearchLookup; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.storage.script.core.ExpressionScript; /** @@ -42,7 +49,21 @@ public ExpressionAggregationScript( @Override public Object 
execute() { - return expressionScript.execute(this::getDoc, this::evaluateExpression).value(); + var expr = expressionScript.execute(this::getDoc, this::evaluateExpression); + if (expr.type() instanceof OpenSearchDataType) { + return expr.value(); + } + switch ((ExprCoreType)expr.type()) { + case TIME: + // Can't get timestamp from `ExprTimeValue` + return MILLIS.between(LocalTime.MIN, expr.timeValue()); + case DATE: + case DATETIME: + case TIMESTAMP: + return expr.timestampValue().toEpochMilli(); + default: + return expr.value(); + } } private ExprValue evaluateExpression(Expression expression, Environment buildCompositeValuesSourceBuilder( .missingBucket(true) .missingOrder(missingOrder) .order(sortOrder); + // Time types values are converted to LONG in ExpressionAggregationScript::execute + if (List.of(TIMESTAMP, TIME, DATE, DATETIME).contains(expr.getDelegated().type())) { + sourceBuilder.userValuetypeHint(ValueType.LONG); + } return helper.build(expr.getDelegated(), sourceBuilder::field, sourceBuilder::script); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java index 79b6bba712..d8abb73140 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.opensearch.storage.script.aggregation; +import static java.time.temporal.ChronoUnit.MILLIS; import static java.util.Collections.emptyMap; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -16,9 +17,12 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.ref; +import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.OPENSEARCH_TEXT; import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.OPENSEARCH_TEXT_KEYWORD; import com.google.common.collect.ImmutableMap; +import java.time.LocalDate; +import java.time.LocalTime; import java.util.Map; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.LeafReaderContext; @@ -32,6 +36,10 @@ import org.opensearch.search.lookup.LeafDocLookup; import org.opensearch.search.lookup.LeafSearchLookup; import org.opensearch.search.lookup.SearchLookup; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDatetimeValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -100,6 +108,53 @@ void can_execute_parse_expression() { .shouldMatch("30"); } + @Test + void can_execute_expression_interpret_dates_for_aggregation() { + assertThat() + .docValues("date", "1961-04-12") + .evaluate( + DSL.date(ref("date", STRING))) + .shouldMatch(new ExprDateValue(LocalDate.of(1961, 4, 12)) + .timestampValue().toEpochMilli()); + } + + @Test + void can_execute_expression_interpret_datetimes_for_aggregation() { + assertThat() + .docValues("datetime", "1984-03-17 22:16:42") + .evaluate( + DSL.datetime(ref("datetime", STRING))) + .shouldMatch(new ExprDatetimeValue("1984-03-17 
22:16:42") + .timestampValue().toEpochMilli()); + } + + @Test + void can_execute_expression_interpret_times_for_aggregation() { + assertThat() + .docValues("time", "22:13:42") + .evaluate( + DSL.time(ref("time", STRING))) + .shouldMatch(MILLIS.between(LocalTime.MIN, LocalTime.of(22, 13, 42))); + } + + @Test + void can_execute_expression_interpret_timestamps_for_aggregation() { + assertThat() + .docValues("timestamp", "1984-03-17 22:16:42") + .evaluate( + DSL.timestamp(ref("timestamp", STRING))) + .shouldMatch(new ExprTimestampValue("1984-03-17 22:16:42") + .timestampValue().toEpochMilli()); + } + + @Test + void can_execute_expression_interpret_non_core_type_for_aggregation() { + assertThat() + .docValues("text", "pewpew") + .evaluate(ref("text", OPENSEARCH_TEXT)) + .shouldMatch("pewpew"); + } + private ExprScriptAssertion assertThat() { return new ExprScriptAssertion(lookup, leafLookup, context); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java index 25fee2047a..32ab263b8f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java @@ -9,8 +9,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.sql.data.type.ExprCoreType.DATE; +import static org.opensearch.sql.data.type.ExprCoreType.DATETIME; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.data.type.ExprCoreType.TIME; +import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; @@ -25,6 +29,9 @@ import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.common.bytes.BytesReference; @@ -34,6 +41,8 @@ import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.parse.ParseExpression; @@ -128,6 +137,24 @@ void should_build_bucket_with_parse_expression() { asc(named("name", parseExpression))))); } + @ParameterizedTest(name = "{0}") + @EnumSource(value = ExprCoreType.class, names = {"TIMESTAMP", "TIME", "DATE", "DATETIME"}) + void terms_bucket_for_datetime_types_uses_long(ExprType dataType) { + assertEquals( + "{\n" + + " \"terms\" : {\n" + + " \"field\" : \"date\",\n" + + " \"missing_bucket\" : true,\n" + + " 
\"value_type\" : \"long\",\n" + + " \"missing_order\" : \"first\",\n" + + " \"order\" : \"asc\"\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + asc(named("date", ref("date", dataType)))))); + } + @SneakyThrows private String buildQuery( List> groupByExpressions) {