Skip to content

Commit

Permalink
ESQL: Introduce per agg filter
Browse files Browse the repository at this point in the history
Add support for aggregation scoped filters that work dynamically on the
 data in each group.

| STATS
    success = COUNT(*) WHERE 200 <= code AND code < 300,
   redirect = COUNT(*) WHERE 300 <= code AND code < 400,
 client_err = COUNT(*) WHERE 400 <= code AND code < 500,
 server_err = COUNT(*) WHERE 500 <= code AND code < 600,
 total_count = COUNT(*)

Implementation wise, the base AggregateFunction has been extended to
 allow a filter to be passed on. This is required to incorporate the
 filter as part of the aggregate equality/identify which would fail with
 the filter as an external component.

As part of the process, the serialization for the existing aggregations
 had to be fixed so AggregateFunction implementations so that it
 delegates to their parent first.
  • Loading branch information
costin committed Oct 1, 2024
1 parent 2eb9274 commit d2db8db
Show file tree
Hide file tree
Showing 44 changed files with 2,531 additions and 1,773 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ADD_DATA_STREAM_OPTIONS = def(8_754_00_0);
public static final TransportVersion CCS_REMOTE_TELEMETRY_STATS = def(8_755_00_0);
public static final TransportVersion ESQL_CCS_EXECUTION_INFO = def(8_756_00_0);
public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_757_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2290,3 +2290,60 @@ from employees
m:integer |a:double |x:integer
74999 |48249.0 |0
;


statsWithFiltering#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000, max_a = max(salary) where salary > 100,
min = min(salary), min_f = min(salary) where salary > 50000, min_a = min(salary) where salary > 100
;

max:integer |max_f:integer |max_a:integer | min:integer | min_f:integer | min_a:integer
74999 |49818 |74999 | 25324 | 50064 | 25324
;

statsWithEverythingFiltered#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_a = max(salary) where salary < 100,
min = min(salary), min_a = min(salary) where salary > 99999
;

max:integer |max_a:integer|min:integer | min_a:integer
74999 |null |25324 | null
;

statsWithBasicExpressionFiltered#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000,
min = min(salary), min_f = min(salary) where salary > 50000,
exp_p = max(salary) + 10000 where salary < 50000,
exp_m = min(salary) % 10000 where salary > 50000
;

max:integer |max_f:integer|min:integer | min_f:integer|exp_p:integer | exp_m:integer
74999 |49818 |25324 | 50064 |59818 | 64
;

statsWithExpressionOverFilters#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats max = max(salary), max_f = max(salary) where salary < 50000,
min = min(salary), min_f = min(salary) where salary > 50000,
exp_gt = max(salary) - min(salary) where salary > 50000,
exp_lt = max(salary) - min(salary) where salary < 50000

;

max:integer |max_f:integer | min:integer | min_f:integer |exp_gt:integer | exp_lt:integer
74999 |49818 | 25324 | 50064 |24935 | 24494
;

statsWithSubstitutedExpressionOverFilters#[skip:-8.16.0,reason:implemented in 8.16]
from employees
| stats sum = sum(salary), s_l = sum(salary) where salary < 50000, s_u = sum(salary) where salary > 50000,
count = count(salary), c_l = count(salary) where salary < 50000, c_u = count(salary) where salary > 50000,
avg = round(avg(salary), 2), a_l = round(avg(salary), 2) where salary < 50000, a_u = round(avg(salary),2) where salary > 50000
;

sum:l |s_l:l | s_u:l | count:l |c_l:l |c_u:l |avg:double |a_l:double | a_u:double
4824855 |2220951 | 2603904 | 100 |58 |42 |48248.55 |38292.26 | 61997.71
;
1 change: 1 addition & 0 deletions x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ PERCENT : '%';

// move it in the main section if the feature gets promoted
DEV_MATCH_OP : {this.isDevVersion()}? DEV_MATCH -> type(DEV_MATCH);
NESTED_WHERE : {this.isDevVersion()}? WHERE -> type(WHERE);

NAMED_OR_POSITIONAL_PARAM
: PARAM (LETTER | UNDERSCORE) UNQUOTED_ID_BODY*
Expand Down
20 changes: 13 additions & 7 deletions x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,15 @@ fields
;

field
: booleanExpression
| qualifiedName ASSIGN booleanExpression
: (qualifiedName ASSIGN)? booleanExpression
;

fromCommand
: FROM indexPattern (COMMA indexPattern)* metadata?
;

indexPattern
: clusterString COLON indexString
| indexString
: (clusterString COLON)? indexString
;

clusterString
Expand All @@ -154,15 +152,23 @@ deprecated_metadata
;

metricsCommand
: DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=fields? (BY grouping=fields)?
: DEV_METRICS indexPattern (COMMA indexPattern)* aggregates=aggFields? (BY grouping=fields)?
;

evalCommand
: EVAL fields
;

statsCommand
: STATS stats=fields? (BY grouping=fields)?
: STATS stats=aggFields? (BY grouping=fields)?
;

aggFields
: aggField (COMMA aggField)*
;

aggField
: field {this.isDevVersion()}? (WHERE booleanExpression)?
;

qualifiedName
Expand Down Expand Up @@ -309,5 +315,5 @@ lookupCommand
;

inlinestatsCommand
: DEV_INLINESTATS stats=fields (BY grouping=fields)?
: DEV_INLINESTATS stats=aggFields (BY grouping=fields)?
;
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ private LogicalPlan resolveStats(Stats stats, List<Attribute> childrenOutput) {
newAggregates.add(agg);
}

// TODO: remove this when Stats interface is removed
stats = changed.get() ? stats.with(stats.child(), groupings, newAggregates) : stats;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.xpack.esql.core.util.Holder;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
Expand Down Expand Up @@ -301,6 +302,11 @@ private static void checkInvalidNamedExpressionUsage(
Set<Failure> failures,
int level
) {
// unwrap filtered expression
if (e instanceof FilteredExpression fe) {
e = fe.delegate();
// TODO add verification for filter clause
}
// found an aggregate, constant or a group, bail out
if (e instanceof AggregateFunction af) {
af.field().forEachDown(AggregateFunction.class, f -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
*/
package org.elasticsearch.xpack.esql.expression.function.aggregate;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;
Expand All @@ -20,8 +22,8 @@
import java.util.List;
import java.util.Objects;

import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;

/**
Expand Down Expand Up @@ -52,25 +54,51 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {

private final Expression field;
private final List<? extends Expression> parameters;
private final Expression filter;

protected AggregateFunction(Source source, Expression field) {
this(source, field, emptyList());
this(source, field, Literal.TRUE, emptyList());
}

protected AggregateFunction(Source source, Expression field, List<? extends Expression> parameters) {
super(source, CollectionUtils.combine(singletonList(field), parameters));
this(source, field, Literal.TRUE, parameters);
}

protected AggregateFunction(Source source, Expression field, Expression filter, List<? extends Expression> parameters) {
super(source, CollectionUtils.combine(asList(field, filter), parameters));
this.field = field;
this.filter = filter;
this.parameters = parameters;
}

protected AggregateFunction(StreamInput in) throws IOException {
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class));
this(
Source.readFrom((PlanStreamInput) in),
in.readNamedWriteable(Expression.class),
in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)
? in.readNamedWriteable(Expression.class)
: Literal.TRUE,
in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)
? in.readNamedWriteableCollectionAsList(Expression.class)
: emptyList()
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
public final void writeTo(StreamOutput out) throws IOException {
Source.EMPTY.writeTo(out);
out.writeNamedWriteable(field);
if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)) {
out.writeNamedWriteable(filter);
out.writeNamedWriteableCollection(parameters);
} else {
deprecatedWriteTo(out);
}
}

@Deprecated(since = "8.16", forRemoval = true)
protected void deprecatedWriteTo(StreamOutput out) throws IOException {
//
}

public Expression field() {
Expand All @@ -81,6 +109,14 @@ public List<? extends Expression> parameters() {
return parameters;
}

public boolean hasFilter() {
return filter != null && filter != Literal.TRUE;
}

public Expression filter() {
return filter;
}

/**
* Returns the input expressions used in aggregation.
* Defaults to a list containing the only the input field.
Expand All @@ -94,6 +130,18 @@ protected TypeResolution resolveType() {
return TypeResolutions.isExact(field, sourceText(), DEFAULT);
}

/**
* Attach a filter to the aggregate function.
*/
public abstract AggregateFunction withFilter(Expression filter);

public AggregateFunction withParameters(List<? extends Expression> parameters) {
if (parameters == this.parameters) {
return this;
}
return (AggregateFunction) replaceChildren(CollectionUtils.combine(asList(field, filter), parameters));
}

@Override
public int hashCode() {
// NB: the hashcode is currently used for key generation so
Expand All @@ -105,7 +153,9 @@ public int hashCode() {
public boolean equals(Object obj) {
if (super.equals(obj)) {
AggregateFunction other = (AggregateFunction) obj;
return Objects.equals(other.field(), field()) && Objects.equals(other.parameters(), parameters());
return Objects.equals(other.field(), field())
&& Objects.equals(other.filter, filter)
&& Objects.equals(other.parameters(), parameters());
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.io.IOException;
import java.util.List;

import static java.util.Collections.emptyList;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;

Expand All @@ -47,6 +48,10 @@ public Avg(Source source, @Param(name = "number", type = { "double", "integer",
super(source, field);
}

protected Avg(Source source, Expression field, Expression filter) {
super(source, field, filter, emptyList());
}

@Override
protected Expression.TypeResolution resolveType() {
return isType(
Expand Down Expand Up @@ -74,19 +79,26 @@ public DataType dataType() {

@Override
protected NodeInfo<Avg> info() {
return NodeInfo.create(this, Avg::new, field());
return NodeInfo.create(this, Avg::new, field(), filter());
}

@Override
public Avg replaceChildren(List<Expression> newChildren) {
return new Avg(source(), newChildren.get(0));
return new Avg(source(), newChildren.get(0), newChildren.get(1));
}

@Override
public Avg withFilter(Expression filter) {
return new Avg(source(), field(), filter);
}

@Override
public Expression surrogate() {
var s = source();
var field = field();

return field().foldable() ? new MvAvg(s, field) : new Div(s, new Sum(s, field), new Count(s, field), dataType());
return field().foldable()
? new MvAvg(s, field)
: new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
import java.io.IOException;
import java.util.List;

import static java.util.Collections.emptyList;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;

public class Count extends AggregateFunction implements EnclosedAgg, ToAggregator, SurrogateExpression {
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Count", Count::new);

@FunctionInfo(
Expand Down Expand Up @@ -86,6 +87,10 @@ public Count(
super(source, field);
}

protected Count(Source source, Expression field, Expression filter) {
super(source, field, filter, emptyList());
}

private Count(StreamInput in) throws IOException {
super(in);
}
Expand All @@ -97,17 +102,17 @@ public String getWriteableName() {

@Override
protected NodeInfo<Count> info() {
return NodeInfo.create(this, Count::new, field());
return NodeInfo.create(this, Count::new, field(), filter());
}

@Override
public Count replaceChildren(List<Expression> newChildren) {
return new Count(source(), newChildren.get(0));
public AggregateFunction withFilter(Expression filter) {
return new Count(source(), field(), filter);
}

@Override
public String innerName() {
return "count";
public Count replaceChildren(List<Expression> newChildren) {
return new Count(source(), newChildren.get(0), newChildren.get(1));
}

@Override
Expand Down
Loading

0 comments on commit d2db8db

Please sign in to comment.