Skip to content

Commit

Permalink
ESQL: Extend STATS command to support aggregate expressions (elastic#…
Browse files Browse the repository at this point in the history
…104958)

Previously only aggregate functions (max/sum/etc..) were allowed inside
 the stats command. This PR allows expressions involving one or multiple
 aggregates to be used, such as:
 stats x = avg(salary % 3) + max(emp_no),
       y = min(emp_no / 3) + 10 - median(salary)
       by z = languages % 2

Improve verifier to not allow scalar functions over grouping for now
  • Loading branch information
costin authored Feb 6, 2024
1 parent af163b2 commit ac09d75
Show file tree
Hide file tree
Showing 9 changed files with 798 additions and 142 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/104958.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 104958
summary: "ESQL: Extend STATS command to support aggregate expressions"
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -1112,3 +1112,50 @@ STATS ck = COUNT(job_positions),
ck:l | cb:l | cd:l | ci:l | c:l | csv:l
221 | 204 | 183 | 183 | 100 | 100
;

nestedAggsNoGrouping#[skip:-8.12.99,reason:supported in 8.13+]
FROM employees
| STATS x = AVG(salary) / 2 + MAX(salary), a = AVG(salary), m = MAX(salary)
;

x:d | a:d | m:i
99123.275 | 48248.55 |74999
;

nestedAggsWithGrouping#[skip:-8.12.99,reason:supported in 8.13+]
FROM employees
| STATS x = ROUND(AVG(salary % 3)) + MAX(emp_no), y = MIN(emp_no / 3) + 10 - MEDIAN(salary) by z = languages % 2
| SORT z
;

x:d | y:d | z:i
10101 | -41474.0 | 0
10098 | -45391.0 | 1
10030 | -44714.5 | null
;

nestedAggsWithScalars#[skip:-8.12.99,reason:supported in 8.13+]
FROM employees
| STATS x = CONCAT(TO_STRING(ROUND(AVG(salary % 3))), TO_STRING(MAX(emp_no))),
y = ROUND((MIN(emp_no / 3) + PI() - MEDIAN(salary))/E())
BY z = languages % 2
;

x:s | y:d | z:i
1.010029 | -16452.0 | null
1.010100 | -15260.0 | 0
1.010097 | -16701.0 | 1
;

nestedAggsOverGroupingWithAlias#[skip:-8.12.99,reason:supported in 8.13]
FROM employees
| STATS e = max(languages) + 1 by l = languages
| SORT l
| LIMIT 3
;

e:i | l:i
2 | 1
3 | 2
4 | 3
;
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ NULL
;


medianOfLong#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765]
medianOfLong#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765]
from employees | stats m = median(salary_change.long), p50 = percentile(salary_change.long, 50);

m:double | p50:double
0 | 0
;

medianOfInteger#[skip:-8.12.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765/Expression spaces are maintained since 8.13]
medianOfInteger#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc/Expression spaces are maintained since 8.13]
// tag::median[]
FROM employees
| STATS MEDIAN(salary), PERCENTILE(salary, 50)
Expand All @@ -90,15 +90,15 @@ MEDIAN(salary):double | PERCENTILE(salary, 50):double
// end::median-result[]
;

medianOfDouble#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765]
medianOfDouble#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765]
from employees | stats m = median(salary_change), p50 = percentile(salary_change, 50);

m:double | p50:double
0.75 | 0.75
;


medianOfLongByKeyword#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765]
medianOfLongByKeyword#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765]
from employees | stats m = median(salary_change.long), p50 = percentile(salary_change.long, 50) by job_positions | sort m desc | limit 4;

m:double | p50:double | job_positions:keyword
Expand All @@ -109,7 +109,7 @@ m:double | p50:double | job_positions:keyword
;


medianOfIntegerByKeyword#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765]
medianOfIntegerByKeyword#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765]
from employees | stats m = median(salary), p50 = percentile(salary, 50) by job_positions | sort m | limit 4;

m:double | p50:double | job_positions:keyword
Expand All @@ -120,7 +120,7 @@ m:double | p50:double | job_positions:keyword
;


medianOfDoubleByKeyword#[skip:-8.11.99,reason:ReplaceDuplicateAggWithEval breaks bwc gh-103765]
medianOfDoubleByKeyword#[skip:-8.12.99,reason:ReplaceStatsAggExpressionWithEval breaks bwc gh-103765]
from employees | stats m = median(salary_change), p50 = percentile(salary_change, 50)by job_positions | sort m desc | limit 4;

m:double | p50:double | job_positions:keyword
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import org.elasticsearch.xpack.ql.capabilities.Unresolvable;
import org.elasticsearch.xpack.ql.common.Failure;
import org.elasticsearch.xpack.ql.expression.Alias;
import org.elasticsearch.xpack.ql.expression.AttributeMap;
import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.expression.Expressions;
import org.elasticsearch.xpack.ql.expression.NamedExpression;
import org.elasticsearch.xpack.ql.expression.TypeResolutions;
import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
Expand Down Expand Up @@ -67,6 +67,8 @@ public Verifier(Metrics metrics) {
Collection<Failure> verify(LogicalPlan plan, BitSet partialMetrics) {
assert partialMetrics != null;
Set<Failure> failures = new LinkedHashSet<>();
// alias map, collected during the first iteration for better error messages
AttributeMap<Expression> aliases = new AttributeMap<>();

// quick verification for unresolved attributes
plan.forEachUp(p -> {
Expand All @@ -80,6 +82,7 @@ Collection<Failure> verify(LogicalPlan plan, BitSet partialMetrics) {
}
// p is resolved, skip
else if (p.resolved()) {
p.forEachExpressionUp(Alias.class, a -> aliases.put(a.toAttribute(), a.child()));
return;
}
// handle aggregate first to disambiguate between missing fields or incorrect function declaration
Expand Down Expand Up @@ -128,7 +131,7 @@ else if (p.resolved()) {
return;
}
checkFilterConditionType(p, failures);
checkAggregate(p, failures);
checkAggregate(p, failures, aliases);
checkRegexExtractOnlyOnStrings(p, failures);

checkRow(p, failures);
Expand All @@ -147,38 +150,60 @@ else if (p.resolved()) {
return failures;
}

private static void checkAggregate(LogicalPlan p, Set<Failure> failures) {
private static void checkAggregate(LogicalPlan p, Set<Failure> failures, AttributeMap<Expression> aliases) {
if (p instanceof Aggregate agg) {
// check aggregates

List<Expression> nakedGroups = new ArrayList<>(agg.groupings().size());
// check grouping
// The grouping can not be an aggregate function
agg.groupings().forEach(e -> {
e.forEachUp(g -> {
if (g instanceof AggregateFunction af) {
failures.add(fail(g, "cannot use an aggregate [{}] for grouping", af));
}
});
nakedGroups.add(Alias.unwrap(e));
});

// check aggregates - accept only aggregate functions or expressions in which each naked attribute is copied as
// specified in the grouping clause
agg.aggregates().forEach(e -> {
var exp = Alias.unwrap(e);
if (exp instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
failures.add(fail(f, "nested aggregations [{}] not allowed inside other aggregations [{}]", f, af));
});
} else {
if (Expressions.match(agg.groupings(), g -> Alias.unwrap(g).semanticEquals(exp)) == false) {
failures.add(
fail(
exp,
"expected an aggregate function or group but got ["
+ exp.sourceText()
+ "] of type ["
+ exp.nodeName()
+ "]"
)
);
}
if (exp.foldable()) {
failures.add(fail(exp, "expected an aggregate function but found [{}]", exp.sourceText()));
}
// traverse the tree to find invalid matches
checkInvalidNamedExpressionUsage(exp, nakedGroups, failures, 0);
});
}
}

// check grouping
// The grouping can not be an aggregate function
agg.groupings().forEach(e -> e.forEachUp(g -> {
if (g instanceof AggregateFunction af) {
failures.add(fail(g, "cannot use an aggregate [{}] for grouping", af));
}
}));
// traverse the expression and look either for an agg function or a grouping match
// stop either when no children are left, the leaves are literals or a reference attribute is given
private static void checkInvalidNamedExpressionUsage(Expression e, List<Expression> groups, Set<Failure> failures, int level) {
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
failures.add(fail(f, "nested aggregations [{}] not allowed inside other aggregations [{}]", f, af));
});
} else if (e.foldable()) {
// don't do anything
}
// don't allow nested groupings for now stats substring(group) by group as we don't optimize yet for them
else if (groups.contains(e)) {
if (level != 0) {
failures.add(fail(e, "scalar functions over groupings [{}] not allowed yet", e.sourceText()));
}
}
// if a reference is found, mark it as an error
else if (e instanceof NamedExpression ne) {
failures.add(fail(e, "column [{}] must appear in the STATS BY clause or be used in an aggregate function", ne.name()));
}
// other keep on going
else {
for (Expression child : e.children()) {
checkInvalidNamedExpressionUsage(child, groups, failures, level + 1);
}
}
}

Expand Down
Loading

0 comments on commit ac09d75

Please sign in to comment.