Skip to content

Commit

Permalink
ESQL: Introduce per agg filter (elastic#113735)
Browse files Browse the repository at this point in the history
Add support for aggregation scoped filters that work dynamically on the
 data in each group.

| STATS
    success = COUNT(*) WHERE 200 <= code AND code < 300,
   redirect = COUNT(*) WHERE 300 <= code AND code < 400,
 client_err = COUNT(*) WHERE 400 <= code AND code < 500,
 server_err = COUNT(*) WHERE 500 <= code AND code < 600,
 total_count = COUNT(*)

Implementation-wise, the base AggregateFunction has been extended to
 allow a filter to be passed on. This is required to incorporate the
 filter as part of the aggregate equality/identity, which would fail with
 the filter as an external component.

As part of the process, the serialization of the existing aggregations
 had to be fixed so that AggregateFunction implementations delegate
 to their parent first.
  • Loading branch information
costin authored and georgewallace committed Oct 25, 2024
1 parent 2961658 commit bb593b5
Show file tree
Hide file tree
Showing 57 changed files with 3,181 additions and 2,113 deletions.
28 changes: 28 additions & 0 deletions docs/changelog/113735.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
pr: 113735
summary: "ESQL: Introduce per agg filter"
area: ES|QL
type: feature
issues: []
highlight:
title: "ESQL: Introduce per agg filter"
body: |-
Add support for aggregation scoped filters that work dynamically on the
data in each group.
[source,esql]
----
| STATS success = COUNT(*) WHERE 200 <= code AND code < 300,
redirect = COUNT(*) WHERE 300 <= code AND code < 400,
client_err = COUNT(*) WHERE 400 <= code AND code < 500,
server_err = COUNT(*) WHERE 500 <= code AND code < 600,
total_count = COUNT(*)
----
Implementation wise, the base AggregateFunction has been extended to
allow a filter to be passed on. This is required to incorporate the
filter as part of the aggregate equality/identity which would fail with
the filter as an external component.
As part of the process, the serialization of the existing aggregations
had to be fixed so that AggregateFunction implementations delegate
to their parent first.
notable: true
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ static TransportVersion def(int id) {
public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0);
public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0);
public static final TransportVersion QUERY_RULE_TEST_API = def(8_769_00_0);
public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_770_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,19 @@ public static int mapSize(int size) {
}
return (int) (size / 0.75f + 1f);
}

/**
 * Collects the given entries into a list, leaving out every {@code null} element.
 *
 * @param entries the values to collect; may itself be {@code null} or empty
 * @param <T> the element type
 * @return the result of {@code emptyList()} when {@code entries} is {@code null} or has no elements;
 *         otherwise a new {@link ArrayList} with the non-null entries in their original order
 *         (which is empty when every entry is {@code null})
 */
@SafeVarargs
@SuppressWarnings("varargs")
public static <T> List<T> nullSafeList(T... entries) {
    final int count = entries == null ? 0 : entries.length;
    if (count == 0) {
        return emptyList();
    }
    // size for the worst case where nothing gets filtered out
    List<T> nonNull = new ArrayList<>(count);
    for (int i = 0; i < count; i++) {
        T candidate = entries[i];
        if (candidate == null) {
            continue;
        }
        nonNull.add(candidate);
    }
    return nonNull;
}
}
183 changes: 183 additions & 0 deletions x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -2290,3 +2290,186 @@ from employees
m:integer |a:double |x:integer
74999 |48249.0 |0
;


statsWithFiltering
required_capability: per_agg_filtering
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000, max_a = max(salary) where salary > 100,
min = min(salary), min_f = min(salary) where salary > 50000, min_a = min(salary) where salary > 100
;

max:integer |max_f:integer |max_a:integer | min:integer | min_f:integer | min_a:integer
74999 |49818 |74999 | 25324 | 50064 | 25324
;

statsWithEverythingFiltered
required_capability: per_agg_filtering
from employees
| stats max = max(salary), max_a = max(salary) where salary < 100,
min = min(salary), min_a = min(salary) where salary > 99999
;

max:integer |max_a:integer|min:integer | min_a:integer
74999 |null |25324 | null
;

statsWithNullFilter
required_capability: per_agg_filtering
from employees
| stats max = max(salary), max_a = max(salary) where null,
min = min(salary), min_a = min(salary) where to_string(null) == "abc"
;

max:integer |max_a:integer|min:integer | min_a:integer
74999 |null |25324 | null
;

statsWithBasicExpressionFiltered
required_capability: per_agg_filtering
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000,
min = min(salary), min_f = min(salary) where salary > 50000,
exp_p = max(salary) + 10000 where salary < 50000,
exp_m = min(salary) % 10000 where salary > 50000
;

max:integer |max_f:integer|min:integer | min_f:integer|exp_p:integer | exp_m:integer
74999 |49818 |25324 | 50064 |59818 | 64
;

statsWithExpressionOverFilters
required_capability: per_agg_filtering
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000,
min = min(salary), min_f = min(salary) where salary > 50000,
exp_gt = max(salary) - min(salary) where salary > 50000,
exp_lt = max(salary) - min(salary) where salary < 50000

;

max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
74999 |49818 | 25324 | 50064 |24935 | 24494
;


statsWithExpressionOfExpressionsOverFilters
required_capability: per_agg_filtering
from employees
| stats max = max(salary + 1), max_f = max(salary + 2) where salary < 50000,
min = min(salary - 1), min_f = min(salary - 2) where salary > 50000,
exp_gt = max(salary + 3) - min(salary - 3) where salary > 50000,
exp_lt = max(salary + 4) - min(salary - 4) where salary < 50000

;

max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
75000 |49820 | 25323 | 50062 |24941 | 24502
;

statsWithSubstitutedExpressionOverFilters
required_capability: per_agg_filtering
from employees
| stats sum = sum(salary), s_l = sum(salary) where salary < 50000, s_u = sum(salary) where salary > 50000,
count = count(salary), c_l = count(salary) where salary < 50000, c_u = count(salary) where salary > 50000,
avg = round(avg(salary), 2), a_l = round(avg(salary), 2) where salary < 50000, a_u = round(avg(salary),2) where salary > 50000
;

sum:l |s_l:l | s_u:l | count:l |c_l:l |c_u:l |avg:double |a_l:double | a_u:double
4824855 |2220951 | 2603904 | 100 |58 |42 |48248.55 |38292.26 | 61997.71
;


statsWithFilterAndGroupBy
required_capability: per_agg_filtering
from employees
| stats m = max(height),
m_f = max(height + 1) where gender == "M" OR is_rehired is null
BY gender, is_rehired
| sort gender, is_rehired
;

m:d |m_f:d |gender:s|is_rehired:bool
2.1 |null |F |false
2.1 |null |F |true
1.85|2.85 |F |null
2.1 |3.1 |M |false
2.1 |3.1 |M |true
2.01|3.01 |M |null
2.06|null |null |false
1.97|null |null |true
1.99|2.99 |null |null
;

statsWithFilterOnGroupBy
required_capability: per_agg_filtering
from employees
| stats m_f = max(height) where gender == "M" BY gender
| sort gender
;

m_f:d |gender:s
null |F
2.1 |M
null |null
;

statsWithGroupByLiteral
required_capability: per_agg_filtering
from employees
| stats m = max(languages) by salary = 2
;

m:i |salary:i
5 |2
;


statsWithFilterOnSameColumn
required_capability: per_agg_filtering
from employees
| stats m = max(languages), m_f = max(languages) where salary > 50000 by salary = 2
| sort salary
;

m:i |m_f:i |salary:i
5 |null |2
;

# the query is reused below in a multi-stats
statsWithFilteringAndGrouping
required_capability: per_agg_filtering
from employees
| stats c = count(), c_f = count(languages) where l > 1,
m_f = max(height) where salary > 50000
by l = languages
| sort c
;

c:l |c_f:l |m_f:d |l:i
10 |0 |2.08 |null
15 |0 |2.06 |1
17 |17 |2.1 |3
18 |18 |1.83 |4
19 |19 |2.03 |2
21 |21 |2.1 |5
;

multiStatsWithFiltering
required_capability: per_agg_filtering
from employees
| stats c = count(), c_f = count(languages) where l > 1,
m_f = max(height) where salary > 50000
by l = languages
| stats c2 = count(), c2_f = count() where m_f > 2.06 , m2 = max(l), m2_f = max(l) where l > 1 by c
| sort c
;

c2:l |c2_f:l |m2:i |m2_f:i |c:l
1 |1 |null |null |10
1 |0 |1 |null |15
1 |1 |3 |3 |17
1 |0 |4 |4 |18
1 |0 |2 |2 |19
1 |1 |5 |5 |21
;
1 change: 1 addition & 0 deletions x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ SLASH : '/';
PERCENT : '%';

MATCH : 'match';
NESTED_WHERE : {this.isDevVersion()}? WHERE -> type(WHERE);

NAMED_OR_POSITIONAL_PARAM
: PARAM (LETTER | UNDERSCORE) UNQUOTED_ID_BODY*
Expand Down
20 changes: 13 additions & 7 deletions x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,15 @@ fields
;

field
: booleanExpression
| qualifiedName ASSIGN booleanExpression
: (qualifiedName ASSIGN)? booleanExpression
;

fromCommand
: FROM indexPattern (COMMA indexPattern)* metadata?
;

indexPattern
: clusterString COLON indexString
| indexString
: (clusterString COLON)? indexString
;

clusterString
Expand All @@ -159,15 +157,23 @@ deprecated_metadata
;

metricsCommand
: DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=fields? (BY grouping=fields)?
: DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=aggFields? (BY grouping=fields)?
;

evalCommand
: EVAL fields
;

statsCommand
: STATS stats=fields? (BY grouping=fields)?
: STATS stats=aggFields? (BY grouping=fields)?
;

// A comma-separated list of one or more aggField entries.
aggFields
: aggField (COMMA aggField)*
;

// A single entry of an aggFields list: a field, optionally followed by WHERE and a boolean expression.
// The isDevVersion() semantic predicate sits between the field and the optional WHERE clause.
// NOTE(review): presumably this limits the per-aggregate filter syntax to dev builds - confirm how the
// predicate behaves at this position for non-dev builds.
aggField
: field {this.isDevVersion()}? (WHERE booleanExpression)?
;

qualifiedName
Expand Down Expand Up @@ -316,5 +322,5 @@ lookupCommand
;

inlinestatsCommand
: DEV_INLINESTATS stats=fields (BY grouping=fields)?
: DEV_INLINESTATS stats=aggFields (BY grouping=fields)?
;
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,12 @@ public enum Cap {
/**
* Fix sorting not allowed on _source and counters.
*/
SORTING_ON_SOURCE_AND_COUNTERS_FORBIDDEN;
SORTING_ON_SOURCE_AND_COUNTERS_FORBIDDEN,

/**
* Allow filter per individual aggregation.
*/
PER_AGG_FILTERING;

private final boolean snapshotOnly;
private final FeatureFlag featureFlag;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
newAggregates.add(agg);
}

// TODO: remove this when Stats interface is removed
stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
Expand Down Expand Up @@ -308,6 +309,29 @@ private static void checkInvalidNamedExpressionUsage(
Set<Failure> failures,
int level
) {
// unwrap filtered expression
if (e instanceof FilteredExpression fe) {
e = fe.delegate();
// make sure they work on aggregate functions
if (e.anyMatch(AggregateFunction.class::isInstance) == false) {
Expression filter = fe.filter();
failures.add(fail(filter, "WHERE clause allowed only for aggregate functions, none found in [{}]", fe.sourceText()));
}
// but that the filter doesn't use grouping or aggregate functions
fe.filter().forEachDown(c -> {
if (c instanceof AggregateFunction af) {
failures.add(
fail(af, "cannot use aggregate function [{}] in aggregate WHERE clause [{}]", af.sourceText(), fe.sourceText())
);
}
// check the bucketing function against the group
else if (c instanceof GroupingFunction gf) {
if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
}
}
});
}
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
Expand All @@ -319,7 +343,7 @@ private static void checkInvalidNamedExpressionUsage(
} else if (e instanceof GroupingFunction gf) {
// optimizer will later unroll expressions with aggs and non-aggs with a grouping function into an EVAL, but that will no longer
// be verified (by check above in checkAggregate()), so do it explicitly here
if (groups.stream().anyMatch(ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
} else if (level == 0) {
addFailureOnGroupingUsedNakedInAggs(failures, gf, "function");
Expand Down
Loading

0 comments on commit bb593b5

Please sign in to comment.