diff --git a/docs/changelog/113735.yaml b/docs/changelog/113735.yaml
new file mode 100644
index 0000000000000..4f6579c7cb9e0
--- /dev/null
+++ b/docs/changelog/113735.yaml
@@ -0,0 +1,28 @@
+pr: 113735
+summary: "ESQL: Introduce per agg filter"
+area: ES|QL
+type: feature
+issues: []
+highlight:
+ title: "ESQL: Introduce per agg filter"
+ body: |-
+ Add support for aggregation-scoped filters that work dynamically on the
+ data in each group.
+
+ [source,esql]
+ ----
+ | STATS success = COUNT(*) WHERE 200 <= code AND code < 300,
+ redirect = COUNT(*) WHERE 300 <= code AND code < 400,
+ client_err = COUNT(*) WHERE 400 <= code AND code < 500,
+ server_err = COUNT(*) WHERE 500 <= code AND code < 600,
+ total_count = COUNT(*)
+ ----
+
+ Implementation-wise, the base AggregateFunction has been extended to
+ allow a filter to be passed on. This is required to incorporate the
+ filter as part of the aggregate equality/identity, which would fail with
+ the filter as an external component.
+ As part of the process, the serialization of the existing aggregations
+ had to be fixed so that AggregateFunction implementations delegate to
+ their parent first.
+ notable: true
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index e4cceb6977adb..dec6974aedb6b 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -243,6 +243,7 @@ static TransportVersion def(int id) {
public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0);
public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0);
public static final TransportVersion QUERY_RULE_TEST_API = def(8_769_00_0);
+ public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_770_00_0);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java
index 48b5fd1605edf..8bfcf4ca5c405 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java
@@ -79,4 +79,19 @@ public static int mapSize(int size) {
}
return (int) (size / 0.75f + 1f);
}
+
+ @SafeVarargs
+ @SuppressWarnings("varargs")
+ public static <T> List<T> nullSafeList(T... entries) {
+ if (entries == null || entries.length == 0) {
+ return emptyList();
+ }
+ List<T> list = new ArrayList<>(entries.length);
+ for (T entry : entries) {
+ if (entry != null) {
+ list.add(entry);
+ }
+ }
+ return list;
+ }
}
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
index 16c19083f78be..93b769e629ef9 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
@@ -2290,3 +2290,186 @@ from employees
m:integer |a:double |x:integer
74999 |48249.0 |0
;
+
+
+statsWithFiltering
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_f = max(salary) where salary < 50000, max_a = max(salary) where salary > 100,
+ min = min(salary), min_f = min(salary) where salary > 50000, min_a = min(salary) where salary > 100
+;
+
+max:integer |max_f:integer |max_a:integer | min:integer | min_f:integer | min_a:integer
+74999 |49818 |74999 | 25324 | 50064 | 25324
+;
+
+statsWithEverythingFiltered
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_a = max(salary) where salary < 100,
+ min = min(salary), min_a = min(salary) where salary > 99999
+;
+
+max:integer |max_a:integer|min:integer | min_a:integer
+74999 |null |25324 | null
+;
+
+statsWithNullFilter
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_a = max(salary) where null,
+ min = min(salary), min_a = min(salary) where to_string(null) == "abc"
+;
+
+max:integer |max_a:integer|min:integer | min_a:integer
+74999 |null |25324 | null
+;
+
+statsWithBasicExpressionFiltered
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_f = max(salary) where salary < 50000,
+ min = min(salary), min_f = min(salary) where salary > 50000,
+ exp_p = max(salary) + 10000 where salary < 50000,
+ exp_m = min(salary) % 10000 where salary > 50000
+;
+
+max:integer |max_f:integer|min:integer | min_f:integer|exp_p:integer | exp_m:integer
+74999 |49818 |25324 | 50064 |59818 | 64
+;
+
+statsWithExpressionOverFilters
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary), max_f = max(salary) where salary < 50000,
+ min = min(salary), min_f = min(salary) where salary > 50000,
+ exp_gt = max(salary) - min(salary) where salary > 50000,
+ exp_lt = max(salary) - min(salary) where salary < 50000
+
+;
+
+max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
+74999 |49818 | 25324 | 50064 |24935 | 24494
+;
+
+
+statsWithExpressionOfExpressionsOverFilters
+required_capability: per_agg_filtering
+from employees
+| stats max = max(salary + 1), max_f = max(salary + 2) where salary < 50000,
+ min = min(salary - 1), min_f = min(salary - 2) where salary > 50000,
+ exp_gt = max(salary + 3) - min(salary - 3) where salary > 50000,
+ exp_lt = max(salary + 4) - min(salary - 4) where salary < 50000
+
+;
+
+max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
+75000 |49820 | 25323 | 50062 |24941 | 24502
+;
+
+statsWithSubstitutedExpressionOverFilters
+required_capability: per_agg_filtering
+from employees
+| stats sum = sum(salary), s_l = sum(salary) where salary < 50000, s_u = sum(salary) where salary > 50000,
+ count = count(salary), c_l = count(salary) where salary < 50000, c_u = count(salary) where salary > 50000,
+ avg = round(avg(salary), 2), a_l = round(avg(salary), 2) where salary < 50000, a_u = round(avg(salary),2) where salary > 50000
+;
+
+sum:l |s_l:l | s_u:l | count:l |c_l:l |c_u:l |avg:double |a_l:double | a_u:double
+4824855 |2220951 | 2603904 | 100 |58 |42 |48248.55 |38292.26 | 61997.71
+;
+
+
+statsWithFilterAndGroupBy
+required_capability: per_agg_filtering
+from employees
+| stats m = max(height),
+ m_f = max(height + 1) where gender == "M" OR is_rehired is null
+ BY gender, is_rehired
+| sort gender, is_rehired
+;
+
+m:d |m_f:d |gender:s|is_rehired:bool
+2.1 |null |F |false
+2.1 |null |F |true
+1.85|2.85 |F |null
+2.1 |3.1 |M |false
+2.1 |3.1 |M |true
+2.01|3.01 |M |null
+2.06|null |null |false
+1.97|null |null |true
+1.99|2.99 |null |null
+;
+
+statsWithFilterOnGroupBy
+required_capability: per_agg_filtering
+from employees
+| stats m_f = max(height) where gender == "M" BY gender
+| sort gender
+;
+
+m_f:d |gender:s
+null |F
+2.1 |M
+null |null
+;
+
+statsWithGroupByLiteral
+required_capability: per_agg_filtering
+from employees
+| stats m = max(languages) by salary = 2
+;
+
+m:i |salary:i
+5 |2
+;
+
+
+statsWithFilterOnSameColumn
+required_capability: per_agg_filtering
+from employees
+| stats m = max(languages), m_f = max(languages) where salary > 50000 by salary = 2
+| sort salary
+;
+
+m:i |m_f:i |salary:i
+5 |null |2
+;
+
+# the query is reused below in a multi-stats
+statsWithFilteringAndGrouping
+required_capability: per_agg_filtering
+from employees
+| stats c = count(), c_f = count(languages) where l > 1,
+ m_f = max(height) where salary > 50000
+ by l = languages
+| sort c
+;
+
+c:l |c_f:l |m_f:d |l:i
+10 |0 |2.08 |null
+15 |0 |2.06 |1
+17 |17 |2.1 |3
+18 |18 |1.83 |4
+19 |19 |2.03 |2
+21 |21 |2.1 |5
+;
+
+multiStatsWithFiltering
+required_capability: per_agg_filtering
+from employees
+| stats c = count(), c_f = count(languages) where l > 1,
+ m_f = max(height) where salary > 50000
+ by l = languages
+| stats c2 = count(), c2_f = count() where m_f > 2.06 , m2 = max(l), m2_f = max(l) where l > 1 by c
+| sort c
+;
+
+c2:l |c2_f:l |m2:i |m2_f:i |c:l
+1 |1 |null |null |10
+1 |0 |1 |null |15
+1 |1 |3 |3 |17
+1 |0 |4 |4 |18
+1 |0 |2 |2 |19
+1 |1 |5 |5 |21
+;
diff --git a/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4 b/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
index d6d45097a1d07..b13606befd2a4 100644
--- a/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
+++ b/x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
@@ -209,6 +209,7 @@ SLASH : '/';
PERCENT : '%';
MATCH : 'match';
+NESTED_WHERE : {this.isDevVersion()}? WHERE -> type(WHERE);
NAMED_OR_POSITIONAL_PARAM
: PARAM (LETTER | UNDERSCORE) UNQUOTED_ID_BODY*
diff --git a/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4 b/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
index 77568d5527cd1..9a95e0e6726ba 100644
--- a/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
+++ b/x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
@@ -123,8 +123,7 @@ fields
;
field
- : booleanExpression
- | qualifiedName ASSIGN booleanExpression
+ : (qualifiedName ASSIGN)? booleanExpression
;
fromCommand
@@ -132,8 +131,7 @@ fromCommand
;
indexPattern
- : clusterString COLON indexString
- | indexString
+ : (clusterString COLON)? indexString
;
clusterString
@@ -159,7 +157,7 @@ deprecated_metadata
;
metricsCommand
- : DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=fields? (BY grouping=fields)?
+ : DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=aggFields? (BY grouping=fields)?
;
evalCommand
@@ -167,7 +165,15 @@ evalCommand
;
statsCommand
- : STATS stats=fields? (BY grouping=fields)?
+ : STATS stats=aggFields? (BY grouping=fields)?
+ ;
+
+aggFields
+ : aggField (COMMA aggField)*
+ ;
+
+aggField
+ : field {this.isDevVersion()}? (WHERE booleanExpression)?
;
qualifiedName
@@ -316,5 +322,5 @@ lookupCommand
;
inlinestatsCommand
- : DEV_INLINESTATS stats=fields (BY grouping=fields)?
+ : DEV_INLINESTATS stats=aggFields (BY grouping=fields)?
;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index 9dc17b020e426..f5baaef4f579d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -370,7 +370,12 @@ public enum Cap {
/**
* Fix sorting not allowed on _source and counters.
*/
- SORTING_ON_SOURCE_AND_COUNTERS_FORBIDDEN;
+ SORTING_ON_SOURCE_AND_COUNTERS_FORBIDDEN,
+
+ /**
+ * Allow filter per individual aggregation.
+ */
+ PER_AGG_FILTERING;
private final boolean snapshotOnly;
private final FeatureFlag featureFlag;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index 90957f55141b9..fe7b945a9b3c1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -488,6 +488,7 @@ private LogicalPlan resolveStats(Stats stats, List childrenOutput) {
newAggregates.add(agg);
}
+ // TODO: remove this when Stats interface is removed
stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
index dd2b72b4d35d9..ef39220d7ffcc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java
@@ -30,6 +30,7 @@
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
@@ -308,6 +309,29 @@ private static void checkInvalidNamedExpressionUsage(
Set failures,
int level
) {
+ // unwrap filtered expression
+ if (e instanceof FilteredExpression fe) {
+ e = fe.delegate();
+ // make sure they work on aggregate functions
+ if (e.anyMatch(AggregateFunction.class::isInstance) == false) {
+ Expression filter = fe.filter();
+ failures.add(fail(filter, "WHERE clause allowed only for aggregate functions, none found in [{}]", fe.sourceText()));
+ }
+ // and that the filter uses neither aggregate functions nor grouping functions absent from the BY clause
+ fe.filter().forEachDown(c -> {
+ if (c instanceof AggregateFunction af) {
+ failures.add(
+ fail(af, "cannot use aggregate function [{}] in aggregate WHERE clause [{}]", af.sourceText(), fe.sourceText())
+ );
+ }
+ // check the bucketing function against the group
+ else if (c instanceof GroupingFunction gf) {
+ if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
+ failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
+ }
+ }
+ });
+ }
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
@@ -319,7 +343,7 @@ private static void checkInvalidNamedExpressionUsage(
} else if (e instanceof GroupingFunction gf) {
// optimizer will later unroll expressions with aggs and non-aggs with a grouping function into an EVAL, but that will no longer
// be verified (by check above in checkAggregate()), so do it explicitly here
- if (groups.stream().anyMatch(ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
+ if (Expressions.anyMatch(groups, ex -> ex instanceof Alias a && a.child().semanticEquals(gf)) == false) {
failures.add(fail(gf, "can only use grouping function [{}] part of the BY clause", gf.sourceText()));
} else if (level == 0) {
addFailureOnGroupingUsedNakedInAggs(failures, gf, "function");
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index faf99d6bd65bc..66151275fc2e8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -259,19 +259,21 @@ private FunctionDefinition[][] functions() {
// grouping functions
new FunctionDefinition[] { def(Bucket.class, Bucket::new, "bucket", "bin"), },
// aggregate functions
+ // since they declare two public constructors - one with filter (for nested where) and one without
+ // use casting to disambiguate between the two
new FunctionDefinition[] {
- def(Avg.class, Avg::new, "avg"),
- def(Count.class, Count::new, "count"),
- def(CountDistinct.class, CountDistinct::new, "count_distinct"),
- def(Max.class, Max::new, "max"),
- def(Median.class, Median::new, "median"),
- def(MedianAbsoluteDeviation.class, MedianAbsoluteDeviation::new, "median_absolute_deviation"),
- def(Min.class, Min::new, "min"),
- def(Percentile.class, Percentile::new, "percentile"),
- def(Sum.class, Sum::new, "sum"),
- def(Top.class, Top::new, "top"),
- def(Values.class, Values::new, "values"),
- def(WeightedAvg.class, WeightedAvg::new, "weighted_avg") },
+ def(Avg.class, uni(Avg::new), "avg"),
+ def(Count.class, uni(Count::new), "count"),
+ def(CountDistinct.class, bi(CountDistinct::new), "count_distinct"),
+ def(Max.class, uni(Max::new), "max"),
+ def(Median.class, uni(Median::new), "median"),
+ def(MedianAbsoluteDeviation.class, uni(MedianAbsoluteDeviation::new), "median_absolute_deviation"),
+ def(Min.class, uni(Min::new), "min"),
+ def(Percentile.class, bi(Percentile::new), "percentile"),
+ def(Sum.class, uni(Sum::new), "sum"),
+ def(Top.class, tri(Top::new), "top"),
+ def(Values.class, uni(Values::new), "values"),
+ def(WeightedAvg.class, bi(WeightedAvg::new), "weighted_avg") },
// math
new FunctionDefinition[] {
def(Abs.class, Abs::new, "abs"),
@@ -482,11 +484,10 @@ public static DataType getTargetType(String[] names) {
}
public static FunctionDescription description(FunctionDefinition def) {
- var constructors = def.clazz().getConstructors();
- if (constructors.length == 0) {
+ Constructor<?> constructor = constructorFor(def.clazz());
+ if (constructor == null) {
return new FunctionDescription(def.name(), List.of(), null, null, false, false);
}
- Constructor<?> constructor = constructors[0];
FunctionInfo functionInfo = functionInfo(def);
String functionDescription = functionInfo == null ? "" : functionInfo.description().replace('\n', ' ');
String[] returnType = functionInfo == null ? new String[] { "?" } : removeUnderConstruction(functionInfo.returnType());
@@ -523,14 +524,29 @@ private static String[] removeUnderConstruction(String[] types) {
}
public static FunctionInfo functionInfo(FunctionDefinition def) {
- var constructors = def.clazz().getConstructors();
- if (constructors.length == 0) {
+ Constructor<?> constructor = constructorFor(def.clazz());
+ if (constructor == null) {
return null;
}
- Constructor<?> constructor = constructors[0];
return constructor.getAnnotation(FunctionInfo.class);
}
+ private static Constructor<?> constructorFor(Class<? extends Function> clazz) {
+ Constructor<?>[] constructors = clazz.getConstructors();
+ if (constructors.length == 0) {
+ return null;
+ }
+ // when dealing with multiple, pick the constructor exposing the FunctionInfo annotation
+ if (constructors.length > 1) {
+ for (Constructor<?> constructor : constructors) {
+ if (constructor.getAnnotation(FunctionInfo.class) != null) {
+ return constructor;
+ }
+ }
+ }
+ return constructors[0];
+ }
+
private void buildDataTypesForStringLiteralConversion(FunctionDefinition[]... groupFunctions) {
for (FunctionDefinition[] group : groupFunctions) {
for (FunctionDefinition def : group) {
@@ -913,15 +929,19 @@ protected interface TernaryConfigurationAwareBuilder {
}
//
- // Utility method for extra argument extraction.
+ // Utility functions to help disambiguate the method handle passed in.
+ // They work by providing additional method information to help the compiler know which method to pick.
//
- protected static Boolean asBool(Object[] extras) {
- if (CollectionUtils.isEmpty(extras)) {
- return null;
- }
- if (extras.length != 1 || (extras[0] instanceof Boolean) == false) {
- throw new QlIllegalArgumentException("Invalid number and types of arguments given to function definition");
- }
- return (Boolean) extras[0];
+ private static BiFunction
*/
@Override public T visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ * The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitAggFields(EsqlBaseParser.AggFieldsContext ctx) { return visitChildren(ctx); }
+ /**
+ * {@inheritDoc}
+ *
+ * The default implementation returns the result of calling
+ * {@link #visitChildren} on {@code ctx}.
+ */
+ @Override public T visitAggField(EsqlBaseParser.AggFieldContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
index c6dcaca736e1f..cf658c4a73141 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
@@ -465,6 +465,26 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitStatsCommand(EsqlBaseParser.StatsCommandContext ctx);
+ /**
+ * Enter a parse tree produced by {@link EsqlBaseParser#aggFields}.
+ * @param ctx the parse tree
+ */
+ void enterAggFields(EsqlBaseParser.AggFieldsContext ctx);
+ /**
+ * Exit a parse tree produced by {@link EsqlBaseParser#aggFields}.
+ * @param ctx the parse tree
+ */
+ void exitAggFields(EsqlBaseParser.AggFieldsContext ctx);
+ /**
+ * Enter a parse tree produced by {@link EsqlBaseParser#aggField}.
+ * @param ctx the parse tree
+ */
+ void enterAggField(EsqlBaseParser.AggFieldContext ctx);
+ /**
+ * Exit a parse tree produced by {@link EsqlBaseParser#aggField}.
+ * @param ctx the parse tree
+ */
+ void exitAggField(EsqlBaseParser.AggFieldContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#qualifiedName}.
* @param ctx the parse tree
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
index 310d3dc76dd6d..86c1d1aafc33a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
@@ -284,6 +284,18 @@ public interface EsqlBaseParserVisitor extends ParseTreeVisitor {
* @return the visitor result
*/
T visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx);
+ /**
+ * Visit a parse tree produced by {@link EsqlBaseParser#aggFields}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitAggFields(EsqlBaseParser.AggFieldsContext ctx);
+ /**
+ * Visit a parse tree produced by {@link EsqlBaseParser#aggField}.
+ * @param ctx the parse tree
+ * @return the visitor result
+ */
+ T visitAggField(EsqlBaseParser.AggFieldContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#qualifiedName}.
* @param ctx the parse tree
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
index 42a1ad6de5224..cda118adb19e6 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
@@ -26,6 +26,7 @@
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar;
+import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.expression.predicate.fulltext.MatchQueryPredicate;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
@@ -44,6 +45,7 @@
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.expression.function.FunctionResolutionStrategy;
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
@@ -742,9 +744,12 @@ private NamedExpression enrichFieldName(EsqlBaseParser.QualifiedNamePatternConte
@Override
public Alias visitField(EsqlBaseParser.FieldContext ctx) {
+ return visitField(ctx, source(ctx));
+ }
+
+ private Alias visitField(EsqlBaseParser.FieldContext ctx, Source source) {
UnresolvedAttribute id = visitQualifiedName(ctx.qualifiedName());
Expression value = expression(ctx.booleanExpression());
- var source = source(ctx);
String name = id == null ? source.text() : id.name();
return new Alias(source, name, value);
}
@@ -754,6 +759,36 @@ public List visitFields(EsqlBaseParser.FieldsContext ctx) {
return ctx != null ? visitList(this, ctx.field(), Alias.class) : new ArrayList<>();
}
+ @Override
+ public NamedExpression visitAggField(EsqlBaseParser.AggFieldContext ctx) {
+ Source source = source(ctx);
+ Alias field = visitField(ctx.field(), source);
+ var filterExpression = ctx.booleanExpression();
+
+ if (filterExpression != null) {
+ Expression condition = expression(filterExpression);
+ Expression child = field.child();
+ // basic check as the filter can be specified only on a function (should be an aggregate but we can't determine that yet)
+ if (field.child().anyMatch(Function.class::isInstance)) {
+ field = field.replaceChild(new FilteredExpression(field.source(), child, condition));
+ }
+ // allow condition only per aggregated function
+ else {
+ throw new ParsingException(
+ condition.source(),
+ "WHERE clause allowed only for aggregate functions [{}]",
+ field.sourceText()
+ );
+ }
+ }
+ return field;
+ }
+
+ @Override
+ public List visitAggFields(EsqlBaseParser.AggFieldsContext ctx) {
+ return ctx != null ? visitList(this, ctx.aggField(), Alias.class) : new ArrayList<>();
+ }
+
/**
* Similar to {@link #visitFields(EsqlBaseParser.FieldsContext)} however avoids wrapping the expression
* into an Alias.
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
index c90c3cba4ef24..dc913cd2f14f4 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
@@ -298,13 +298,12 @@ public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) {
return input -> new Aggregate(source(ctx), input, Aggregate.AggregateType.STANDARD, stats.groupings, stats.aggregates);
}
- private record Stats(List groupings, List extends NamedExpression> aggregates) {
+ private record Stats(List groupings, List extends NamedExpression> aggregates) {}
- }
-
- private Stats stats(Source source, EsqlBaseParser.FieldsContext groupingsCtx, EsqlBaseParser.FieldsContext aggregatesCtx) {
+ private Stats stats(Source source, EsqlBaseParser.FieldsContext groupingsCtx, EsqlBaseParser.AggFieldsContext aggregatesCtx) {
List groupings = visitGrouping(groupingsCtx);
- List aggregates = new ArrayList<>(visitFields(aggregatesCtx));
+ List aggregates = new ArrayList<>(visitAggFields(aggregatesCtx));
+
if (aggregates.isEmpty() && groupings.isEmpty()) {
throw new ParsingException(source, "At least one aggregation or grouping expression required in [{}]", source.text());
}
@@ -341,9 +340,11 @@ public PlanFactory visitInlinestatsCommand(EsqlBaseParser.InlinestatsCommandCont
if (false == EsqlPlugin.INLINESTATS_FEATURE_FLAG.isEnabled()) {
throw new ParsingException(source(ctx), "INLINESTATS command currently requires a snapshot build");
}
- List aggregates = new ArrayList<>(visitFields(ctx.stats));
+ List aggFields = visitAggFields(ctx.stats);
+ List aggregates = new ArrayList<>(aggFields);
List groupings = visitGrouping(ctx.grouping);
aggregates.addAll(groupings);
+ // TODO: add support for filters
return input -> new InlineStats(source(ctx), input, new ArrayList<>(groupings), aggregates);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java
index 8445c8236c45a..3b7240dcd693b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java
@@ -59,6 +59,7 @@ static AggregateType readType(StreamInput in) throws IOException {
private final AggregateType aggregateType;
private final List groupings;
private final List extends NamedExpression> aggregates;
+
private List lazyOutput;
public Aggregate(
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
index 0e71963e29270..94a9246a56f83 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
@@ -10,10 +10,12 @@
import org.elasticsearch.compute.aggregation.Aggregator;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.FilteredAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.GroupingAggregator;
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.operator.AggregationOperator;
+import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
@@ -24,6 +26,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
+import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
@@ -231,11 +234,14 @@ private void aggregatesToFactory(
boolean grouping,
Consumer consumer
) {
+ // extract filtering channels - and wrap the aggregation with the new evaluator expression only during the init phase
for (NamedExpression ne : aggregates) {
+ // a filter can only appear on aggregate function, not on the grouping columns
+
if (ne instanceof Alias alias) {
var child = alias.child();
if (child instanceof AggregateFunction aggregateFunction) {
- List extends NamedExpression> sourceAttr;
+ List sourceAttr = new ArrayList<>();
if (mode == AggregatorMode.INITIAL) {
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
@@ -251,19 +257,22 @@ private void aggregatesToFactory(
);
}
} else {
- sourceAttr = aggregateFunction.inputExpressions().stream().map(e -> {
- Attribute attr = Expressions.attribute(e);
+ // extra dependencies like TS ones (that require a timestamp)
+ for (Expression input : aggregateFunction.references()) {
+ Attribute attr = Expressions.attribute(input);
if (attr == null) {
throw new EsqlIllegalArgumentException(
"Cannot work with target field [{}] for agg [{}]",
- e.sourceText(),
+ input.sourceText(),
aggregateFunction.sourceText()
);
}
- return attr;
- }).toList();
+ sourceAttr.add(attr);
+ }
}
- } else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
+ }
+ // coordinator/exchange phase
+ else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(aggregateFunction);
} else {
@@ -274,16 +283,27 @@ private void aggregatesToFactory(
}
List inputChannels = sourceAttr.stream().map(attr -> layout.get(attr.id()).channel()).toList();
assert inputChannels.stream().allMatch(i -> i >= 0) : inputChannels;
- if (aggregateFunction instanceof ToAggregator agg) {
- consumer.accept(new AggFunctionSupplierContext(agg.supplier(inputChannels), mode));
- } else {
- throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
+
+ AggregatorFunctionSupplier aggSupplier = supplier(aggregateFunction, inputChannels);
+
+ // apply the filter only when the aggregator input is not partial (raw rows) - partial input has already been filtered
+ if (aggregateFunction.hasFilter() && mode.isInputPartial() == false) {
+ EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(aggregateFunction.filter(), layout);
+ aggSupplier = new FilteredAggregatorFunctionSupplier(aggSupplier, evalFactory);
}
+ consumer.accept(new AggFunctionSupplierContext(aggSupplier, mode));
}
}
}
}
+ private static AggregatorFunctionSupplier supplier(AggregateFunction aggregateFunction, List inputChannels) {
+ if (aggregateFunction instanceof ToAggregator delegate) {
+ return delegate.supplier(inputChannels);
+ }
+ throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
+ }
+
private record GroupSpec(Integer channel, Attribute attribute) {
BlockHash.GroupSpec toHashGroupSpec() {
if (channel == null) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
index 13ce9ba77cc71..c322135198262 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
@@ -98,39 +98,39 @@ private record AggDef(Class> aggClazz, String type, String extra, boolean grou
.collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));
/** Cache of aggregates to intermediate expressions. */
- private final HashMap> cache;
+ private final HashMap> cache;
AggregateMapper() {
cache = new HashMap<>();
}
- public List extends NamedExpression> mapNonGrouping(List extends Expression> aggregates) {
+ public List mapNonGrouping(List extends Expression> aggregates) {
return doMapping(aggregates, false);
}
- public List extends NamedExpression> mapNonGrouping(Expression aggregate) {
+ public List mapNonGrouping(Expression aggregate) {
return map(aggregate, false).toList();
}
- public List extends NamedExpression> mapGrouping(List extends Expression> aggregates) {
+ public List mapGrouping(List extends Expression> aggregates) {
return doMapping(aggregates, true);
}
- private List extends NamedExpression> doMapping(List extends Expression> aggregates, boolean grouping) {
+ private List doMapping(List extends Expression> aggregates, boolean grouping) {
AttributeMap attrToExpressions = new AttributeMap<>();
aggregates.stream().flatMap(agg -> map(agg, grouping)).forEach(ne -> attrToExpressions.put(ne.toAttribute(), ne));
return attrToExpressions.values().stream().toList();
}
- public List extends NamedExpression> mapGrouping(Expression aggregate) {
+ public List mapGrouping(Expression aggregate) {
return map(aggregate, true).toList();
}
- private Stream extends NamedExpression> map(Expression aggregate, boolean grouping) {
+ private Stream map(Expression aggregate, boolean grouping) {
return cache.computeIfAbsent(Alias.unwrap(aggregate), aggKey -> computeEntryForAgg(aggKey, grouping)).stream();
}
- private static List extends NamedExpression> computeEntryForAgg(Expression aggregate, boolean grouping) {
+ private static List computeEntryForAgg(Expression aggregate, boolean grouping) {
var aggDef = aggDefOrNull(aggregate, grouping);
if (aggDef != null) {
var is = getNonNull(aggDef);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index f881c0e1a9bba..ce072e7b0a438 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -67,10 +67,7 @@
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
-import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
-import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.TestLocalPhysicalPlanOptimizer;
-import org.elasticsearch.xpack.esql.optimizer.TestPhysicalPlanOptimizer;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -167,7 +164,6 @@ public class CsvTests extends ESTestCase {
private final EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry();
private final EsqlParser parser = new EsqlParser();
private final Mapper mapper = new Mapper(functionRegistry);
- private final PhysicalPlanOptimizer physicalPlanOptimizer = new TestPhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration));
private ThreadPool threadPool;
private Executor executor;
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 6b13420f4ca67..11c21ad46d3a8 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -1226,7 +1226,7 @@ public void testAggsOverGroupingKey() throws Exception {
assertThat(output, hasSize(2));
var aggs = agg.aggregates();
var min = as(Alias.unwrap(aggs.get(0)), Min.class);
- assertThat(min.arguments(), hasSize(1));
+ assertThat(min.arguments(), hasSize(2)); // field + filter
var group = Alias.unwrap(agg.groupings().get(0));
assertEquals(min.arguments().get(0), group);
}
@@ -1248,7 +1248,7 @@ public void testAggsOverGroupingKeyWithAlias() throws Exception {
assertThat(output, hasSize(2));
var aggs = agg.aggregates();
var min = as(Alias.unwrap(aggs.get(0)), Min.class);
- assertThat(min.arguments(), hasSize(1));
+ assertThat(min.arguments(), hasSize(2)); // field + filter
assertEquals(Expressions.attribute(min.arguments().get(0)), Expressions.attribute(agg.groupings().get(0)));
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index ecf012718eaf8..63f7629f3c720 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -360,6 +360,40 @@ public void testAggsInsideGrouping() {
);
}
+ public void testAggFilterOnNonAggregates() {
+ assertEquals(
+ "1:36: WHERE clause allowed only for aggregate functions, none found in [emp_no + 1 where languages > 1]",
+ error("from test | stats emp_no + 1 where languages > 1 by emp_no")
+ );
+ assertEquals(
+ "1:53: WHERE clause allowed only for aggregate functions, none found in [abs(emp_no + languages) % 2 WHERE languages > 1]",
+ error("from test | stats abs(emp_no + languages) % 2 WHERE languages > 1 by emp_no, languages")
+ );
+ }
+
+ public void testAggFilterOnBucketingOrAggFunctions() {
+ // query passes when the bucket function is part of the BY clause
+ query("from test | stats max(languages) WHERE bucket(salary, 10) > 1 by bucket(salary, 10)");
+
+ // but fails when the bucket function used in the filter is not one of the BY clause groupings
+ assertEquals(
+ "1:40: can only use grouping function [bucket(salary, 10)] part of the BY clause",
+ error("from test | stats max(languages) WHERE bucket(salary, 10) > 1 by emp_no")
+ );
+ // an aggregate function is rejected when used directly inside the WHERE clause of an aggregate
+ assertEquals(
+ "1:40: cannot use aggregate function [max(salary)] in aggregate WHERE clause [max(languages) WHERE max(salary) > 1]",
+ error("from test | stats max(languages) WHERE max(salary) > 1 by emp_no")
+ );
+ // the same rejection applies when the aggregate function is nested inside a larger filter expression
+ assertEquals(
+ "1:40: cannot use aggregate function [max(salary)] in aggregate WHERE clause [max(languages) WHERE max(salary) + 2 > 1]",
+ error("from test | stats max(languages) WHERE max(salary) + 2 > 1 by emp_no")
+ );
+ // an alias (m) declared in the same STATS is reported as an unknown column when referenced from the filter
+ assertEquals("1:60: Unknown column [m]", error("from test | stats m = max(languages), min(languages) WHERE m + 2 > 1 by emp_no"));
+ }
+
public void testGroupingInsideAggsAsAgg() {
assertEquals(
"1:18: can only use grouping function [bucket(emp_no, 5.)] part of the BY clause",
@@ -1507,6 +1541,10 @@ public void testToDatePeriodToTimeDurationWithInvalidType() {
);
}
+ private void query(String query) {
+ defaultAnalyzer.analyze(parser.createStatement(query));
+ }
+
private String error(String query) {
return error(query, defaultAnalyzer);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java
index 94b2a81b308d7..ea7c480817317 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/RateSerializationTests.java
@@ -36,4 +36,9 @@ protected Rate mutateInstance(Rate instance) throws IOException {
}
return new Rate(source, field, timestamp, unit);
}
+
+ @Override
+ protected boolean alwaysEmptySource() {
+ return true;
+ }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
index 82bf57d1a194e..e74b26c87c84f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopSerializationTests.java
@@ -36,4 +36,9 @@ protected Top mutateInstance(Top instance) throws IOException {
}
return new Top(source, field, limit, order);
}
+
+ @Override
+ protected boolean alwaysEmptySource() {
+ return true;
+ }
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index c05b5dd165485..8d7c1997f78e3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -537,6 +537,24 @@ public void testCombineProjectionWithDuplicateAggregation() {
assertThat(Expressions.names(agg.groupings()), contains("last_name", "first_name"));
}
+ /**
+ * Limit[1000[INTEGER]]
+ * \_Aggregate[STANDARD,[],[SUM(salary{f}#12,true[BOOLEAN]) AS sum(salary), SUM(salary{f}#12,last_name{f}#11 == [44 6f 65][KEYW
+ * ORD]) AS sum(salary) WheRe last_name == "Doe"]]
+ * \_EsRelation[test][_meta_field{f}#13, emp_no{f}#7, first_name{f}#8, ge..]
+ */
+ public void testStatsWithFilteringDefaultAliasing() {
+ var plan = plan("""
+ from test
+ | stats sum(salary), sum(salary) WheRe last_name == "Doe"
+ """);
+ // the default name of a filtered aggregate is expected to include its WHERE clause, with the query's casing preserved
+ var limit = as(plan, Limit.class);
+ var agg = as(limit.child(), Aggregate.class);
+ assertThat(agg.aggregates(), hasSize(2));
+ assertThat(Expressions.names(agg.aggregates()), contains("sum(salary)", "sum(salary) WheRe last_name == \"Doe\""));
+ }
+
public void testQlComparisonOptimizationsApply() {
var plan = plan("""
from test
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
index 80a2d49d0d94a..67b4dd71260aa 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
@@ -208,7 +208,7 @@ public void testParenthesizedExpression() {
}
public void testCommandNamesAsIdentifiers() {
- Expression expr = whereExpression("from and where");
+ Expression expr = whereExpression("from and limit");
assertThat(expr, instanceOf(And.class));
And and = (And) expr;
@@ -216,7 +216,7 @@ public void testCommandNamesAsIdentifiers() {
assertThat(((UnresolvedAttribute) and.left()).name(), equalTo("from"));
assertThat(and.right(), instanceOf(UnresolvedAttribute.class));
- assertThat(((UnresolvedAttribute) and.right()).name(), equalTo("where"));
+ assertThat(((UnresolvedAttribute) and.right()).name(), equalTo("limit"));
}
public void testIdentifiersCaseSensitive() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
index 53621a79aedac..c797f426d2ae5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
@@ -20,14 +20,18 @@
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
+import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern;
import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
@@ -321,6 +325,61 @@ public void testAggsWithGroupKeyAsAgg() throws Exception {
}
}
+ public void testStatsWithGroupKeyAndAggFilter() throws Exception {
+ var a = attribute("a");
+ var f = new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(a));
+ var filter = new Alias(EMPTY, "min(a) where a > 1", new FilteredExpression(EMPTY, f, new GreaterThan(EMPTY, a, integer(1))));
+ assertEquals(
+ new Aggregate(EMPTY, PROCESSING_CMD_INPUT, Aggregate.AggregateType.STANDARD, List.of(a), List.of(filter, a)),
+ processingCommand("stats min(a) where a > 1 by a")
+ );
+ }
+
+ public void testStatsWithGroupKeyAndMixedAggAndFilter() throws Exception {
+ var a = attribute("a");
+ var min = new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(a));
+ var max = new UnresolvedFunction(EMPTY, "max", DEFAULT, List.of(a));
+ var avg = new UnresolvedFunction(EMPTY, "avg", DEFAULT, List.of(a));
+ var min_alias = new Alias(EMPTY, "min", min);
+
+ var max_filter_ex = new Or(
+ EMPTY,
+ new GreaterThan(EMPTY, new Mod(EMPTY, a, integer(3)), integer(10)),
+ new GreaterThan(EMPTY, new Div(EMPTY, a, integer(2)), integer(100))
+ );
+ var max_filter = new Alias(EMPTY, "max", new FilteredExpression(EMPTY, max, max_filter_ex));
+
+ var avg_filter_ex = new GreaterThan(EMPTY, new Div(EMPTY, a, integer(2)), integer(100));
+ var avg_filter = new Alias(EMPTY, "avg", new FilteredExpression(EMPTY, avg, avg_filter_ex));
+
+ assertEquals(
+ new Aggregate(
+ EMPTY,
+ PROCESSING_CMD_INPUT,
+ Aggregate.AggregateType.STANDARD,
+ List.of(a),
+ List.of(min_alias, max_filter, avg_filter, a)
+ ),
+ processingCommand("""
+ stats
+ min = min(a),
+ max = max(a) WHERE (a % 3 > 10 OR a / 2 > 100),
+ avg = avg(a) WHERE a / 2 > 100
+ BY a
+ """)
+ );
+ }
+
+ public void testStatsWithoutGroupKeyMixedAggAndFilter() throws Exception {
+ var a = attribute("a");
+ var f = new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(a));
+ var filter = new Alias(EMPTY, "min(a) where a > 1", new FilteredExpression(EMPTY, f, new GreaterThan(EMPTY, a, integer(1))));
+ assertEquals(
+ new Aggregate(EMPTY, PROCESSING_CMD_INPUT, Aggregate.AggregateType.STANDARD, List.of(), List.of(filter)),
+ processingCommand("stats min(a) where a > 1")
+ );
+ }
+
public void testInlineStatsWithGroups() {
var query = "inlinestats b = min(a) by c, d.e";
if (Build.current().isSnapshot() == false) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
index d186b4c199d77..7075c9fe58d63 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
@@ -21,6 +21,8 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
+import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttributeTests;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedNamedExpression;
@@ -164,6 +166,16 @@ public void testInfoParameters() throws Exception {
* in the parameters and not included.
*/
expectedCount -= 1;
+
+ // special exceptions with private constructors
+ if (MetadataAttribute.class.equals(subclass) || ReferenceAttribute.class.equals(subclass)) {
+ expectedCount++;
+ }
+
+ if (FieldAttribute.class.equals(subclass)) {
+ expectedCount += 2;
+ }
+
assertEquals(expectedCount, info(node).properties().size());
}
@@ -174,6 +186,9 @@ public void testInfoParameters() throws Exception {
* implementations in the process.
*/
public void testTransform() throws Exception {
+ if (FieldAttribute.class.equals(subclass)) {
+ assumeTrue("FieldAttribute private constructor", false);
+ }
Constructor ctor = longestCtor(subclass);
Object[] nodeCtorArgs = ctorArgs(ctor);
T node = ctor.newInstance(nodeCtorArgs);