Skip to content

Commit

Permalink
Merge branch 'main' into support-implicit-cast-from-string-to-date-re…
Browse files Browse the repository at this point in the history
…based
  • Loading branch information
dai-chen committed Jul 30, 2021
2 parents c595612 + b3dfc49 commit 7609987
Show file tree
Hide file tree
Showing 29 changed files with 428 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
Expression arg = node.getField().accept(this, context);
Aggregator aggregator = (Aggregator) repository.compile(
builtinFunctionName.get().getName(), Collections.singletonList(arg));
if (node.getCondition() != null) {
aggregator.condition(analyze(node.getCondition(), context));
aggregator.distinct(node.getDistinct());
if (node.condition() != null) {
aggregator.condition(analyze(node.condition(), context));
}
return aggregator;
} else {
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,16 @@ public static UnresolvedExpression aggregate(

public static UnresolvedExpression filteredAggregate(
String func, UnresolvedExpression field, UnresolvedExpression condition) {
return new AggregateFunction(func, field, condition);
return new AggregateFunction(func, field).condition(condition);
}

/**
 * Builds an aggregate function node with its distinct flag enabled.
 *
 * @param func aggregate function name
 * @param field expression the aggregate function is applied to
 * @return aggregate function node constructed with {@code distinct = true}
 */
public static UnresolvedExpression distinctAggregate(String func, UnresolvedExpression field) {
  final AggregateFunction distinctNode = new AggregateFunction(func, field, true);
  return distinctNode;
}

/**
 * Builds an aggregate function node with its distinct flag enabled and a filter condition
 * attached.
 *
 * @param func aggregate function name
 * @param field expression the aggregate function is applied to
 * @param condition filter condition set on the aggregate function node
 * @return aggregate function node constructed with {@code distinct = true} and the condition set
 */
public static UnresolvedExpression filteredDistinctCount(
    String func, UnresolvedExpression field, UnresolvedExpression condition) {
  final AggregateFunction distinctNode = new AggregateFunction(func, field, true);
  return distinctNode.condition(condition);
}

public static Function function(String funcName, UnresolvedExpression... funcArgs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@

import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.experimental.Accessors;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.common.utils.StringUtils;

Expand All @@ -45,7 +48,10 @@ public class AggregateFunction extends UnresolvedExpression {
private final String funcName;
private final UnresolvedExpression field;
private final List<UnresolvedExpression> argList;
@Setter
@Accessors(fluent = true)
private UnresolvedExpression condition;
private Boolean distinct = false;

/**
* Constructor.
Expand All @@ -62,14 +68,13 @@ public AggregateFunction(String funcName, UnresolvedExpression field) {
* Constructor.
* @param funcName function name.
* @param field {@link UnresolvedExpression}.
* @param condition condition in aggregation filter.
* @param distinct whether distinct field is specified or not.
*/
public AggregateFunction(String funcName, UnresolvedExpression field,
UnresolvedExpression condition) {
public AggregateFunction(String funcName, UnresolvedExpression field, Boolean distinct) {
this.funcName = funcName;
this.field = field;
this.argList = Collections.emptyList();
this.condition = condition;
this.distinct = distinct;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ public static ExprValue fromObjectValue(Object o, ExprCoreType type) {
}
}

/**
 * Returns the result of {@code byteValue()} on the given expression value.
 *
 * @param exprValue expression value to read from
 * @return whatever {@code exprValue.byteValue()} yields
 */
public static Byte getByteValue(ExprValue exprValue) {
  final Byte extracted = exprValue.byteValue();
  return extracted;
}

/**
 * Returns the result of {@code shortValue()} on the given expression value.
 *
 * @param exprValue expression value to read from
 * @return whatever {@code exprValue.shortValue()} yields
 */
public static Short getShortValue(ExprValue exprValue) {
  final Short extracted = exprValue.shortValue();
  return extracted;
}

/**
 * Returns the result of {@code integerValue()} on the given expression value.
 *
 * @param exprValue expression value to read from
 * @return whatever {@code exprValue.integerValue()} yields
 */
public static Integer getIntegerValue(ExprValue exprValue) {
  final Integer extracted = exprValue.integerValue();
  return extracted;
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,10 @@ public Aggregator count(Expression... expressions) {
return aggregate(BuiltinFunctionName.COUNT, expressions);
}

/**
 * Builds a count aggregator over the given expressions with its distinct flag turned on.
 *
 * @param expressions arguments forwarded to {@link #count(Expression...)}
 * @return the count aggregator after {@code distinct(true)} has been applied
 */
public Aggregator distinctCount(Expression... expressions) {
  final Aggregator counter = count(expressions);
  return counter.distinct(true);
}

/**
 * Builds the aggregator registered under {@code BuiltinFunctionName.VARSAMP}.
 *
 * @param expressions arguments passed to the aggregator
 * @return aggregator produced for the VARSAMP built-in function
 */
public Aggregator varSamp(Expression... expressions) {
  final Aggregator sampleVariance = aggregate(BuiltinFunctionName.VARSAMP, expressions);
  return sampleVariance;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ public abstract class Aggregator<S extends AggregationState>
@Getter
@Accessors(fluent = true)
protected Expression condition;
@Setter
@Getter
@Accessors(fluent = true)
protected Boolean distinct = false;

/**
* Create an {@link AggregationState} which will be used for aggregation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

import static org.opensearch.sql.utils.ExpressionUtils.format;

import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
Expand All @@ -45,33 +47,51 @@ public CountAggregator(List<Expression> arguments, ExprCoreType returnType) {

@Override
public CountAggregator.CountState create() {
return new CountState();
return distinct ? new DistinctCountState() : new CountState();
}

@Override
protected CountState iterate(ExprValue value, CountState state) {
state.count++;
state.count(value);
return state;
}

@Override
public String toString() {
return String.format(Locale.ROOT, "count(%s)", format(getArguments()));
return distinct
? String.format(Locale.ROOT, "count(distinct %s)", format(getArguments()))
: String.format(Locale.ROOT, "count(%s)", format(getArguments()));
}

/**
* Count State.
*/
protected static class CountState implements AggregationState {
private int count;
protected int count;

CountState() {
this.count = 0;
}

public void count(ExprValue value) {
count++;
}

@Override
public ExprValue result() {
return ExprValueUtils.integerValue(count);
}
}

/**
 * Count state that only counts a value the first time it is seen.
 *
 * <p>Previously seen values are remembered in a {@link HashSet}, so two values are treated as
 * the same when they are equal according to {@code equals}/{@code hashCode}.
 */
protected static class DistinctCountState extends CountState {
  private final Set<ExprValue> distinctValues = new HashSet<>();

  /**
   * Increments the count only if {@code value} has not been counted before.
   *
   * @param value value to record
   */
  @Override
  public void count(ExprValue value) {
    // Set#add reports whether the element was newly inserted, so a separate
    // contains() lookup beforehand is unnecessary.
    if (distinctValues.add(value)) {
      count++;
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public class NamedAggregator extends Aggregator<AggregationState> {

/**
* NamedAggregator.
* The aggregator properties {@link #condition} is inherited by named aggregator
* to avoid errors introduced by the property inconsistency.
* The aggregator properties {@link #condition} and {@link #distinct}
* are inherited by named aggregator to avoid errors introduced by the property inconsistency.
*
* @param name name
* @param delegated delegated
Expand All @@ -67,6 +67,7 @@ public NamedAggregator(
this.name = name;
this.delegated = delegated;
this.condition = delegated.condition;
this.distinct = delegated.distinct;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,24 @@ public void variance_mapto_varPop() {
);
}

// Pairs the expression built by dsl.distinctCount over integer_value with the AST built by
// AstDSL.distinctAggregate("count", ...) over the same field.
// NOTE(review): assertAnalyzeEqual is not visible here; presumably it analyzes the second
// argument and compares the result with the first - confirm.
@Test
public void distinct_count() {
assertAnalyzeEqual(
dsl.distinctCount(DSL.ref("integer_value", INTEGER)),
AstDSL.distinctAggregate("count", qualifiedName("integer_value"))
);
}

// Same pairing as the unfiltered distinct count case, with a filter added on both sides:
// the expected aggregator carries the condition integer_value > 1, and the AST node carries
// the matching ">" function over integer_value and the literal 1.
// NOTE(review): assertAnalyzeEqual is not visible here; presumably it analyzes the second
// argument and compares the result with the first - confirm.
@Test
public void filtered_distinct_count() {
assertAnalyzeEqual(
dsl.distinctCount(DSL.ref("integer_value", INTEGER))
.condition(dsl.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))),
AstDSL.filteredDistinctCount("count", qualifiedName("integer_value"), function(
">", qualifiedName("integer_value"), intLiteral(1)))
);
}

/**
 * Runs the expression analyzer on an unresolved expression using the shared analysis context.
 *
 * @param unresolvedExpression expression to analyze
 * @return expression returned by {@code expressionAnalyzer.analyze}
 */
protected Expression analyze(UnresolvedExpression unresolvedExpression) {
  final Expression analyzed =
      expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
  return analyzed;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ public class ExprValueUtilsTest {
Lists.newArrayList(Iterables.concat(numberValues, nonNumberValues));

private static List<Function<ExprValue, Object>> numberValueExtractor = Arrays.asList(
ExprValue::byteValue,
ExprValue::shortValue,
ExprValueUtils::getByteValue,
ExprValueUtils::getShortValue,
ExprValueUtils::getIntegerValue,
ExprValueUtils::getLongValue,
ExprValueUtils::getFloatValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,29 @@ public class AggregationTest extends ExpressionTestBase {
"timestamp_value",
"2040-01-01 07:00:00")));

// Four-row fixture in which some column values repeat across rows:
//   integer_value: 1, 1, 2, 3
//   double_value:  4d, 3d, 2d, 1d (all different)
//   struct_value:  {str=1}, {str=1}, {str=2}, {str1=1}
//   array_value:   [1], [1], [2], [1, 2]
// The first two rows share the same integer, struct and array values.
protected static List<ExprValue> tuples_with_duplicates =
Arrays.asList(
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 1,
"double_value", 4d,
"struct_value", ImmutableMap.of("str", 1),
"array_value", ImmutableList.of(1))),
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 1,
"double_value", 3d,
"struct_value", ImmutableMap.of("str", 1),
"array_value", ImmutableList.of(1))),
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 2,
"double_value", 2d,
"struct_value", ImmutableMap.of("str", 2),
"array_value", ImmutableList.of(2))),
ExprValueUtils.tupleValue(ImmutableMap.of(
"integer_value", 3,
"double_value", 1d,
"struct_value", ImmutableMap.of("str1", 1),
"array_value", ImmutableList.of(1, 2))));

protected static List<ExprValue> tuples_with_null_and_missing =
Arrays.asList(
ExprValueUtils.tupleValue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,35 @@ public void filtered_count() {
assertEquals(3, result.value());
}

/** Distinct count over integer_value of the duplicate-bearing fixture is expected to be 3. */
@Test
public void distinct_count() {
  final ExprValue counted =
      aggregation(dsl.distinctCount(DSL.ref("integer_value", INTEGER)), tuples_with_duplicates);
  final Object actual = counted.value();
  assertEquals(3, actual);
}

/**
 * Distinct count over integer_value, restricted by the condition double_value > 1, is
 * expected to be 2 for the duplicate-bearing fixture.
 */
@Test
public void filtered_distinct_count() {
  final ExprValue counted =
      aggregation(
          dsl.distinctCount(DSL.ref("integer_value", INTEGER))
              .condition(dsl.greater(DSL.ref("double_value", DOUBLE), DSL.literal(1d))),
          tuples_with_duplicates);
  final Object actual = counted.value();
  assertEquals(2, actual);
}

/** Distinct count over struct_value of the duplicate-bearing fixture is expected to be 3. */
@Test
public void distinct_count_map() {
  final ExprValue counted =
      aggregation(dsl.distinctCount(DSL.ref("struct_value", STRUCT)), tuples_with_duplicates);
  final Object actual = counted.value();
  assertEquals(3, actual);
}

/** Distinct count over array_value of the duplicate-bearing fixture is expected to be 3. */
@Test
public void distinct_count_array() {
  final ExprValue counted =
      aggregation(dsl.distinctCount(DSL.ref("array_value", ARRAY)), tuples_with_duplicates);
  final Object actual = counted.value();
  assertEquals(3, actual);
}

@Test
public void count_with_missing() {
ExprValue result = aggregation(dsl.count(DSL.ref("integer_value", INTEGER)),
Expand Down Expand Up @@ -166,6 +195,9 @@ public void valueOf() {
public void test_to_string() {
  // Plain count renders without the distinct keyword.
  Aggregator plainCount = dsl.count(DSL.ref("integer_value", INTEGER));
  assertEquals("count(integer_value)", plainCount.toString());

  // Distinct count renders with the distinct keyword before the argument.
  Aggregator distinctCount = dsl.distinctCount(DSL.ref("integer_value", INTEGER));
  assertEquals("count(distinct integer_value)", distinctCount.toString());
}

@Test
Expand Down
26 changes: 26 additions & 0 deletions docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,19 @@ Example::
| 2.8613807855648994 |
+--------------------+

DISTINCT COUNT Aggregation
--------------------------

To get the count of distinct values of a field, add the ``DISTINCT`` keyword before the field in the count aggregation. Example::

os> SELECT COUNT(DISTINCT gender), COUNT(gender) FROM accounts;
fetched rows / total rows = 1/1
+--------------------------+-----------------+
| COUNT(DISTINCT gender) | COUNT(gender) |
|--------------------------+-----------------|
| 2 | 4 |
+--------------------------+-----------------+

HAVING Clause
=============

Expand Down Expand Up @@ -456,3 +469,16 @@ The ``FILTER`` clause can be used in aggregation functions without GROUP BY as w
| 4 | 1 |
+--------------+------------+

Distinct count aggregate with FILTER
------------------------------------

The ``FILTER`` clause can also be used with a distinct count; the filter is applied before the distinct values of the specified field are counted. For example::

os> SELECT COUNT(DISTINCT firstname) FILTER(WHERE age > 30) AS distinct_count FROM accounts
fetched rows / total rows = 1/1
+------------------+
| distinct_count |
|------------------|
| 3 |
+------------------+

15 changes: 15 additions & 0 deletions docs/user/ppl/cmd/stats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,18 @@ PPL query::
| 36 | 32 | M |
+------------+------------+----------+

Example 7: Calculate the distinct count of a field
==================================================

To get the count of distinct values of a field, you can use the ``DISTINCT_COUNT`` (or ``DC``) function instead of ``COUNT``. The example calculates both the count and the distinct count of the gender field across all accounts.

PPL query::

os> source=accounts | stats count(gender), distinct_count(gender);
fetched rows / total rows = 1/1
+-----------------+--------------------------+
| count(gender) | distinct_count(gender) |
|-----------------+--------------------------|
| 4 | 2 |
+-----------------+--------------------------+

Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,19 @@ public void testStatsCountAll() throws IOException {
verifyDataRows(response, rows(1000));
}

/**
 * Runs the distinct count aggregation under both of its names, {@code distinct_count} and
 * {@code dc}, and checks the returned schema and rows.
 */
@Test
public void testStatsDistinctCount() throws IOException {
  String genderQuery =
      String.format("source=%s | stats distinct_count(gender)", TEST_INDEX_ACCOUNT);
  JSONObject genderResponse = executeQuery(genderQuery);
  verifySchema(genderResponse, schema("distinct_count(gender)", null, "integer"));
  verifyDataRows(genderResponse, rows(2));

  String ageQuery = String.format("source=%s | stats dc(age)", TEST_INDEX_ACCOUNT);
  JSONObject ageResponse = executeQuery(ageQuery);
  verifySchema(ageResponse, schema("dc(age)", null, "integer"));
  verifyDataRows(ageResponse, rows(21));
}

@Test
public void testStatsMin() throws IOException {
JSONObject response = executeQuery(String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,15 @@ protected void init() throws Exception {
}

@Test
void filteredAggregateWithSubquery() throws IOException {
void filteredAggregatePushedDown() throws IOException {
JSONObject response = executeQuery(
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM " + TEST_INDEX_BANK);
verifySchema(response, schema("COUNT(*)", null, "integer"));
verifyDataRows(response, rows(3));
}

@Test
void filteredAggregateNotPushedDown() throws IOException {
JSONObject response = executeQuery(
"SELECT COUNT(*) FILTER(WHERE age > 35) FROM (SELECT * FROM " + TEST_INDEX_BANK
+ ") AS a");
Expand Down
Loading

0 comments on commit 7609987

Please sign in to comment.