Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

[PPL] Support Count aggregator and OR operator #493

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Field;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Function;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Literal;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.Or;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedAttribute;
import com.amazon.opendistroforelasticsearch.sql.ast.expression.UnresolvedExpression;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
Expand Down Expand Up @@ -87,6 +88,14 @@ public Expression visitAnd(And node, AnalysisContext context) {
return dsl.and(context.peek(), left, right);
}

@Override
public Expression visitOr(Or node, AnalysisContext context) {
Expression left = node.getLeft().accept(this, context);
Expression right = node.getRight().accept(this, context);

return dsl.or(context.peek(), left, right);
}

@Override
public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) {
Optional<BuiltinFunctionName> builtinFunctionName = BuiltinFunctionName.of(node.getFuncName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public static Literal nullLiteral() {
return literal(null, DataType.NULL);
}

public static UnresolvedExpression map(String origin, String target) {
public static Map map(String origin, String target) {
return new Map(new Field(origin), new Field(target));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import com.google.common.collect.ImmutableList;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

import java.util.List;

@ToString
@EqualsAndHashCode(callSuper = false)
@Getter
@RequiredArgsConstructor
public class Rename extends UnresolvedPlan {
private final List<Map> renameList;
private UnresolvedPlan child;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import com.amazon.opendistroforelasticsearch.sql.expression.env.Environment;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import lombok.RequiredArgsConstructor;

import java.util.Arrays;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor
public class DSL {
Expand Down Expand Up @@ -123,4 +122,9 @@ public Aggregator sum(Environment<Expression, ExprType> env, Expression... expre
return (Aggregator)
repository.compile(BuiltinFunctionName.SUM.getName(), Arrays.asList(expressions), env);
}

public Aggregator count(Environment<Expression, ExprType> env, Expression... expressions) {
return (Aggregator)
repository.compile(BuiltinFunctionName.COUNT.getName(), Arrays.asList(expressions), env);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class AggregatorFunction {
public static void register(BuiltinFunctionRepository repository) {
repository.register(avg());
repository.register(sum());
repository.register(count());
}

private static FunctionResolver avg() {
Expand All @@ -54,6 +55,31 @@ private static FunctionResolver avg() {
);
}

private static FunctionResolver count() {
FunctionName functionName = BuiltinFunctionName.COUNT.getName();
return new FunctionResolver(
functionName,
new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>()
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.INTEGER)),
dai-chen marked this conversation as resolved.
Show resolved Hide resolved
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.LONG)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.FLOAT)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.DOUBLE)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.STRING)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.STRUCT)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.ARRAY)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.put(new FunctionSignature(functionName, Collections.singletonList(ExprType.BOOLEAN)),
arguments -> new CountAggregator(arguments, ExprType.INTEGER))
.build()
);
}

private static FunctionResolver sum() {
FunctionName functionName = BuiltinFunctionName.SUM.getName();
return new FunctionResolver(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import static com.amazon.opendistroforelasticsearch.sql.utils.ExpressionUtils.format;

/**
* The average aggregator aggregate the value evaluated by the expression.
* The count aggregator aggregate the value evaluated by the expression.
dai-chen marked this conversation as resolved.
Show resolved Hide resolved
* If the expression evaluated result is NULL or MISSING, then the result is NULL.
*/
public class AvgAggregator extends Aggregator<AvgAggregator.AvgState> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.sql.expression.aggregation;

import static com.amazon.opendistroforelasticsearch.sql.utils.ExpressionUtils.format;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprType;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.aggregation.CountAggregator.CountState;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.storage.bindingtuple.BindingTuple;
import java.util.List;
import java.util.Locale;

public class CountAggregator extends Aggregator<CountState> {

public CountAggregator(List<Expression> arguments, ExprType returnType) {
super(BuiltinFunctionName.COUNT.getName(), arguments, returnType);
}

@Override
public CountAggregator.CountState create() {
return new CountState();
}

@Override
public CountState iterate(BindingTuple tuple, CountState state) {
Expression expression = getArguments().get(0);
ExprValue value = expression.valueOf(tuple);
if (!(value.isNull() || value.isMissing())) {
state.count++;
}
return state;
}

@Override
public String toString() {
return String.format(Locale.ROOT, "count(%s)", format(getArguments()));
}

/** Count State. */
protected class CountState implements AggregationState {
private int count;

public CountState() {
this.count = 0;
}

@Override
public ExprValue result() {
return ExprValueUtils.integerValue(count);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ public enum BuiltinFunctionName {

/** Aggregation Function. */
AVG(FunctionName.of("avg")),
SUM(FunctionName.of("sum"));
SUM(FunctionName.of("sum")),
COUNT(FunctionName.of("count"));

private final FunctionName name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ public void and() {
);
}

@Test
public void or() {
assertAnalyzeEqual(
dsl.or(typeEnv, DSL.ref("boolean_value"), DSL.literal(LITERAL_TRUE)),
AstDSL.or(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true))
);
}

@Test
public void undefined_var_semantic_check_failed() {
SemanticCheckException exception = assertThrows(SemanticCheckException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,52 +15,84 @@

package com.amazon.opendistroforelasticsearch.sql.expression.aggregation;

import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.booleanValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.collectionValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.tupleValue;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class AggregationTest extends ExpressionTestBase {

protected static List<ExprValue> tuples = Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2,
"long_value", 2L,
"string_value", "m",
"double_value", 2d,
"float_value", 2f)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1,
"long_value", 1L,
"string_value", "f",
"double_value", 1d,
"float_value", 1f)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 3,
"long_value", 3L,
"string_value", "m",
"double_value", 3d,
"float_value", 3f)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 4,
"long_value", 4L,
"string_value", "f",
"double_value", 4d,
"float_value", 4f)));
protected static List<ExprValue> tuples =
Arrays.asList(
ExprValueUtils.tupleValue(
new ImmutableMap.Builder<String, Object>()
.put("integer_value", 2)
.put("long_value", 2L)
.put("string_value", "m")
.put("double_value", 2d)
.put("float_value", 2f)
.put("boolean_value", true)
.put("struct_value", ImmutableMap.of("str", 1))
.put("array_value", ImmutableList.of(1))
.build()),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"integer_value",
1,
"long_value",
1L,
"string_value",
"f",
"double_value",
1d,
"float_value",
1f)),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"integer_value",
3,
"long_value",
3L,
"string_value",
"m",
"double_value",
3d,
"float_value",
3f)),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"integer_value",
4,
"long_value",
4L,
"string_value",
"f",
"double_value",
4d,
"float_value",
4f)));

protected static List<ExprValue> tuples_with_null_and_missing = Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 2,
"string_value", "m",
"double_value", 3d)),
ExprValueUtils.tupleValue(ImmutableMap.of("integer_value", 1,
"string_value", "f",
"double_value", 4d)),
ExprValueUtils.tupleValue(Collections.singletonMap("double_value", null)));
protected static List<ExprValue> tuples_with_null_and_missing =
Arrays.asList(
ExprValueUtils.tupleValue(
ImmutableMap.of("integer_value", 2, "string_value", "m", "double_value", 3d)),
ExprValueUtils.tupleValue(
ImmutableMap.of("integer_value", 1, "string_value", "f", "double_value", 4d)),
ExprValueUtils.tupleValue(Collections.singletonMap("double_value", null)));

protected ExprValue aggregation(Aggregator aggregator, List<ExprValue> tuples) {
AggregationState state = aggregator.create();
for (ExprValue tuple : tuples) {
aggregator.iterate(tuple.bindingTuples(), state);
}
return state.result();
protected ExprValue aggregation(Aggregator aggregator, List<ExprValue> tuples) {
AggregationState state = aggregator.create();
for (ExprValue tuple : tuples) {
aggregator.iterate(tuple.bindingTuples(), state);
}
return state.result();
}
}
Loading