From bf5d29cf38ca37f85af9bcf9a6e8f5f6280fa5af Mon Sep 17 00:00:00 2001
From: Andras Palinkas
Date: Wed, 9 Dec 2020 12:13:27 -0500
Subject: [PATCH] SQL: Fix SUM(all zeroes) to return 0 instead of NULL (#65796)

Previously, `SUM` over all-zero values returned `NULL`. After this change the
SUM SQL function call is automatically rewritten to use a `stats` aggregation
instead of a `sum` aggregation. The `stats` aggregation only results in `NULL`
if there were no rows or no values (all nulls) to aggregate, which is the
expected behaviour across different SQL implementations.

This is a workaround for issue #45251. Once the results of the `sum`
aggregation can differentiate between `SUM(all nulls)` and `SUM(all zeroes)`,
the optimizer rule introduced in this commit needs to be removed.

(cherry-picked from b74792a)
---
 .../qa/server/src/main/resources/agg.csv-spec | 417 ++++++++++++++++++
 .../sql/qa/server/src/main/resources/logs.csv | 1 +
 .../server/src/main/resources/pivot.csv-spec | 13 +-
 .../xpack/sql/optimizer/Optimizer.java | 34 ++
 .../xpack/sql/optimizer/OptimizerTests.java | 95 ++--
 .../sql/planner/QueryTranslatorTests.java | 16 +
 6 files changed, 547 insertions(+), 29 deletions(-)

diff --git a/x-pack/plugin/sql/qa/server/src/main/resources/agg.csv-spec b/x-pack/plugin/sql/qa/server/src/main/resources/agg.csv-spec index 084af17a9d82..30d2a6f88d4b 100644 --- a/x-pack/plugin/sql/qa/server/src/main/resources/agg.csv-spec +++ b/x-pack/plugin/sql/qa/server/src/main/resources/agg.csv-spec @@ -1325,3 +1325,420 @@ F |1964-10-18T00:00:00.000Z|1952-04-19T00:00:00.000Z M |1965-01-03T00:00:00.000Z|1952-02-27T00:00:00.000Z ; + +// +// Aggregations on NULLs and Zeros +// + +allZerosWithFirst +schema::FIRST_AllZeros:i +SELECT FIRST(bytes_in) as "FIRST_AllZeros" FROM logs WHERE bytes_in = 0; + +FIRST_AllZeros +--------------- +0 +; + + +allNullsWithFirst +schema::FIRST_AllNulls:i +SELECT FIRST(bytes_out) as "FIRST_AllNulls" FROM logs WHERE bytes_out IS NULL; + +FIRST_AllNulls +--------------- +null +; + + +allZerosWithLast +schema::LAST_AllZeros:i +SELECT LAST(bytes_in) as "LAST_AllZeros" FROM logs WHERE bytes_in = 0; + + LAST_AllZeros +--------------- +0 +; + + +allNullsWithLast +schema::LAST_AllNulls:i +SELECT LAST(bytes_out) as "LAST_AllNulls" FROM logs WHERE bytes_out IS NULL; + + LAST_AllNulls +--------------- +null +; + + +allZerosWithCount +schema::COUNT_AllZeros:l +SELECT COUNT(bytes_in) as "COUNT_AllZeros" FROM logs WHERE bytes_in = 0; + +COUNT_AllZeros +--------------- +2 +; + + +allNullsWithCount +schema::COUNT_AllNulls:l +SELECT COUNT(bytes_out) as "COUNT_AllNulls" FROM logs WHERE bytes_out IS NULL; + +COUNT_AllNulls +--------------- +0 +; + + + +allZerosWithAvg +schema::AVG_AllZeros:d +SELECT AVG(bytes_in) as "AVG_AllZeros" FROM logs WHERE bytes_in = 0; + + AVG_AllZeros +--------------- +0.0 +; + + +allNullsWithAvg +schema::AVG_AllNulls:d +SELECT AVG(bytes_out) as "AVG_AllNulls" FROM logs WHERE bytes_out IS NULL; + + AVG_AllNulls +--------------- +null +; + + +allZerosWithMin +schema::MIN_AllZeros:i +SELECT MIN(bytes_in) as "MIN_AllZeros" FROM logs WHERE bytes_in = 0; + + MIN_AllZeros +--------------- +0 +; + + +allNullsWithMin +schema::MIN_AllNulls:i +SELECT MIN(bytes_out) as "MIN_AllNulls" FROM logs WHERE bytes_out IS NULL; + + MIN_AllNulls +--------------- +null +; + + +allZerosWithMax +schema::MAX_AllZeros:i +SELECT MAX(bytes_in) as "MAX_AllZeros" FROM logs WHERE bytes_in = 0; + + MAX_AllZeros +--------------- +0 +; + + +allNullsWithMax +schema::MAX_AllNulls:i +SELECT MAX(bytes_out) as 
"MAX_AllNulls" FROM logs WHERE bytes_out IS NULL; + + MAX_AllNulls +--------------- +null +; + + +allZerosWithSum +schema::SUM_AllZeros:i +SELECT SUM(bytes_in) as "SUM_AllZeros" FROM logs WHERE bytes_in = 0; + + SUM_AllZeros +--------------- +0 +; + + +allNullsWithSum +schema::SUM_AllNulls:i +SELECT SUM(bytes_out) as "SUM_AllNulls" FROM logs WHERE bytes_out IS NULL; + + SUM_AllNulls +--------------- +null +; + + +allZerosWithPercentile +schema::PERCENTILE_AllZeros:d +SELECT PERCENTILE(bytes_in, 0) as "PERCENTILE_AllZeros" FROM logs WHERE bytes_in = 0; + +PERCENTILE_AllZeros +------------------- +0.0 +; + + +allNullsWithPercentile +schema::PERCENTILE_AllNulls:d +SELECT PERCENTILE(bytes_out, 0) as "PERCENTILE_AllNulls" FROM logs WHERE bytes_out IS NULL; + +PERCENTILE_AllNulls +------------------- +null +; + + +allZerosWithPercentileRank +schema::PERCENTILE_RANK_AllZeros:d +SELECT PERCENTILE_RANK(bytes_in, 0) as "PERCENTILE_RANK_AllZeros" FROM logs WHERE bytes_in = 0; + +PERCENTILE_RANK_AllZeros +------------------------ +100.0 +; + + +allNullsWithPercentileRank +schema::PERCENTILE_RANK_AllNulls:d +SELECT PERCENTILE_RANK(bytes_out, 0) as "PERCENTILE_RANK_AllNulls" FROM logs WHERE bytes_out IS NULL; + +PERCENTILE_RANK_AllNulls +------------------------ +null +; + + +allZerosWithSumOfSquares +schema::SUM_OF_SQUARES_AllZeros:d +SELECT SUM_OF_SQUARES(bytes_in) as "SUM_OF_SQUARES_AllZeros" FROM logs WHERE bytes_in = 0; + +SUM_OF_SQUARES_AllZeros +----------------------- +0.0 +; + + +allNullsWithSumOfSquares +schema::SUM_OF_SQUARES_AllNulls:d +SELECT SUM_OF_SQUARES(bytes_out) as "SUM_OF_SQUARES_AllNulls" FROM logs WHERE bytes_out IS NULL; + +SUM_OF_SQUARES_AllNulls +----------------------- +null +; + + +allZerosWithStddevPop +schema::STDDEV_POP_AllZeros:d +SELECT STDDEV_POP(bytes_in) as "STDDEV_POP_AllZeros" FROM logs WHERE bytes_in = 0; + +STDDEV_POP_AllZeros +------------------- +0.0 +; + + +allNullsWithStddevPop +schema::STDDEV_POP_AllNulls:d +SELECT STDDEV_POP(bytes_out) as "STDDEV_POP_AllNulls" FROM logs WHERE bytes_out IS NULL; + +STDDEV_POP_AllNulls +------------------- +null +; + + +allZerosWithStddevSamp +schema::STDDEV_SAMP_AllZeros:d +SELECT STDDEV_SAMP(bytes_in) as "STDDEV_SAMP_AllZeros" FROM logs WHERE bytes_in = 0; + +STDDEV_SAMP_AllZeros +-------------------- +0.0 +; + + +allNullsWithStddevSamp +schema::STDDEV_SAMP_AllNulls:d +SELECT STDDEV_SAMP(bytes_out) as "STDDEV_SAMP_AllNulls" FROM logs WHERE bytes_out IS NULL; + +STDDEV_SAMP_AllNulls +-------------------- +null +; + + +allZerosWithVarSamp +schema::VAR_SAMP_AllZeros:d +SELECT VAR_SAMP(bytes_in) as "VAR_SAMP_AllZeros" FROM logs WHERE bytes_in = 0; + +VAR_SAMP_AllZeros +----------------- +0.0 +; + + +allNullsWithVarSamp +schema::VAR_SAMP_AllNulls:d +SELECT VAR_SAMP(bytes_out) as "VAR_SAMP_AllNulls" FROM logs WHERE bytes_out IS NULL; + +VAR_SAMP_AllNulls +----------------- +null +; + + +allZerosWithVarPop +schema::VAR_POP_AllZeros:d +SELECT VAR_POP(bytes_in) as "VAR_POP_AllZeros" FROM logs WHERE bytes_in = 0; + +VAR_POP_AllZeros +---------------- +0.0 +; + + +allNullsWithVarPop +schema::VAR_POP_AllNulls:d +SELECT VAR_POP(bytes_out) as "VAR_POP_AllNulls" FROM logs WHERE bytes_out IS NULL; + +VAR_POP_AllNulls +---------------- +null +; + + +allZerosWithSkewness +schema::SKEWNESS_AllZeros:d +SELECT SKEWNESS(bytes_in) as "SKEWNESS_AllZeros" FROM logs WHERE bytes_in = 0; + +SKEWNESS_AllZeros +----------------- +NaN +; + + +allNullsWithSkewness +schema::SKEWNESS_AllNulls:d +SELECT SKEWNESS(bytes_out) as "SKEWNESS_AllNulls" FROM logs 
WHERE bytes_out IS NULL; + +SKEWNESS_AllNulls +----------------- +null +; + + +allZerosWithMad +schema::MAD_AllZeros:d +SELECT MAD(bytes_in) as "MAD_AllZeros" FROM logs WHERE bytes_in = 0; + + MAD_AllZeros +--------------- +0.0 +; + + +allNullsWithMad +schema::MAD_AllNulls:d +SELECT MAD(bytes_out) as "MAD_AllNulls" FROM logs WHERE bytes_out IS NULL; + + MAD_AllNulls +--------------- +NaN +; + + +allZerosWithKurtosis +schema::KURTOSIS_AllZeros:d +SELECT KURTOSIS(bytes_in) as "KURTOSIS_AllZeros" FROM logs WHERE bytes_in = 0; + +KURTOSIS_AllZeros +----------------- +NaN +; + + +allNullsWithKurtosis +schema::KURTOSIS_AllNulls:d +SELECT KURTOSIS(bytes_out) as "KURTOSIS_AllNulls" FROM logs WHERE bytes_out IS NULL; + +KURTOSIS_AllNulls +----------------- +null +; + +nullsAndZerosCombined +schema::COUNT(*):l|COUNT_AllZeros:l|COUNT_AllNulls:l|FIRST_AllZeros:i|FIRST_AllNulls:i|SUM_AllZeros:i|SUM_AllNulls:i +SELECT + COUNT(*), + COUNT(bytes_in) AS "COUNT_AllZeros", + COUNT(bytes_out) AS "COUNT_AllNulls", + FIRST(bytes_in) AS "FIRST_AllZeros", + FIRST(bytes_out) AS "FIRST_AllNulls", + SUM(bytes_in) AS "SUM_AllZeros", + SUM(bytes_out) AS "SUM_AllNulls" +FROM logs +WHERE bytes_in = 0 AND bytes_out IS NULL; + + COUNT(*) |COUNT(bytes_in)|COUNT(bytes_out)|FIRST_AllZeros |FIRST_AllNulls | SUM_AllZeros | SUM_AllNulls +---------------+---------------+----------------+---------------+---------------+---------------+--------------- +1 |1 |0 |0 |null |0 |null +; + + +groupedByNullsAndZeros +schema::bytes_in:i|COUNT(*):l|SUM(bytes_in):i|MIN(bytes_in):i|MAX(bytes_in):i|AVG(bytes_in):d +SELECT + bytes_in, + COUNT(*), + SUM(bytes_in), + MIN(bytes_in), + MAX(bytes_in), + AVG(bytes_in) +FROM logs +WHERE NVL(bytes_in, 0) = 0 +GROUP BY bytes_in +ORDER BY bytes_in DESC NULLS LAST; + + bytes_in | COUNT(*) | SUM(bytes_in) | MIN(bytes_in) | MAX(bytes_in) | AVG(bytes_in) +---------------+---------------+---------------+---------------+---------------+--------------- +0 |2 |0 |0 |0 |0.0 +null |1 |null |null |null |null +; + +groupedByMultipleSumsWithNullsAndZeros +schema::SUM(bytes_in):i|SUM(bytes_out):i|client_ip:s|c:l +SELECT + SUM(bytes_in), + SUM(bytes_out), + client_ip, + COUNT(*) AS c +FROM logs +WHERE client_ip = '10.0.0.0/16' AND NVL(bytes_out, 0) = 0 +GROUP BY client_ip +ORDER BY c DESC, SUM(bytes_in) ASC NULLS FIRST; + + SUM(bytes_in) |SUM(bytes_out) | client_ip | c +---------------+---------------+---------------+--------------- +232 |null |10.0.1.199 |10 +124 |null |10.0.1.166 |7 +336 |null |10.0.1.122 |7 +8 |null |10.0.1.205 |2 +16 |null |10.0.1.201 |2 +16 |null |10.0.1.203 |2 +28 |null |10.0.1.207 |2 +40 |null |10.0.1.222 |2 +56 |null |10.0.0.130 |2 +null |null |10.0.2.129 |1 +8 |null |10.0.1.202 |1 +8 |null |10.0.1.206 |1 +8 |null |10.0.1.208 |1 +16 |null |10.0.1.13 |1 +28 |null |10.0.0.107 |1 +30 |null |10.0.0.147 |1 +32 |null |10.0.1.177 |1 +48 |null |10.0.0.109 |1 +; \ No newline at end of file diff --git a/x-pack/plugin/sql/qa/server/src/main/resources/logs.csv b/x-pack/plugin/sql/qa/server/src/main/resources/logs.csv index 7103f578b80b..aacb772e6e34 100644 --- a/x-pack/plugin/sql/qa/server/src/main/resources/logs.csv +++ b/x-pack/plugin/sql/qa/server/src/main/resources/logs.csv @@ -99,3 +99,4 @@ id,@timestamp,bytes_in,bytes_out,client_ip,client_port,dest_ip,status 98,2017-11-10T21:12:24Z,74,90,10.0.0.134,57203,172.20.10.1,OK 99,2017-11-10T21:17:37Z,39,512,10.0.0.128,29333,,OK 100,2017-11-10T03:21:36Z,64,183,10.0.0.129,4541,172.16.1.1,OK +101,2017-11-10T23:22:36Z,,,10.0.2.129,4541,172.20.11.1,OK diff --git 
a/x-pack/plugin/sql/qa/server/src/main/resources/pivot.csv-spec b/x-pack/plugin/sql/qa/server/src/main/resources/pivot.csv-spec index f8b580a7bc06..3265c1038a47 100644 --- a/x-pack/plugin/sql/qa/server/src/main/resources/pivot.csv-spec +++ b/x-pack/plugin/sql/qa/server/src/main/resources/pivot.csv-spec @@ -197,10 +197,21 @@ null |10043 |Yishay |M |1990-10-20 null |10044 |Mingsen |F |1994-05-21 00:00:00.0|Casley |39728 |null 1952-04-19 00:00:00.0|10009 |Sumant |F |1985-02-18 00:00:00.0|Peac |66174 |null 1953-01-07 00:00:00.0|10067 |Claudi |M |1987-03-04 00:00:00.0|Stavenow |null |52044 - // end::sumWithoutSubquery ; +sumWithZeros +SELECT * +FROM (SELECT client_ip, status, bytes_in FROM logs WHERE NVL(bytes_in, 0) = 0) +PIVOT (SUM(bytes_in) FOR status IN ('OK','Error')); + + client_ip | 'OK' | 'Error' +---------------+---------------+--------------- +10.0.1.199 |0 |null +10.0.1.205 |0 |null +10.0.2.129 |null |null +; + sumWithInnerAggregateSumOfSquares schema::birth_date:ts|emp_no:i|first_name:s|gender:s|hire_date:ts|last_name:s|1:d|2:d SELECT * FROM test_emp PIVOT (SUM_OF_SQUARES(salary) FOR languages IN (1, 2)) LIMIT 5; diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java index a0fdf5370f05..83cc60d80448 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java @@ -164,6 +164,7 @@ protected Iterable.Batch> batches() { new ReplaceAggsWithMatrixStats(), new ReplaceAggsWithExtendedStats(), new ReplaceAggsWithStats(), + new ReplaceSumWithStats(), new PromoteStatsToExtendedStats(), new ReplaceAggsWithPercentiles(), new ReplaceAggsWithPercentileRanks() @@ -983,6 +984,39 @@ public LogicalPlan apply(LogicalPlan p) { } } + // This class is a workaround for the SUM(all zeros) = NULL issue raised in https://github.com/elastic/elasticsearch/issues/45251 and + // should be removed as soon as root cause is fixed and the sum aggregation results can differentiate between SUM(all zeroes) + // and SUM(all nulls) + // NOTE: this rule should always be applied AFTER the ReplaceAggsWithStats rule + static class ReplaceSumWithStats extends OptimizerBasicRule { + + @Override + public LogicalPlan apply(LogicalPlan plan) { + final Map statsPerField = new LinkedHashMap<>(); + + plan.forEachExpressionsUp(e -> { + if (e instanceof Sum) { + statsPerField.computeIfAbsent(((Sum) e).field(), field -> { + Source source = new Source(field.sourceLocation(), "STATS(" + field.sourceText() + ")"); + return new Stats(source, field); + }); + } + }); + + if (statsPerField.isEmpty() == false) { + plan = plan.transformExpressionsUp(e -> { + if (e instanceof Sum) { + Sum sum = (Sum) e; + return new InnerAggregate(sum, statsPerField.get(sum.field())); + } + return e; + }); + } + + return plan; + } + } + static class PromoteStatsToExtendedStats extends OptimizerBasicRule { @Override diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java index 5c1248d2b768..1d46ccdf343b 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java @@ -119,7 +119,6 @@ import 
org.elasticsearch.xpack.sql.session.EmptyExecutable; import java.lang.reflect.Constructor; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -219,7 +218,7 @@ public void testReplaceFoldableAttributes() { // WHERE a < 10 LogicalPlan p = new Filter(EMPTY, FROM(), lessThanOf(a, L(10))); // SELECT - p = new Project(EMPTY, p, Arrays.asList(a, b)); + p = new Project(EMPTY, p, asList(a, b)); // ORDER BY p = new OrderBy(EMPTY, p, singletonList(new Order(EMPTY, b, OrderDirection.ASC, null))); @@ -269,14 +268,14 @@ public void testConstantFoldingDatetime() { public void testConstantFoldingIn() { In in = new In(EMPTY, ONE, - Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))); + asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))); Literal result= (Literal) new ConstantFolding().rule(in); assertEquals(true, result.value()); } public void testConstantFoldingIn_LeftValueNotFoldable() { In in = new In(EMPTY, getFieldAttribute(), - Arrays.asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))); + asList(ONE, TWO, ONE, THREE, new Sub(EMPTY, THREE, ONE), ONE, FOUR, new Abs(EMPTY, new Sub(EMPTY, TWO, FIVE)))); Alias as = new Alias(in.source(), in.sourceText(), in); Project p = new Project(EMPTY, FROM(), Collections.singletonList(as)); p = (Project) new ConstantFolding().apply(p); @@ -287,13 +286,13 @@ public void testConstantFoldingIn_LeftValueNotFoldable() { } public void testConstantFoldingIn_RightValueIsNull() { - In in = new In(EMPTY, getFieldAttribute(), Arrays.asList(NULL, NULL)); + In in = new In(EMPTY, getFieldAttribute(), asList(NULL, NULL)); Literal result= (Literal) new ConstantFolding().rule(in); assertNull(result.value()); } public void testConstantFoldingIn_LeftValueIsNull() { - In in = new In(EMPTY, NULL, Arrays.asList(ONE, TWO, THREE)); + In in = new In(EMPTY, NULL, asList(ONE, TWO, THREE)); Literal result= (Literal) new ConstantFolding().rule(in); assertNull(result.value()); } @@ -426,9 +425,9 @@ public void testNullFoldingDoesNotApplyOnArbitraryConditionals() throws Exceptio Class clazz = (Class) randomFrom(Coalesce.class, Greatest.class, Least.class); Constructor ctor = clazz.getConstructor(Source.class, List.class); - ArbitraryConditionalFunction conditionalFunction = ctor.newInstance(EMPTY, Arrays.asList(NULL, ONE, TWO)); + ArbitraryConditionalFunction conditionalFunction = ctor.newInstance(EMPTY, asList(NULL, ONE, TWO)); assertEquals(conditionalFunction, rule.rule(conditionalFunction)); - conditionalFunction = ctor.newInstance(EMPTY, Arrays.asList(NULL, NULL, NULL)); + conditionalFunction = ctor.newInstance(EMPTY, asList(NULL, NULL, NULL)); assertEquals(conditionalFunction, rule.rule(conditionalFunction)); } @@ -461,7 +460,7 @@ private List randomListOfNulls() { public void testSimplifyCoalesceFirstLiteral() { Expression e = new SimplifyConditional() .rule(new Coalesce(EMPTY, - Arrays.asList(NULL, TRUE, FALSE, new Abs(EMPTY, getFieldAttribute())))); + asList(NULL, TRUE, FALSE, new Abs(EMPTY, getFieldAttribute())))); assertEquals(Coalesce.class, e.getClass()); assertEquals(1, e.children().size()); assertEquals(TRUE, e.children().get(0)); @@ -585,7 +584,7 @@ public void testSimplifyCaseConditionsFoldWhenFalse() { // ELSE 'default' // END - Case c = new Case(EMPTY, Arrays.asList( + Case c = new Case(EMPTY, asList( new IfConditional(EMPTY, equalsOf(getFieldAttribute(), ONE), 
literal("foo1")), new IfConditional(EMPTY, equalsOf(ONE, TWO), literal("bar1")), new IfConditional(EMPTY, equalsOf(TWO, ONE), literal("bar2")), @@ -611,7 +610,7 @@ public void testSimplifyCaseConditionsFoldCompletely_FoldableElse() { // // 'foo2' - Case c = new Case(EMPTY, Arrays.asList( + Case c = new Case(EMPTY, asList( new IfConditional(EMPTY, equalsOf(ONE, TWO), literal("foo1")), new IfConditional(EMPTY, equalsOf(ONE, ONE), literal("foo2")), literal("default"))); assertFalse(c.foldable()); @@ -636,7 +635,7 @@ public void testSimplifyCaseConditionsFoldCompletely_NonFoldableElse() { // // myField (non-foldable) - Case c = new Case(EMPTY, Arrays.asList( + Case c = new Case(EMPTY, asList( new IfConditional(EMPTY, equalsOf(ONE, TWO), literal("foo1")), getFieldAttribute("myField"))); assertFalse(c.foldable()); @@ -794,8 +793,8 @@ public void testTranslateMinToFirst() { Min min2 = new Min(EMPTY, getFieldAttribute()); OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), - Arrays.asList(a("min1", min1), a("min2", min2))), - Arrays.asList( + asList(a("min1", min1), a("min2", min2))), + asList( new Order(EMPTY, min1, OrderDirection.ASC, Order.NullsPosition.LAST), new Order(EMPTY, min2, OrderDirection.ASC, Order.NullsPosition.LAST))); LogicalPlan result = new ReplaceMinMaxWithTopHits().apply(plan); @@ -819,8 +818,8 @@ public void testTranslateMaxToLast() { Max max1 = new Max(EMPTY, new FieldAttribute(EMPTY, "str", new EsField("str", KEYWORD, emptyMap(), true))); Max max2 = new Max(EMPTY, getFieldAttribute()); - OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), Arrays.asList(a("max1", max1), a("max2", max2))), - Arrays.asList( + OrderBy plan = new OrderBy(EMPTY, new Aggregate(EMPTY, FROM(), emptyList(), asList(a("max1", max1), a("max2", max2))), + asList( new Order(EMPTY, max1, OrderDirection.ASC, Order.NullsPosition.LAST), new Order(EMPTY, max2, OrderDirection.ASC, Order.NullsPosition.LAST))); LogicalPlan result = new ReplaceMinMaxWithTopHits().apply(plan); @@ -849,8 +848,8 @@ public void testSortAggregateOnOrderByWithTwoFields() { Order secondOrderBy = new Order(EMPTY, secondField, OrderDirection.ASC, Order.NullsPosition.LAST); OrderBy orderByPlan = new OrderBy(EMPTY, - new Aggregate(EMPTY, FROM(), Arrays.asList(secondField, firstField), Arrays.asList(secondAlias, firstAlias)), - Arrays.asList(firstOrderBy, secondOrderBy)); + new Aggregate(EMPTY, FROM(), asList(secondField, firstField), asList(secondAlias, firstAlias)), + asList(firstOrderBy, secondOrderBy)); LogicalPlan result = new SortAggregateOnOrderBy().apply(orderByPlan); assertTrue(result instanceof OrderBy); @@ -881,8 +880,8 @@ public void testSortAggregateOnOrderByOnlyAliases() { Order secondOrderBy = new Order(EMPTY, secondAlias, OrderDirection.ASC, Order.NullsPosition.LAST); OrderBy orderByPlan = new OrderBy(EMPTY, - new Aggregate(EMPTY, FROM(), Arrays.asList(secondAlias, firstAlias), Arrays.asList(secondAlias, firstAlias)), - Arrays.asList(firstOrderBy, secondOrderBy)); + new Aggregate(EMPTY, FROM(), asList(secondAlias, firstAlias), asList(secondAlias, firstAlias)), + asList(firstOrderBy, secondOrderBy)); LogicalPlan result = new SortAggregateOnOrderBy().apply(orderByPlan); assertTrue(result instanceof OrderBy); @@ -906,8 +905,8 @@ public void testSortAggregateOnOrderByOnlyAliases() { public void testPivotRewrite() { FieldAttribute column = getFieldAttribute("pivot"); FieldAttribute number = getFieldAttribute("number"); - List values = Arrays.asList(new Alias(EMPTY, "ONE", L(1)), new 
Alias(EMPTY, "TWO", L(2))); - List aggs = Arrays.asList(new Alias(EMPTY, "AVG", new Avg(EMPTY, number))); + List values = asList(new Alias(EMPTY, "ONE", L(1)), new Alias(EMPTY, "TWO", L(2))); + List aggs = asList(new Alias(EMPTY, "AVG", new Avg(EMPTY, number))); Pivot pivot = new Pivot(EMPTY, new EsRelation(EMPTY, new EsIndex("table", emptyMap()), false), column, values, aggs); LogicalPlan result = new RewritePivot().apply(pivot); @@ -919,7 +918,7 @@ public void testPivotRewrite() { assertEquals(In.class, f.condition().getClass()); In in = (In) f.condition(); assertEquals(column, in.value()); - assertEquals(Arrays.asList(L(1), L(2)), in.list()); + assertEquals(asList(L(1), L(2)), in.list()); } /** @@ -933,7 +932,7 @@ public void testAggregatesPromoteToStats_WithFullTextPredicatesConditions() { FullTextPredicate matchPredicate = new MatchQueryPredicate(EMPTY, matchField, "A", StringUtils.EMPTY); FullTextPredicate multiMatchPredicate = new MultiMatchQueryPredicate(EMPTY, "match_field", "A", StringUtils.EMPTY); FullTextPredicate stringQueryPredicate = new StringQueryPredicate(EMPTY, "match_field:A", StringUtils.EMPTY); - List predicates = Arrays.asList(matchPredicate, multiMatchPredicate, stringQueryPredicate); + List predicates = asList(matchPredicate, multiMatchPredicate, stringQueryPredicate); FullTextPredicate left = randomFrom(predicates); FullTextPredicate right = randomFrom(predicates); @@ -946,15 +945,15 @@ public void testAggregatesPromoteToStats_WithFullTextPredicatesConditions() { List aggregates; boolean isSimpleStats = randomBoolean(); if (isSimpleStats) { - aggregates = Arrays.asList(new Avg(EMPTY, aggField), new Sum(EMPTY, aggField), new Min(EMPTY, aggField), + aggregates = asList(new Avg(EMPTY, aggField), new Sum(EMPTY, aggField), new Min(EMPTY, aggField), new Max(EMPTY, aggField)); } else { - aggregates = Arrays.asList(new StddevPop(EMPTY, aggField), new SumOfSquares(EMPTY, aggField), new VarPop(EMPTY, aggField)); + aggregates = asList(new StddevPop(EMPTY, aggField), new SumOfSquares(EMPTY, aggField), new VarPop(EMPTY, aggField)); } AggregateFunction firstAggregate = randomFrom(aggregates); AggregateFunction secondAggregate = randomValueOtherThan(firstAggregate, () -> randomFrom(aggregates)); Aggregate aggregatePlan = new Aggregate(EMPTY, filter, singletonList(matchField), - Arrays.asList(new Alias(EMPTY, "first", firstAggregate), new Alias(EMPTY, "second", secondAggregate))); + asList(new Alias(EMPTY, "first", firstAggregate), new Alias(EMPTY, "second", secondAggregate))); LogicalPlan result; if (isSimpleStats) { result = new ReplaceAggsWithStats().apply(aggregatePlan); @@ -1001,7 +1000,7 @@ public void testReplaceAttributesWithTarget() { Alias aAlias = new Alias(EMPTY, "aAlias", a); Alias bAlias = new Alias(EMPTY, "bAlias", b); - Project p = new Project(EMPTY, FROM(), Arrays.asList(aAlias, bAlias)); + Project p = new Project(EMPTY, FROM(), asList(aAlias, bAlias)); Filter f = new Filter(EMPTY, p, new And(EMPTY, greaterThanOf(aAlias.toAttribute(), L(1)), greaterThanOf(bAlias.toAttribute(), L(2)))); @@ -1023,4 +1022,44 @@ public void testReplaceAttributesWithTarget() { gt = (GreaterThan) and.left(); assertEquals(a, gt.left()); } + + // + // ReplaceSumWithStats rule + // + public void testSumIsReplacedWithStats() { + FieldAttribute fa = getFieldAttribute(); + Sum sum = new Sum(EMPTY, fa); + + Alias sumAlias = new Alias(EMPTY, "sum", sum); + + Aggregate aggregate = new Aggregate(EMPTY, FROM(), emptyList(), asList(sumAlias)); + LogicalPlan optimizedPlan = new 
Optimizer().optimize(aggregate); + assertTrue(optimizedPlan instanceof Aggregate); + Aggregate p = (Aggregate) optimizedPlan; + assertEquals(1, p.aggregates().size()); + assertTrue(p.aggregates().get(0) instanceof Alias); + Alias alias = (Alias) p.aggregates().get(0); + assertTrue(alias.child() instanceof InnerAggregate); + assertEquals(sum, ((InnerAggregate) alias.child()).inner()); + } + + /** + * Once the root cause of https://github.com/elastic/elasticsearch/issues/45251 is fixed in the sum ES aggregation + * (can differentiate between SUM(all zeroes) and SUM(all nulls)), + * remove the {@link OptimizerTests#testSumIsReplacedWithStats()}, and re-enable the following test. + */ + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/45251") + public void testSumIsNotReplacedWithStats() { + FieldAttribute fa = getFieldAttribute(); + Sum sum = new Sum(EMPTY, fa); + + Alias sumAlias = new Alias(EMPTY, "sum", sum); + + Aggregate aggregate = new Aggregate(EMPTY, FROM(), emptyList(), asList(sumAlias)); + LogicalPlan optimizedPlan = new Optimizer().optimize(aggregate); + assertTrue(optimizedPlan instanceof Aggregate); + Aggregate p = (Aggregate) optimizedPlan; + assertEquals(1, p.aggregates().size()); + assertEquals(sumAlias, p.aggregates().get(0)); + } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java index 8546c513022f..f8c31dd1d865 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java @@ -2405,4 +2405,20 @@ public void testPercentileOptimization() { test.accept("PERCENTILE", p -> ((PercentilesAggregationBuilder)p).percentiles()); test.accept("PERCENTILE_RANK", p -> ((PercentileRanksAggregationBuilder)p).values()); } + + // Tests the workaround for the SUM(all zeros) = NULL issue raised in https://github.com/elastic/elasticsearch/issues/45251 and + // should be removed as soon as root cause is fixed and the sum aggregation results can differentiate between SUM(all zeroes) + // and SUM(all nulls) + public void testReplaceSumWithStats() { + List testCases = Arrays.asList( + "SELECT keyword, SUM(int) FROM test GROUP BY keyword", + "SELECT SUM(int) FROM test", + "SELECT * FROM (SELECT some.string, keyword, int FROM test) PIVOT (SUM(int) FOR keyword IN ('a', 'b'))"); + for (String testCase : testCases) { + PhysicalPlan physicalPlan = optimizeAndPlan(testCase); + assertEquals(EsQueryExec.class, physicalPlan.getClass()); + EsQueryExec eqe = (EsQueryExec) physicalPlan; + assertThat(eqe.queryContainer().toString().replaceAll("\\s+", ""), containsString("{\"stats\":{\"field\":\"int\"}}")); + } + } }
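
For context on why the workaround helps: a plain sum aggregation reports 0.0 both when every value is 0 and when there are no values at all, while stats additionally carries the value count. Below is a minimal sketch of that distinction, not part of the patch, using the standard Elasticsearch aggregation builders; the field name and the request/sqlSum helpers are illustrative only.

import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.metrics.Stats;
import org.elasticsearch.search.builder.SearchSourceBuilder;

public class SumVsStatsSketch {

    // Builds a search that runs both aggregations over the same field. The plain
    // "sum" aggregation returns 0.0 whether every value is 0 or there are no
    // values at all, so a caller cannot tell SUM(all zeroes) from SUM(all nulls).
    static SearchSourceBuilder request(String field) {
        return new SearchSourceBuilder()
            .size(0)
            .aggregation(AggregationBuilders.sum("plain_sum").field(field))
            .aggregation(AggregationBuilders.stats("sum_as_stats").field(field));
    }

    // The stats aggregation also exposes the value count, which is what allows
    // reporting NULL for SUM(all nulls) while still returning 0 for SUM(all zeroes).
    static Double sqlSum(Stats stats) {
        return stats.getCount() == 0 ? null : stats.getSum();
    }
}

The ReplaceSumWithStats rule above achieves the same effect inside the optimizer by wrapping each Sum in an InnerAggregate backed by a shared Stats, which is why testReplaceSumWithStats asserts that the generated query contains {"stats":{"field":"int"}} rather than a plain sum aggregation.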