Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
[PPL] Support Count aggregator and OR operator (#493)
Browse files Browse the repository at this point in the history
* [PPL] Support Count aggregator and OR operator

* address comments
  • Loading branch information
penghuo authored May 28, 2020
1 parent 5ae2203 commit 1c16f2f
Show file tree
Hide file tree
Showing 12 changed files with 302 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Or;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedAttribute;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
Expand Down Expand Up @@ -87,6 +88,14 @@ public Expression visitAnd(And node, AnalysisContext context) {
return dsl.and(context.peek(), left, right);
}

@Override
public Expression visitOr(Or node, AnalysisContext context) {
Expression left = node.getLeft().accept(this, context);
Expression right = node.getRight().accept(this, context);

return dsl.or(context.peek(), left, right);
}

@Override
public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) {
Optional<BuiltinFunctionName> builtinFunctionName = BuiltinFunctionName.of(node.getFuncName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public static Literal nullLiteral() {
return literal(null, DataType.NULL);
}

public static UnresolvedExpression map(String origin, String target) {
public static Map map(String origin, String target) {
return new Map(new Field(origin), new Field(target));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import com.google.common.collect.ImmutableList;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

import java.util.List;

@ToString
@EqualsAndHashCode(callSuper = false)
@Getter
@RequiredArgsConstructor
public class Rename extends UnresolvedPlan {
private final List<Map> renameList;
private UnresolvedPlan child;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import lombok.RequiredArgsConstructor;

import java.util.Arrays;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor
public class DSL {
Expand Down Expand Up @@ -123,4 +122,9 @@ public Aggregator sum(Environment<Expression, ExprType> env, Expression... expre
return (Aggregator)
repository.compile(BuiltinFunctionName.SUM.getName(), Arrays.asList(expressions), env);
}

public Aggregator count(Environment<Expression, ExprType> env, Expression... expressions) {
return (Aggregator)
repository.compile(BuiltinFunctionName.COUNT.getName(), Arrays.asList(expressions), env);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class AggregatorFunction {
public static void register(BuiltinFunctionRepository repository) {
repository.register(avg());
repository.register(sum());
repository.register(count());
}

private static FunctionResolver avg() {
Expand All @@ -54,6 +55,31 @@ private static FunctionResolver avg() {
);
}

private static FunctionResolver count() {
FunctionName functionName = BuiltinFunctionName.COUNT.getName();
return new FunctionResolver(
functionName,
new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.INTEGER)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.LONG)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.FLOAT)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.DOUBLE)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.STRING)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.STRUCT)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.ARRAY)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.BOOLEAN)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.build()
);
}

private static FunctionResolver sum() {
FunctionName functionName = BuiltinFunctionName.SUM.getName();
return new FunctionResolver(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.sql.expression.aggregation;

import static com.amazon.opendistroforelasticsearch.sql.utils.ExpressionUtils.format;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprType;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.CountAggregator.CountState;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple;
import java.util.List;
import java.util.Locale;

public class CountAggregator extends Aggregator<CountState> {

public CountAggregator(List<Expression> arguments, ExprType returnType) {
super(BuiltinFunctionName.COUNT.getName(), arguments, returnType);
}

@Override
public CountAggregator.CountState create() {
return new CountState();
}

@Override
public CountState iterate(BindingTuple tuple, CountState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.count++;
}
return state;
}

@Override
public String toString() {
return String.format(Locale.ROOT, "count(%s)", format(getArguments()));
}

/** Count State. */
protected class CountState implements AggregationState {
private int count;

public CountState() {
this.count = 0;
}

@Override
public ExprValue result() {
return ExprValueUtils.integerValue(count);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public enum BuiltinFunctionName {

/** Aggregation Function. */
AVG(FunctionName.of("avg")),
SUM(FunctionName.of("sum"));
SUM(FunctionName.of("sum")),
COUNT(FunctionName.of("count"));

private final FunctionName name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ public void and() {
);
}

@Test
public void or() {
assertAnalyzeEqual(
dsl.or(typeEnv, DSL.ref("boolean_value"), DSL.literal(LITERAL_TRUE)),
AstDSL.or(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true))
);
}

@Test
public void undefined_var_semantic_check_failed() {
SemanticCheckException exception = assertThrows(SemanticCheckException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,52 +15,84 @@

package com.amazon.opendistroforelasticsearch.sql.expression.aggregation;

import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.booleanValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.collectionValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.tupleValue;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class AggregationTest extends ExpressionTestBase {

protected static List<ExprValue> tuples = Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2,
"long_value", 2L,
"string_value", "m",
"double_value", 2d,
"float_value", 2f)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1,
"long_value", 1L,
"string_value", "f",
"double_value", 1d,
"float_value", 1f)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3,
"long_value", 3L,
"string_value", "m",
"double_value", 3d,
"float_value", 3f)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 4,
"long_value", 4L,
"string_value", "f",
"double_value", 4d,
"float_value", 4f)));
protected static List<ExprValue> tuples =
Arrays.asList(
ExprValueUtils.tupleValue(
new ImmutableMap.Builder<String, Object>()
.put("integer_value", 2)
.put("long_value", 2L)
.put("string_value", "m")
.put("double_value", 2d)
.put("float_value", 2f)
.put("boolean_value", true)
.put("struct_value", ImmutableMap.of("str", 1))
.put("array_value", ImmutableList.of(1))
.build()),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"integer_value",
1,
"long_value",
1L,
"string_value",
"f",
"double_value",
1d,
"float_value",
1f)),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"integer_value",
3,
"long_value",
3L,
"string_value",
"m",
"double_value",
3d,
"float_value",
3f)),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"integer_value",
4,
"long_value",
4L,
"string_value",
"f",
"double_value",
4d,
"float_value",
4f)));

protected static List<ExprValue> tuples_with_null_and_missing = Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2,
"string_value", "m",
"double_value", 3d)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1,
"string_value", "f",
"double_value", 4d)),
ExprValueUtils.tupleValue(Collections.singletonMap("double_value", null)));
protected static List<ExprValue> tuples_with_null_and_missing =
Arrays.asList(
ExprValueUtils.tupleValue(
ImmutableMap.of("integer_value", 2, "string_value", "m", "double_value", 3d)),
ExprValueUtils.tupleValue(
ImmutableMap.of("integer_value", 1, "string_value", "f", "double_value", 4d)),
ExprValueUtils.tupleValue(Collections.singletonMap("double_value", null)));

protected ExprValue aggregation(Aggregator aggregator, List<ExprValue> tuples) {
AggregationState state = aggregator.create();
for (ExprValue tuple : tuples) {
aggregator.iterate(tuple.bindingTuples(), state);
}
return state.result();
protected ExprValue aggregation(Aggregator aggregator, List<ExprValue> tuples) {
AggregationState state = aggregator.create();
for (ExprValue tuple : tuples) {
aggregator.iterate(tuple.bindingTuples(), state);
}
return state.result();
}
}
Loading

0 comments on commit 1c16f2f

Please sign in to comment.