ESQL: Introduce per agg filter #113735

Merged
merged 25 commits into from
Oct 15, 2024
19 changes: 19 additions & 0 deletions docs/changelog/113735.yaml
@@ -0,0 +1,19 @@
pr: 113735
summary: "ESQL: Introduce per agg filter"
area: ES|QL
type: feature
issues: []
highlight:
title: "ESQL: Introduce per agg filter"
body: "Add support for aggregation scoped filters that work dynamically on the \n\
data in each group.\n\n```\n| STATS\n success = COUNT(*) WHERE 200 <= code\
\ AND code < 300,\n redirect = COUNT(*) WHERE 300 <= code AND code < 400,\n\
\ client_err = COUNT(*) WHERE 400 <= code AND code < 500,\n server_err = COUNT(*)\
\ WHERE 500 <= code AND code < 600,\ntotal_count = COUNT(*)\n```\n\nImplementation\
\ wise, the base AggregateFunction has been extended to \nallow a filter to be\
\ passed on. This is required to incorporate the \nfilter as part of the aggregate\
\ equality/identity which would fail with \nthe filter as an external component.\n\
\nAs part of the process, the serialization for the existing aggregations \nhad\
\ to be fixed so that AggregateFunction implementations \ndelegate to their\
\ parent first."
notable: true
Expand Up @@ -230,6 +230,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ADD_DATA_STREAM_OPTIONS = def(8_754_00_0);
public static final TransportVersion CCS_REMOTE_TELEMETRY_STATS = def(8_755_00_0);
public static final TransportVersion ESQL_CCS_EXECUTION_INFO = def(8_756_00_0);
public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_757_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Contributor

Thought: I think it's going to be important to teach users with good examples that

| STATS s = sum(field) WHERE field > 0

is fundamentally different from

| STATS s = sum(field) | WHERE s > 0

I expect users will try stuff like

| STATS s = sum(field) WHERE field > 0 AND s > 0

which is invalid.

Contributor

> I expect users will try stuff like
> | STATS s = sum(field) WHERE field > 0 AND s > 0
> which is invalid.

This isn't currently verified (which is what the comment in Verifier is about), but I wonder whether this type of aggregation (with no group) should be invalid at all.

Contributor

I'm not sure I understand correctly; but

| STATS s = sum(field) WHERE field > 0 AND s > 0

is invalid because the WHERE inside the aggregation refers to the result of the aggregation. To be correct, the second part of the predicate needs to be moved into a separate WHERE command:

| STATS s = sum(field) WHERE field > 0 | WHERE s > 0
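
The distinction discussed in this thread can be sketched in plain Java. This is an illustration only, not ES|QL code: `PerAggFilterSketch`, `sumWhere`, and `sumThenWhere` are hypothetical names, and for simplicity an aggregate over zero matching rows yields 0 here rather than null.

```java
import java.util.List;
import java.util.OptionalLong;
import java.util.function.LongPredicate;

final class PerAggFilterSketch {

    // Models: STATS s = SUM(field) WHERE field > 0
    // The predicate sees each input row before the row reaches the aggregate.
    static long sumWhere(List<Long> rows, LongPredicate rowFilter) {
        long sum = 0;
        for (long value : rows) {
            if (rowFilter.test(value)) {
                sum += value;
            }
        }
        return sum;
    }

    // Models: STATS s = SUM(field) | WHERE s > 0
    // All rows are aggregated first; the predicate then keeps or drops the single result row.
    static OptionalLong sumThenWhere(List<Long> rows, LongPredicate resultFilter) {
        long sum = 0;
        for (long value : rows) {
            sum += value;
        }
        return resultFilter.test(sum) ? OptionalLong.of(sum) : OptionalLong.empty();
    }
}
```

With rows `[-5, 2, 1]`, `sumWhere(rows, v -> v > 0)` adds only the positive values, while `sumThenWhere(rows, s -> s > 0)` sums everything to -2 and then drops the result row. A predicate that mixes both (row values and the aggregate result) has no single point at which it could be evaluated, which is why the combined form is invalid.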

Contributor

Constant aggregations with WHERE are not tested and currently incorrect, at least for those that rely on COUNT(*), like SUM.

Reproducer:

POST /test/_doc?refresh
{"a" : 1}

POST /test/_doc?refresh
{"a" : 2}

POST /test/_doc?refresh
{"a" : 3}

POST /test/_doc?refresh
{"a" : 4}

POST /_query
{"query": "from test | stats sum(1) where a > 3"}

The query should return 1 as only 1 row satisfies a > 3, but it returns 4. That's because we don't propagate the filter into the COUNT(*) in Sum.surrogate.
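
A minimal Java sketch of the reproducer above, assuming (as the comment suggests) that SUM over a constant is rewritten to the constant multiplied by COUNT(*). The class and method names are hypothetical; the point is only that the filter has to travel with the COUNT(*).

```java
import java.util.List;
import java.util.function.IntPredicate;

final class ConstantSumSketch {

    // COUNT(*) restricted to the rows accepted by the filter.
    static long count(List<Integer> a, IntPredicate filter) {
        long n = 0;
        for (int value : a) {
            if (filter.test(value)) {
                n++;
            }
        }
        return n;
    }

    // Assumed rewrite with the reported problem: the filter is accepted but
    // never handed to COUNT(*), so every row is counted.
    static long sumConstantDroppingFilter(long constant, List<Integer> a, IntPredicate filter) {
        return constant * count(a, value -> true);
    }

    // Rewrite with the filter propagated into COUNT(*).
    static long sumConstantWithFilter(long constant, List<Integer> a, IntPredicate filter) {
        return constant * count(a, filter);
    }
}
```

For `a = [1, 2, 3, 4]` and the filter `a > 3`, the first variant yields 4 (the reported result) and the second yields 1 (the expected result).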

Contributor

@bpintea Oct 2, 2024

> we don't propagate the filter into the COUNT(*) in Sum.surrogate

There might be more to it, since `from test | stats count(*) where a > 3` also fails (and so does `from test | stats count(a) where a > 3`), but `from test | stats c_gt3 = count(*) where a > 3, c_lt3 = count(*) where a < 3` works (which is why the results for the last added test are actually correct).

Member Author

That's because there's no true folding of the aggregation; instead we try to rewrite them into MV_functions + case,
which falls short when the filter (which might be foldable or not) has to be evaluated.

Member Author

As a side note, when the filters are foldable, they evaluate to either true (in which case they are essentially discarded) or false, meaning the agg won't run and can be folded to its initial value: 0 for count and null for the other aggs.
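
The folding rule described in this comment could look roughly like the following Java sketch. Everything here is hypothetical (the `AggKind` enum, the use of a nullable `Boolean` to mean "filter is not foldable"); it only restates the rule, it is not the optimizer's code.

```java
final class FoldableFilterSketch {

    enum AggKind { COUNT, SUM, MIN, MAX }

    // Convention of this sketch: foldedFilter is TRUE or FALSE when the filter
    // folds to a constant, and null when it must be evaluated per row.

    // A constant-true filter accepts every row, so it can be discarded.
    static boolean filterCanBeDropped(Boolean foldedFilter) {
        return Boolean.TRUE.equals(foldedFilter);
    }

    // A constant-false filter rejects every row, so the aggregate never runs.
    static boolean aggCanBeSkipped(Boolean foldedFilter) {
        return Boolean.FALSE.equals(foldedFilter);
    }

    // Value an aggregate folds to when it never sees a row:
    // 0 for count, null for the other aggregates.
    static Long initialValue(AggKind kind) {
        return kind == AggKind.COUNT ? Long.valueOf(0L) : null;
    }
}
```

A non-foldable filter (null in this sketch) triggers neither shortcut, so the aggregate has to run with the filter evaluated per row.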

Contributor

Alright, let's address the const case later.

Contributor

Opened #115522 to call this out explicitly as it's actually producing incorrect results.

Contributor

More things to test: BUCKET in the WHERE clause. Seems to work for this very simple query:

row a = 1 | stats sum(a) where bucket(a, 2) > -1 by bucket(a,2)

Member Author

Good catch; however, I believe this is a bug. Grouping functions should only be allowed in the BY clause. Here it could work because it's the same as the grouping; however, I tried it and we allow a bucket with a different field and argument.
/cc @bpintea

Contributor

I tried it and we allow a bucket with different field and argument

The semantic equality for BUCKETs is currently run over the aggs in STATS against the GroupingFunctions in its BY clause. But since checkInvalidNamedExpressionUsage() isn't run on all AggregateFunction children and since the filter-on-agg is now added within the AggregateFunction itself, the condition isn't caught.
So this check needs to be extended. But IMO it should be either a follow-up to this PR or part of it; I don't think it's a general issue.

Contributor

> I believe this is a bug

This still works and I didn't find an existing issue, so I opened #115521.

Expand Up @@ -2290,3 +2290,60 @@ from employees
m:integer |a:double |x:integer
74999 |48249.0 |0
;


statsWithFiltering#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000, max_a = max(salary) where salary > 100,
min = min(salary), min_f = min(salary) where salary > 50000, min_a = min(salary) where salary > 100
;

max:integer |max_f:integer |max_a:integer | min:integer | min_f:integer | min_a:integer
74999 |49818 |74999 | 25324 | 50064 | 25324
;

statsWithEverythingFiltered#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_a = max(salary) where salary < 100,
min = min(salary), min_a = min(salary) where salary > 99999
;

max:integer |max_a:integer|min:integer | min_a:integer
74999 |null |25324 | null
;

statsWithBasicExpressionFiltered#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000,
min = min(salary), min_f = min(salary) where salary > 50000,
exp_p = max(salary) + 10000 where salary < 50000,
exp_m = min(salary) % 10000 where salary > 50000
;

max:integer |max_f:integer|min:integer | min_f:integer|exp_p:integer | exp_m:integer
74999 |49818 |25324 | 50064 |59818 | 64
;

statsWithExpressionOverFilters#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000,
min = min(salary), min_f = min(salary) where salary > 50000,
exp_gt = max(salary) - min(salary) where salary > 50000,
exp_lt = max(salary) - min(salary) where salary < 50000

;

max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
74999 |49818 | 25324 | 50064 |24935 | 24494
;

statsWithSubstitutedExpressionOverFilters#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats sum = sum(salary), s_l = sum(salary) where salary < 50000, s_u = sum(salary) where salary > 50000,
count = count(salary), c_l = count(salary) where salary < 50000, c_u = count(salary) where salary > 50000,
avg = round(avg(salary), 2), a_l = round(avg(salary), 2) where salary < 50000, a_u = round(avg(salary),2) where salary > 50000
;

sum:l |s_l:l | s_u:l | count:l |c_l:l |c_u:l |avg:double |a_l:double | a_u:double
4824855 |2220951 | 2603904 | 100 |58 |42 |48248.55 |38292.26 | 61997.71
;
1 change: 1 addition & 0 deletions x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
Expand Up @@ -212,6 +212,7 @@ PERCENT : '%';

// move it in the main section if the feature gets promoted
DEV_MATCH_OP : {this.isDevVersion()}? DEV_MATCH -> type(DEV_MATCH);
NESTED_WHERE : {this.isDevVersion()}? WHERE -> type(WHERE);

NAMED_OR_POSITIONAL_PARAM
: PARAM (LETTER | UNDERSCORE) UNQUOTED_ID_BODY*
20 changes: 13 additions & 7 deletions x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
Expand Up @@ -118,17 +118,15 @@ fields
;

field
: booleanExpression
| qualifiedName ASSIGN booleanExpression
: (qualifiedName ASSIGN)? booleanExpression
;

fromCommand
: FROM indexPattern (COMMA indexPattern)* metadata?
;

indexPattern
: clusterString COLON indexString
| indexString
: (clusterString COLON)? indexString
Comment on lines +126 to +134
Member Author

Small improvements to the grammar that don't change the parsing.

;

clusterString
Expand All @@ -154,15 +152,23 @@ deprecated_metadata
;

metricsCommand
: DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=fields? (BY grouping=fields)?
: DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=aggFields? (BY grouping=fields)?
;

evalCommand
: EVAL fields
;

statsCommand
: STATS stats=fields? (BY grouping=fields)?
: STATS stats=aggFields? (BY grouping=fields)?
;

aggFields
: aggField (COMMA aggField)*
;

aggField
: field {this.isDevVersion()}? (WHERE booleanExpression)?
;

qualifiedName
Expand Down Expand Up @@ -309,5 +315,5 @@ lookupCommand
;

inlinestatsCommand
: DEV_INLINESTATS stats=fields (BY grouping=fields)?
: DEV_INLINESTATS stats=aggFields (BY grouping=fields)?
;
Expand Up @@ -488,6 +488,7 @@ private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
newAggregates.add(agg);
}

// TODO: remove this when Stats interface is removed
stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats;
}

Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
Expand Down Expand Up @@ -301,6 +302,15 @@ private static void checkInvalidNamedExpressionUsage(
Set<Failure> failures,
int level
) {
// unwrap filtered expression
if (e instanceof FilteredExpression fe) {
e = fe.delegate();
if (e.anyMatch(AggregateFunction.class::isInstance) == false) {
Expression filter = fe.filter();
failures.add(fail(filter, "WHERE clause allowed only for aggregate functions, none found in [{}]", e.sourceText()));
}
// TODO add verification for filter clause
}
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
Expand Up @@ -6,10 +6,12 @@
*/
package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;
Expand All @@ -20,8 +22,8 @@
import java.util.List;
import java.util.Objects;

import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;

/**
Expand Down Expand Up @@ -52,25 +54,51 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {

private final Expression field;
private final List<? extends Expression> parameters;
private final Expression filter;

protected AggregateFunction(Source source, Expression field) {
this(source, field, emptyList());
this(source, field, Literal.TRUE, emptyList());
}

protected AggregateFunction(Source source, Expression field, List<? extends Expression> parameters) {
super(source, CollectionUtils.combine(singletonList(field), parameters));
this(source, field, Literal.TRUE, parameters);
}

protected AggregateFunction(Source source, Expression field, Expression filter, List<? extends Expression> parameters) {
super(source, CollectionUtils.combine(asList(field, filter), parameters));
this.field = field;
this.filter = filter;
this.parameters = parameters;
}

protected AggregateFunction(StreamInput in) throws IOException {
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(Expression.class),
in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)
? in.readNamedWriteable(Expression.class)
: Literal.TRUE,
in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)
? in.readNamedWriteableCollectionAsList(Expression.class)
: emptyList()
);
Comment on lines +75 to +84
Member Author

Uniform serialization across all aggregate functions. We should have had that from the beginning, to avoid having each subclass add its own repetitive serialization.
/cc @nik9000

Member

I think it's easier to make the filter wrap the aggregation rather than put it inside each one. I also think that serializing a list just in case aggs use it is probably not a great choice. It's fine, but it feels like something we'd do when we need it.
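
The version-gated read and write discussed in this thread follow a common pattern: new fields go on the wire only when the peer is known to understand them, and the reader substitutes defaults otherwise. The sketch below shows that pattern with plain `java.io` streams. It is an assumption-laden simplification: `Agg`, the string-typed fields, and the integer `PER_AGG_FILTER_VERSION` are hypothetical stand-ins, and the old-format branch writes nothing beyond the field, whereas the PR delegates to `deprecatedWriteTo` there.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

final class VersionGatedSerializationSketch {

    // Hypothetical version number at which the filter was added to the wire format.
    static final int PER_AGG_FILTER_VERSION = 8_757;

    record Agg(String field, String filter, List<String> parameters) {}

    static byte[] write(Agg agg, int peerVersion) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeUTF(agg.field());
            // Only peers on the new version expect the filter and the parameter list.
            if (peerVersion >= PER_AGG_FILTER_VERSION) {
                out.writeUTF(agg.filter());
                out.writeInt(agg.parameters().size());
                for (String parameter : agg.parameters()) {
                    out.writeUTF(parameter);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static Agg read(byte[] data, int peerVersion) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            String field = in.readUTF();
            if (peerVersion < PER_AGG_FILTER_VERSION) {
                // Old peers never sent a filter: default to "always true" and no parameters.
                return new Agg(field, "TRUE", List.of());
            }
            String filter = in.readUTF();
            int size = in.readInt();
            List<String> parameters = new ArrayList<>(size);
            for (int i = 0; i < size; i++) {
                parameters.add(in.readUTF());
            }
            return new Agg(field, filter, List.copyOf(parameters));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Both sides must gate on the same version, otherwise the reader consumes bytes the writer never produced. In this sketch, writing to an older peer silently drops the filter, so a real implementation would have to decide whether that is acceptable or should be rejected.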

}

@Override
public void writeTo(StreamOutput out) throws IOException {
public final void writeTo(StreamOutput out) throws IOException {
Source.EMPTY.writeTo(out);
out.writeNamedWriteable(field);
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)) {
out.writeNamedWriteable(filter);
out.writeNamedWriteableCollection(parameters);
} else {
deprecatedWriteTo(out);
}
}

@Deprecated(since = "8.16", forRemoval = true)
protected void deprecatedWriteTo(StreamOutput out) throws IOException {
//
}

public Expression field() {
Expand All @@ -81,6 +109,14 @@ public List<? extends Expression> parameters() {
return parameters;
}

public boolean hasFilter() {
return filter != null && filter != Literal.TRUE;
}

public Expression filter() {
return filter;
}

/**
* Returns the input expressions used in aggregation.
* Defaults to a list containing only the input field.
Expand All @@ -94,6 +130,18 @@ protected TypeResolution resolveType() {
return TypeResolutions.isExact(field, sourceText(), DEFAULT);
}

/**
* Attach a filter to the aggregate function.
*/
public abstract AggregateFunction withFilter(Expression filter);

public AggregateFunction withParameters(List<? extends Expression> parameters) {
if (parameters == this.parameters) {
return this;
}
return (AggregateFunction) replaceChildren(CollectionUtils.combine(asList(field, filter), parameters));
}

@Override
public int hashCode() {
// NB: the hashcode is currently used for key generation so
Expand All @@ -105,7 +153,9 @@ public int hashCode() {
public boolean equals(Object obj) {
if (super.equals(obj)) {
AggregateFunction other = (AggregateFunction) obj;
return Objects.equals(other.field(), field()) && Objects.equals(other.parameters(), parameters());
return Objects.equals(other.field(), field())
&& Objects.equals(other.filter, filter)
&& Objects.equals(other.parameters(), parameters());
}
return false;
}
Expand Up @@ -23,6 +23,7 @@
import java.io.IOException;
import java.util.List;

import static java.util.Collections.emptyList;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;

Expand All @@ -47,6 +48,10 @@ public Avg(Source source, @Param(name = "number", type = { "double", "integer",
super(source, field);
}

protected Avg(Source source, Expression field, Expression filter) {
super(source, field, filter, emptyList());
}

@Override
protected Expression.TypeResolution resolveType() {
return isType(
Expand Down Expand Up @@ -74,19 +79,26 @@ public DataType dataType() {

@Override
protected NodeInfo<Avg> info() {
return NodeInfo.create(this, Avg::new, field());
return NodeInfo.create(this, Avg::new, field(), filter());
}

@Override
public Avg replaceChildren(List<Expression> newChildren) {
return new Avg(source(), newChildren.get(0));
return new Avg(source(), newChildren.get(0), newChildren.get(1));
}

@Override
public Avg withFilter(Expression filter) {
return new Avg(source(), field(), filter);
Comment on lines +83 to +93
Member Author

The side effect is that all aggs had to be modified to take the parent filter into account.

Member

Could we wrap them with something like we do at runtime? It feels like that might work.

Member Author

Not sure what you mean. My initial approach was to wrap the agg and the filter in a separate class (FilteredAggregate with two children: the aggregate and the filter expression); however, this messed up the optimizer, since the rules picked up the aggregate functions but had no idea there was a filter associated with them.
This is problematic for scenarios where the same agg is defined with different filters:
STATS c = count(), c_f = count() WHERE a > 1
Since the WHERE filter is disjoint from the count, the rules replaced the second count with the first one, which is incorrect since they have different execution paths.
Furthermore, at compute time the filter and the actual agg get fused anyway, so the two are a pair regardless.
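
The deduplication problem described here comes down to what takes part in equality. The Java sketch below illustrates it under stated assumptions: the records `Agg`, `BareAgg`, and `Wrapped` are hypothetical, and a `HashSet` stands in for optimizer rules that collapse equal aggregate functions.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

final class AggIdentitySketch {

    // Filter is a component of the aggregate, so it takes part in equals/hashCode.
    record Agg(String name, String field, String filter) {}

    // Wrapper approach: the aggregate knows nothing about its filter.
    record BareAgg(String name, String field) {}

    record Wrapped(BareAgg agg, String filter) {}

    // Rules that deduplicate aggregates keep both entries when the filters differ.
    static int distinctWithFusedFilter(List<Agg> aggs) {
        return new HashSet<>(aggs).size();
    }

    // Rules that look only at the inner aggregate see identical functions
    // and collapse them, losing the filter distinction.
    static int distinctSeenByRules(List<Wrapped> wrapped) {
        Set<BareAgg> seen = new HashSet<>();
        for (Wrapped w : wrapped) {
            seen.add(w.agg());
        }
        return seen.size();
    }
}
```

For `count(*)` and `count(*) WHERE a > 1`, the fused representation keeps two distinct aggregates, while the wrapped representation exposes a single `BareAgg` to any rule that collects aggregate functions.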

Member

They do get fused anyway, but it's just a lot more convenient if every agg doesn't have to worry about being filtered. If filter is a member here we have to think about it all over the place. It's enough that the fused filter at compute time has to add all the tests for it, I'd love to avoid that bubbling all the way out. OTOH, if that's how it has to be, I'm not deep enough to object.

Member

@fang-xing-esql Oct 3, 2024

I wonder if this means each aggregation evaluates its own filter? If so, from a performance perspective, one advantage of separating the filter from the aggregation is that, when the same filter applies to different aggregations, it is easier to recognize that it is the same filter and evaluate it once for all aggregations that share it, instead of evaluating it within each aggregation separately. However, I couldn't figure out how the current infrastructure would support that yet. At first look, this reminds me of (common) table expressions and UNION ALL subqueries in SQL, but we don't have that infrastructure in ES|QL yet, and it could be a much broader scope.

Member Author

This approach doesn't prevent optimizing the filter evaluation.
Each aggregation ends up having its own filter because count(*) is different from count(*) where a > 10. Initially they were decoupled; however, the rules considered them the same since the filter was sitting outside the aggregation.
When dealing with the same filter, we can assemble a different pipeline so the mask or masked block is reused across multiple aggregations.
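
One way to picture the mask reuse mentioned above is the Java sketch below. It is an assumption, not the compute engine's design: filters are identified by a string key (standing in for expression equality), each distinct filter is evaluated once into a boolean mask, and every aggregate sharing that filter reads the same mask.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntPredicate;

final class SharedMaskSketch {

    // Evaluate each distinct filter exactly once over the input values.
    static Map<String, boolean[]> buildMasks(int[] values, Map<String, IntPredicate> filtersByKey) {
        Map<String, boolean[]> masks = new HashMap<>();
        for (Map.Entry<String, IntPredicate> entry : filtersByKey.entrySet()) {
            boolean[] mask = new boolean[values.length];
            for (int i = 0; i < values.length; i++) {
                mask[i] = entry.getValue().test(values[i]);
            }
            masks.put(entry.getKey(), mask);
        }
        return masks;
    }

    // COUNT(*) over the rows selected by the mask.
    static long count(boolean[] mask) {
        long n = 0;
        for (boolean selected : mask) {
            if (selected) {
                n++;
            }
        }
        return n;
    }

    // SUM(values) over the rows selected by the mask.
    static long sum(int[] values, boolean[] mask) {
        long total = 0;
        for (int i = 0; i < values.length; i++) {
            if (mask[i]) {
                total += values[i];
            }
        }
        return total;
    }
}
```

Here `count` and `sum` can both be fed the mask built for the key `"a > 10"`, so the predicate runs once per row no matter how many aggregates share it.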

}

@Override
public Expression surrogate() {
var s = source();
var field = field();

return field().foldable() ? new MvAvg(s, field) : new Div(s, new Sum(s, field), new Count(s, field), dataType());
return field().foldable()
? new MvAvg(s, field)
: new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
}
}
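
The non-foldable branch of the surrogate above rewrites a filtered AVG as a filtered SUM divided by a filtered COUNT. A small Java sketch of that arithmetic follows. The names and the sample salaries used with it are made up for illustration and are not taken from the `employees` test data.

```java
import java.util.function.IntPredicate;

final class FilteredAvgSketch {

    static long sumWhere(int[] salaries, IntPredicate filter) {
        long sum = 0;
        for (int salary : salaries) {
            if (filter.test(salary)) {
                sum += salary;
            }
        }
        return sum;
    }

    static long countWhere(int[] salaries, IntPredicate filter) {
        long count = 0;
        for (int salary : salaries) {
            if (filter.test(salary)) {
                count++;
            }
        }
        return count;
    }

    // AVG(x) WHERE f is modeled as SUM(x) WHERE f divided by COUNT(x) WHERE f.
    // The same filter must reach both operands; null models "no rows matched".
    static Double avgWhere(int[] salaries, IntPredicate filter) {
        long count = countWhere(salaries, filter);
        return count == 0 ? null : (double) sumWhere(salaries, filter) / count;
    }
}
```

With the made-up salaries `{30000, 40000, 60000, 80000}` and the filter `salary < 50000`, the filtered sum is 70000 over 2 rows. If the filter reached only one of the two operands, the quotient would mix a filtered numerator with an unfiltered denominator (or the reverse) and the average would be wrong.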