Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Support filter clause in aggregations #960

Merged
merged 8 commits into from
Jan 6, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
Optional<BuiltinFunctionName> builtinFunctionName = BuiltinFunctionName.of(node.getFuncName());
if (builtinFunctionName.isPresent()) {
Expression arg = node.getField().accept(this, context);
return (Aggregator)
repository.compile(
Aggregator aggregator = (Aggregator) repository.compile(
builtinFunctionName.get().getName(), Collections.singletonList(arg));
if (node.getCondition() != null) {
aggregator.condition(analyze(node.getCondition(), context));
}
return aggregator;
} else {
throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public static UnresolvedExpression aggregate(
return new AggregateFunction(func, field, Arrays.asList(args));
}

public static UnresolvedExpression filteredAggregate(
String func, UnresolvedExpression field, UnresolvedExpression condition) {
return new AggregateFunction(func, field, condition);
}

public static Function function(String funcName, UnresolvedExpression... funcArgs) {
return new Function(funcName, Arrays.asList(funcArgs));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class AggregateFunction extends UnresolvedExpression {
private final String funcName;
private final UnresolvedExpression field;
private final List<UnresolvedExpression> argList;
private UnresolvedExpression condition;

/**
* Constructor.
Expand All @@ -46,6 +47,20 @@ public AggregateFunction(String funcName, UnresolvedExpression field) {
this.argList = Collections.emptyList();
}

/**
* Constructor.
* @param funcName function name.
* @param field {@link UnresolvedExpression}.
* @param condition condition in aggregation filter.
*/
public AggregateFunction(String funcName, UnresolvedExpression field,
UnresolvedExpression condition) {
this.funcName = funcName;
this.field = field;
this.argList = Collections.emptyList();
this.condition = condition;
}

@Override
public List<UnresolvedExpression> getChild() {
return Collections.singletonList(field);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.amazon.opendistroforelasticsearch.sql.analysis.ExpressionAnalyzer;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException;
Expand All @@ -30,6 +31,8 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;

/**
* Aggregator which will iterate on the {@link BindingTuple}s to aggregate the result.
Expand All @@ -46,6 +49,10 @@ public abstract class Aggregator<S extends AggregationState>
@Getter
private final List<Expression> arguments;
protected final ExprCoreType returnType;
@Setter
@Getter
@Accessors(fluent = true)
protected Expression condition;

/**
* Create an {@link AggregationState} which will be used for aggregation.
Expand Down Expand Up @@ -77,4 +84,14 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
return visitor.visitAggregator(this, context);
}

/**
* Util method to get value of condition in aggregation filter.
*/
public boolean conditionValue(BindingTuple tuple) {
if (condition == null) {
return true;
}
return ExprValueUtils.getBooleanValue(condition.valueOf(tuple));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public AvgState create() {
public AvgState iterate(BindingTuple tuple, AvgState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
if (!(value.isNull() || value.isMissing()) && conditionValue(tuple)) {
chloe-zh marked this conversation as resolved.
Show resolved Hide resolved
state.count++;
state.total += ExprValueUtils.getDoubleValue(value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public CountAggregator.CountState create() {
public CountState iterate(BindingTuple tuple, CountState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
if (!(value.isNull() || value.isMissing()) && conditionValue(tuple)) {
state.count++;
}
return state;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public MaxState create() {
public MaxState iterate(BindingTuple tuple, MaxState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
if (!(value.isNull() || value.isMissing()) && conditionValue(tuple)) {
state.max(value);
}
return state;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public MinState create() {
public MinState iterate(BindingTuple tuple, MinState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
if (!(value.isNull() || value.isMissing()) && conditionValue(tuple)) {
state.min(value);
}
return state;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public SumState create() {
public SumState iterate(BindingTuple tuple, SumState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
if (!(value.isNull() || value.isMissing()) && conditionValue(tuple)) {
state.isEmptyCollection = false;
state.add(value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
package com.amazon.opendistroforelasticsearch.sql.analysis;

import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.field;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.filteredAggregate;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.function;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.intLiteral;
import static com.amazon.opendistroforelasticsearch.sql.ast.dsl.AstDSL.qualifiedName;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.integerValue;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
Expand Down Expand Up @@ -256,6 +260,16 @@ public void undefined_aggregation_function() {
assertEquals("Unsupported aggregation function ESTDC_ERROR", exception.getMessage());
}

@Test
public void aggregation_filter() {
assertAnalyzeEqual(
dsl.avg(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
AstDSL.filteredAggregate("avg", qualifiedName("integer_value"),
function(">", qualifiedName("integer_value"), intLiteral(1)))
);
}

protected Expression analyze(UnresolvedExpression unresolvedExpression) {
return expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ public void avg_arithmetic_expression() {
assertEquals(25.0, result.value());
}

@Test
public void filtered_avg() {
ExprValue result = aggregation(dsl.avg(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(3.0, result.value());
}

@Test
public void avg_with_missing() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ public void count_array_field_expression() {
assertEquals(1, result.value());
}

@Test
public void filtered_count() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(3, result.value());
}

@Test
public void count_with_missing() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ public void test_max_arithmetic_expression() {
assertEquals(4, result.value());
}

@Test
public void filtered_max() {
ExprValue result = aggregation(dsl.max(DSL.ref("integer_value", INTEGER))
.condition(dsl.less(DSL.ref("integer_value", INTEGER), DSL.literal(4))), tuples);
assertEquals(3, result.value());
}

@Test
public void test_max_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ public void test_min_arithmetic_expression() {
assertEquals(1, result.value());
}

@Test
public void filtered_min() {
ExprValue result = aggregation(dsl.min(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(2, result.value());
}

@Test
public void test_min_null() {
ExprValue result =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ public void sum_string_field_expression() {
assertEquals("unexpected type [STRING] in sum aggregation", exception.getMessage());
}

@Test
public void filtered_sum() {
ExprValue result = aggregation(dsl.sum(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), tuples);
assertEquals(9, result.value());
}

@Test
public void sum_with_missing() {
ExprValue result =
Expand Down
39 changes: 39 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,42 @@ Additionally, a ``HAVING`` clause can work without ``GROUP BY`` clause. This is
| Total of age > 100 |
+------------------------+


FILTER Clause
=============

Description
-----------

A ``FILTER`` clause can set specific condition for the current aggregation bucket, following the syntax ``aggregation_function(expr) FILTER(WHERE condition_expr)``. If a filter is specified, then only the input rows for which the condition in the filter clause evaluates to true are fed to the aggregate function; other rows are discarded. The aggregation with filter clause can be use in ``SELECT`` clause only.

FILTER with GROUP BY
--------------------

The group by aggregation with ``FILTER`` clause can set different conditions for each aggregation bucket. Here is an example to use ``FILTER`` in group by aggregation::

od> SELECT avg(age) FILTER(WHERE balance > 10000) AS filtered, gender FROM accounts GROUP BY gender
fetched rows / total rows = 2/2
+------------+----------+
| filtered | gender |
|------------+----------|
| 28.0 | F |
| 32.0 | M |
+------------+----------+

FILTER without GROUP BY
-----------------------

The ``FILTER`` clause can be used in aggregation functions without GROUP BY as well. For example::

od> SELECT
... count(*) AS unfiltered,
... count(*) FILTER(WHERE age > 34) AS filtered
... FROM accounts
fetched rows / total rows = 1/1
+--------------+------------+
| unfiltered | filtered |
|--------------+------------|
| 4 | 1 |
+--------------+------------+

Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;

/**
Expand Down Expand Up @@ -82,6 +83,13 @@ private static Map<String, Object> parseInternal(Aggregation aggregation) {
resultMap.put(
aggregation.getName(),
handleNanValue(((NumericMetricsAggregation.SingleValue) aggregation).value()));
} else if (aggregation instanceof Filter) {
// parse sub-aggregations for FilterAggregation response
List<Aggregation> aggList = ((Filter) aggregation).getAggregations().asList();
aggList.forEach(internalAgg -> {
Map<String, Object> intermediateMap = parseInternal(internalAgg);
resultMap.put(internalAgg.getName(), intermediateMap.get(internalAgg.getName()));
});
} else {
throw new IllegalStateException("unsupported aggregation type " + aggregation.getType());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;

import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.FilterQueryBuilder;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.serialization.ExpressionSerializer;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionNodeVisitor;
Expand All @@ -29,6 +30,7 @@
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;

/**
Expand All @@ -38,10 +40,12 @@ public class MetricAggregationBuilder
extends ExpressionNodeVisitor<AggregationBuilder, Object> {

private final AggregationBuilderHelper<ValuesSourceAggregationBuilder<?>> helper;
private final FilterQueryBuilder filterBuilder;

public MetricAggregationBuilder(
ExpressionSerializer serializer) {
this.helper = new AggregationBuilderHelper<>(serializer);
this.filterBuilder = new FilterQueryBuilder(serializer);
}

/**
Expand All @@ -62,28 +66,35 @@ public AggregatorFactories.Builder build(List<NamedAggregator> aggregatorList) {
public AggregationBuilder visitNamedAggregator(NamedAggregator node,
Object context) {
Expression expression = node.getArguments().get(0);
Expression condition = node.getDelegated().condition();
String name = node.getName();

switch (node.getFunctionName().getFunctionName()) {
case "avg":
return make(AggregationBuilders.avg(name), expression);
return make(AggregationBuilders.avg(name), expression, condition, name);
case "sum":
return make(AggregationBuilders.sum(name), expression);
return make(AggregationBuilders.sum(name), expression, condition, name);
case "count":
return make(AggregationBuilders.count(name), replaceStarOrLiteral(expression));
return make(
AggregationBuilders.count(name), replaceStarOrLiteral(expression), condition, name);
case "min":
return make(AggregationBuilders.min(name), expression);
return make(AggregationBuilders.min(name), expression, condition, name);
case "max":
return make(AggregationBuilders.max(name), expression);
return make(AggregationBuilders.max(name), expression, condition, name);
default:
throw new IllegalStateException(
String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName()));
}
}

private ValuesSourceAggregationBuilder<?> make(ValuesSourceAggregationBuilder<?> builder,
Expression expression) {
return helper.build(expression, builder::field, builder::script);
private AggregationBuilder make(ValuesSourceAggregationBuilder<?> builder,
Expression expression, Expression condition, String name) {
ValuesSourceAggregationBuilder aggregationBuilder =
helper.build(expression, builder::field, builder::script);
if (condition != null) {
return makeFilterAggregation(aggregationBuilder, condition, name);
}
return aggregationBuilder;
}

/**
Expand All @@ -102,4 +113,18 @@ private Expression replaceStarOrLiteral(Expression countArg) {
return countArg;
}

/**
* Make builder to build FilterAggregation for aggregations with filter in the bucket.
* @param subAggBuilder AggregationBuilder instance which the filter is applied to.
* @param condition Condition expression in the filter.
* @param name Name of the FilterAggregation instance to build.
* @return {@link FilterAggregationBuilder}.
*/
private FilterAggregationBuilder makeFilterAggregation(AggregationBuilder subAggBuilder,
Expression condition, String name) {
return AggregationBuilders
.filter(name, filterBuilder.build(condition))
.subAggregation(subAggBuilder);
}

}
Loading