From ce61f3c3d20f2ef41a21bf6d7a81b43db371ffe8 Mon Sep 17 00:00:00 2001 From: penghuo Date: Mon, 7 Jun 2021 11:33:03 -0700 Subject: [PATCH 01/11] impl variance frontend and backend --- core/build.gradle | 1 + .../sql/analysis/ExpressionAnalyzer.java | 3 +- .../org/opensearch/sql/expression/DSL.java | 8 + .../aggregation/AggregatorFunction.java | 24 +++ .../aggregation/VarianceAggregator.java | 93 +++++++++ .../function/BuiltinFunctionName.java | 22 ++ .../sql/analysis/ExpressionAnalyzerTest.java | 8 + .../aggregation/VarianceAggregatorTest.java | 190 ++++++++++++++++++ sql/src/main/antlr/OpenSearchSQLLexer.g4 | 3 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 2 +- .../sql/parser/AstExpressionBuilderTest.java | 21 ++ 11 files changed, 373 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java diff --git a/core/build.gradle b/core/build.gradle index 69acf5cef3..1c6c0c0481 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -51,6 +51,7 @@ dependencies { compile group: 'org.springframework', name: 'spring-beans', version: '5.2.5.RELEASE' compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' compile group: 'com.facebook.presto', name: 'presto-matching', version: '0.240' + compile group: 'org.apache.commons', name: 'commons-math3', version: '3.6.1' compile project(':common') testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 0f207c0374..d5c1538b77 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -155,7 +155,8 @@ public Expression visitNot(Not node, AnalysisContext context) { 
@Override public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext context) { - Optional builtinFunctionName = BuiltinFunctionName.of(node.getFuncName()); + Optional builtinFunctionName = + BuiltinFunctionName.ofAggregation(node.getFuncName()); if (builtinFunctionName.isPresent()) { Expression arg = node.getField().accept(this, context); Aggregator aggregator = (Aggregator) repository.compile( diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 31050afc87..6af2b19742 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -500,6 +500,14 @@ public Aggregator count(Expression... expressions) { return aggregate(BuiltinFunctionName.COUNT, expressions); } + public Aggregator varSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.VARSAMP, expressions); + } + + public Aggregator varPop(Expression... 
expressions) { + return aggregate(BuiltinFunctionName.VARPOP, expressions); + } + public RankingWindowFunction rowNumber() { return (RankingWindowFunction) repository.compile( BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList()); diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index a6be7378f7..cdbb9855f3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -68,6 +68,8 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(count()); repository.register(min()); repository.register(max()); + repository.register(varSamp()); + repository.register(varPop()); } private static FunctionResolver avg() { @@ -159,4 +161,26 @@ private static FunctionResolver max() { .build() ); } + + private static FunctionResolver varSamp() { + FunctionName functionName = BuiltinFunctionName.VARSAMP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> new VarianceAggregator(true, arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver varPop() { + FunctionName functionName = BuiltinFunctionName.VARPOP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> new VarianceAggregator(false, arguments, DOUBLE)) + .build() + ); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java new file mode 100644 index 0000000000..7abfdcb987 --- /dev/null +++ 
b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.stat.descriptive.moment.Variance; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * Variance Aggregator. + */ +public class VarianceAggregator extends Aggregator { + + private final boolean isSampleVariance; + + /** + * VarianceAggregator constructor. + * + * @param isSampleVariance true for sample variance aggregator, false for population variance + * aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. + */ + public VarianceAggregator( + Boolean isSampleVariance, List arguments, ExprCoreType returnType) { + super( + isSampleVariance + ? 
BuiltinFunctionName.VARSAMP.getName() + : BuiltinFunctionName.VARPOP.getName(), + arguments, + returnType); + this.isSampleVariance = isSampleVariance; + } + + @Override + public VarianceState create() { + return new VarianceState(isSampleVariance); + } + + @Override + protected VarianceState iterate(ExprValue value, VarianceState state) { + state.evaluate(value); + return state; + } + + @Override + public String toString() { + return StringUtils.format( + "%s(%s)", isSampleVariance ? "var_samp" : "var_pop", format(getArguments())); + } + + protected static class VarianceState implements AggregationState { + + private final Variance variance; + + private final List values = new ArrayList<>(); + + public VarianceState(boolean isSampleVariance) { + this.variance = new Variance(isSampleVariance); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 + ? ExprNullValue.of() + : doubleValue(variance.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 0210161abe..9c541bbe7d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -12,6 +12,7 @@ package org.opensearch.sql.expression.function; import com.google.common.collect.ImmutableMap; +import java.util.Locale; import java.util.Map; import java.util.Optional; import lombok.Getter; @@ -126,6 +127,10 @@ public enum BuiltinFunctionName { COUNT(FunctionName.of("count")), MIN(FunctionName.of("min")), MAX(FunctionName.of("max")), + // sample variance + VARSAMP(FunctionName.of("var_samp")), + // population variance + VARPOP(FunctionName.of("var_pop")), /** * Text Functions. 
@@ -189,7 +194,24 @@ public enum BuiltinFunctionName { ALL_NATIVE_FUNCTIONS = builder.build(); } + private static final Map AGGREGATION_FUNC_MAPPING = + new ImmutableMap.Builder() + .put("max", BuiltinFunctionName.MAX) + .put("min", BuiltinFunctionName.MIN) + .put("avg", BuiltinFunctionName.AVG) + .put("count", BuiltinFunctionName.COUNT) + .put("sum", BuiltinFunctionName.SUM) + .put("var_pop", BuiltinFunctionName.VARPOP) + .put("var_samp", BuiltinFunctionName.VARSAMP) + .put("variance", BuiltinFunctionName.VARPOP) + .build(); + public static Optional of(String str) { return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); } + + public static Optional ofAggregation(String functionName) { + return Optional.ofNullable( + AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index aa8d2b12de..8cb7288273 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -292,6 +292,14 @@ public void aggregation_filter() { ); } + @Test + public void variance_mapto_varPop() { + assertAnalyzeEqual( + dsl.varPop(DSL.ref("integer_value", INTEGER)), + AstDSL.aggregate("variance", qualifiedName("integer_value")) + ); + } + protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java new file mode 100644 index 0000000000..09fb8b8012 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/VarianceAggregatorTest.java @@ -0,0 +1,190 @@ +/* + * 
Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.expression.DSL.ref; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.storage.bindingtuple.BindingTuple; + +@ExtendWith(MockitoExtension.class) +public class VarianceAggregatorTest extends AggregationTest { + + @Mock Expression 
expression; + + @Mock ExprValue tupleValue; + + @Mock BindingTuple tuple; + + @Test + public void variance_sample_field_expression() { + ExprValue result = + varianceSample(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.6666666666666667, result.value()); + } + + @Test + public void variance_population_field_expression() { + ExprValue result = + variancePop(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.25, result.value()); + } + + @Test + public void variance_sample_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.varSamp(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(166.66666666666666, result.value()); + } + + @Test + public void variance_pop_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.varPop(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(125d, result.value()); + } + + @Test + public void filtered_variance_sample() { + ExprValue result = + aggregation( + dsl.varSamp(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(1.0, result.value()); + } + + @Test + public void filtered_variance_pop() { + ExprValue result = + aggregation( + dsl.varPop(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(0.6666666666666666, result.value()); + } + + @Test + public void variance_sample_with_missing() { + ExprValue result = varianceSample(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void variance_population_with_missing() { + ExprValue result = variancePop(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.25, result.value()); + } + + @Test + public void variance_sample_with_null() { + ExprValue result = varianceSample(doubleValue(3d), 
doubleValue(4d), nullValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void variance_pop_with_null() { + ExprValue result = variancePop(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.25, result.value()); + } + + @Test + public void variance_sample_with_all_missing_or_null() { + ExprValue result = varianceSample(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void variance_pop_with_all_missing_or_null() { + ExprValue result = variancePop(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void valueOf() { + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> dsl.avg(ref("double_value", DOUBLE)).valueOf(valueEnv())); + assertEquals("can't evaluate on aggregator: avg", exception.getMessage()); + } + + @Test + public void variance_sample_to_string() { + Aggregator avgAggregator = dsl.varSamp(ref("integer_value", INTEGER)); + assertEquals("var_samp(integer_value)", avgAggregator.toString()); + } + + @Test + public void variance_pop_to_string() { + Aggregator avgAggregator = dsl.varPop(ref("integer_value", INTEGER)); + assertEquals("var_pop(integer_value)", avgAggregator.toString()); + } + + @Test + public void variance_sample_nested_to_string() { + Aggregator avgAggregator = + dsl.varSamp( + dsl.multiply( + ref("integer_value", INTEGER), DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals( + String.format("var_samp(*(%s, %d))", ref("integer_value", INTEGER), 10), + avgAggregator.toString()); + } + + private ExprValue varianceSample(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.varSamp(expression), mockTuples(value, values)); + } + + private ExprValue variancePop(ExprValue value, ExprValue... 
values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.varPop(expression), mockTuples(value, values)); + } + + private List mockTuples(ExprValue value, ExprValue... values) { + List mockTuples = new ArrayList<>(); + when(tupleValue.bindingTuples()).thenReturn(tuple); + mockTuples.add(tupleValue); + for (ExprValue exprValue : values) { + mockTuples.add(tupleValue); + } + return mockTuples; + } +} diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 94f8e7c87a..828f9709ca 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -126,6 +126,9 @@ COUNT: 'COUNT'; MAX: 'MAX'; MIN: 'MIN'; SUM: 'SUM'; +VAR_POP: 'VAR_POP'; +VAR_SAMP: 'VAR_SAMP'; +VARIANCE: 'VARIANCE'; // Common function Keywords diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 0ad08781bf..92144abb54 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -345,7 +345,7 @@ filterClause ; aggregationFunctionName - : AVG | COUNT | SUM | MIN | MAX + : AVG | COUNT | SUM | MIN | MAX | VAR_POP | VAR_SAMP | VARIANCE ; mathematicalFunctionName diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index a3c8494e7a..e4e8028f05 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -410,6 +410,27 @@ public void filteredAggregation() { ); } + @Test + public void canBuildVarSamp() { + assertEquals( + aggregate("var_samp", qualifiedName("age")), + buildExprAst("var_samp(age)")); + } + + @Test + public void canBuildVarPop() { + assertEquals( + aggregate("var_pop", qualifiedName("age")), + 
buildExprAst("var_pop(age)")); + } + + @Test + public void canBuildVariance() { + assertEquals( + aggregate("variance", qualifiedName("age")), + buildExprAst("variance(age)")); + } + private Node buildExprAst(String expr) { OpenSearchSQLLexer lexer = new OpenSearchSQLLexer(new CaseInsensitiveCharStream(expr)); OpenSearchSQLParser parser = new OpenSearchSQLParser(new CommonTokenStream(lexer)); From fae8138e27939a045be7955c7795c7928303ac4b Mon Sep 17 00:00:00 2001 From: penghuo Date: Fri, 4 Jun 2021 15:43:49 -0700 Subject: [PATCH 02/11] Support construct AggregationResponseParser during Aggregator build stage --- .../value/OpenSearchExprValueFactory.java | 16 +- .../OpenSearchAggregationResponseParser.java | 114 ----------- .../response/OpenSearchResponse.java | 2 +- .../agg/CompositeAggregationParser.java | 51 +++++ .../opensearch/response/agg/FilterParser.java | 38 ++++ .../opensearch/response/agg/MetricParser.java | 36 ++++ .../response/agg/MetricParserHelper.java | 56 +++++ .../agg/NoBucketAggregationParser.java | 41 ++++ .../OpenSearchAggregationResponseParser.java | 31 +++ .../response/agg/SingleValueParser.java | 39 ++++ .../opensearch/response/agg/StatsParser.java | 41 ++++ .../sql/opensearch/response/agg/Utils.java | 27 +++ .../opensearch/storage/OpenSearchIndex.java | 4 +- .../storage/OpenSearchIndexScan.java | 10 +- .../aggregation/AggregationQueryBuilder.java | 47 +++-- .../dsl/MetricAggregationBuilder.java | 93 ++++++--- .../response/AggregationResponseUtils.java | 4 + ...enSearchAggregationResponseParserTest.java | 192 ++++++++++++------ .../response/OpenSearchResponseTest.java | 42 ++-- .../AggregationQueryBuilderTest.java | 17 +- .../dsl/MetricAggregationBuilderTest.java | 2 +- 21 files changed, 650 insertions(+), 253 deletions(-) delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java create mode 100644 
opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index 313347aec1..001363b476 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -63,7 +63,7 @@ import java.util.List; import java.util.Map; import java.util.function.Function; -import lombok.AllArgsConstructor; +import lombok.Getter; import lombok.Setter; import org.opensearch.common.time.DateFormatters; import org.opensearch.sql.data.model.ExprBooleanValue; @@ -86,11 +86,11 @@ import org.opensearch.sql.opensearch.data.utils.Content; import org.opensearch.sql.opensearch.data.utils.ObjectContent; import org.opensearch.sql.opensearch.data.utils.OpenSearchJsonContent; +import 
org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; /** * Construct ExprValue from OpenSearch response. */ -@AllArgsConstructor public class OpenSearchExprValueFactory { /** * The Mapping of Field and ExprType. @@ -98,6 +98,10 @@ public class OpenSearchExprValueFactory { @Setter private Map typeMapping; + @Getter + @Setter + private OpenSearchAggregationResponseParser parser; + private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder() .appendOptional(SQL_LITERAL_DATE_TIME_FORMAT) @@ -131,6 +135,14 @@ public class OpenSearchExprValueFactory { .put(OPENSEARCH_BINARY, c -> new OpenSearchExprBinaryValue(c.stringValue())) .build(); + /** + * Constructor of OpenSearchExprValueFactory. + */ + public OpenSearchExprValueFactory( + Map typeMapping) { + this.typeMapping = typeMapping; + } + /** * The struct construction has the following assumption. 1. The field has OpenSearch Object * data type. https://www.elastic.co/guide/en/elasticsearch/reference/current/object.html 2. The diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java deleted file mode 100644 index bb029cddb0..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParser.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -/* - * - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. 
- * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - * - */ - -package org.opensearch.sql.opensearch.response; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.experimental.UtilityClass; -import org.opensearch.search.aggregations.Aggregation; -import org.opensearch.search.aggregations.Aggregations; -import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; -import org.opensearch.search.aggregations.bucket.filter.Filter; -import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; - -/** - * AggregationResponseParser. - */ -@UtilityClass -public class OpenSearchAggregationResponseParser { - - /** - * Parse Aggregations as a list of field and value map. - * - * @param aggregations aggregations - * @return a list of field and value map - */ - public static List> parse(Aggregations aggregations) { - List aggregationList = aggregations.asList(); - ImmutableList.Builder> builder = new ImmutableList.Builder<>(); - Map noBucketMap = new HashMap<>(); - - for (Aggregation aggregation : aggregationList) { - if (aggregation instanceof CompositeAggregation) { - for (CompositeAggregation.Bucket bucket : - ((CompositeAggregation) aggregation).getBuckets()) { - builder.add(parse(bucket)); - } - } else { - noBucketMap.putAll(parseInternal(aggregation)); - } - - } - // Todo, there is no better way to difference the with/without bucket from aggregations result. - return noBucketMap.isEmpty() ? 
builder.build() : Collections.singletonList(noBucketMap); - } - - private static Map parse(CompositeAggregation.Bucket bucket) { - Map resultMap = new HashMap<>(); - // The NodeClient return InternalComposite - - // build pair - resultMap.putAll(bucket.getKey()); - - // build pair - for (Aggregation aggregation : bucket.getAggregations()) { - resultMap.putAll(parseInternal(aggregation)); - } - - return resultMap; - } - - private static Map parseInternal(Aggregation aggregation) { - Map resultMap = new HashMap<>(); - if (aggregation instanceof NumericMetricsAggregation.SingleValue) { - resultMap.put( - aggregation.getName(), - handleNanValue(((NumericMetricsAggregation.SingleValue) aggregation).value())); - } else if (aggregation instanceof Filter) { - // parse sub-aggregations for FilterAggregation response - List aggList = ((Filter) aggregation).getAggregations().asList(); - aggList.forEach(internalAgg -> { - Map intermediateMap = parseInternal(internalAgg); - resultMap.put(internalAgg.getName(), intermediateMap.get(internalAgg.getName())); - }); - } else { - throw new IllegalStateException("unsupported aggregation type " + aggregation.getType()); - } - return resultMap; - } - - @VisibleForTesting - protected static Object handleNanValue(double value) { - return Double.isNaN(value) ? 
null : value; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index fc7421aec3..156490d93a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -103,7 +103,7 @@ public boolean isAggregationResponse() { */ public Iterator iterator() { if (isAggregationResponse()) { - return OpenSearchAggregationResponseParser.parse(aggregations).stream().map(entry -> { + return exprValueFactory.getParser().parse(aggregations).stream().map(entry -> { ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); for (Map.Entry value : entry.entrySet()) { builder.put(value.getKey(), exprValueFactory.construct(value.getKey(), value.getValue())); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java new file mode 100644 index 0000000000..00e8a5154c --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; + +/** + * Composite Aggregation Parser which includes composite aggregation and metric parsers. + */ +public class CompositeAggregationParser implements OpenSearchAggregationResponseParser { + + private final MetricParserHelper metricsParser; + + public CompositeAggregationParser(MetricParser... metricParserList) { + metricsParser = new MetricParserHelper(Arrays.asList(metricParserList)); + } + + public CompositeAggregationParser(List metricParserList) { + metricsParser = new MetricParserHelper(metricParserList); + } + + @Override + public List> parse(Aggregations aggregations) { + return ((CompositeAggregation) aggregations.asList().get(0)) + .getBuckets().stream().map(this::parse).collect(Collectors.toList()); + } + + private Map parse(CompositeAggregation.Bucket bucket) { + Map resultMap = new HashMap<>(); + resultMap.putAll(bucket.getKey()); + resultMap.putAll(metricsParser.parse(bucket.getAggregations())); + return resultMap; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java new file mode 100644 index 0000000000..cfcba82c18 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. 
This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Map; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.bucket.filter.Filter; + +/** + * {@link Filter} Parser. + * The current use case is filter aggregation, e.g. avg(age) filter(balance>0). The filter parser + * does nothing and returns the result from metricsParser. + */ +@Builder +public class FilterParser implements MetricParser { + + private final MetricParser metricsParser; + + @Getter private final String name; + + @Override + public Map parse(Aggregation aggregations) { + return metricsParser.parse(((Filter) aggregations).getAggregations().asList().get(0)); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java new file mode 100644 index 0000000000..15f05e5b05 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParser.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Map; +import org.opensearch.search.aggregations.Aggregation; + +/** + * Metric Aggregation Parser. + */ +public interface MetricParser { + + /** + * Get the name of metric parser. + */ + String getName(); + + /** + * Parse the {@link Aggregation}. + * + * @param aggregation {@link Aggregation} + * @return the map between metric name and metric value. + */ + Map parse(Aggregation aggregation); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java new file mode 100644 index 0000000000..54b9305f49 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.sql.common.utils.StringUtils; + +/** + * Parse multiple metrics in one bucket. 
+ */ +@RequiredArgsConstructor +public class MetricParserHelper { + + private final Map metricParserMap; + + public MetricParserHelper(List metricParserList) { + metricParserMap = + metricParserList.stream().collect(Collectors.toMap(MetricParser::getName, m -> m)); + } + + /** + * Parse {@link Aggregations}. + * + * @param aggregations {@link Aggregations} + * @return the map between metric name and metric value. + */ + public Map parse(Aggregations aggregations) { + Map resultMap = new HashMap<>(); + for (Aggregation aggregation : aggregations) { + if (metricParserMap.containsKey(aggregation.getName())) { + resultMap.putAll(metricParserMap.get(aggregation.getName()).parse(aggregation)); + } else { + throw new RuntimeException(StringUtils.format("couldn't parse field %s in aggregation " + + "response", aggregation.getName())); + } + } + return resultMap; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java new file mode 100644 index 0000000000..5756003523 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/NoBucketAggregationParser.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.opensearch.search.aggregations.Aggregations; + +/** + * No Bucket Aggregation Parser which includes only metric parsers. + */ +public class NoBucketAggregationParser implements OpenSearchAggregationResponseParser { + + private final MetricParserHelper metricsParser; + + public NoBucketAggregationParser(MetricParser... metricParserList) { + metricsParser = new MetricParserHelper(Arrays.asList(metricParserList)); + } + + public NoBucketAggregationParser(List metricParserList) { + metricsParser = new MetricParserHelper(metricParserList); + } + + @Override + public List> parse(Aggregations aggregations) { + return Collections.singletonList(metricsParser.parse(aggregations)); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java new file mode 100644 index 0000000000..3a19747ef3 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/OpenSearchAggregationResponseParser.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License.
+ */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.List; +import java.util.Map; +import org.opensearch.search.aggregations.Aggregations; + +/** + * OpenSearch Aggregation Response Parser. + */ +public interface OpenSearchAggregationResponseParser { + + /** + * Parse the OpenSearch Aggregation Response. + * @param aggregations Aggregations. + * @return aggregation result. + */ + List> parse(Aggregations aggregations); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java new file mode 100644 index 0000000000..7536a24661 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import static org.opensearch.sql.opensearch.response.agg.Utils.handleNanValue; + +import java.util.Collections; +import java.util.Map; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; + +/** + * {@link NumericMetricsAggregation.SingleValue} metric parser. 
+ */ +@RequiredArgsConstructor +public class SingleValueParser implements MetricParser { + + @Getter private final String name; + + @Override + public Map parse(Aggregation agg) { + return Collections.singletonMap( + agg.getName(), + handleNanValue(((NumericMetricsAggregation.SingleValue) agg).value())); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java new file mode 100644 index 0000000000..6cac2fbdc9 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import static org.opensearch.sql.opensearch.response.agg.Utils.handleNanValue; + +import java.util.Collections; +import java.util.Map; +import java.util.function.Function; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.ExtendedStats; + +/** + * {@link ExtendedStats} metric parser. 
+ */ +@RequiredArgsConstructor +public class StatsParser implements MetricParser { + + private final Function valueExtractor; + + @Getter private final String name; + + @Override + public Map parse(Aggregation agg) { + return Collections.singletonMap( + agg.getName(), handleNanValue(valueExtractor.apply((ExtendedStats) agg))); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java new file mode 100644 index 0000000000..28b9d41e83 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/Utils.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public class Utils { + /** + * Utils to handle NaN value. + * @return null if the value is NaN, otherwise the value. + */ + public static Object handleNanValue(double value) { + return Double.isNaN(value) ?
null : value; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 74e966637f..0198abe7a1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -32,6 +32,7 @@ import java.util.Map; import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.index.query.QueryBuilder; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.sql.common.setting.Settings; @@ -43,6 +44,7 @@ import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; @@ -163,7 +165,7 @@ public PhysicalPlan visitIndexAggregation(OpenSearchLogicalIndexAgg node, } AggregationQueryBuilder builder = new AggregationQueryBuilder(new DefaultExpressionSerializer()); - List aggregationBuilder = + Pair, OpenSearchAggregationResponseParser> aggregationBuilder = builder.buildAggregationBuilder(node.getAggregatorList(), node.getGroupByList(), node.getSortList()); context.pushDownAggregation(aggregationBuilder); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java index 
99b11c21a4..57980f23b9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java @@ -40,6 +40,7 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -55,6 +56,7 @@ import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.storage.TableScanOperator; /** @@ -138,12 +140,14 @@ public void pushDown(QueryBuilder query) { /** * Push down aggregation to DSL request. - * @param aggregationBuilderList aggregation query. + * @param aggregationBuilder pair of aggregation query and aggregation parser. 
*/ - public void pushDownAggregation(List aggregationBuilderList) { + public void pushDownAggregation( + Pair, OpenSearchAggregationResponseParser> aggregationBuilder) { SearchSourceBuilder source = request.getSourceBuilder(); - aggregationBuilderList.forEach(aggregationBuilder -> source.aggregation(aggregationBuilder)); + aggregationBuilder.getLeft().forEach(builder -> source.aggregation(builder)); source.size(0); + request.getExprValueFactory().setParser(aggregationBuilder.getRight()); } /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java index a89ba042ee..403f99e593 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java @@ -42,6 +42,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.type.ExprType; @@ -50,6 +51,10 @@ import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.MetricParser; +import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.script.aggregation.dsl.BucketAggregationBuilder; import 
org.opensearch.sql.opensearch.storage.script.aggregation.dsl.MetricAggregationBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -82,25 +87,35 @@ public AggregationQueryBuilder( this.metricBuilder = new MetricAggregationBuilder(serializer); } - /** - * Build AggregationBuilder. - */ - public List buildAggregationBuilder( - List namedAggregatorList, - List groupByList, - List> sortList) { + /** Build AggregationBuilder. */ + public Pair, OpenSearchAggregationResponseParser> + buildAggregationBuilder( + List namedAggregatorList, + List groupByList, + List> sortList) { + + final Pair> metrics = + metricBuilder.build(namedAggregatorList); + if (groupByList.isEmpty()) { // no bucket - return ImmutableList - .copyOf(metricBuilder.build(namedAggregatorList).getAggregatorFactories()); + return Pair.of( + ImmutableList.copyOf(metrics.getLeft().getAggregatorFactories()), + new NoBucketAggregationParser(metrics.getRight())); } else { - final GroupSortOrder groupSortOrder = new GroupSortOrder(sortList); - return Collections.singletonList(AggregationBuilders.composite("composite_buckets", - bucketBuilder - .build(groupByList.stream().sorted(groupSortOrder).map(expr -> Pair.of(expr, - groupSortOrder.apply(expr))).collect(Collectors.toList()))) - .subAggregations(metricBuilder.build(namedAggregatorList)) - .size(AGGREGATION_BUCKET_SIZE)); + GroupSortOrder groupSortOrder = new GroupSortOrder(sortList); + return Pair.of( + Collections.singletonList( + AggregationBuilders.composite( + "composite_buckets", + bucketBuilder.build( + groupByList.stream() + .sorted(groupSortOrder) + .map(expr -> Pair.of(expr, groupSortOrder.apply(expr))) + .collect(Collectors.toList()))) + .subAggregations(metrics.getLeft()) + .size(AGGREGATION_BUCKET_SIZE)), + new CompositeAggregationParser(metrics.getRight())); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index f3807ae662..0dbfec02c1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -30,7 +30,9 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import java.util.ArrayList; import java.util.List; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; @@ -41,20 +43,22 @@ import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.response.agg.FilterParser; +import org.opensearch.sql.opensearch.response.agg.MetricParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; /** - * Build the Metric Aggregation from {@link NamedAggregator}. + * Build the Metric Aggregation and List of {@link MetricParser} from {@link NamedAggregator}. 
*/ public class MetricAggregationBuilder - extends ExpressionNodeVisitor { + extends ExpressionNodeVisitor, Object> { private final AggregationBuilderHelper> helper; private final FilterQueryBuilder filterBuilder; - public MetricAggregationBuilder( - ExpressionSerializer serializer) { + public MetricAggregationBuilder(ExpressionSerializer serializer) { this.helper = new AggregationBuilderHelper<>(serializer); this.filterBuilder = new FilterQueryBuilder(serializer); } @@ -65,55 +69,89 @@ public MetricAggregationBuilder( * @param aggregatorList aggregator list * @return AggregatorFactories.Builder */ - public AggregatorFactories.Builder build(List aggregatorList) { + public Pair> build( + List aggregatorList) { AggregatorFactories.Builder builder = new AggregatorFactories.Builder(); + List metricParserList = new ArrayList<>(); for (NamedAggregator aggregator : aggregatorList) { - builder.addAggregator(aggregator.accept(this, null)); + Pair pair = aggregator.accept(this, null); + builder.addAggregator(pair.getLeft()); + metricParserList.add(pair.getRight()); } - return builder; + return Pair.of(builder, metricParserList); } @Override - public AggregationBuilder visitNamedAggregator(NamedAggregator node, - Object context) { + public Pair visitNamedAggregator( + NamedAggregator node, Object context) { Expression expression = node.getArguments().get(0); Expression condition = node.getDelegated().condition(); String name = node.getName(); switch (node.getFunctionName().getFunctionName()) { case "avg": - return make(AggregationBuilders.avg(name), expression, condition, name); + return make( + AggregationBuilders.avg(name), + expression, + condition, + name, + new SingleValueParser(name)); case "sum": - return make(AggregationBuilders.sum(name), expression, condition, name); + return make( + AggregationBuilders.sum(name), + expression, + condition, + name, + new SingleValueParser(name)); case "count": return make( - AggregationBuilders.count(name), 
replaceStarOrLiteral(expression), condition, name); + AggregationBuilders.count(name), + replaceStarOrLiteral(expression), + condition, + name, + new SingleValueParser(name)); case "min": - return make(AggregationBuilders.min(name), expression, condition, name); + return make( + AggregationBuilders.min(name), + expression, + condition, + name, + new SingleValueParser(name)); case "max": - return make(AggregationBuilders.max(name), expression, condition, name); + return make( + AggregationBuilders.max(name), + expression, + condition, + name, + new SingleValueParser(name)); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); } } - private AggregationBuilder make(ValuesSourceAggregationBuilder builder, - Expression expression, Expression condition, String name) { + private Pair make( + ValuesSourceAggregationBuilder builder, + Expression expression, + Expression condition, + String name, + MetricParser parser) { ValuesSourceAggregationBuilder aggregationBuilder = helper.build(expression, builder::field, builder::script); if (condition != null) { - return makeFilterAggregation(aggregationBuilder, condition, name); + return Pair.of( + makeFilterAggregation(aggregationBuilder, condition, name), + FilterParser.builder().name(name).metricsParser(parser).build()); } - return aggregationBuilder; + return Pair.of(aggregationBuilder, parser); } /** - * Replace star or literal with OpenSearch metadata field "_index". Because: - * 1) Analyzer already converts * to string literal, literal check here can handle - * both COUNT(*) and COUNT(1). - * 2) Value count aggregation on _index counts all docs (after filter), therefore - * it has same semantics as COUNT(*) or COUNT(1). + * Replace star or literal with OpenSearch metadata field "_index". Because: 1) Analyzer already + * converts * to string literal, literal check here can handle both COUNT(*) and COUNT(1). 
2) + * Value count aggregation on _index counts all docs (after filter), therefore it has same + * semantics as COUNT(*) or COUNT(1). + * * @param countArg count function argument * @return Reference to _index if literal, otherwise return original argument expression */ @@ -126,16 +164,15 @@ private Expression replaceStarOrLiteral(Expression countArg) { /** * Make builder to build FilterAggregation for aggregations with filter in the bucket. + * * @param subAggBuilder AggregationBuilder instance which the filter is applied to. * @param condition Condition expression in the filter. * @param name Name of the FilterAggregation instance to build. * @return {@link FilterAggregationBuilder}. */ - private FilterAggregationBuilder makeFilterAggregation(AggregationBuilder subAggBuilder, - Expression condition, String name) { - return AggregationBuilders - .filter(name, filterBuilder.build(condition)) + private FilterAggregationBuilder makeFilterAggregation( + AggregationBuilder subAggBuilder, Expression condition, String name) { + return AggregationBuilders.filter(name, filterBuilder.build(condition)) .subAggregation(subAggBuilder); } - } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java index c8ef830635..173b33575c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java @@ -55,9 +55,11 @@ import org.opensearch.search.aggregations.bucket.terms.ParsedStringTerms; import org.opensearch.search.aggregations.bucket.terms.StringTerms; import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder; import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; import 
org.opensearch.search.aggregations.metrics.MinAggregationBuilder; import org.opensearch.search.aggregations.metrics.ParsedAvg; +import org.opensearch.search.aggregations.metrics.ParsedExtendedStats; import org.opensearch.search.aggregations.metrics.ParsedMax; import org.opensearch.search.aggregations.metrics.ParsedMin; import org.opensearch.search.aggregations.metrics.ParsedSum; @@ -74,6 +76,8 @@ public class AggregationResponseUtils { .put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)) .put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)) .put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)) + .put(ExtendedStatsAggregationBuilder.NAME, + (p, c) -> ParsedExtendedStats.fromXContent(p, (String) c)) .put(StringTerms.NAME, (p, c) -> ParsedStringTerms.fromXContent(p, (String) c)) .put(LongTerms.NAME, (p, c) -> ParsedLongTerms.fromXContent(p, (String) c)) .put(DoubleTerms.NAME, (p, c) -> ParsedDoubleTerms.fromXContent(p, (String) c)) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java index b49bec4d44..120d48b601 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java @@ -34,6 +34,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.opensearch.response.AggregationResponseUtils.fromJson; +import static org.opensearch.sql.opensearch.response.agg.Utils.handleNanValue; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -41,6 +43,13 @@ import 
org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.search.aggregations.metrics.ExtendedStats; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.FilterParser; +import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.response.agg.StatsParser; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchAggregationResponseParserTest { @@ -55,7 +64,10 @@ void no_bucket_one_metric_should_pass() { + " \"value\": 40\n" + " }\n" + "}"; - assertThat(parse(response), contains(entry("max", 40d))); + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("max") + ); + assertThat(parse(parser, response), contains(entry("max", 40d))); } /** @@ -71,7 +83,11 @@ void no_bucket_two_metric_should_pass() { + " \"value\": 20\n" + " }\n" + "}"; - assertThat(parse(response), + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("max"), + new SingleValueParser("min") + ); + assertThat(parse(parser, response), contains(entry("max", 40d,"min", 20d))); } @@ -104,7 +120,10 @@ void one_bucket_one_metric_should_pass() { + " ]\n" + " }\n" + "}"; - assertThat(parse(response), + + OpenSearchAggregationResponseParser parser = new CompositeAggregationParser( + new SingleValueParser("avg")); + assertThat(parse(parser, response), containsInAnyOrder(ImmutableMap.of("type", "cost", "avg", 20d), ImmutableMap.of("type", "sale", "avg", 105d))); } @@ -139,7 +158,9 @@ void two_bucket_one_metric_should_pass() { + " ]\n" + " }\n" + "}"; - assertThat(parse(response), + OpenSearchAggregationResponseParser parser = new 
CompositeAggregationParser( + new SingleValueParser("avg")); + assertThat(parse(parser, response), containsInAnyOrder(ImmutableMap.of("type", "cost", "region", "us", "avg", 20d), ImmutableMap.of("type", "sale", "region", "uk", "avg", 130d))); } @@ -147,81 +168,132 @@ void two_bucket_one_metric_should_pass() { @Test void unsupported_aggregation_should_fail() { String response = "{\n" - + " \"date_histogram#max\": {\n" + + " \"date_histogram#date_histogram\": {\n" + " \"value\": 40\n" + " }\n" + "}"; - IllegalStateException exception = - assertThrows(IllegalStateException.class, () -> parse(response)); - assertEquals("unsupported aggregation type date_histogram", exception.getMessage()); + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("max") + ); + RuntimeException exception = + assertThrows(RuntimeException.class, () -> parse(parser, response)); + assertEquals( + "couldn't parse field date_histogram in aggregation response", exception.getMessage()); } @Test void nan_value_should_return_null() { - assertNull(OpenSearchAggregationResponseParser.handleNanValue(Double.NaN)); + assertNull(handleNanValue(Double.NaN)); } - /** - * SELECT AVG(age) FILTER(WHERE age > 37) as filtered FROM accounts. 
- */ @Test void filter_aggregation_should_pass() { - String response = "{\n" - + " \"filter#filtered\" : {\n" - + " \"doc_count\" : 3,\n" - + " \"avg#filtered\" : {\n" - + " \"value\" : 37.0\n" - + " }\n" - + " }\n" - + " }"; - assertThat(parse(response), contains(entry("filtered", 37.0))); + String response = "{\n" + + " \"filter#filtered\" : {\n" + + " \"doc_count\" : 3,\n" + + " \"avg#filtered\" : {\n" + + " \"value\" : 37.0\n" + + " }\n" + + " }\n" + + " }"; + OpenSearchAggregationResponseParser parser = + new NoBucketAggregationParser( + FilterParser.builder() + .name("filtered") + .metricsParser(new SingleValueParser("filtered")) + .build()); + assertThat(parse(parser, response), contains(entry("filtered", 37.0))); } - /** - * SELECT AVG(age) FILTER(WHERE age > 37) as filtered FROM accounts GROUP BY gender. - */ @Test void filter_aggregation_group_by_should_pass() { - String response = "{\n" - + " \"composite#composite_buckets\":{\n" - + " \"after_key\":{\n" - + " \"gender\":\"m\"\n" - + " },\n" - + " \"buckets\":[\n" - + " {\n" - + " \"key\":{\n" - + " \"gender\":\"f\"\n" - + " },\n" - + " \"doc_count\":3,\n" - + " \"filter#filter\":{\n" - + " \"doc_count\":1,\n" - + " \"avg#avg\":{\n" - + " \"value\":39.0\n" - + " }\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\":{\n" - + " \"gender\":\"m\"\n" - + " },\n" - + " \"doc_count\":4,\n" - + " \"filter#filter\":{\n" - + " \"doc_count\":2,\n" - + " \"avg#avg\":{\n" - + " \"value\":36.0\n" - + " }\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"; - assertThat(parse(response), containsInAnyOrder( + String response = "{\n" + + " \"composite#composite_buckets\":{\n" + + " \"after_key\":{\n" + + " \"gender\":\"m\"\n" + + " },\n" + + " \"buckets\":[\n" + + " {\n" + + " \"key\":{\n" + + " \"gender\":\"f\"\n" + + " },\n" + + " \"doc_count\":3,\n" + + " \"filter#filter\":{\n" + + " \"doc_count\":1,\n" + + " \"avg#avg\":{\n" + + " \"value\":39.0\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\":{\n" + + 
" \"gender\":\"m\"\n" + + " },\n" + + " \"doc_count\":4,\n" + + " \"filter#filter\":{\n" + + " \"doc_count\":2,\n" + + " \"avg#avg\":{\n" + + " \"value\":36.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + OpenSearchAggregationResponseParser parser = new CompositeAggregationParser( + FilterParser.builder() + .name("filter") + .metricsParser(new SingleValueParser("avg")) + .build() + ); + assertThat(parse(parser, response), containsInAnyOrder( entry("gender", "f", "avg", 39.0), entry("gender", "m", "avg", 36.0))); } - public List> parse(String json) { - return OpenSearchAggregationResponseParser.parse(AggregationResponseUtils.fromJson(json)); + /** + * SELECT MAX(age) as max, STDDEV(age) as min FROM accounts. + */ + @Test + void no_bucket_max_and_extended_stats() { + String response = "{\n" + + " \"extended_stats#esField\": {\n" + + " \"count\": 2033,\n" + + " \"min\": 0,\n" + + " \"max\": 360,\n" + + " \"avg\": 45.47958681751107,\n" + + " \"sum\": 92460,\n" + + " \"sum_of_squares\": 22059450,\n" + + " \"variance\": 8782.295820390027,\n" + + " \"variance_population\": 8782.295820390027,\n" + + " \"variance_sampling\": 8786.61781636463,\n" + + " \"std_deviation\": 93.71390409320287,\n" + + " \"std_deviation_population\": 93.71390409320287,\n" + + " \"std_deviation_sampling\": 93.73696078049805,\n" + + " \"std_deviation_bounds\": {\n" + + " \"upper\": 232.9073950039168,\n" + + " \"lower\": -141.94822136889468,\n" + + " \"upper_population\": 232.9073950039168,\n" + + " \"lower_population\": -141.94822136889468,\n" + + " \"upper_sampling\": 232.95350837850717,\n" + + " \"lower_sampling\": -141.99433474348504\n" + + " }\n" + + " },\n" + + " \"max#maxField\": {\n" + + " \"value\": 360\n" + + " }\n" + + "}"; + + NoBucketAggregationParser parser = new NoBucketAggregationParser( + new SingleValueParser("maxField"), + new StatsParser(ExtendedStats::getStdDeviation, "esField") + ); + assertThat(parse(parser, response), + contains(entry("esField", 
93.71390409320287, "maxField", 360D))); + } + + public List> parse(OpenSearchAggregationResponseParser parser, String json) { + return parser.parse(fromJson(json)); } public Map entry(String name, Object value) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 184312afa1..c9cde4f634 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -42,8 +42,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.search.SearchResponse; import org.opensearch.search.SearchHit; @@ -53,6 +51,7 @@ import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @ExtendWith(MockitoExtension.class) class OpenSearchResponseTest { @@ -72,6 +71,9 @@ class OpenSearchResponseTest { @Mock private Aggregations aggregations; + @Mock + private OpenSearchAggregationResponseParser parser; + private ExprTupleValue exprTupleValue1 = ExprTupleValue.fromExprValueMap(ImmutableMap.of("id1", new ExprIntegerValue(1))); @@ -147,26 +149,24 @@ void response_isnot_aggregation_when_aggregation_is_empty() { @Test void aggregation_iterator() { - try ( - MockedStatic mockedStatic = Mockito - .mockStatic(OpenSearchAggregationResponseParser.class)) { - when(OpenSearchAggregationResponseParser.parse(any())) - .thenReturn(Arrays.asList(ImmutableMap.of("id1", 1), ImmutableMap.of("id2", 2))); - 
when(searchResponse.getAggregations()).thenReturn(aggregations); - when(factory.construct(anyString(), any())).thenReturn(new ExprIntegerValue(1)) - .thenReturn(new ExprIntegerValue(2)); - - int i = 0; - for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) { - if (i == 0) { - assertEquals(exprTupleValue1, hit); - } else if (i == 1) { - assertEquals(exprTupleValue2, hit); - } else { - fail("More search hits returned than expected"); - } - i++; + when(parser.parse(any())) + .thenReturn(Arrays.asList(ImmutableMap.of("id1", 1), ImmutableMap.of("id2", 2))); + when(searchResponse.getAggregations()).thenReturn(aggregations); + when(factory.getParser()).thenReturn(parser); + when(factory.construct(anyString(), any())) + .thenReturn(new ExprIntegerValue(1)) + .thenReturn(new ExprIntegerValue(2)); + + int i = 0; + for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) { + if (i == 0) { + assertEquals(exprTupleValue1, hit); + } else if (i == 1) { + assertEquals(exprTupleValue2, hit); + } else { + fail("More search hits returned than expected"); } + i++; } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java index 2242298bed..62643baad2 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java @@ -423,13 +423,18 @@ private String buildQuery(List namedAggregatorList, } @SneakyThrows - private String buildQuery(List namedAggregatorList, - List groupByList, - List> sortList) { + private String buildQuery( + List namedAggregatorList, + List groupByList, + List> sortList) { ObjectMapper objectMapper = new ObjectMapper(); - return objectMapper.readTree( - 
queryBuilder.buildAggregationBuilder(namedAggregatorList, groupByList, sortList).get(0) - .toString()) + return objectMapper + .readTree( + queryBuilder + .buildAggregationBuilder(namedAggregatorList, groupByList, sortList) + .getLeft() + .get(0) + .toString()) .toPrettyString(); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index b956a2f5a0..85b3bd5a65 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -211,7 +211,7 @@ void should_throw_exception_for_unsupported_exception() { private String buildQuery(List namedAggregatorList) { ObjectMapper objectMapper = new ObjectMapper(); return objectMapper.readTree( - aggregationBuilder.build(namedAggregatorList).toString()) + aggregationBuilder.build(namedAggregatorList).getLeft().toString()) .toPrettyString(); } } From a19df4bb1f7e32cb4aeb4e6345c4594f073fa9e0 Mon Sep 17 00:00:00 2001 From: penghuo Date: Mon, 7 Jun 2021 17:10:03 -0700 Subject: [PATCH 03/11] add var and varp for PPL Signed-off-by: penghuo --- .../function/BuiltinFunctionName.java | 2 + docs/user/dql/aggregations.rst | 150 ++++++++++++++++++ docs/user/ppl/cmd/stats.rst | 132 +++++++++++++++ .../correctness/queries/aggregation.txt | 4 +- .../dsl/MetricAggregationBuilder.java | 16 ++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../ppl/parser/AstExpressionBuilderTest.java | 42 +++++ 7 files changed, 346 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 9c541bbe7d..f531ee4bbd 
100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -204,6 +204,8 @@ public enum BuiltinFunctionName { .put("var_pop", BuiltinFunctionName.VARPOP) .put("var_samp", BuiltinFunctionName.VARSAMP) .put("variance", BuiltinFunctionName.VARPOP) + .put("var", BuiltinFunctionName.VARSAMP) + .put("varp", BuiltinFunctionName.VARPOP) .build(); public static Optional of(String str) { diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 98b565e1ec..5309bc8233 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -135,6 +135,156 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. +Aggregation Functions +===================== + +COUNT +----- + +Description +>>>>>>>>>>> + +Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. + +Example: + + os> SELECT gender, count(*) as countV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+----------+ + | gender | countV | + |----------+----------| + | F | 1 | + | M | 3 | + +----------+----------+ + +SUM +--- + +Description +>>>>>>>>>>> + +Usage: SUM(expr). Returns the sum of expr. + +Example: + + os> SELECT gender, sum(age) as sumV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+--------+ + | gender | sumV | + |----------+--------| + | F | 28 | + | M | 101 | + +----------+--------+ + +AVG +--- + +Description +>>>>>>>>>>> + +Usage: AVG(expr). Returns the average value of expr. 
+ +Example: + + os> SELECT gender, avg(age) as avgV FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+--------------------+ + | gender | avgV | + |----------+--------------------| + | F | 28.0 | + | M | 33.666666666666664 | + +----------+--------------------+ + +MAX +--- + +Description +>>>>>>>>>>> + +Usage: MAX(expr). Returns the maximum value of expr. + +Example: + + os> SELECT max(age) as maxV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | maxV | + |--------| + | 36 | + +--------+ + +MIN +--- + +Description +>>>>>>>>>>> + +Usage: MIN(expr). Returns the minimum value of expr. + +Example: + + os> SELECT min(age) as minV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | minV | + |--------| + | 28 | + +--------+ + +VAR_POP +------- + +Description +>>>>>>>>>>> + +Usage: VAR_POP(expr). Returns the population standard variance of expr. + +Example: + + os> SELECT var_pop(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | varV | + |--------| + | 8.1875 | + +--------+ + +VAR_SAMP +-------- + +Description +>>>>>>>>>>> + +Usage: VAR_SAMP(expr). Returns the sample variance of expr. + +Example: + + os> SELECT var_samp(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | varV | + |--------------------| + | 10.916666666666666 | + +--------------------+ + +VARIANCE +-------- + +Description +>>>>>>>>>>> + +Usage: VARIANCE(expr). Returns the population standard variance of expr. VARIANCE() is a synonym for the standard SQL function VAR_POP() + +Example: + + os> SELECT variance(age) as varV FROM accounts; + fetched rows / total rows = 1/1 + +--------+ + | varV | + |--------| + | 8.1875 | + +--------+ + HAVING Clause ============= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index 3aca304fcd..d36d1e3a4b 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -38,6 +38,138 @@ stats ... [by-clause]... 
* aggregation: mandatory. A aggregation function. The argument of aggregation must be field. * by-clause: optional. The one or more fields to group the results by. **Default**: If no is specified, the stats command returns only one row, which is the aggregation over the entire result set. + +Aggregation Functions +===================== + +COUNT +----- + +Description +>>>>>>>>>>> + +Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. + +Example: + + os> source=accounts | stats count(); + fetched rows / total rows = 1/1 + +-----------+ + | count() | + |-----------| + | 4 | + +-----------+ + +SUM +--- + +Description +>>>>>>>>>>> + +Usage: SUM(expr). Returns the sum of expr. + +Example: + + os> source=accounts | stats sum(age) by gender; + fetched rows / total rows = 2/2 + +------------+----------+ + | sum(age) | gender | + |------------+----------| + | 28 | F | + | 101 | M | + +------------+----------+ + +AVG +--- + +Description +>>>>>>>>>>> + +Usage: AVG(expr). Returns the average value of expr. + +Example: + + os> source=accounts | stats avg(age) by gender; + fetched rows / total rows = 2/2 + +--------------------+----------+ + | avg(age) | gender | + |--------------------+----------| + | 28.0 | F | + | 33.666666666666664 | M | + +--------------------+----------+ + +MAX +--- + +Description +>>>>>>>>>>> + +Usage: MAX(expr). Returns the maximum value of expr. + +Example: + + os> source=accounts | stats max(age); + fetched rows / total rows = 1/1 + +------------+ + | max(age) | + |------------| + | 36 | + +------------+ + +MIN +--- + +Description +>>>>>>>>>>> + +Usage: MIN(expr). Returns the minimum value of expr. + +Example: + + os> source=accounts | stats min(age); + fetched rows / total rows = 1/1 + +------------+ + | min(age) | + |------------| + | 28 | + +------------+ + +VAR +------- + +Description +>>>>>>>>>>> + +Usage: VAR(expr). Returns the sample variance of expr. 
+ +Example: + + os> source=accounts | stats var(age); + fetched rows / total rows = 1/1 + +--------------------+ + | var(age) | + |--------------------| + | 10.916666666666666 | + +--------------------+ + +VARP +-------- + +Description +>>>>>>>>>>> + +Usage: VARP(expr). Returns the population standard variance of expr. + +Example: + + os> source=accounts | stats varp(age); + fetched rows / total rows = 1/1 + +-------------+ + | varp(age) | + |-------------| + | 8.1875 | + +-------------+ + Example 1: Calculate the count of events ======================================== diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 6c6e5b73a1..9318420c04 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -5,4 +5,6 @@ SELECT SUM(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights -SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights +SELECT VAR_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 0dbfec02c1..0699f103ec 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -37,6 +37,7 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; @@ -46,6 +47,7 @@ import org.opensearch.sql.opensearch.response.agg.FilterParser; import org.opensearch.sql.opensearch.response.agg.MetricParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.response.agg.StatsParser; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -124,6 +126,20 @@ public Pair visitNamedAggregator( condition, name, new SingleValueParser(name)); + case "var_samp": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getVarianceSampling,name)); + case "var_pop": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getVariancePopulation,name)); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 77aecf5a44..b4073840c4 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -139,7 +139,7 @@ statsFunction ; statsFunctionName - : AVG | COUNT | SUM | MIN | MAX + : AVG | COUNT | SUM | MIN | MAX | VAR | VARP ; 
percentileAggFunction diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 07ad97401e..b4763c4bb1 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -335,6 +335,48 @@ public void testAggFuncCallExpr() { )); } + @Test + public void testVarAggregationShouldPass() { + assertEqual("source=t | stats var(a) by b", + agg( + relation("t"), + exprList( + alias( + "var(a)", + aggregate("var", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testVarpAggregationShouldPass() { + assertEqual("source=t | stats varp(a) by b", + agg( + relation("t"), + exprList( + alias( + "varp(a)", + aggregate("varp", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + @Test public void testPercentileAggFuncExpr() { assertEqual("source=t | stats percentile<1>(a)", From 0bb25e88ae49839e62861ee1e23c09c694693011 Mon Sep 17 00:00:00 2001 From: penghuo Date: Mon, 7 Jun 2021 18:04:14 -0700 Subject: [PATCH 04/11] add UT Signed-off-by: penghuo --- .../aggregation/AggregatorFunction.java | 6 ++- .../aggregation/VarianceAggregator.java | 16 ++++++++ .../dsl/MetricAggregationBuilderTest.java | 37 +++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index cdbb9855f3..23153e4229 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -35,6 +35,8 @@ import static 
org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.TIME; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; import com.google.common.collect.ImmutableMap; import java.util.Collections; @@ -168,7 +170,7 @@ private static FunctionResolver varSamp() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new VarianceAggregator(true, arguments, DOUBLE)) + arguments -> variancePopulation(arguments, DOUBLE)) .build() ); } @@ -179,7 +181,7 @@ private static FunctionResolver varPop() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new VarianceAggregator(false, arguments, DOUBLE)) + arguments -> varianceSample(arguments, DOUBLE)) .build() ); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java index 7abfdcb987..bd9f0948f6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/VarianceAggregator.java @@ -33,6 +33,22 @@ public class VarianceAggregator extends Aggregator arguments, + ExprCoreType returnType) { + return new VarianceAggregator(false, arguments, returnType); + } + + /** + * Build Sample Variance {@link VarianceAggregator}. + */ + public static Aggregator varianceSample(List arguments, + ExprCoreType returnType) { + return new VarianceAggregator(true, arguments, returnType); + } + /** * VarianceAggregator constructor. 
* diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 85b3bd5a65..1df8ceaa4c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -35,6 +35,8 @@ import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; +import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Arrays; @@ -53,6 +55,7 @@ import org.opensearch.sql.expression.aggregation.MinAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.aggregation.SumAggregator; +import org.opensearch.sql.expression.aggregation.VarianceAggregator; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -185,6 +188,40 @@ void should_build_max_aggregation() { new MaxAggregator(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_varPop_aggregation() { + assertEquals( + "{\n" + + " \"var_pop(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("var_pop(age)", + variancePopulation(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_varSamp_aggregation() { + assertEquals( + "{\n" + + " \"var_samp(age)\" 
: {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("var_samp(age)", + varianceSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); From 10257fd9864eb38044f982e69f3438440dc712eb Mon Sep 17 00:00:00 2001 From: penghuo Date: Mon, 7 Jun 2021 19:12:45 -0700 Subject: [PATCH 05/11] fix UT Signed-off-by: penghuo --- .../sql/expression/aggregation/AggregatorFunction.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index 23153e4229..7953d9c2f0 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -170,7 +170,7 @@ private static FunctionResolver varSamp() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> variancePopulation(arguments, DOUBLE)) + arguments -> varianceSample(arguments, DOUBLE)) .build() ); } @@ -181,7 +181,7 @@ private static FunctionResolver varPop() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> varianceSample(arguments, DOUBLE)) + arguments -> variancePopulation(arguments, DOUBLE)) .build() ); } From 62db7a8e2a321f00ec72a6739279993658bcf6e1 Mon Sep 17 00:00:00 2001 From: penghuo Date: Tue, 8 Jun 2021 07:44:28 -0700 Subject: [PATCH 06/11] fix doc format Signed-off-by: penghuo --- docs/user/dql/aggregations.rst | 8 ++++++++ docs/user/ppl/cmd/stats.rst | 7 +++++++ 2 files changed, 15 
insertions(+) diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 5309bc8233..1f894a7a6a 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -157,6 +157,7 @@ Example: | M | 3 | +----------+----------+ + SUM --- @@ -176,6 +177,7 @@ Example: | M | 101 | +----------+--------+ + AVG --- @@ -195,6 +197,7 @@ Example: | M | 33.666666666666664 | +----------+--------------------+ + MAX --- @@ -213,6 +216,7 @@ Example: | 36 | +--------+ + MIN --- @@ -231,6 +235,7 @@ Example: | 28 | +--------+ + VAR_POP ------- @@ -249,6 +254,7 @@ Example: | 8.1875 | +--------+ + VAR_SAMP -------- @@ -267,6 +273,7 @@ Example: | 10.916666666666666 | +--------------------+ + VARIANCE -------- @@ -285,6 +292,7 @@ Example: | 8.1875 | +--------+ + HAVING Clause ============= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index d36d1e3a4b..b4825c257f 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -60,6 +60,7 @@ Example: | 4 | +-----------+ + SUM --- @@ -79,6 +80,7 @@ Example: | 101 | M | +------------+----------+ + AVG --- @@ -98,6 +100,7 @@ Example: | 33.666666666666664 | M | +--------------------+----------+ + MAX --- @@ -116,6 +119,7 @@ Example: | 36 | +------------+ + MIN --- @@ -134,6 +138,7 @@ Example: | 28 | +------------+ + VAR ------- @@ -152,6 +157,7 @@ Example: | 10.916666666666666 | +--------------------+ + VARP -------- @@ -170,6 +176,7 @@ Example: | 8.1875 | +-------------+ + Example 1: Calculate the count of events ======================================== From 73f5b34eeb154f8e0cacd1dbb23e4ac0e63b449a Mon Sep 17 00:00:00 2001 From: penghuo Date: Tue, 8 Jun 2021 07:48:12 -0700 Subject: [PATCH 07/11] fix doc format Signed-off-by: penghuo --- docs/user/dql/aggregations.rst | 24 ++++++++---------------- docs/user/ppl/cmd/stats.rst | 25 +++++++++---------------- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/docs/user/dql/aggregations.rst 
b/docs/user/dql/aggregations.rst index 1f894a7a6a..841e053291 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -146,7 +146,7 @@ Description Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. -Example: +Example:: os> SELECT gender, count(*) as countV FROM accounts GROUP BY gender; fetched rows / total rows = 2/2 @@ -157,7 +157,6 @@ Example: | M | 3 | +----------+----------+ - SUM --- @@ -166,7 +165,7 @@ Description Usage: SUM(expr). Returns the sum of expr. -Example: +Example:: os> SELECT gender, sum(age) as sumV FROM accounts GROUP BY gender; fetched rows / total rows = 2/2 @@ -177,7 +176,6 @@ Example: | M | 101 | +----------+--------+ - AVG --- @@ -186,7 +184,7 @@ Description Usage: AVG(expr). Returns the average value of expr. -Example: +Example:: os> SELECT gender, avg(age) as avgV FROM accounts GROUP BY gender; fetched rows / total rows = 2/2 @@ -197,7 +195,6 @@ Example: | M | 33.666666666666664 | +----------+--------------------+ - MAX --- @@ -206,7 +203,7 @@ Description Usage: MAX(expr). Returns the maximum value of expr. -Example: +Example:: os> SELECT max(age) as maxV FROM accounts; fetched rows / total rows = 1/1 @@ -216,7 +213,6 @@ Example: | 36 | +--------+ - MIN --- @@ -225,7 +221,7 @@ Description Usage: MIN(expr). Returns the minimum value of expr. -Example: +Example:: os> SELECT min(age) as minV FROM accounts; fetched rows / total rows = 1/1 @@ -235,7 +231,6 @@ Example: | 28 | +--------+ - VAR_POP ------- @@ -244,7 +239,7 @@ Description Usage: VAR_POP(expr). Returns the population standard variance of expr. -Example: +Example:: os> SELECT var_pop(age) as varV FROM accounts; fetched rows / total rows = 1/1 @@ -254,7 +249,6 @@ Example: | 8.1875 | +--------+ - VAR_SAMP -------- @@ -263,7 +257,7 @@ Description Usage: VAR_SAMP(expr). Returns the sample variance of expr. 
-Example: +Example:: os> SELECT var_samp(age) as varV FROM accounts; fetched rows / total rows = 1/1 @@ -273,7 +267,6 @@ Example: | 10.916666666666666 | +--------------------+ - VARIANCE -------- @@ -282,7 +275,7 @@ Description Usage: VARIANCE(expr). Returns the population standard variance of expr. VARIANCE() is a synonym for the standard SQL function VAR_POP() -Example: +Example:: os> SELECT variance(age) as varV FROM accounts; fetched rows / total rows = 1/1 @@ -292,7 +285,6 @@ Example: | 8.1875 | +--------+ - HAVING Clause ============= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index b4825c257f..b9381f814e 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -50,7 +50,7 @@ Description Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. -Example: +Example:: os> source=accounts | stats count(); fetched rows / total rows = 1/1 @@ -60,7 +60,6 @@ Example: | 4 | +-----------+ - SUM --- @@ -69,7 +68,7 @@ Description Usage: SUM(expr). Returns the sum of expr. -Example: +Example:: os> source=accounts | stats sum(age) by gender; fetched rows / total rows = 2/2 @@ -80,7 +79,6 @@ Example: | 101 | M | +------------+----------+ - AVG --- @@ -89,7 +87,7 @@ Description Usage: AVG(expr). Returns the average value of expr. -Example: +Example:: os> source=accounts | stats avg(age) by gender; fetched rows / total rows = 2/2 @@ -100,7 +98,6 @@ Example: | 33.666666666666664 | M | +--------------------+----------+ - MAX --- @@ -109,7 +106,7 @@ Description Usage: MAX(expr). Returns the maximum value of expr. -Example: +Example:: os> source=accounts | stats max(age); fetched rows / total rows = 1/1 @@ -119,7 +116,6 @@ Example: | 36 | +------------+ - MIN --- @@ -128,7 +124,7 @@ Description Usage: MIN(expr). Returns the minimum value of expr. 
-Example: +Example:: os> source=accounts | stats min(age); fetched rows / total rows = 1/1 @@ -138,16 +134,15 @@ Example: | 28 | +------------+ - VAR -------- +--- Description >>>>>>>>>>> Usage: VAR(expr). Returns the sample variance of expr. -Example: +Example:: os> source=accounts | stats var(age); fetched rows / total rows = 1/1 @@ -157,16 +152,15 @@ Example: | 10.916666666666666 | +--------------------+ - VARP --------- +---- Description >>>>>>>>>>> Usage: VARP(expr). Returns the population standard variance of expr. -Example: +Example:: os> source=accounts | stats varp(age); fetched rows / total rows = 1/1 @@ -176,7 +170,6 @@ Example: | 8.1875 | +-------------+ - Example 1: Calculate the count of events ======================================== From 7f24b788b05936311a93d590bc2c41f2fa60783b Mon Sep 17 00:00:00 2001 From: penghuo Date: Tue, 8 Jun 2021 07:50:54 -0700 Subject: [PATCH 08/11] fix the doc Signed-off-by: penghuo --- docs/user/dql/aggregations.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 841e053291..31d0ed19ba 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -273,7 +273,7 @@ VARIANCE Description >>>>>>>>>>> -Usage: VARIANCE(expr). Returns the population standard variance of expr. VARIANCE() is a synonym for the standard SQL function VAR_POP() +Usage: VARIANCE(expr). Returns the population standard variance of expr. VARIANCE() is a synonym for the VAR_POP() function. 
Example:: From 066c0e0071149331ac00afd1013540bf2fda004b Mon Sep 17 00:00:00 2001 From: penghuo Date: Tue, 8 Jun 2021 14:38:38 -0700 Subject: [PATCH 09/11] add stddev_samp and stddev_pop Signed-off-by: penghuo --- .../org/opensearch/sql/expression/DSL.java | 8 + .../aggregation/AggregatorFunction.java | 26 +++ .../aggregation/StdDevAggregator.java | 110 +++++++++++ .../function/BuiltinFunctionName.java | 10 +- .../aggregation/StdDevAggregatorTest.java | 182 ++++++++++++++++++ docs/user/dql/aggregations.rst | 72 +++++++ docs/user/ppl/cmd/stats.rst | 64 ++++-- .../correctness/queries/aggregation.txt | 4 +- .../dsl/MetricAggregationBuilder.java | 14 ++ ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 6 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../ppl/parser/AstExpressionBuilderTest.java | 54 +++++- sql/src/main/antlr/OpenSearchSQLLexer.g4 | 4 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 2 +- 14 files changed, 531 insertions(+), 27 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 6af2b19742..560414592c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -508,6 +508,14 @@ public Aggregator varPop(Expression... expressions) { return aggregate(BuiltinFunctionName.VARPOP, expressions); } + public Aggregator stddevSamp(Expression... expressions) { + return aggregate(BuiltinFunctionName.STDDEV_SAMP, expressions); + } + + public Aggregator stddevPop(Expression... 
expressions) { + return aggregate(BuiltinFunctionName.STDDEV_POP, expressions); + } + public RankingWindowFunction rowNumber() { return (RankingWindowFunction) repository.compile( BuiltinFunctionName.ROW_NUMBER.getName(), Collections.emptyList()); diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index 7953d9c2f0..640ae8a934 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -35,6 +35,8 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.TIME; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample; import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; @@ -72,6 +74,8 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(max()); repository.register(varSamp()); repository.register(varPop()); + repository.register(stddevSamp()); + repository.register(stddevPop()); } private static FunctionResolver avg() { @@ -185,4 +189,26 @@ private static FunctionResolver varPop() { .build() ); } + + private static FunctionResolver stddevSamp() { + FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> stddevSample(arguments, DOUBLE)) + .build() + ); + } + + private static FunctionResolver stddevPop() { + 
FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName(); + return new FunctionResolver( + functionName, + new ImmutableMap.Builder() + .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), + arguments -> stddevPopulation(arguments, DOUBLE)) + .build() + ); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java new file mode 100644 index 0000000000..0cd8494449 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/StdDevAggregator.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.utils.ExpressionUtils.format; + +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * StandardDeviation Aggregator. 
+ */ +public class StdDevAggregator extends Aggregator { + + private final boolean isSampleStdDev; + + /** + * Build Population Standard Deviation {@link StdDevAggregator}. + */ + public static Aggregator stddevPopulation(List arguments, + ExprCoreType returnType) { + return new StdDevAggregator(false, arguments, returnType); + } + + /** + * Build Sample Standard Deviation {@link StdDevAggregator}. + */ + public static Aggregator stddevSample(List arguments, + ExprCoreType returnType) { + return new StdDevAggregator(true, arguments, returnType); + } + + /** + * StdDevAggregator constructor. + * + * @param isSampleStdDev true for sample standard deviation aggregator, false for population + * standard deviation aggregator. + * @param arguments aggregator arguments. + * @param returnType aggregator return types. + */ + public StdDevAggregator( + Boolean isSampleStdDev, List arguments, ExprCoreType returnType) { + super( + isSampleStdDev + ? BuiltinFunctionName.STDDEV_SAMP.getName() + : BuiltinFunctionName.STDDEV_POP.getName(), + arguments, + returnType); + this.isSampleStdDev = isSampleStdDev; + } + + @Override + public StdDevAggregator.StdDevState create() { + return new StdDevAggregator.StdDevState(isSampleStdDev); + } + + @Override + protected StdDevAggregator.StdDevState iterate(ExprValue value, + StdDevAggregator.StdDevState state) { + state.evaluate(value); + return state; + } + + @Override + public String toString() { + return StringUtils.format( + "%s(%s)", isSampleStdDev ? 
"stddev_samp" : "stddev_pop", format(getArguments())); + } + + protected static class StdDevState implements AggregationState { + + private final StandardDeviation standardDeviation; + + private final List values = new ArrayList<>(); + + public StdDevState(boolean isSampleStdDev) { + this.standardDeviation = new StandardDeviation(isSampleStdDev); + } + + public void evaluate(ExprValue value) { + values.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return values.size() == 0 + ? ExprNullValue.of() + : doubleValue(standardDeviation.evaluate(values.stream().mapToDouble(d -> d).toArray())); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f531ee4bbd..24e65d4b5d 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -131,6 +131,10 @@ public enum BuiltinFunctionName { VARSAMP(FunctionName.of("var_samp")), // population standard variance VARPOP(FunctionName.of("var_pop")), + // sample standard deviation. + STDDEV_SAMP(FunctionName.of("stddev_samp")), + // population standard deviation. + STDDEV_POP(FunctionName.of("stddev_pop")), /** * Text Functions. 
@@ -204,8 +208,10 @@ public enum BuiltinFunctionName { .put("var_pop", BuiltinFunctionName.VARPOP) .put("var_samp", BuiltinFunctionName.VARSAMP) .put("variance", BuiltinFunctionName.VARPOP) - .put("var", BuiltinFunctionName.VARSAMP) - .put("varp", BuiltinFunctionName.VARPOP) + .put("std", BuiltinFunctionName.STDDEV_POP) + .put("stddev", BuiltinFunctionName.STDDEV_POP) + .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) + .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) .build(); public static Optional of(String str) { diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java new file mode 100644 index 0000000000..ef085a81d3 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/StdDevAggregatorTest.java @@ -0,0 +1,182 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.expression.DSL.ref; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.storage.bindingtuple.BindingTuple; + +@ExtendWith(MockitoExtension.class) +public class StdDevAggregatorTest extends AggregationTest { + + @Mock + Expression expression; + + @Mock + ExprValue tupleValue; + + @Mock + BindingTuple tuple; + + @Test + public void stddev_sample_field_expression() { + ExprValue result = + stddevSample(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.2909944487358056, result.value()); + } + + @Test + public void stddev_population_field_expression() { + ExprValue result = + stddevPop(integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertEquals(1.118033988749895, result.value()); + } + + @Test + public void stddev_sample_arithmetic_expression() { + ExprValue result = + aggregation( + 
dsl.stddevSamp(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(12.909944487358056, result.value()); + } + + @Test + public void stddev_population_arithmetic_expression() { + ExprValue result = + aggregation( + dsl.stddevPop(dsl.multiply(ref("integer_value", INTEGER), DSL.literal(10))), tuples); + assertEquals(11.180339887498949, result.value()); + } + + @Test + public void filtered_stddev_sample() { + ExprValue result = + aggregation( + dsl.stddevSamp(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(1.0, result.value()); + } + + @Test + public void filtered_stddev_population() { + ExprValue result = + aggregation( + dsl.stddevPop(ref("integer_value", INTEGER)) + .condition(dsl.greater(ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(0.816496580927726, result.value()); + } + + @Test + public void stddev_sample_with_missing() { + ExprValue result = stddevSample(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.7071067811865476, result.value()); + } + + @Test + public void stddev_population_with_missing() { + ExprValue result = stddevPop(integerValue(2), integerValue(1), missingValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void stddev_sample_with_null() { + ExprValue result = stddevSample(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.7071067811865476, result.value()); + } + + @Test + public void stddev_pop_with_null() { + ExprValue result = stddevPop(doubleValue(3d), doubleValue(4d), nullValue()); + assertEquals(0.5, result.value()); + } + + @Test + public void stddev_sample_with_all_missing_or_null() { + ExprValue result = stddevSample(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + @Test + public void stddev_pop_with_all_missing_or_null() { + ExprValue result = stddevPop(missingValue(), nullValue()); + assertTrue(result.isNull()); + } + + 
@Test + public void stddev_sample_to_string() { + Aggregator aggregator = dsl.stddevSamp(ref("integer_value", INTEGER)); + assertEquals("stddev_samp(integer_value)", aggregator.toString()); + } + + @Test + public void stddev_pop_to_string() { + Aggregator aggregator = dsl.stddevPop(ref("integer_value", INTEGER)); + assertEquals("stddev_pop(integer_value)", aggregator.toString()); + } + + @Test + public void stddev_sample_nested_to_string() { + Aggregator avgAggregator = + dsl.stddevSamp( + dsl.multiply( + ref("integer_value", INTEGER), DSL.literal(ExprValueUtils.integerValue(10)))); + assertEquals( + String.format("stddev_samp(*(%s, %d))", ref("integer_value", INTEGER), 10), + avgAggregator.toString()); + } + + private ExprValue stddevSample(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.stddevSamp(expression), mockTuples(value, values)); + } + + private ExprValue stddevPop(ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(dsl.stddevPop(expression), mockTuples(value, values)); + } + + private List mockTuples(ExprValue value, ExprValue... values) { + List mockTuples = new ArrayList<>(); + when(tupleValue.bindingTuples()).thenReturn(tuple); + mockTuples.add(tupleValue); + for (ExprValue exprValue : values) { + mockTuples.add(tupleValue); + } + return mockTuples; + } +} diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 31d0ed19ba..1d6d172981 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -285,6 +285,78 @@ Example:: | 8.1875 | +--------+ +STDDEV_POP +---------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_POP(expr). Returns the population standard deviation of expr. 
+ +Example:: + + os> SELECT stddev_pop(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + +STDDEV_SAMP +----------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_SAMP(expr). Returns the sample standard deviation of expr. + +Example:: + + os> SELECT stddev_samp(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +-------------------+ + | stddevV | + |-------------------| + | 3.304037933599835 | + +-------------------+ + +STD +--- + +Description +>>>>>>>>>>> + +Usage: STD(expr). Returns the population standard deviation of expr. STD() is a synonym for the STDDEV_POP() function. + +Example:: + + os> SELECT std(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + +STDDEV +------ + +Description +>>>>>>>>>>> + +Usage: STDDEV(expr). Returns the population standard deviation of expr. STDDEV() is a synonym for the STDDEV_POP() function. + +Example:: + + os> SELECT stddev(age) as stddevV FROM accounts; + fetched rows / total rows = 1/1 + +--------------------+ + | stddevV | + |--------------------| + | 2.8613807855648994 | + +--------------------+ + HAVING Clause ============= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index b9381f814e..f6dad255ef 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -134,41 +134,77 @@ Example:: | 28 | +------------+ -VAR ---- +VAR_SAMP +-------- Description >>>>>>>>>>> -Usage: VAR(expr). Returns the sample variance of expr. +Usage: VAR_SAMP(expr). Returns the sample variance of expr. 
Example:: - os> source=accounts | stats var(age); + os> source=accounts | stats var_samp(age); fetched rows / total rows = 1/1 +--------------------+ - | var(age) | + | var_samp(age) | |--------------------| | 10.916666666666666 | +--------------------+ -VARP ----- +VAR_POP +------- Description >>>>>>>>>>> -Usage: VARP(expr). Returns the population standard variance of expr. +Usage: VAR_POP(expr). Returns the population standard variance of expr. Example:: - os> source=accounts | stats varp(age); + os> source=accounts | stats var_pop(age); fetched rows / total rows = 1/1 - +-------------+ - | varp(age) | - |-------------| - | 8.1875 | - +-------------+ + +----------------+ + | var_pop(age) | + |----------------| + | 8.1875 | + +----------------+ + +STDDEV_SAMP +----------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_SAMP(expr). Returns the sample standard deviation of expr. + +Example:: + + os> source=accounts | stats stddev_samp(age); + fetched rows / total rows = 1/1 + +--------------------+ + | stddev_samp(age) | + |--------------------| + | 3.304037933599835 | + +--------------------+ + +STDDEV_POP +---------- + +Description +>>>>>>>>>>> + +Usage: STDDEV_POP(expr). Returns the population standard deviation of expr. 
+ +Example:: + + os> source=accounts | stats stddev_pop(age); + fetched rows / total rows = 1/1 + +--------------------+ + | stddev_pop(age) | + |--------------------| + | 2.8613807855648994 | + +--------------------+ Example 1: Calculate the count of events ======================================== diff --git a/integ-test/src/test/resources/correctness/queries/aggregation.txt b/integ-test/src/test/resources/correctness/queries/aggregation.txt index 9318420c04..45aa658783 100644 --- a/integ-test/src/test/resources/correctness/queries/aggregation.txt +++ b/integ-test/src/test/resources/correctness/queries/aggregation.txt @@ -7,4 +7,6 @@ SELECT MAX(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT MIN(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights SELECT MIN(timestamp) FROM opensearch_dashboards_sample_data_flights SELECT VAR_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights -SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file +SELECT VAR_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT STDDEV_POP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights +SELECT STDDEV_SAMP(AvgTicketPrice) FROM opensearch_dashboards_sample_data_flights \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 0699f103ec..3d40258288 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -140,6 +140,20 @@ public Pair visitNamedAggregator( condition, name, new StatsParser(ExtendedStats::getVariancePopulation,name)); + case "stddev_samp": + 
return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getStdDeviationSampling,name)); + case "stddev_pop": + return make( + AggregationBuilders.extendedStats(name), + expression, + condition, + name, + new StatsParser(ExtendedStats::getStdDeviationPopulation,name)); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 3874a0a50e..cb665f6c88 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -151,8 +151,10 @@ STDEV: 'STDEV'; STDEVP: 'STDEVP'; SUM: 'SUM'; SUMSQ: 'SUMSQ'; -VAR: 'VAR'; -VARP: 'VARP'; +VAR_SAMP: 'VAR_SAMP'; +VAR_POP: 'VAR_POP'; +STDDEV_SAMP: 'STDDEV_SAMP'; +STDDEV_POP: 'STDDEV_POP'; PERCENTILE: 'PERCENTILE'; FIRST: 'FIRST'; LAST: 'LAST'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index b4073840c4..d552ad0756 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -139,7 +139,7 @@ statsFunction ; statsFunctionName - : AVG | COUNT | SUM | MIN | MAX | VAR | VARP + : AVG | COUNT | SUM | MIN | MAX | VAR_SAMP | VAR_POP | STDDEV_SAMP | STDDEV_POP ; percentileAggFunction diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index b4763c4bb1..71ef692abf 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -337,13 +337,13 @@ public void testAggFuncCallExpr() { @Test public void testVarAggregationShouldPass() { - assertEqual("source=t | stats var(a) by b", + assertEqual("source=t | stats var_samp(a) by b", agg( relation("t"), 
exprList( alias( - "var(a)", - aggregate("var", field("a")) + "var_samp(a)", + aggregate("var_samp", field("a")) ) ), emptyList(), @@ -358,13 +358,55 @@ public void testVarAggregationShouldPass() { @Test public void testVarpAggregationShouldPass() { - assertEqual("source=t | stats varp(a) by b", + assertEqual("source=t | stats var_pop(a) by b", agg( relation("t"), exprList( alias( - "varp(a)", - aggregate("varp", field("a")) + "var_pop(a)", + aggregate("var_pop", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testStdDevAggregationShouldPass() { + assertEqual("source=t | stats stddev_samp(a) by b", + agg( + relation("t"), + exprList( + alias( + "stddev_samp(a)", + aggregate("stddev_samp", field("a")) + ) + ), + emptyList(), + exprList( + alias( + "b", + field("b") + )), + defaultStatsArgs() + )); + } + + @Test + public void testStdDevPAggregationShouldPass() { + assertEqual("source=t | stats stddev_pop(a) by b", + agg( + relation("t"), + exprList( + alias( + "stddev_pop(a)", + aggregate("stddev_pop", field("a")) ) ), emptyList(), diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 828f9709ca..426c77cf06 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -129,6 +129,10 @@ SUM: 'SUM'; VAR_POP: 'VAR_POP'; VAR_SAMP: 'VAR_SAMP'; VARIANCE: 'VARIANCE'; +STD: 'STD'; +STDDEV: 'STDDEV'; +STDDEV_POP: 'STDDEV_POP'; +STDDEV_SAMP: 'STDDEV_SAMP'; // Common function Keywords diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 92144abb54..18c75b94ff 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -345,7 +345,7 @@ filterClause ; aggregationFunctionName - : AVG | COUNT | SUM | MIN | MAX | VAR_POP | VAR_SAMP | VARIANCE + : AVG | COUNT | SUM | MIN | MAX | VAR_POP | VAR_SAMP | VARIANCE | 
STD | STDDEV | STDDEV_POP | STDDEV_SAMP ; mathematicalFunctionName From b0982a6a3358e87dd295547c1d038e8bd4fa919e Mon Sep 17 00:00:00 2001 From: penghuo Date: Tue, 8 Jun 2021 15:00:59 -0700 Subject: [PATCH 10/11] fix UT coverage --- .../dsl/MetricAggregationBuilderTest.java | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 1df8ceaa4c..95a2383475 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -35,6 +35,8 @@ import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation; +import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample; import static org.opensearch.sql.expression.aggregation.VarianceAggregator.variancePopulation; import static org.opensearch.sql.expression.aggregation.VarianceAggregator.varianceSample; @@ -222,6 +224,40 @@ void should_build_varSamp_aggregation() { varianceSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_stddevPop_aggregation() { + assertEquals( + "{\n" + + " \"stddev_pop(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("stddev_pop(age)", + stddevPopulation(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + + @Test + void should_build_stddevSamp_aggregation() { + assertEquals( + "{\n" + + 
" \"stddev_samp(age)\" : {\n" + + " \"extended_stats\" : {\n" + + " \"field\" : \"age\",\n" + + " \"sigma\" : 2.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + Arrays.asList( + named("stddev_samp(age)", + stddevSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); + } + @Test void should_throw_exception_for_unsupported_aggregator() { when(aggregator.getFunctionName()).thenReturn(new FunctionName("unsupported_agg")); From c0952e161f4a13fbbde3d51ffb616c5873d9b447 Mon Sep 17 00:00:00 2001 From: penghuo Date: Wed, 9 Jun 2021 17:11:16 -0700 Subject: [PATCH 11/11] address comments Signed-off-by: penghuo --- docs/user/dql/window.rst | 86 ++++++++++++++++++- .../resources/correctness/queries/window.txt | 12 +++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/docs/user/dql/window.rst b/docs/user/dql/window.rst index 6d71f0637a..feb2aaa44e 100644 --- a/docs/user/dql/window.rst +++ b/docs/user/dql/window.rst @@ -20,7 +20,7 @@ A window function consists of 2 pieces: a function and a window definition. A wi There are three categories of common window functions: -1. **Aggregate Functions**: COUNT(), MIN(), MAX(), AVG() and SUM(). +1. **Aggregate Functions**: COUNT(), MIN(), MAX(), AVG(), SUM(), STDDEV_POP, STDDEV_SAMP, VAR_POP and VAR_SAMP. 2. **Ranking Functions**: ROW_NUMBER(), RANK(), DENSE_RANK(), PERCENT_RANK() and NTILE(). 3. **Analytic Functions**: CUME_DIST(), LAG() and LEAD(). @@ -146,6 +146,90 @@ Here is an example for ``SUM`` function:: | M | 39225 | 49091 | +----------+-----------+-------+ +STDDEV_POP +---------- + +Here is an example for ``STDDEV_POP`` function:: + + os> SELECT + ... gender, balance, + ... STDDEV_POP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... 
FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 753.0 | + | M | 39225 | 16177.091422406222 | + +----------+-----------+--------------------+ + +STDDEV_SAMP +----------- + +Here is an example for ``STDDEV_SAMP`` function:: + + os> SELECT + ... gender, balance, + ... STDDEV_SAMP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 1064.9028124669405 | + | M | 39225 | 19812.809753624886 | + +----------+-----------+--------------------+ + +VAR_POP +------- + +Here is an example for ``VAR_POP`` function:: + + os> SELECT + ... gender, balance, + ... VAR_POP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+--------------------+ + | gender | balance | val | + |----------+-----------+--------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 567009.0 | + | M | 39225 | 261698286.88888893 | + +----------+-----------+--------------------+ + +VAR_SAMP +-------- + +Here is an example for ``VAR_SAMP`` function:: + + os> SELECT + ... gender, balance, + ... VAR_SAMP(balance) OVER( + ... PARTITION BY gender ORDER BY balance + ... ) AS val + ... 
FROM accounts; + fetched rows / total rows = 4/4 + +----------+-----------+-------------------+ + | gender | balance | val | + |----------+-----------+-------------------| + | F | 32838 | 0.0 | + | M | 4180 | 0.0 | + | M | 5686 | 1134018.0 | + | M | 39225 | 392547430.3333334 | + +----------+-----------+-------------------+ + Ranking Functions ================= diff --git a/integ-test/src/test/resources/correctness/queries/window.txt b/integ-test/src/test/resources/correctness/queries/window.txt index a8d134a254..c3f2715322 100644 --- a/integ-test/src/test/resources/correctness/queries/window.txt +++ b/integ-test/src/test/resources/correctness/queries/window.txt @@ -9,10 +9,18 @@ SELECT DistanceMiles, SUM(DistanceMiles) OVER () AS num FROM opensearch_dashboar SELECT DistanceMiles, AVG(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MAX(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT DistanceMiles, MIN(DistanceMiles) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, STDDEV_POP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, STDDEV_SAMP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, VAR_POP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights +SELECT AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER () AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, SUM(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, AVG(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, MAX(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights SELECT FlightDelayMin, DistanceMiles, 
MIN(DistanceMiles) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights +SELECT FlightDelayMin, AvgTicketPrice, STDDEV_POP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, STDDEV_SAMP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, VAR_POP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin +SELECT FlightDelayMin, AvgTicketPrice, VAR_SAMP(AvgTicketPrice) OVER (ORDER BY FlightDelayMin) AS num FROM opensearch_dashboards_sample_data_flights ORDER BY FlightDelayMin SELECT user, RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (ORDER BY user) AS cnt FROM opensearch_dashboards_sample_data_ecommerce @@ -20,6 +28,8 @@ SELECT user, SUM(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dash SELECT user, AVG(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MIN(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, STDDEV_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user +SELECT user, VAR_POP(day_of_week_i) OVER (ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user SELECT user, RANK() OVER (ORDER BY user DESC) AS rnk FROM opensearch_dashboards_sample_data_ecommerce SELECT user, DENSE_RANK() OVER (ORDER BY user DESC) AS rnk FROM 
opensearch_dashboards_sample_data_ecommerce SELECT user, COUNT(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS cnt FROM opensearch_dashboards_sample_data_ecommerce @@ -27,6 +37,8 @@ SELECT user, SUM(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS nu SELECT user, AVG(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MAX(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT user, MIN(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce +SELECT user, STDDEV_POP(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user +SELECT user, VAR_POP(day_of_week_i) OVER (PARTITION BY user ORDER BY order_id) AS num FROM opensearch_dashboards_sample_data_ecommerce ORDER BY user SELECT customer_gender, user, ROW_NUMBER() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT customer_gender, user, RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce SELECT customer_gender, user, DENSE_RANK() OVER (PARTITION BY customer_gender ORDER BY user) AS num FROM opensearch_dashboards_sample_data_ecommerce