Skip to content

Commit

Permalink
ESQL: Allow grouping key inside stats expressions (elastic#106579)
Browse files Browse the repository at this point in the history
Similar to aggs, allow grouping keys to used inside STATS expressions by
 introducing a synthetic eval, e.g.:

STATS a = x + count(*) BY x becomes
STATS c = count(*) BY x | EVAL a = x + c | KEEP a, x

To better handle overriding aliases, introduce EsqlAggregate which keeps
 the declared structure intact during analysis and verification while
 merging the output. The deduplication happens now in the optimization
 phase.

 Fix small bug that caused replacement of expressions inside aggregations
to be skipped despite being applied
Improved Verifier to not repeat error messages in case for Aggregates
Removed verification heuristics for missing columns as functions as it
 was too broad
  • Loading branch information
costin authored Apr 4, 2024
1 parent 0cc19f3 commit 7483844
Show file tree
Hide file tree
Showing 11 changed files with 728 additions and 153 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/106579.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 106579
summary: "ESQL: Allow grouping key inside stats expressions"
area: ES|QL
type: enhancement
issues: []
141 changes: 140 additions & 1 deletion x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ FROM employees
;

// tag::docsStatsByExpression-result[]
my_count:long |LEFT(last_name, 1):keyword
my_count:long |LEFT(last_name, 1):keyword
2 |A
11 |B
5 |C
Expand Down Expand Up @@ -1188,6 +1188,145 @@ e:i | l:i
4 | 3
;

nestedAggsOverGroupingExpressionWithoutAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(languages + emp_no) + 1 by languages + emp_no
| SORT e
| LIMIT 3
;

e:i | languages + emp_no:i
10004 | 10003
10007 | 10006
10008 | 10007
;

nestedAggsOverGroupingExpressionMultiGroupWithoutAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(languages + emp_no + 10) + 1 by languages + emp_no, f = emp_no % 3
| SORT e, f
| LIMIT 3
;

e:i | languages + emp_no:i | f:i
10014 | 10003 | 2
10017 | 10006 | 0
10018 | 10007 | 0
;

nestedAggsOverGroupingExpressionWithAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(languages + emp_no + 10) + 1 by languages + emp_no
| SORT e
| LIMIT 3
;

e:i | languages + emp_no:i
10014 | 10003
10017 | 10006
10018 | 10007
;

nestedAggsOverGroupingExpressionWithAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = max(a), f = min(a), g = count(a) + 1 by a = languages + emp_no
| SORT a
| LIMIT 3
;

e: i | f:i | g:l | a:i
10003 | 10003 | 2 | 10003
10006 | 10006 | 2 | 10006
10007 | 10007 | 3 | 10007
;

nestedAggsOverGroupingTwiceWithAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS vals = COUNT() BY x = emp_no, x = languages
| SORT x
| LIMIT 3
;

vals: l| x:i
15 | 1
19 | 2
17 | 3
;

nestedAggsOverGroupingWithAlias#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = length(f) + 1, count(*) by f = first_name
| SORT f
| LIMIT 3
;

e:i | count(*):l | f:s
10 | 1 | Alejandro
8 | 1 | Amabile
7 | 1 | Anneke
;

nestedAggsOverGroupingWithAliasInsideExpression#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS m = max(l), o = min(s) by l = languages, s = salary + 1
| SORT l, s
| LIMIT 5
;

m:i | o:i | l:i | s:i
1 | 25977 | 1 | 25977
1 | 28036 | 1 | 28036
1 | 34342 | 1 | 34342
1 | 39111 | 1 | 39111
1 | 39729 | 1 | 39729
;

nestedAggsOverGroupingWithAliasAndProjection#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = length(f) + 1, c = count(*) by f = first_name
| KEEP e
| SORT e
| LIMIT 5
;

e:i
4
4
4
4
5
;

nestedAggsOverGroupingAndAggWithAliasAndProjection#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = length(f) + count(*), m = max(emp_no) by f = first_name
| KEEP e
| SORT e
| LIMIT 5
;

e:l
4
4
4
4
5
;

nestedAggsOverGroupingExpAndAggWithAliasAndProjection#[skip:-8.13.99,reason:supported in 8.14]
FROM employees
| STATS e = f + count(*), m = max(emp_no) by f = length(first_name) % 2
| KEEP e
| SORT e
| LIMIT 3
;

e:l
44
47
null
;

defaultNameWithSpace
ROW a = 1 | STATS couNt(*) | SORT `couNt(*)`
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
package org.elasticsearch.xpack.esql.analysis;

import org.elasticsearch.common.logging.HeaderWarning;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.expression.NamedExpressions;
import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.plan.logical.Drop;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.EsqlAggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsqlUnresolvedRelation;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Keep;
Expand All @@ -27,9 +28,11 @@
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules;
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules.BaseAnalyzerRule;
import org.elasticsearch.xpack.ql.analyzer.AnalyzerRules.ParameterizedAnalyzerRule;
import org.elasticsearch.xpack.ql.capabilities.Resolvables;
import org.elasticsearch.xpack.ql.common.Failure;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.Attribute;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
import org.elasticsearch.xpack.ql.expression.EmptyAttribute;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
Expand All @@ -40,6 +43,8 @@
import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.ql.expression.UnresolvedStar;
import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition;
import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry;
import org.elasticsearch.xpack.ql.expression.function.UnresolvedFunction;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.ql.index.EsIndex;
Expand All @@ -53,13 +58,15 @@
import org.elasticsearch.xpack.ql.rule.ParameterizedRuleExecutor;
import org.elasticsearch.xpack.ql.rule.Rule;
import org.elasticsearch.xpack.ql.rule.RuleExecutor;
import org.elasticsearch.xpack.ql.session.Configuration;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataType;
import org.elasticsearch.xpack.ql.type.DataTypes;
import org.elasticsearch.xpack.ql.type.EsField;
import org.elasticsearch.xpack.ql.type.InvalidMappedField;
import org.elasticsearch.xpack.ql.type.UnsupportedEsField;
import org.elasticsearch.xpack.ql.util.CollectionUtils;
import org.elasticsearch.xpack.ql.util.Holder;
import org.elasticsearch.xpack.ql.util.StringUtils;

import java.util.ArrayList;
Expand All @@ -70,7 +77,6 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -85,7 +91,6 @@
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToLong;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE;
import static org.elasticsearch.xpack.ql.analyzer.AnalyzerRules.resolveFunction;
import static org.elasticsearch.xpack.ql.type.DataTypes.DATETIME;
import static org.elasticsearch.xpack.ql.type.DataTypes.DOUBLE;
import static org.elasticsearch.xpack.ql.type.DataTypes.FLOAT;
Expand All @@ -105,14 +110,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
private static final Iterable<RuleExecutor.Batch<LogicalPlan>> rules;

static {
var resolution = new Batch<>(
"Resolution",
new ResolveTable(),
new ResolveEnrich(),
new ResolveRefs(),
new ResolveFunctions(),
new RemoveDuplicateProjections()
);
var resolution = new Batch<>("Resolution", new ResolveTable(), new ResolveEnrich(), new ResolveFunctions(), new ResolveRefs());
var finish = new Batch<>("Finish Analysis", Limiter.ONCE, new AddImplicitLimit(), new PromoteStringsInDateComparisons());
rules = List.of(resolution, finish);
}
Expand Down Expand Up @@ -313,6 +311,10 @@ protected LogicalPlan doRule(LogicalPlan plan) {
childrenOutput.addAll(output);
}

if (plan instanceof Aggregate agg) {
return resolveAggregate(agg, childrenOutput);
}

if (plan instanceof Drop d) {
return resolveDrop(d, childrenOutput);
}
Expand All @@ -337,7 +339,60 @@ protected LogicalPlan doRule(LogicalPlan plan) {
return resolveMvExpand(p, childrenOutput);
}

return plan.transformExpressionsUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
return plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
}

private LogicalPlan resolveAggregate(Aggregate a, List<Attribute> childrenOutput) {
// if the grouping is resolved but the aggs are not, use the former to resolve the latter
// e.g. STATS a ... GROUP BY a = x + 1
Holder<Boolean> changed = new Holder<>(false);
List<Expression> groupings = a.groupings();
// first resolve groupings since the aggs might refer to them
// trying to globally resolve unresolved attributes will lead to some being marked as unresolvable
if (Resolvables.resolved(groupings) == false) {
List<Expression> newGroupings = new ArrayList<>(groupings.size());
for (Expression g : groupings) {
Expression resolved = g.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
if (resolved != g) {
changed.set(true);
}
newGroupings.add(resolved);
}
groupings = newGroupings;
if (changed.get()) {
a = new EsqlAggregate(a.source(), a.child(), newGroupings, a.aggregates());
changed.set(false);
}
}

if (a.expressionsResolved() == false && Resolvables.resolved(groupings)) {
AttributeMap<Expression> resolved = new AttributeMap<>();
for (Expression e : groupings) {
Attribute attr = Expressions.attribute(e);
if (attr != null) {
resolved.put(attr, attr);
}
}
List<Attribute> resolvedList = NamedExpressions.mergeOutputAttributes(new ArrayList<>(resolved.keySet()), childrenOutput);
List<NamedExpression> newAggregates = new ArrayList<>();

for (NamedExpression aggregate : a.aggregates()) {
var agg = (NamedExpression) aggregate.transformUp(UnresolvedAttribute.class, ua -> {
Expression ne = ua;
Attribute maybeResolved = maybeResolveAttribute(ua, resolvedList);
if (maybeResolved != null) {
changed.set(true);
ne = maybeResolved;
}
return ne;
});
newAggregates.add(agg);
}

a = changed.get() ? new EsqlAggregate(a.source(), a.child(), groupings, newAggregates) : a;
}

return a;
}

private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
Expand Down Expand Up @@ -664,59 +719,30 @@ private static class ResolveFunctions extends ParameterizedAnalyzerRule<LogicalP

@Override
protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
return plan.transformExpressionsUp(
return plan.transformExpressionsOnly(
UnresolvedFunction.class,
uf -> resolveFunction(uf, context.configuration(), context.functionRegistry())
);
}
}

/**
* Rule that removes duplicate projects - this is done as a separate rule to allow
* full validation of the node before looking at the duplication.
* The duplication needs to be addressed to avoid ambiguity errors from commands further down
* the line.
*/
private static class RemoveDuplicateProjections extends BaseAnalyzerRule {

@Override
protected boolean skipResolved() {
return false;
}

@Override
protected LogicalPlan doRule(LogicalPlan plan) {
if (plan.resolved()) {
if (plan instanceof Aggregate agg) {
plan = removeAggDuplicates(agg);
}
}
return plan;
}

private static LogicalPlan removeAggDuplicates(Aggregate agg) {
var groupings = agg.groupings();
var newGroupings = new LinkedHashSet<>(groupings);
// reuse existing objects
groupings = newGroupings.size() == groupings.size() ? groupings : new ArrayList<>(newGroupings);

var aggregates = agg.aggregates();
var newAggregates = new ArrayList<>(aggregates);
var nameSet = Sets.newHashSetWithExpectedSize(newAggregates.size());
// remove duplicates in reverse to preserve the last one appearing
for (int i = newAggregates.size() - 1; i >= 0; i--) {
var aggregate = newAggregates.get(i);
if (nameSet.add(aggregate.name()) == false) {
newAggregates.remove(i);
public static org.elasticsearch.xpack.ql.expression.function.Function resolveFunction(
UnresolvedFunction uf,
Configuration configuration,
FunctionRegistry functionRegistry
) {
org.elasticsearch.xpack.ql.expression.function.Function f = null;
if (uf.analyzed()) {
f = uf;
} else {
String functionName = functionRegistry.resolveAlias(uf.name());
if (functionRegistry.functionExists(functionName) == false) {
f = uf.missing(functionName, functionRegistry.listFunctions());
} else {
FunctionDefinition def = functionRegistry.resolveFunction(functionName);
f = uf.buildResolved(configuration, def);
}
}
// reuse existing objects
aggregates = newAggregates.size() == aggregates.size() ? aggregates : newAggregates;
// replace aggregate if needed
agg = (groupings == agg.groupings() && newAggregates == agg.aggregates())
? agg
: new Aggregate(agg.source(), agg.child(), groupings, aggregates);
return agg;
return f;
}
}

Expand Down
Loading

0 comments on commit 7483844

Please sign in to comment.