From ac09d75078efc08610f016779dc57299eeb768d0 Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Tue, 6 Feb 2024 05:08:10 +0200 Subject: [PATCH] ESQL: Extend STATS command to support aggregate expressions (#104958) Previously only aggregate functions (max/sum/etc..) were allowed inside the stats command. This PR allows expressions involving one or multiple aggregates to be used, such as: stats x = avg(salary % 3) + max(emp_no), y = min(emp_no / 3) + 10 - median(salary) by z = languages % 2 Improve verifier to not allow scalar functions over grouping for now --- docs/changelog/104958.yaml | 5 + .../src/main/resources/stats.csv-spec | 47 +++ .../main/resources/stats_percentile.csv-spec | 12 +- .../xpack/esql/analysis/Verifier.java | 81 +++-- .../esql/optimizer/LogicalPlanOptimizer.java | 328 ++++++++++++------ .../xpack/esql/parser/ExpressionBuilder.java | 2 +- .../xpack/esql/analysis/AnalyzerTests.java | 112 ++++++ .../xpack/esql/analysis/VerifierTests.java | 58 +++- .../optimizer/LogicalPlanOptimizerTests.java | 295 +++++++++++++++- 9 files changed, 798 insertions(+), 142 deletions(-) create mode 100644 docs/changelog/104958.yaml diff --git a/docs/changelog/104958.yaml b/docs/changelog/104958.yaml new file mode 100644 index 0000000000000..936342db03b45 --- /dev/null +++ b/docs/changelog/104958.yaml @@ -0,0 +1,5 @@ +pr: 104958 +summary: "ESQL: Extend STATS command to support aggregate expressions" +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 65b01aae461e5..fbb38df87ed75 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -1112,3 +1112,50 @@ STATS ck = COUNT(job_positions), ck:l | cb:l | cd:l | ci:l | c:l | csv:l 221 | 204 | 183 | 183 | 100 | 100 ; + +nestedAggsNoGrouping#[skip:-8.12.99,reason:supported in 
8.13+] +FROM employees +| STATS x = AVG(salary) / 2 + MAX(salary), a = AVG(salary), m = MAX(salary) +; + +x:d | a:d | m:i +99123.275 | 48248.55 |74999 +; + +nestedAggsWithGrouping#[skip:-8.12.99,reason:supported in 8.13+] +FROM employees +| STATS x = ROUND(AVG(salary % 3)) + MAX(emp_no), y = MIN(emp_no / 3) + 10 - MEDIAN(salary) by z = languages % 2 +| SORT z +; + +x:d | y:d | z:i +10101 | -41474.0 | 0 +10098 | -45391.0 | 1 +10030 | -44714.5 | null +; + +nestedAggsWithScalars#[skip:-8.12.99,reason:supported in 8.13+] +FROM employees +| STATS x = CONCAT(TO_STRING(ROUND(AVG(salary % 3))), TO_STRING(MAX(emp_no))), + y = ROUND((MIN(emp_no / 3) + PI() - MEDIAN(salary))/E()) + BY z = languages % 2 +; + +x:s | y:d | z:i +1.010029 | -16452.0 | null +1.010100 | -15260.0 | 0 +1.010097 | -16701.0 | 1 +; + +nestedAggsOverGroupingWithAlias#[skip:-8.12.99,reason:supported in 8.13] +FROM employees +| STATS e = max(languages) + 1 by l = languages +| SORT l +| LIMIT 3 +; + +e:i | l:i +2 | 1 +3 | 2 +4 | 3 +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec index 8ac93dc5455bd..db386e877b9c3 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_percentile.csv-spec @@ -70,14 +70,14 @@ NULL ; -medianOfLong#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765] +medianOfLong#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765] from employees | stats m = median(salary_change.long), p50 = percentile(salary_change.long, 50); m:double | p50:double 0 | 0 ; -medianOfInteger#[skip:-8.12.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765/Expression spaces are maintained since 8.13] +medianOfInteger#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc/Expression spaces are maintained since 8.13] // 
tag::median[] FROM employees | STATS MEDIAN(salary), PERCENTILE(salary, 50) @@ -90,7 +90,7 @@ MEDIAN(salary):double | PERCENTILE(salary, 50):double // end::median-result[] ; -medianOfDouble#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765] +medianOfDouble#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765] from employees | stats m = median(salary_change), p50 = percentile(salary_change, 50); m:double | p50:double @@ -98,7 +98,7 @@ m:double | p50:double ; -medianOfLongByKeyword#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765] +medianOfLongByKeyword#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765] from employees | stats m = median(salary_change.long), p50 = percentile(salary_change.long, 50) by job_positions | sort m desc | limit 4; m:double | p50:double | job_positions:keyword @@ -109,7 +109,7 @@ m:double | p50:double | job_positions:keyword ; -medianOfIntegerByKeyword#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765] +medianOfIntegerByKeyword#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765] from employees | stats m = median(salary), p50 = percentile(salary, 50) by job_positions | sort m | limit 4; m:double | p50:double | job_positions:keyword @@ -120,7 +120,7 @@ m:double | p50:double | job_positions:keyword ; -medianOfDoubleByKeyword#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765] +medianOfDoubleByKeyword#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765] from employees | stats m = median(salary_change), p50 = percentile(salary_change, 50)by job_positions | sort m desc | limit 4; m:double | p50:double | job_positions:keyword diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index d0d1a4f4ef573..903c0f948f2e1 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -21,8 +21,8 @@ import org.elasticsearch.xpack.ql.capabilities.Unresolvable; import org.elasticsearch.xpack.ql.common.Failure; import org.elasticsearch.xpack.ql.expression.Alias; +import org.elasticsearch.xpack.ql.expression.AttributeMap; import org.elasticsearch.xpack.ql.expression.Expression; -import org.elasticsearch.xpack.ql.expression.Expressions; import org.elasticsearch.xpack.ql.expression.NamedExpression; import org.elasticsearch.xpack.ql.expression.TypeResolutions; import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute; @@ -67,6 +67,8 @@ public Verifier(Metrics metrics) { Collection verify(LogicalPlan plan, BitSet partialMetrics) { assert partialMetrics != null; Set failures = new LinkedHashSet<>(); + // alias map, collected during the first iteration for better error messages + AttributeMap aliases = new AttributeMap<>(); // quick verification for unresolved attributes plan.forEachUp(p -> { @@ -80,6 +82,7 @@ Collection verify(LogicalPlan plan, BitSet partialMetrics) { } // p is resolved, skip else if (p.resolved()) { + p.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), a.child())); return; } // handle aggregate first to disambiguate between missing fields or incorrect function declaration @@ -128,7 +131,7 @@ else if (p.resolved()) { return; } checkFilterConditionType(p, failures); - checkAggregate(p, failures); + checkAggregate(p, failures, aliases); checkRegexExtractOnlyOnStrings(p, failures); checkRow(p, failures); @@ -147,38 +150,60 @@ else if (p.resolved()) { return failures; } - private static void checkAggregate(LogicalPlan p, Set failures) { + private static void checkAggregate(LogicalPlan p, Set failures, AttributeMap aliases) { if (p instanceof Aggregate agg) { - // check aggregates + + List nakedGroups = new 
ArrayList<>(agg.groupings().size()); + // check grouping + // The grouping can not be an aggregate function + agg.groupings().forEach(e -> { + e.forEachUp(g -> { + if (g instanceof AggregateFunction af) { + failures.add(fail(g, "cannot use an aggregate [{}] for grouping", af)); + } + }); + nakedGroups.add(Alias.unwrap(e)); + }); + + // check aggregates - accept only aggregate functions or expressions in which each naked attribute is copied as + // specified in the grouping clause agg.aggregates().forEach(e -> { var exp = Alias.unwrap(e); - if (exp instanceof AggregateFunction af) { - af.field().forEachDown(AggregateFunction.class, f -> { - failures.add(fail(f, "nested aggregations [{}] not allowed inside other aggregations [{}]", f, af)); - }); - } else { - if (Expressions.match(agg.groupings(), g -> Alias.unwrap(g).semanticEquals(exp)) == false) { - failures.add( - fail( - exp, - "expected an aggregate function or group but got [" - + exp.sourceText() - + "] of type [" - + exp.nodeName() - + "]" - ) - ); - } + if (exp.foldable()) { + failures.add(fail(exp, "expected an aggregate function but found [{}]", exp.sourceText())); } + // traverse the tree to find invalid matches + checkInvalidNamedExpressionUsage(exp, nakedGroups, failures, 0); }); + } + } - // check grouping - // The grouping can not be an aggregate function - agg.groupings().forEach(e -> e.forEachUp(g -> { - if (g instanceof AggregateFunction af) { - failures.add(fail(g, "cannot use an aggregate [{}] for grouping", af)); - } - })); + // traverse the expression and look either for an agg function or a grouping match + // stop either when no children are left, the leaves are literals or a reference attribute is given + private static void checkInvalidNamedExpressionUsage(Expression e, List groups, Set failures, int level) { + // found an aggregate, constant or a group, bail out + if (e instanceof AggregateFunction af) { + af.field().forEachDown(AggregateFunction.class, f -> { + failures.add(fail(f, 
"nested aggregations [{}] not allowed inside other aggregations [{}]", f, af)); + }); + } else if (e.foldable()) { + // don't do anything + } + // don't allow nested groupings for now stats substring(group) by group as we don't optimize yet for them + else if (groups.contains(e)) { + if (level != 0) { + failures.add(fail(e, "scalar functions over groupings [{}] not allowed yet", e.sourceText())); + } + } + // if a reference is found, mark it as an error + else if (e instanceof NamedExpression ne) { + failures.add(fail(e, "column [{}] must appear in the STATS BY clause or be used in an aggregate function", ne.name())); + } + // other keep on going + else { + for (Expression child : e.children()) { + checkInvalidNamedExpressionUsage(child, groups, failures, level + 1); + } } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index 81f712ae0408a..71595b074afc7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.esql.plan.logical.MvExpand; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.logical.TopN; -import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; import org.elasticsearch.xpack.esql.planner.PlannerUtils; @@ -77,6 +76,7 @@ import java.util.function.Predicate; import static java.util.Arrays.asList; +import static java.util.Collections.singleton; import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputExpressions; import static 
org.elasticsearch.xpack.ql.expression.Expressions.asAttributes; import static org.elasticsearch.xpack.ql.optimizer.OptimizerRules.FoldNull; @@ -138,22 +138,20 @@ protected static Batch operators() { } protected static Batch cleanup() { - return new Batch<>( - "Clean Up", - new ReplaceDuplicateAggWithEval(), - // pushing down limits again, because ReplaceDuplicateAggWithEval could create new Project nodes that can still be optimized - new PushDownAndCombineLimits(), - new ReplaceLimitAndSortAsTopN() - ); + return new Batch<>("Clean Up", new ReplaceLimitAndSortAsTopN()); } protected static List> rules() { var substitutions = new Batch<>( "Substitutions", Limiter.ONCE, + // first extract nested aggs top-level - this simplifies the rest of the rules + new ReplaceStatsAggExpressionWithEval(), + // second extract nested aggs inside of them + new ReplaceStatsNestedExpressionWithEval(), + // lastly replace surrogate functions new SubstituteSurrogates(), new ReplaceRegexMatch(), - new ReplaceNestedExpressionWithEval(), new ReplaceAliasingEvalWithProject() // new NormalizeAggregate(), - waits on https://github.com/elastic/elasticsearch/issues/100634 ); @@ -189,6 +187,7 @@ protected LogicalPlan rule(Aggregate aggregate) { } } + int[] counter = new int[] { 0 }; // 0. 
check list of surrogate expressions for (NamedExpression agg : aggs) { Expression e = Alias.unwrap(agg); @@ -205,7 +204,7 @@ protected LogicalPlan rule(Aggregate aggregate) { var attr = aggFuncToAttr.get(af); // the agg doesn't exist in the Aggregate, create an alias for it and save its attribute if (attr == null) { - var temporaryName = temporaryName(agg, af); + var temporaryName = temporaryName(af, agg, counter[0]++); // create a synthetic alias (so it doesn't clash with a user defined name) var newAlias = new Alias(agg.source(), temporaryName, null, af, null, true); attr = newAlias.toAttribute(); @@ -239,15 +238,31 @@ protected LogicalPlan rule(Aggregate aggregate) { // project away transient fields and re-enforce the original order using references (not copies) to the original aggs // this works since the replaced aliases have their nameId copied to avoid having to update all references (which has // a cascading effect) - plan = new EsqlProject(source, plan, Expressions.asAttributes(aggs)); + plan = new Project(source, plan, Expressions.asAttributes(aggs)); } } return plan; } - static String temporaryName(NamedExpression agg, AggregateFunction af) { - return "__" + agg.name() + "_" + af.functionName() + "@" + Integer.toHexString(af.hashCode()); + static String temporaryName(Expression inner, Expression outer, int suffix) { + String in = toString(inner); + String out = toString(outer); + return "$$" + in + "$" + out + "$" + suffix; + } + + static int TO_STRING_LIMIT = 16; + + static String toString(Expression ex) { + return ex instanceof AggregateFunction af ? af.functionName() : extractString(ex); + } + + static String extractString(Expression ex) { + return ex instanceof NamedExpression ne ? ne.name() : limitToString(ex.sourceText()).replace(' ', '_'); + } + + static String limitToString(String string) { + return string.length() > TO_STRING_LIMIT ?
string.substring(0, TO_STRING_LIMIT - 1) + ">" : string; } } @@ -259,17 +274,23 @@ static class ConvertStringToByteRef extends OptimizerRules.OptimizerExpressionRu @Override protected Expression rule(Literal lit) { - if (lit.value() == null) { + Object value = lit.value(); + + if (value == null) { return lit; } - if (lit.value() instanceof String s) { + if (value instanceof String s) { return Literal.of(lit, new BytesRef(s)); } - if (lit.value() instanceof List l) { + if (value instanceof List l) { if (l.isEmpty() || false == l.get(0) instanceof String) { return lit; } - return Literal.of(lit, l.stream().map(v -> new BytesRef((String) v)).toList()); + List byteRefs = new ArrayList<>(l.size()); + for (Object v : l) { + byteRefs.add(new BytesRef(v.toString())); + } + return Literal.of(lit, byteRefs); } return lit; } @@ -288,39 +309,80 @@ protected LogicalPlan rule(UnaryPlan plan) { if (plan instanceof Project project) { if (child instanceof Project p) { // eliminate lower project but first replace the aliases in the upper one - return p.withProjections(combineProjections(project.projections(), p.projections())); - } else if (child instanceof Aggregate a) { + project = p.withProjections(combineProjections(project.projections(), p.projections())); + child = project.child(); + plan = project; + // don't return the plan since the grandchild (now child) might be an aggregate that could not be folded on the way up + // e.g. 
stats c = count(x) | project c, c as x | project x + // try to apply the rule again opportunistically as another node might be pushed in (a limit might be pushed in) + } + // check if the projection eliminates certain aggregates + // but be mindful of aliases to existing aggregates that we don't want to duplicate to avoid redundant work + if (child instanceof Aggregate a) { var aggs = a.aggregates(); - var newAggs = combineProjections(project.projections(), aggs); - var newGroups = replacePrunedAliasesUsedInGroupBy(a.groupings(), aggs, newAggs); - return new Aggregate(a.source(), a.child(), newGroups, newAggs); + var newAggs = projectAggregations(project.projections(), aggs); + // project can be fully removed + if (newAggs != null) { + var newGroups = replacePrunedAliasesUsedInGroupBy(a.groupings(), aggs, newAggs); + plan = new Aggregate(a.source(), a.child(), newGroups, newAggs); + } } + return plan; } // Agg with underlying Project (group by on sub-queries) if (plan instanceof Aggregate a) { if (child instanceof Project p) { - return new Aggregate(a.source(), p.child(), a.groupings(), combineProjections(a.aggregates(), p.projections())); + plan = new Aggregate(a.source(), p.child(), a.groupings(), combineProjections(a.aggregates(), p.projections())); } } return plan; } + // variant of #combineProjections specialized for project followed by agg due to the rewrite rules applied on aggregations + // this method tries to combine the projections by paying attention to: + // - aggregations that are projected away - remove them + // - aliases in the project that point to aggregates - keep them in place (to avoid duplicating the aggs) + private static List projectAggregations( + List upperProjection, + List lowerAggregations + ) { + AttributeMap lowerAliases = new AttributeMap<>(); + for (NamedExpression ne : lowerAggregations) { + lowerAliases.put(ne.toAttribute(), Alias.unwrap(ne)); + } + + AttributeSet seen = new AttributeSet(); + for (NamedExpression upper : 
upperProjection) { + Expression unwrapped = Alias.unwrap(upper); + // projection contains an inner alias (points to an existing field inside the projection) + if (seen.contains(unwrapped)) { + return null; + } + seen.add(Expressions.attribute(unwrapped)); + } + + lowerAggregations = combineProjections(upperProjection, lowerAggregations); + + return lowerAggregations; + } + // normally only the upper projections should survive but since the lower list might have aliases definitions // that might be reused by the upper one, these need to be replaced. // for example an alias defined in the lower list might be referred in the upper - without replacing it the alias becomes invalid - private List combineProjections(List upper, List lower) { + private static List combineProjections( + List upper, + List lower + ) { // collect aliases in the lower list - AttributeMap.Builder aliasesBuilder = AttributeMap.builder(); + AttributeMap aliases = new AttributeMap<>(); for (NamedExpression ne : lower) { if ((ne instanceof Attribute) == false) { - aliasesBuilder.put(ne.toAttribute(), ne); + aliases.put(ne.toAttribute(), ne); } } - - AttributeMap aliases = aliasesBuilder.build(); List replaced = new ArrayList<>(); // replace any matching attribute with a lower alias (if there's a match) @@ -366,10 +428,7 @@ private List replacePrunedAliasesUsedInGroupBy( } public static Expression trimNonTopLevelAliases(Expression e) { - if (e instanceof Alias a) { - return new Alias(a.source(), a.name(), a.qualifier(), trimAliases(a.child()), a.id()); - } - return trimAliases(e); + return e instanceof Alias a ?
a.replaceChild(trimAliases(a.child())) : trimAliases(e); } private static Expression trimAliases(Expression e) { @@ -1071,7 +1130,7 @@ protected Expression regexToEquals(RegexMatch regexMatch, Literal literal) { * becomes * eval `a + 1` = a + 1, `x % 2` = x % 2 | stats sum(`a+1`_ref) by `x % 2`_ref */ - static class ReplaceNestedExpressionWithEval extends OptimizerRules.OptimizerRule { + static class ReplaceStatsNestedExpressionWithEval extends OptimizerRules.OptimizerRule { @Override protected LogicalPlan rule(Aggregate aggregate) { @@ -1103,12 +1162,11 @@ protected LogicalPlan rule(Aggregate aggregate) { expToAttribute.put(a.child().canonical(), a.toAttribute()); } + int[] counter = new int[] { 0 }; // for the aggs make sure to unwrap the agg function and check the existing groupings - for (int i = 0, s = aggs.size(); i < s; i++) { - NamedExpression agg = aggs.get(i); - + for (NamedExpression agg : aggs) { NamedExpression a = (NamedExpression) agg.transformDown(Alias.class, as -> { - // if the child a nested expression + // if the child is a nested expression Expression child = as.child(); // shortcut for common scenario @@ -1123,9 +1181,6 @@ protected LogicalPlan rule(Aggregate aggregate) { return ref; } - // TODO: break expression into aggregate functions (sum(x + 1) / max(y + 2)) - // List afs = a.collectFirstChildren(AggregateFunction.class::isInstance); - // 1. look for the aggregate function var replaced = child.transformUp(AggregateFunction.class, af -> { Expression result = af; @@ -1135,7 +1190,7 @@ protected LogicalPlan rule(Aggregate aggregate) { if (field instanceof Attribute == false && field.foldable() == false) { // 3. 
create a new alias if one doesn't exist yet no reference Attribute attr = expToAttribute.computeIfAbsent(field.canonical(), k -> { - Alias newAlias = new Alias(k.source(), temporaryName(agg, af), null, k, null, true); + Alias newAlias = new Alias(k.source(), syntheticName(k, af, counter[0]++), null, k, null, true); evals.add(newAlias); aggsChanged.set(true); return newAlias.toAttribute(); @@ -1165,8 +1220,141 @@ protected LogicalPlan rule(Aggregate aggregate) { return aggregate; } - static String temporaryName(NamedExpression agg, AggregateFunction af) { - return SubstituteSurrogates.temporaryName(agg, af); + static String syntheticName(Expression expression, AggregateFunction af, int counter) { + return SubstituteSurrogates.temporaryName(expression, af, counter); + } + } + + /** + * Replace nested expressions over aggregates with synthetic eval post the aggregation + * stats a = sum(a) + min(b) by x + * becomes + * stats a1 = sum(a), a2 = min(b) by x | eval a = a1 + a2 | keep a, x + * + * Since the logic is very similar, this rule also handles duplicate aggregate functions to avoid duplicate compute + * stats a = min(x), b = min(x), c = count(*), d = count() by g + * becomes + * stats a = min(x), c = count(*) by g | eval b = a, d = c | keep a, b, c, d, g + */ + static class ReplaceStatsAggExpressionWithEval extends OptimizerRules.OptimizerRule { + ReplaceStatsAggExpressionWithEval() { + super(TransformDirection.UP); + } + + @Override + protected LogicalPlan rule(Aggregate aggregate) { + // build alias map + AttributeMap aliases = new AttributeMap<>(); + aggregate.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), a.child())); + + // break down each aggregate into AggregateFunction + // preserve the projection at the end + List aggs = aggregate.aggregates(); + + // root/naked aggs + Map rootAggs = Maps.newLinkedHashMapWithExpectedSize(aggs.size()); + // evals (original expression relying on multiple aggs) + List newEvals = new ArrayList<>(); + List 
newProjections = new ArrayList<>(); + // track the aggregate aggs (including grouping which is not an AggregateFunction) + List newAggs = new ArrayList<>(); + + Holder changed = new Holder<>(false); + int[] counter = new int[] { 0 }; + + for (NamedExpression agg : aggs) { + if (agg instanceof Alias as) { + // if the child is a nested expression + Expression child = as.child(); + + // common case - handle duplicates + if (child instanceof AggregateFunction af) { + AggregateFunction canonical = (AggregateFunction) af.canonical(); + Expression field = canonical.field().transformUp(e -> aliases.resolve(e, e)); + canonical = (AggregateFunction) canonical.replaceChildren( + CollectionUtils.combine(singleton(field), canonical.parameters()) + ); + + Alias found = rootAggs.get(canonical); + // aggregate is new + if (found == null) { + rootAggs.put(canonical, as); + newAggs.add(as); + newProjections.add(as.toAttribute()); + } + // agg already exists - preserve the current alias but point it to the existing agg + // thus don't add it to the list of aggs as we don't want duplicated compute + else { + changed.set(true); + newProjections.add(as.replaceChild(found.toAttribute())); + } + } + // nested expression over aggregate function - replace them with reference and move the expression into a + // follow-up eval + else { + Holder transformed = new Holder<>(false); + Expression aggExpression = child.transformUp(AggregateFunction.class, af -> { + transformed.set(true); + changed.set(true); + + AggregateFunction canonical = (AggregateFunction) af.canonical(); + Alias alias = rootAggs.get(canonical); + if (alias == null) { + // create synthetic alias over the found agg function + alias = new Alias( + af.source(), + syntheticName(canonical, child, counter[0]++), + as.qualifier(), + canonical, + null, + true + ); + // and remember it to remove duplicates + rootAggs.put(canonical, alias); + // add it to the list of aggregates and continue + newAggs.add(alias); + } + // (even when found)
return a reference to it + return alias.toAttribute(); + }); + + Alias alias = as; + if (transformed.get()) { + // if at least a change occurred, update the alias and add it to the eval + alias = as.replaceChild(aggExpression); + newEvals.add(alias); + } + // aliased grouping + else { + newAggs.add(alias); + } + + newProjections.add(alias.toAttribute()); + } + } + // not an alias (e.g. grouping field) + else { + newAggs.add(agg); + newProjections.add(agg.toAttribute()); + } + } + + LogicalPlan plan = aggregate; + if (changed.get()) { + Source source = aggregate.source(); + plan = new Aggregate(source, aggregate.child(), aggregate.groupings(), newAggs); + if (newEvals.size() > 0) { + plan = new Eval(source, plan, newEvals); + } + // preserve initial projection + plan = new Project(source, plan, newProjections); + } + + return plan; + } + + static String syntheticName(Expression expression, Expression af, int counter) { + return SubstituteSurrogates.temporaryName(expression, af, counter); } } @@ -1330,58 +1518,4 @@ private static LogicalPlan normalize(Aggregate aggregate, AttributeMap { - - ReplaceDuplicateAggWithEval() { - super(TransformDirection.UP); - } - - @Override - protected LogicalPlan rule(Aggregate aggregate) { - LogicalPlan plan = aggregate; - - boolean foundDuplicate = false; - var aggs = aggregate.aggregates(); - Map seenAggs = Maps.newMapWithExpectedSize(aggs.size()); - List projections = new ArrayList<>(); - List keptAggs = new ArrayList<>(aggs.size()); - - for (NamedExpression agg : aggs) { - var attr = agg.toAttribute(); - if (agg instanceof Alias as && as.child() instanceof AggregateFunction af) { - var seen = seenAggs.putIfAbsent(af, attr); - if (seen != null) { - foundDuplicate = true; - projections.add(as.replaceChild(seen)); - } - // otherwise keep the agg in place - else { - keptAggs.add(agg); - projections.add(attr); - } - } else { - keptAggs.add(agg); - projections.add(attr); - } - } - - // at least one duplicate found - add the projection 
(to keep the output in place) - if (foundDuplicate) { - var source = aggregate.source(); - var newAggregate = new Aggregate(source, aggregate.child(), aggregate.groupings(), keptAggs); - plan = new Project(source, newAggregate, projections); - } - - return plan; - } - } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java index d4cfb6b95176b..e07494c19e1ff 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java @@ -423,7 +423,7 @@ public List visitFields(EsqlBaseParser.FieldsContext ctx) { } /** - * Similar to {@link #visitFields(EsqlBaseParser.FieldsContext)} however avoids wrapping the exception + * Similar to {@link #visitFields(EsqlBaseParser.FieldsContext)} however avoids wrapping the expression * into an Alias. 
*/ public List visitGrouping(EsqlBaseParser.FieldsContext ctx) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index ee77ff93b7687..a1f1aae7e6e25 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -1554,6 +1555,112 @@ public void testUnresolvedMvExpand() { assertThat(e.getMessage(), containsString("Unknown column [bar]")); } + public void testRegularStats() { + var plan = analyze(""" + from tests + | stats by salary + """); + + var limit = as(plan, Limit.class); + } + + public void testLiteralInAggregateNoGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |stats 1 + """)); + + assertThat(e.getMessage(), containsString("expected an aggregate function but found [1]")); + } + + public void testLiteralBehindEvalInAggregateNoGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |eval x = 1 + |stats x + """)); + + assertThat(e.getMessage(), containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function")); + } + + public void testLiteralsInAggregateNoGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |stats 1 + 2 + """)); + + assertThat(e.getMessage(), containsString("expected an aggregate function but found [1 + 2]")); + } + + public 
void testLiteralsBehindEvalInAggregateNoGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |eval x = 1 + 2 + |stats x + """)); + + assertThat(e.getMessage(), containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function")); + } + + public void testFoldableInAggregateWithGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |stats 1 + 2 by languages + """)); + + assertThat(e.getMessage(), containsString("expected an aggregate function but found [1 + 2]")); + } + + public void testLiteralsInAggregateWithGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |stats "a" by languages + """)); + + assertThat(e.getMessage(), containsString("expected an aggregate function but found [\"a\"]")); + } + + public void testFoldableBehindEvalInAggregateWithGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |eval x = 1 + 2 + |stats x by languages + """)); + + assertThat(e.getMessage(), containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function")); + } + + public void testFoldableInGrouping() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |stats x by 1 + """)); + + assertThat(e.getMessage(), containsString("[x] is not an aggregate function")); + } + + public void testScalarFunctionsInStats() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |stats salary % 3 by languages + """)); + + assertThat( + e.getMessage(), + containsString("column [salary] must appear in the STATS BY clause or be used in an aggregate function") + ); + } + + public void testDeferredGroupingInStats() { + var e = expectThrows(VerificationException.class, () -> analyze(""" + from test + |eval x = first_name + |stats x by first_name + """)); + + assertThat(e.getMessage(), 
containsString("column [x] must appear in the STATS BY clause or be used in an aggregate function")); + } + public void testUnsupportedTypesInStats() { verifyUnsupported( """ @@ -1654,4 +1761,9 @@ private void assertEmptyEsRelation(LogicalPlan plan) { assertThat(esRelation.output(), equalTo(NO_FIELDS)); assertTrue(esRelation.index().mapping().isEmpty()); } + + @Override + protected IndexAnalyzers createDefaultIndexAnalyzers() { + return super.createDefaultIndexAnalyzers(); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 06d20de70bce3..1257cc5ee8bd6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -62,13 +62,17 @@ public void testRoundFunctionInvalidInputs() { public void testAggsExpressionsInStatsAggs() { assertEquals( - "1:44: expected an aggregate function or group but got [salary] of type [FieldAttribute]", + "1:44: column [salary] must appear in the STATS BY clause or be used in an aggregate function", error("from test | eval z = 2 | stats x = avg(z), salary by emp_no") ); assertEquals( - "1:19: expected an aggregate function or group but got [length(first_name)] of type [Length]", + "1:26: scalar functions over groupings [first_name] not allowed yet", error("from test | stats length(first_name), count(1) by first_name") ); + assertEquals( + "1:36: scalar functions over groupings [languages] not allowed yet", + error("from test | stats max(languages) + languages by l = languages") + ); assertEquals( "1:23: nested aggregations [max(salary)] not allowed inside other aggregations [max(max(salary))]", error("from test | stats max(max(salary)) by first_name") @@ -77,10 +81,6 @@ public void testAggsExpressionsInStatsAggs() { "1:25: argument of [avg(first_name)] 
must be [numeric except unsigned_long], found value [first_name] type [keyword]", error("from test | stats count(avg(first_name)) by first_name") ); - assertEquals( - "1:23: expected an aggregate function or group but got [emp_no + avg(emp_no)] of type [Add]", - error("from test | stats x = emp_no + avg(emp_no) by emp_no") - ); assertEquals( "1:23: second argument of [percentile(languages, languages)] must be a constant, received [languages]", error("from test | stats x = percentile(languages, languages) by emp_no") @@ -89,6 +89,7 @@ public void testAggsExpressionsInStatsAggs() { "1:23: second argument of [count_distinct(languages, languages)] must be a constant, received [languages]", error("from test | stats x = count_distinct(languages, languages) by emp_no") ); + } public void testAggsInsideGrouping() { @@ -98,10 +99,55 @@ public void testAggsInsideGrouping() { ); } + public void testAggsWithInvalidGrouping() { + assertEquals( + "1:35: column [languages] must appear in the STATS BY clause or be used in an aggregate function", + error("from test| stats max(languages) + languages by l = languages % 3") + ); + } + + public void testAggsIgnoreCanonicalGrouping() { + // the grouping column should appear verbatim - ignore canonical representation as they complicate things significantly + // for no real benefit (1+languages != languages + 1) + assertEquals( + "1:39: column [languages] must appear in the STATS BY clause or be used in an aggregate function", + error("from test| stats max(languages) + 1 + languages by l = languages + 1") + ); + } + + public void testAggsWithoutAgg() { + // salary is neither a grouping nor wrapped in an aggregate function, so verification is expected to fail + assertEquals( + "1:35: column [salary] must appear in the STATS BY clause or be used in an aggregate function", + error("from test| stats max(languages) + salary by l = languages + 1") + ); + } + public void testAggsInsideEval() throws Exception { assertEquals("1:29: aggregate function [max(b)] not allowed outside STATS command", error("row a = 1, b = 2 | eval x =
max(b)")); } + public void testAggsWithExpressionOverAggs() { + assertEquals( + "1:44: scalar functions over groupings [languages] not allowed yet", + error("from test | stats max(languages + 1) , m = languages + min(salary + 1) by l = languages, s = salary") + ); + } + + public void testAggScalarOverGroupingColumn() { + assertEquals( + "1:26: scalar functions over groupings [first_name] not allowed yet", + error("from test | stats length(first_name), count(1) by first_name") + ); + } + + public void testGroupingInAggs() { + assertEquals("2:12: column [salary] must appear in the STATS BY clause or be used in an aggregate function", error(""" + from test + |stats e = salary + max(salary) by languages + """)); + } + public void testDoubleRenamingField() { assertEquals( "1:44: Column [emp_no] renamed to [r1] and is no longer available [emp_no as r3]", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index ed3df60ecf13b..06b81d9c4608e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -28,15 +28,19 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; +import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString; import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateFormat; import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateParse; import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc; import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow; import 
org.elasticsearch.xpack.esql.expression.function.scalar.math.Round; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Dissect; @@ -230,6 +234,28 @@ public void testCombineProjectionWithAggregation() { assertThat(Expressions.names(agg.groupings()), contains("last_name", "first_name")); } + /** + * Project[[s{r}#4 AS d, s{r}#4, last_name{f}#21, first_name{f}#18]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[last_name{f}#21, first_name{f}#18],[SUM(salary{f}#22) AS s, last_name{f}#21, first_name{f}#18]] + * \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..] 
+ */ + public void testCombineProjectionWithDuplicateAggregation() { + var plan = plan(""" + from test + | stats s = sum(salary), d = sum(salary), c = sum(salary) by last_name, first_name + | keep d, s, last_name, first_name + """); + + var project = as(plan, Project.class); + assertThat(Expressions.names(project.projections()), contains("d", "s", "last_name", "first_name")); + var limit = as(project.child(), Limit.class); + var agg = as(limit.child(), Aggregate.class); + assertThat(Expressions.names(agg.aggregates()), contains("s", "last_name", "first_name")); + assertThat(Alias.unwrap(agg.aggregates().get(0)), instanceOf(Sum.class)); + assertThat(Expressions.names(agg.groupings()), contains("last_name", "first_name")); + } + public void testQlComparisonOptimizationsApply() { var plan = plan(""" from test @@ -1774,11 +1800,11 @@ public void testSimpleAvgReplacement() { var agg = as(limit.child(), Aggregate.class); var aggs = agg.aggregates(); var a = as(aggs.get(0), Alias.class); - assertThat(a.name(), startsWith("__a_SUM@")); + assertThat(a.name(), startsWith("$$SUM$a$")); var sum = as(a.child(), Sum.class); a = as(aggs.get(1), Alias.class); - assertThat(a.name(), startsWith("__a_COUNT@")); + assertThat(a.name(), startsWith("$$COUNT$a$")); var count = as(a.child(), Count.class); assertThat(Expressions.names(agg.groupings()), contains("last_name")); @@ -1799,7 +1825,7 @@ public void testClashingAggAvgReplacement() { """); assertThat(Expressions.names(plan.output()), contains("a", "c", "s", "last_name")); - var project = as(plan, EsqlProject.class); + var project = as(plan, Project.class); var eval = as(project.child(), Eval.class); var f = eval.fields(); assertThat(f, hasSize(1)); @@ -1835,7 +1861,7 @@ public void testSemiClashingAvgReplacement() { var agg = as(limit.child(), Aggregate.class); var aggs = agg.aggregates(); var a = as(aggs.get(0), Alias.class); - assertThat(a.name(), startsWith("__a_COUNT@")); + assertThat(a.name(), startsWith("$$COUNT$a$0")); var 
sum = as(a.child(), Count.class); a = as(aggs.get(1), Alias.class); @@ -2895,6 +2921,267 @@ public void testNestedMultiExpressionsInGroupingAndAggs() { assertThat(Expressions.names(agg.output()), contains("count(salary + 1)", "max(salary + 23)", "languages + 1", "emp_no % 3")); } + /** + * Expects + * Project[[x{r}#5]] + * \_Eval[[____x_AVG@9efc3cf3_SUM@daf9f221{r}#18 / ____x_AVG@9efc3cf3_COUNT@53cd08ed{r}#19 AS __x_AVG@9efc3cf3, __x_AVG@ + * 9efc3cf3{r}#16 / 2[INTEGER] + __x_MAX@475d0e4d{r}#17 AS x]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[],[SUM(salary{f}#11) AS ____x_AVG@9efc3cf3_SUM@daf9f221, COUNT(salary{f}#11) AS ____x_AVG@9efc3cf3_COUNT@53cd0 + * 8ed, MAX(salary{f}#11) AS __x_MAX@475d0e4d]] + * \_EsRelation[test][_meta_field{f}#12, emp_no{f}#6, first_name{f}#7, ge..] + */ + public void testStatsExpOverAggs() { + var plan = optimizedPlan(""" + from test + | stats x = avg(salary) /2 + max(salary) + """); + + var project = as(plan, Project.class); + assertThat(Expressions.names(project.projections()), contains("x")); + var eval = as(project.child(), Eval.class); + var fields = eval.fields(); + assertThat(Expressions.name(fields.get(1)), is("x")); + // sum/count to compute avg + var div = as(fields.get(0).child(), Div.class); + // avg + max + var add = as(fields.get(1).child(), Add.class); + var limit = as(eval.child(), Limit.class); + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + assertThat(aggs, hasSize(3)); + var sum = as(Alias.unwrap(aggs.get(0)), Sum.class); + assertThat(Expressions.name(sum.field()), is("salary")); + var count = as(Alias.unwrap(aggs.get(1)), Count.class); + assertThat(Expressions.name(count.field()), is("salary")); + var max = as(Alias.unwrap(aggs.get(2)), Max.class); + assertThat(Expressions.name(max.field()), is("salary")); + } + + /** + * Expects + * Project[[x{r}#5, y{r}#9, z{r}#12]] + * \_Eval[[$$SUM$$$AVG$avg(salary_%_3)>$0$0{r}#29 / $$COUNT$$$AVG$avg(salary_%_3)>$0$1{r}#30 AS 
$$AVG$avg(salary_%_3)>$0, + * $$AVG$avg(salary_%_3)>$0{r}#23 + $$MAX$avg(salary_%_3)>$1{r}#24 AS x, + * $$MIN$min(emp_no_/_3)>$2{r}#25 + 10[INTEGER] - $$MEDIAN$min(emp_no_/_3)>$3{r}#26 AS y]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[z{r}#12],[SUM($$salary_%_3$AVG$0{r}#27) AS $$SUM$$$AVG$avg(salary_%_3)>$0$0, + * COUNT($$salary_%_3$AVG$0{r}#27) AS $$COUNT$$$AVG$avg(salary_%_3)>$0$1, + * MAX(emp_no{f}#13) AS $$MAX$avg(salary_%_3)>$1, + * MIN($$emp_no_/_3$MIN$1{r}#28) AS $$MIN$min(emp_no_/_3)>$2, + * PERCENTILE(salary{f}#18,50[INTEGER]) AS $$MEDIAN$min(emp_no_/_3)>$3, z{r}#12]] + * \_Eval[[languages{f}#16 % 2[INTEGER] AS z, + * salary{f}#18 % 3[INTEGER] AS $$salary_%_3$AVG$0, + * emp_no{f}#13 / 3[INTEGER] AS $$emp_no_/_3$MIN$1]] + * \_EsRelation[test][_meta_field{f}#19, emp_no{f}#13, first_name{f}#14, ..] + */ + public void testStatsExpOverAggsMulti() { + var plan = optimizedPlan(""" + from test + | stats x = avg(salary % 3) + max(emp_no), y = min(emp_no / 3) + 10 - median(salary) by z = languages % 2 + """); + + var project = as(plan, Project.class); + assertThat(Expressions.names(project.projections()), contains("x", "y", "z")); + var eval = as(project.child(), Eval.class); + var fields = eval.fields(); + // avg = Sum/Count + assertThat(Expressions.name(fields.get(0)), containsString("AVG")); + assertThat(Alias.unwrap(fields.get(0)), instanceOf(Div.class)); + // avg + max + assertThat(Expressions.name(fields.get(1)), containsString("x")); + assertThat(Alias.unwrap(fields.get(1)), instanceOf(Add.class)); + // min + 10 - median + assertThat(Expressions.name(fields.get(2)), containsString("y")); + assertThat(Alias.unwrap(fields.get(2)), instanceOf(Sub.class)); + + var limit = as(eval.child(), Limit.class); + + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + var sum = as(Alias.unwrap(aggs.get(0)), Sum.class); + var count = as(Alias.unwrap(aggs.get(1)), Count.class); + var max = as(Alias.unwrap(aggs.get(2)), Max.class); + var min = 
as(Alias.unwrap(aggs.get(3)), Min.class); + var percentile = as(Alias.unwrap(aggs.get(4)), Percentile.class); + + eval = as(agg.child(), Eval.class); + fields = eval.fields(); + assertThat(Expressions.name(fields.get(0)), is("z")); + assertThat(Expressions.name(fields.get(1)), containsString("AVG")); + assertThat(Expressions.name(Alias.unwrap(fields.get(1))), containsString("salary")); + assertThat(Expressions.name(fields.get(2)), containsString("MIN")); + assertThat(Expressions.name(Alias.unwrap(fields.get(2))), containsString("emp_no")); + } + + /** + * Expects + * Project[[x{r}#5, y{r}#9, z{r}#12]] + * \_Eval[[$$SUM$$$AVG$CONCAT(TO_STRIN>$0$0{r}#29 / $$COUNT$$$AVG$CONCAT(TO_STRIN>$0$1{r}#30 AS $$AVG$CONCAT(TO_STRIN>$0, + * CONCAT(TOSTRING($$AVG$CONCAT(TO_STRIN>$0{r}#23),TOSTRING($$MAX$CONCAT(TO_STRIN>$1{r}#24)) AS x, + * $$MIN$(MIN(emp_no_/_3>$2{r}#25 + 3.141592653589793[DOUBLE] - $$MEDIAN$(MIN(emp_no_/_3>$3{r}#26 / 2.718281828459045[DOUBLE] + * AS y]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[z{r}#12],[SUM($$salary_%_3$AVG$0{r}#27) AS $$SUM$$$AVG$CONCAT(TO_STRIN>$0$0, + * COUNT($$salary_%_3$AVG$0{r}#27) AS $$COUNT$$$AVG$CONCAT(TO_STRIN>$0$1, + * MAX(emp_no{f}#13) AS $$MAX$CONCAT(TO_STRIN>$1, + * MIN($$emp_no_/_3$MIN$1{r}#28) AS $$MIN$(MIN(emp_no_/_3>$2, + * PERCENTILE(salary{f}#18,50[INTEGER]) AS $$MEDIAN$(MIN(emp_no_/_3>$3, z{r}#12]] + * \_Eval[[languages{f}#16 % 2[INTEGER] AS z, + * salary{f}#18 % 3[INTEGER] AS $$salary_%_3$AVG$0, + * emp_no{f}#13 / 3[INTEGER] AS $$emp_no_/_3$MIN$1]] + * \_EsRelation[test][_meta_field{f}#19, emp_no{f}#13, first_name{f}#14, ..] 
+ */ + public void testStatsExpOverAggsWithScalars() { + var plan = optimizedPlan(""" + from test + | stats x = CONCAT(TO_STRING(AVG(salary % 3)), TO_STRING(MAX(emp_no))), + y = (MIN(emp_no / 3) + PI() - MEDIAN(salary))/E() + by z = languages % 2 + """); + + var project = as(plan, Project.class); + assertThat(Expressions.names(project.projections()), contains("x", "y", "z")); + var eval = as(project.child(), Eval.class); + var fields = eval.fields(); + // avg = Sum/Count + assertThat(Expressions.name(fields.get(0)), containsString("AVG")); + assertThat(Alias.unwrap(fields.get(0)), instanceOf(Div.class)); + // concat(to_string(avg), to_string(max)) + assertThat(Expressions.name(fields.get(1)), containsString("x")); + var concat = as(Alias.unwrap(fields.get(1)), Concat.class); + var toString = as(concat.children().get(0), ToString.class); + toString = as(concat.children().get(1), ToString.class); + // (min + pi - median) / e + assertThat(Expressions.name(fields.get(2)), containsString("y")); + assertThat(Alias.unwrap(fields.get(2)), instanceOf(Div.class)); + + var limit = as(eval.child(), Limit.class); + + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + var sum = as(Alias.unwrap(aggs.get(0)), Sum.class); + var count = as(Alias.unwrap(aggs.get(1)), Count.class); + var max = as(Alias.unwrap(aggs.get(2)), Max.class); + var min = as(Alias.unwrap(aggs.get(3)), Min.class); + var percentile = as(Alias.unwrap(aggs.get(4)), Percentile.class); + assertThat(Expressions.name(aggs.get(5)), is("z")); + + eval = as(agg.child(), Eval.class); + fields = eval.fields(); + assertThat(Expressions.name(fields.get(0)), is("z")); + assertThat(Expressions.name(fields.get(1)), containsString("AVG")); + assertThat(Expressions.name(Alias.unwrap(fields.get(1))), containsString("salary")); + assertThat(Expressions.name(fields.get(2)), containsString("MIN")); + assertThat(Expressions.name(Alias.unwrap(fields.get(2))), containsString("emp_no")); + } + + /** + * Expects + *
Project[[a{r}#5, b{r}#9, $$max(salary)_+_3>$COUNT$2{r}#46 AS d, $$count(salary)_->$MIN$3{r}#47 AS e, $$avg(salary)_+_m + * >$MAX$1{r}#45 AS g]] + * \_Eval[[$$$$avg(salary)_+_m>$AVG$0$SUM$0{r}#48 / $$max(salary)_+_3>$COUNT$2{r}#46 AS $$avg(salary)_+_m>$AVG$0, $$avg( + * salary)_+_m>$AVG$0{r}#44 + $$avg(salary)_+_m>$MAX$1{r}#45 AS a, $$avg(salary)_+_m>$MAX$1{r}#45 + 3[INTEGER] + + * 3.141592653589793[DOUBLE] + $$max(salary)_+_3>$COUNT$2{r}#46 AS b]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[w{r}#28],[SUM(salary{f}#39) AS $$$$avg(salary)_+_m>$AVG$0$SUM$0, MAX(salary{f}#39) AS $$avg(salary)_+_m>$MAX$1 + * , COUNT(salary{f}#39) AS $$max(salary)_+_3>$COUNT$2, MIN(salary{f}#39) AS $$count(salary)_->$MIN$3]] + * \_Eval[[languages{f}#37 % 2[INTEGER] AS w]] + * \_EsRelation[test][_meta_field{f}#40, emp_no{f}#34, first_name{f}#35, ..] + */ + public void testStatsExpOverAggsWithScalarAndDuplicateAggs() { + var plan = optimizedPlan(""" + from test + | stats a = avg(salary) + max(salary), + b = max(salary) + 3 + PI() + count(salary), + c = count(salary) - min(salary), + d = count(salary), + e = min(salary), + f = max(salary), + g = max(salary) + by w = languages % 2 + | keep a, b, d, e, g + """); + + var project = as(plan, Project.class); + var projections = project.projections(); + assertThat(Expressions.names(projections), contains("a", "b", "d", "e", "g")); + var refA = Alias.unwrap(projections.get(0)); + var refB = Alias.unwrap(projections.get(1)); + var refD = Alias.unwrap(projections.get(2)); + var refE = Alias.unwrap(projections.get(3)); + var refG = Alias.unwrap(projections.get(4)); + + var eval = as(project.child(), Eval.class); + var fields = eval.fields(); + // avg = Sum/Count + assertThat(Expressions.name(fields.get(0)), containsString("AVG")); + assertThat(Alias.unwrap(fields.get(0)), instanceOf(Div.class)); + // avg + max + assertThat(Expressions.name(fields.get(1)), is("a")); + var add = as(Alias.unwrap(fields.get(1)), Add.class); + var max_salary = add.right(); 
+ assertThat(Expressions.attribute(fields.get(1)), is(Expressions.attribute(refA))); + + assertThat(Expressions.name(fields.get(2)), is("b")); + assertThat(Expressions.attribute(fields.get(2)), is(Expressions.attribute(refB))); + + add = as(Alias.unwrap(fields.get(2)), Add.class); + add = as(add.left(), Add.class); + add = as(add.left(), Add.class); + assertThat(Expressions.attribute(max_salary), is(Expressions.attribute(add.left()))); + + var limit = as(eval.child(), Limit.class); + + var agg = as(limit.child(), Aggregate.class); + var aggs = agg.aggregates(); + var sum = as(Alias.unwrap(aggs.get(0)), Sum.class); + + assertThat(Expressions.attribute(aggs.get(1)), is(Expressions.attribute(max_salary))); + var max = as(Alias.unwrap(aggs.get(1)), Max.class); + var count = as(Alias.unwrap(aggs.get(2)), Count.class); + var min = as(Alias.unwrap(aggs.get(3)), Min.class); + + eval = as(agg.child(), Eval.class); + fields = eval.fields(); + assertThat(Expressions.name(fields.get(0)), is("w")); + } + + /** + * Expects + * Project[[a{r}#5, a{r}#5 AS b, w{r}#12]] + * \_Limit[500[INTEGER]] + * \_Aggregate[[w{r}#12],[SUM($$salary_/_2_+_la>$SUM$0{r}#26) AS a, w{r}#12]] + * \_Eval[[emp_no{f}#16 % 2[INTEGER] AS w, salary{f}#21 / 2[INTEGER] + languages{f}#19 AS $$salary_/_2_+_la>$SUM$0]] + * \_EsRelation[test][_meta_field{f}#22, emp_no{f}#16, first_name{f}#17, ..] 
+ */ + public void testStatsWithCanonicalAggregate() throws Exception { + var plan = optimizedPlan(""" + from test + | stats a = sum(salary / 2 + languages), + b = sum(languages + salary / 2) + by w = emp_no % 2 + | keep a, b, w + """); + + var project = as(plan, Project.class); + assertThat(Expressions.names(project.projections()), contains("a", "b", "w")); + assertThat(Expressions.name(Alias.unwrap(project.projections().get(1))), is("a")); + var limit = as(project.child(), Limit.class); + var aggregate = as(limit.child(), Aggregate.class); + var aggregates = aggregate.aggregates(); + assertThat(Expressions.names(aggregates), contains("a", "w")); + var unwrapped = Alias.unwrap(aggregates.get(0)); + var sum = as(unwrapped, Sum.class); + var sum_argument = sum.field(); + var grouping = aggregates.get(1); + + var eval = as(aggregate.child(), Eval.class); + var fields = eval.fields(); + assertThat(Expressions.attribute(fields.get(0)), is(Expressions.attribute(grouping))); + assertThat(Expressions.attribute(fields.get(1)), is(Expressions.attribute(sum_argument))); + } + private LogicalPlan optimizedPlan(String query) { return plan(query); }