Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Support filter clause in aggregations #960

Merged
merged 8 commits into from
Jan 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
Optional<BuiltinFunctionName> builtinFunctionName = BuiltinFunctionName.of(node.getFuncName());
if (builtinFunctionName.isPresent()) {
Expression arg = node.getField().accept(this, context);
return (Aggregator)
repository.compile(
Aggregator aggregator = (Aggregator) repository.compile(
builtinFunctionName.get().getName(), Collections.singletonList(arg));
if (node.getCondition() != null) {
aggregator.condition(analyze(node.getCondition(), context));
}
return aggregator;
} else {
throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public static UnresolvedExpression aggregate(
return new AggregateFunction(func, field, Arrays.asList(args));
}

/**
 * Build an unresolved aggregate function that carries a filter condition.
 *
 * @param func aggregate function name
 * @param field expression the aggregate function is applied to
 * @param condition condition of the aggregation filter
 * @return {@link AggregateFunction} created from the given name, field and condition
 */
public static UnresolvedExpression filteredAggregate(
    String func, UnresolvedExpression field, UnresolvedExpression condition) {
  final AggregateFunction filtered = new AggregateFunction(func, field, condition);
  return filtered;
}

/**
 * Build a {@link Function} node from a function name and its arguments.
 *
 * @param funcName function name
 * @param funcArgs function arguments, passed to the node as a list via {@code Arrays.asList}
 * @return the new {@link Function}
 */
public static Function function(String funcName, UnresolvedExpression... funcArgs) {
return new Function(funcName, Arrays.asList(funcArgs));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class AggregateFunction extends UnresolvedExpression {
private final String funcName;
private final UnresolvedExpression field;
private final List<UnresolvedExpression> argList;
private UnresolvedExpression condition;

/**
* Constructor.
Expand All @@ -46,6 +47,20 @@ public AggregateFunction(String funcName, UnresolvedExpression field) {
this.argList = Collections.emptyList();
}

/**
 * Constructor for an aggregate function with a filter condition.
 * The argument list is initialized to an empty list.
 *
 * @param funcName function name.
 * @param field {@link UnresolvedExpression} the function is applied to.
 * @param condition condition in aggregation filter.
 */
public AggregateFunction(String funcName,
                         UnresolvedExpression field,
                         UnresolvedExpression condition) {
  this.condition = condition;
  this.argList = Collections.emptyList();
  this.field = field;
  this.funcName = funcName;
}

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(field);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.amazon.opendistroforelasticsearch.sql.analysis.ExpressionAnalyzer;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException;
Expand All @@ -30,6 +31,8 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;

/**
* Aggregator which will iterate on the {@link BindingTuple}s to aggregate the result.
Expand All @@ -46,20 +49,40 @@ public abstract class Aggregator<S extends AggregationState>
@Getter
private final List<Expression> arguments;
protected final ExprCoreType returnType;
@Setter
@Getter
@Accessors(fluent = true)
protected Expression condition;

/**
* Create an {@link AggregationState} which will be used for aggregation.
*/
public abstract S create();

/**
* Iterate on the {@link BindingTuple}.
* Iterate on {@link ExprValue}.
* @param value {@link ExprValue}
* @param state {@link AggregationState}
* @return {@link AggregationState}
*/
protected abstract S iterate(ExprValue value, S state);

/**
* Let the aggregator iterate on the {@link BindingTuple}.
* ExprValues that are missing or null, and tuples that do not satisfy {@link #condition},
* are filtered out before the specific aggregator iterates on the ExprValue in the tuple.
*
* @param tuple {@link BindingTuple}
* @param state {@link AggregationState}
* @return {@link AggregationState}
*/
public abstract S iterate(BindingTuple tuple, S state);
public S iterate(BindingTuple tuple, S state) {
  // Resolve the first argument of the aggregator against the tuple.
  final ExprValue argValue = getArguments().get(0).valueOf(tuple);
  // The condition is evaluated last, and only for values that are neither null nor missing.
  final boolean accepted =
      !argValue.isNull() && !argValue.isMissing() && conditionValue(tuple);
  if (accepted) {
    return iterate(argValue, state);
  }
  // Rejected input leaves the state untouched.
  return state;
}

@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
Expand All @@ -77,4 +100,14 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
return visitor.visitAggregator(this, context);
}

/**
 * Evaluate the aggregation filter condition against a tuple.
 *
 * @param tuple {@link BindingTuple} the condition is evaluated on
 * @return true if no condition is set, otherwise the boolean value of the condition
 */
public boolean conditionValue(BindingTuple tuple) {
  // An absent condition accepts every tuple; otherwise the condition decides.
  return condition == null
      || ExprValueUtils.getBooleanValue(condition.valueOf(tuple));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,9 @@ public AvgState create() {
}

@Override
public AvgState iterate(BindingTuple tuple, AvgState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.count++;
state.total += ExprValueUtils.getDoubleValue(value);
}
protected AvgState iterate(ExprValue value, AvgState state) {
state.count++;
state.total += ExprValueUtils.getDoubleValue(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,8 @@ public CountAggregator.CountState create() {
}

@Override
public CountState iterate(BindingTuple tuple, CountState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.count++;
}
protected CountState iterate(ExprValue value, CountState state) {
state.count++;
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,8 @@ public MaxState create() {
}

@Override
public MaxState iterate(BindingTuple tuple, MaxState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.max(value);
}
protected MaxState iterate(ExprValue value, MaxState state) {
state.max(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,8 @@ public MinState create() {
}

@Override
public MinState iterate(BindingTuple tuple, MinState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.min(value);
}
protected MinState iterate(ExprValue value, MinState state) {
state.min(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package com.amazon.opendistroforelasticsearch.sql.expression.aggregation;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple;
import com.google.common.base.Strings;
Expand Down Expand Up @@ -63,8 +64,8 @@ public AggregationState create() {
}

@Override
public AggregationState iterate(BindingTuple tuple, AggregationState state) {
return delegated.iterate(tuple, state);
protected AggregationState iterate(ExprValue value, AggregationState state) {
return delegated.iterate(value, state);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,9 @@ public SumState create() {
}

@Override
public SumState iterate(BindingTuple tuple, SumState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.isEmptyCollection = false;
state.add(value);
}
protected SumState iterate(ExprValue value, SumState state) {
state.isEmptyCollection = false;
state.add(value);
return state;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
package com.amazon.opendistroforelasticsearch.sql.analysis;

import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.field;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.filteredAggregate;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.function;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intLiteral;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.integerValue;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
Expand Down Expand Up @@ -256,6 +260,16 @@ public void undefined_aggregation_function() {
assertEquals("Unsupported aggregation function ESTDC_ERROR", exception.getMessage());
}

@Test
public void aggregation_filter() {
  // Unresolved form: avg(integer_value) filtered by integer_value > 1.
  UnresolvedExpression unresolved = filteredAggregate(
      "avg",
      qualifiedName("integer_value"),
      function(">", qualifiedName("integer_value"), intLiteral(1)));

  assertAnalyzeEqual(
      dsl.avg(DSL.ref("integer_value", INTEGER))
          .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
      unresolved);
}

/**
 * Analyze an unresolved expression with the test's expression analyzer and analysis context.
 */
protected Expression analyze(UnresolvedExpression unresolvedExpression) {
  final Expression analyzed =
      expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
  return analyzed;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ public void avg_arithmetic_expression() {
assertEquals(25.0, result.value());
}

@Test
public void filtered_avg() {
  // avg(integer_value) restricted to tuples where integer_value > 1.
  ExprValue result = aggregation(
      dsl.avg(DSL.ref("integer_value", INTEGER))
          .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
      tuples);
  final Object actual = result.value();
  assertEquals(3.0, actual);
}

@Test
public void avg_with_missing() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ public void count_array_field_expression() {
assertEquals(1, result.value());
}

@Test
public void filtered_count() {
  // count(integer_value) restricted to tuples where integer_value > 1.
  ExprValue result = aggregation(
      dsl.count(DSL.ref("integer_value", INTEGER))
          .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
      tuples);
  final Object actual = result.value();
  assertEquals(3, actual);
}

@Test
public void count_with_missing() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ public void test_max_arithmetic_expression() {
assertEquals(4, result.value());
}

@Test
public void filtered_max() {
  // max(integer_value) restricted to tuples where integer_value < 4.
  ExprValue result = aggregation(
      dsl.max(DSL.ref("integer_value", INTEGER))
          .condition(dsl.less(DSL.ref("integer_value", INTEGER), DSL.literal(4))),
      tuples);
  final Object actual = result.value();
  assertEquals(3, actual);
}

@Test
public void test_max_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ public void test_min_arithmetic_expression() {
assertEquals(1, result.value());
}

@Test
public void filtered_min() {
  // min(integer_value) restricted to tuples where integer_value > 1.
  ExprValue result = aggregation(
      dsl.min(DSL.ref("integer_value", INTEGER))
          .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
      tuples);
  final Object actual = result.value();
  assertEquals(2, actual);
}

@Test
public void test_min_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ public void sum_string_field_expression() {
assertEquals("unexpected type [STRING] in sum aggregation", exception.getMessage());
}

@Test
public void filtered_sum() {
  // sum(integer_value) restricted to tuples where integer_value > 1.
  ExprValue result = aggregation(
      dsl.sum(DSL.ref("integer_value", INTEGER))
          .condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
      tuples);
  final Object actual = result.value();
  assertEquals(9, actual);
}

@Test
public void sum_with_missing() {
ExprValue result =
Expand Down
39 changes: 39 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,42 @@ Additionally, a ``HAVING`` clause can work without ``GROUP BY`` clause. This is
| Total of age > 100 |
+------------------------+


FILTER Clause
=============

Description
-----------

A ``FILTER`` clause can set a specific condition for the current aggregation bucket, following the syntax ``aggregation_function(expr) FILTER(WHERE condition_expr)``. If a filter is specified, then only the input rows for which the condition in the filter clause evaluates to true are fed to the aggregate function; other rows are discarded. An aggregation with a filter clause can be used in the ``SELECT`` clause only.

FILTER with GROUP BY
--------------------

A group-by aggregation with a ``FILTER`` clause can set different conditions for each aggregation bucket. Here is an example of using ``FILTER`` in a group-by aggregation::

od> SELECT avg(age) FILTER(WHERE balance > 10000) AS filtered, gender FROM accounts GROUP BY gender
fetched rows / total rows = 2/2
+------------+----------+
| filtered | gender |
|------------+----------|
| 28.0 | F |
| 32.0 | M |
+------------+----------+

FILTER without GROUP BY
-----------------------

The ``FILTER`` clause can be used in aggregation functions without GROUP BY as well. For example::

od> SELECT
... count(*) AS unfiltered,
... count(*) FILTER(WHERE age > 34) AS filtered
... FROM accounts
fetched rows / total rows = 1/1
+--------------+------------+
| unfiltered | filtered |
|--------------+------------|
| 4 | 1 |
+--------------+------------+

Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;

/**
Expand Down Expand Up @@ -82,6 +83,13 @@ private static Map<String, Object> parseInternal(Aggregation aggregation) {
resultMap.put(
aggregation.getName(),
handleNanValue(((NumericMetricsAggregation.SingleValue) aggregation).value()));
} else if (aggregation instanceof Filter) {
// parse sub-aggregations for FilterAggregation response
List<Aggregation> aggList = ((Filter) aggregation).getAggregations().asList();
aggList.forEach(internalAgg -> {
Map<String, Object> intermediateMap = parseInternal(internalAgg);
resultMap.put(internalAgg.getName(), intermediateMap.get(internalAgg.getName()));
});
} else {
throw new IllegalStateException("unsupported aggregation type " + aggregation.getType());
}
Expand Down
Loading