diff --git a/benchmarks/src/jmh/java/org/opensearch/sql/expression/operator/predicate/ComparisonOperatorBenchmark.java b/benchmarks/src/jmh/java/org/opensearch/sql/expression/operator/predicate/ComparisonOperatorBenchmark.java index d2642dd645..01b2068694 100644 --- a/benchmarks/src/jmh/java/org/opensearch/sql/expression/operator/predicate/ComparisonOperatorBenchmark.java +++ b/benchmarks/src/jmh/java/org/opensearch/sql/expression/operator/predicate/ComparisonOperatorBenchmark.java @@ -38,7 +38,7 @@ @Fork(value = 1) public class ComparisonOperatorBenchmark { - @Param(value = { "int", "string", "date" }) + @Param(value = {"int", "string", "date"}) private String testDataType; private final Map params = @@ -65,9 +65,7 @@ public void testGreaterOperator() { private void run(Function dsl) { ExprValue param = params.get(testDataType); - FunctionExpression func = dsl.apply(new Expression[] { - literal(param), literal(param) - }); + FunctionExpression func = dsl.apply(new Expression[] {literal(param), literal(param)}); func.valueOf(); } } diff --git a/legacy/build.gradle b/legacy/build.gradle index d89f7affe7..fce04ae9ba 100644 --- a/legacy/build.gradle +++ b/legacy/build.gradle @@ -53,6 +53,9 @@ compileJava { } } +checkstyleTest.ignoreFailures = true +checkstyleMain.ignoreFailures = true + // TODO: Similarly, need to fix compiling errors in test source code compileTestJava.options.warnings = false compileTestJava { diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/BaseType.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/BaseType.java index 280b7b4c76..37e0c4d4b3 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/BaseType.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/base/BaseType.java @@ -3,24 +3,21 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.base; import java.util.List; import org.opensearch.sql.legacy.antlr.semantic.types.Type; -/** - * Base type interface - */ +/** Base type interface */ public interface BaseType extends Type { - @Override - default Type construct(List others) { - return this; - } + @Override + default Type construct(List others) { + return this; + } - @Override - default String usage() { - return getName(); - } + @Override + default String usage() { + return getName(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/AggregateFunction.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/AggregateFunction.java index 37e4091b0a..9cebf3dda6 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/AggregateFunction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/function/AggregateFunction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.function; import static org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchDataType.DOUBLE; @@ -15,41 +14,38 @@ import org.opensearch.sql.legacy.antlr.semantic.types.Type; import org.opensearch.sql.legacy.antlr.semantic.types.TypeExpression; -/** - * Aggregate function - */ +/** Aggregate function */ public enum AggregateFunction implements TypeExpression { - COUNT( - func().to(INTEGER), // COUNT(*) - func(OPENSEARCH_TYPE).to(INTEGER) - ), - MAX(func(T(NUMBER)).to(T)), - MIN(func(T(NUMBER)).to(T)), - AVG(func(T(NUMBER)).to(DOUBLE)), - SUM(func(T(NUMBER)).to(T)); - - private TypeExpressionSpec[] specifications; - - AggregateFunction(TypeExpressionSpec... specifications) { - this.specifications = specifications; - } - - @Override - public String getName() { - return name(); - } - - @Override - public TypeExpressionSpec[] specifications() { - return specifications; - } - - private static TypeExpressionSpec func(Type... argTypes) { - return new TypeExpressionSpec().map(argTypes); - } - - @Override - public String toString() { - return "Function [" + name() + "]"; - } + COUNT( + func().to(INTEGER), // COUNT(*) + func(OPENSEARCH_TYPE).to(INTEGER)), + MAX(func(T(NUMBER)).to(T)), + MIN(func(T(NUMBER)).to(T)), + AVG(func(T(NUMBER)).to(DOUBLE)), + SUM(func(T(NUMBER)).to(T)); + + private TypeExpressionSpec[] specifications; + + AggregateFunction(TypeExpressionSpec... specifications) { + this.specifications = specifications; + } + + @Override + public String getName() { + return name(); + } + + @Override + public TypeExpressionSpec[] specifications() { + return specifications; + } + + private static TypeExpressionSpec func(Type... argTypes) { + return new TypeExpressionSpec().map(argTypes); + } + + @Override + public String toString() { + return "Function [" + name() + "]"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/operator/ComparisonOperator.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/operator/ComparisonOperator.java index 993d996df3..19e8f85aa3 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/operator/ComparisonOperator.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/semantic/types/operator/ComparisonOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types.operator; import static org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchDataType.BOOLEAN; @@ -12,53 +11,50 @@ import java.util.List; import org.opensearch.sql.legacy.antlr.semantic.types.Type; -/** - * Type for comparison operator - */ +/** Type for comparison operator */ public enum ComparisonOperator implements Type { - - EQUAL("="), - NOT_EQUAL("<>"), - NOT_EQUAL2("!="), - GREATER_THAN(">"), - GREATER_THAN_OR_EQUAL_TO(">="), - SMALLER_THAN("<"), - SMALLER_THAN_OR_EQUAL_TO("<="), - IS("IS"); - - /** Actual name representing the operator */ - private final String name; - - ComparisonOperator(String name) { - this.name = name; - } - - @Override - public String getName() { - return name; - } - - @Override - public Type construct(List actualArgs) { - if (actualArgs.size() != 2) { - return TYPE_ERROR; - } - - Type leftType = actualArgs.get(0); - Type rightType = actualArgs.get(1); - if (leftType.isCompatible(rightType) || rightType.isCompatible(leftType)) { - return BOOLEAN; - } - return TYPE_ERROR; - } - - @Override - public String usage() { - return "Please use compatible types from each side."; + EQUAL("="), + NOT_EQUAL("<>"), + NOT_EQUAL2("!="), + GREATER_THAN(">"), + GREATER_THAN_OR_EQUAL_TO(">="), + SMALLER_THAN("<"), + SMALLER_THAN_OR_EQUAL_TO("<="), + IS("IS"); + + /** Actual name representing the operator */ + private final String name; + + ComparisonOperator(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + + @Override + public Type construct(List actualArgs) { + if (actualArgs.size() != 2) { + return TYPE_ERROR; } - @Override - public String toString() { - return "Operator [" + getName() + "]"; + Type leftType = actualArgs.get(0); + Type rightType = actualArgs.get(1); + if (leftType.isCompatible(rightType) || rightType.isCompatible(leftType)) { + return BOOLEAN; } + return TYPE_ERROR; + } + + @Override + public String usage() { + return "Please use compatible types from each side."; + } + + @Override + public String toString() { + return "Operator [" + getName() + "]"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/syntax/CaseInsensitiveCharStream.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/syntax/CaseInsensitiveCharStream.java index de7e60e9f3..c7cb212826 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/syntax/CaseInsensitiveCharStream.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/syntax/CaseInsensitiveCharStream.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.syntax; import org.antlr.v4.runtime.CharStream; @@ -11,63 +10,64 @@ import org.antlr.v4.runtime.misc.Interval; /** - * Custom stream to convert character to upper case for case insensitive grammar before sending to lexer. + * Custom stream to convert character to upper case for case insensitive grammar before sending to + * lexer. */ public class CaseInsensitiveCharStream implements CharStream { - /** Character stream */ - private final CharStream charStream; + /** Character stream */ + private final CharStream charStream; - public CaseInsensitiveCharStream(String sql) { - this.charStream = CharStreams.fromString(sql); - } + public CaseInsensitiveCharStream(String sql) { + this.charStream = CharStreams.fromString(sql); + } - @Override - public String getText(Interval interval) { - return charStream.getText(interval); - } + @Override + public String getText(Interval interval) { + return charStream.getText(interval); + } - @Override - public void consume() { - charStream.consume(); - } + @Override + public void consume() { + charStream.consume(); + } - @Override - public int LA(int i) { - int c = charStream.LA(i); - if (c <= 0) { - return c; - } - return Character.toUpperCase(c); + @Override + public int LA(int i) { + int c = charStream.LA(i); + if (c <= 0) { + return c; } + return Character.toUpperCase(c); + } - @Override - public int mark() { - return charStream.mark(); - } + @Override + public int mark() { + return charStream.mark(); + } - @Override - public void release(int marker) { - charStream.release(marker); - } + @Override + public void release(int marker) { + charStream.release(marker); + } - @Override - public int index() { - return charStream.index(); - } + @Override + public int index() { + return charStream.index(); + } - @Override - public void seek(int index) { - charStream.seek(index); - } + @Override + public void seek(int index) { + charStream.seek(index); + } - @Override - public int size() { - return charStream.size(); - } + @Override + public int size() { + return charStream.size(); + } - @Override - public String getSourceName() { - return charStream.getSourceName(); - } + @Override + public String getSourceName() { + return charStream.getSourceName(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitor.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitor.java index 90a8274568..00db9a6591 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.visitor; import static java.util.Collections.emptyList; @@ -55,78 +54,74 @@ import org.opensearch.sql.legacy.antlr.parser.OpenSearchLegacySqlParser.TableNamePatternContext; import org.opensearch.sql.legacy.antlr.parser.OpenSearchLegacySqlParserBaseVisitor; -/** - * ANTLR parse tree visitor to drive the analysis process. - */ -public class AntlrSqlParseTreeVisitor extends OpenSearchLegacySqlParserBaseVisitor { - - /** Generic visitor to perform the real action on parse tree */ - private final GenericSqlParseTreeVisitor visitor; - - public AntlrSqlParseTreeVisitor(GenericSqlParseTreeVisitor visitor) { - this.visitor = visitor; - } - - @Override - public T visitRoot(RootContext ctx) { - visitor.visitRoot(); - return super.visitRoot(ctx); - } - - @Override - public T visitUnionSelect(UnionSelectContext ctx) { - T union = visitor.visitOperator("UNION"); - return reduce(union, - asList( - ctx.querySpecification(), - ctx.unionStatement() - ) - ); - } - - @Override - public T visitMinusSelect(MinusSelectContext ctx) { - T minus = visitor.visitOperator("MINUS"); - return reduce(minus, asList(ctx.querySpecification(), ctx.minusStatement())); - } - - @Override - public T visitInPredicate(InPredicateContext ctx) { - T in = visitor.visitOperator("IN"); - PredicateContext field = ctx.predicate(); - ParserRuleContext subquery = (ctx.selectStatement() != null) ? ctx.selectStatement() : ctx.expressions(); - return reduce(in, Arrays.asList(field, subquery)); - } - - @Override - public T visitTableSources(TableSourcesContext ctx) { - if (ctx.tableSource().size() < 2) { - return super.visitTableSources(ctx); - } - T commaJoin = visitor.visitOperator("JOIN"); - return reduce(commaJoin, ctx.tableSource()); - } - - @Override - public T visitTableSourceBase(TableSourceBaseContext ctx) { - if (ctx.joinPart().isEmpty()) { - return super.visitTableSourceBase(ctx); - } - T join = visitor.visitOperator("JOIN"); - return reduce(join, asList(ctx.tableSourceItem(), ctx.joinPart())); - } - - @Override - public T visitInnerJoin(InnerJoinContext ctx) { - return visitJoin(ctx.children, ctx.tableSourceItem()); - } - - @Override - public T visitOuterJoin(OuterJoinContext ctx) { - return visitJoin(ctx.children, ctx.tableSourceItem()); - } +/** ANTLR parse tree visitor to drive the analysis process. */ +public class AntlrSqlParseTreeVisitor + extends OpenSearchLegacySqlParserBaseVisitor { + + /** Generic visitor to perform the real action on parse tree */ + private final GenericSqlParseTreeVisitor visitor; + + public AntlrSqlParseTreeVisitor(GenericSqlParseTreeVisitor visitor) { + this.visitor = visitor; + } + + @Override + public T visitRoot(RootContext ctx) { + visitor.visitRoot(); + return super.visitRoot(ctx); + } + + @Override + public T visitUnionSelect(UnionSelectContext ctx) { + T union = visitor.visitOperator("UNION"); + return reduce(union, asList(ctx.querySpecification(), ctx.unionStatement())); + } + + @Override + public T visitMinusSelect(MinusSelectContext ctx) { + T minus = visitor.visitOperator("MINUS"); + return reduce(minus, asList(ctx.querySpecification(), ctx.minusStatement())); + } + + @Override + public T visitInPredicate(InPredicateContext ctx) { + T in = visitor.visitOperator("IN"); + PredicateContext field = ctx.predicate(); + ParserRuleContext subquery = + (ctx.selectStatement() != null) ? ctx.selectStatement() : ctx.expressions(); + return reduce(in, Arrays.asList(field, subquery)); + } + + @Override + public T visitTableSources(TableSourcesContext ctx) { + if (ctx.tableSource().size() < 2) { + return super.visitTableSources(ctx); + } + T commaJoin = visitor.visitOperator("JOIN"); + return reduce(commaJoin, ctx.tableSource()); + } + + @Override + public T visitTableSourceBase(TableSourceBaseContext ctx) { + if (ctx.joinPart().isEmpty()) { + return super.visitTableSourceBase(ctx); + } + T join = visitor.visitOperator("JOIN"); + return reduce(join, asList(ctx.tableSourceItem(), ctx.joinPart())); + } + + @Override + public T visitInnerJoin(InnerJoinContext ctx) { + return visitJoin(ctx.children, ctx.tableSourceItem()); + } + + @Override + public T visitOuterJoin(OuterJoinContext ctx) { + return visitJoin(ctx.children, ctx.tableSourceItem()); + } /** + *
      * Enforce visit order because ANTLR is generic and unaware.
      *
      * Visiting order is:
@@ -137,275 +132,273 @@ public T visitOuterJoin(OuterJoinContext ctx) {
      *     => HAVING
      *      => ORDER BY
      *       => LIMIT
+     *  
*/ @Override public T visitQuerySpecification(QuerySpecificationContext ctx) { visitor.visitQuery(); - // Always visit FROM clause first to define symbols - FromClauseContext fromClause = ctx.fromClause(); - visit(fromClause.tableSources()); - - if (fromClause.whereExpr != null) { - visit(fromClause.whereExpr); - } - - // Note visit GROUP BY and HAVING later than SELECT for alias definition - T result = visitSelectElements(ctx.selectElements()); - fromClause.groupByItem().forEach(this::visit); - if (fromClause.havingExpr != null) { - visit(fromClause.havingExpr); - } - - if (ctx.orderByClause() != null) { - visitOrderByClause(ctx.orderByClause()); - } - if (ctx.limitClause() != null) { - visitLimitClause(ctx.limitClause()); - } - - visitor.endVisitQuery(); - return result; - } - - @Override - public T visitSubqueryTableItem(SubqueryTableItemContext ctx) { - throw new EarlyExitAnalysisException("Exit when meeting subquery in from"); - } - - /** Visit here instead of tableName because we need alias */ - @Override - public T visitAtomTableItem(AtomTableItemContext ctx) { - String alias = (ctx.alias == null) ? "" : ctx.alias.getText(); - T result = visit(ctx.tableName()); - visitor.visitAs(alias, result); - return result; - } - - @Override - public T visitSimpleTableName(SimpleTableNameContext ctx) { - return visitor.visitIndexName(ctx.getText()); - } - - @Override - public T visitTableNamePattern(TableNamePatternContext ctx) { - return visitor.visitIndexName(ctx.getText()); - } - - @Override - public T visitTableAndTypeName(TableAndTypeNameContext ctx) { - return visitor.visitIndexName(ctx.uid(0).getText()); - } - - @Override - public T visitFullColumnName(FullColumnNameContext ctx) { - return visitor.visitFieldName(ctx.getText()); - } - - @Override - public T visitUdfFunctionCall(UdfFunctionCallContext ctx) { - String funcName = ctx.fullId().getText(); - T func = visitor.visitFunctionName(funcName); - return reduce(func, ctx.functionArgs()); - } - - @Override - public T visitScalarFunctionCall(ScalarFunctionCallContext ctx) { - UnsupportedSemanticVerifier.verify(ctx); - T func = visit(ctx.scalarFunctionName()); - return reduce(func, ctx.functionArgs()); - } - - @Override - public T visitMathOperator(MathOperatorContext ctx) { - UnsupportedSemanticVerifier.verify(ctx); - return super.visitMathOperator(ctx); - } - - @Override - public T visitRegexpPredicate(RegexpPredicateContext ctx) { - UnsupportedSemanticVerifier.verify(ctx); - return super.visitRegexpPredicate(ctx); - } - - @Override - public T visitSelectElements(SelectElementsContext ctx) { - return visitor.visitSelect(ctx.selectElement(). - stream(). - map(this::visit). - collect(Collectors.toList())); - } - - @Override - public T visitSelectStarElement(OpenSearchLegacySqlParser.SelectStarElementContext ctx) { - return visitor.visitSelectAllColumn(); - } - - @Override - public T visitSelectColumnElement(SelectColumnElementContext ctx) { - return visitSelectItem(ctx.fullColumnName(), ctx.uid()); - } - - @Override - public T visitSelectFunctionElement(SelectFunctionElementContext ctx) { - return visitSelectItem(ctx.functionCall(), ctx.uid()); - } - - @Override - public T visitSelectExpressionElement(SelectExpressionElementContext ctx) { - return visitSelectItem(ctx.expression(), ctx.uid()); - } - - @Override - public T visitAggregateWindowedFunction(AggregateWindowedFunctionContext ctx) { - String funcName = ctx.getChild(0).getText(); - T func = visitor.visitFunctionName(funcName); - return reduce(func, ctx.functionArg()); - } - - @Override - public T visitFunctionNameBase(FunctionNameBaseContext ctx) { - return visitor.visitFunctionName(ctx.getText()); - } - - @Override - public T visitBinaryComparisonPredicate(BinaryComparisonPredicateContext ctx) { - if (isNamedArgument(ctx)) { // Essentially named argument is assign instead of comparison - return defaultResult(); - } - - T op = visit(ctx.comparisonOperator()); - return reduce(op, Arrays.asList(ctx.left, ctx.right)); - } - - @Override - public T visitIsExpression(IsExpressionContext ctx) { - T op = visitor.visitOperator("IS"); - return op.reduce(Arrays.asList( - visit(ctx.predicate()), - visitor.visitBoolean(ctx.testValue.getText())) - ); - } - - @Override - public T visitConvertedDataType(OpenSearchLegacySqlParser.ConvertedDataTypeContext ctx) { - if (ctx.getChild(0) != null && !Strings.isNullOrEmpty(ctx.getChild(0).getText())) { - return visitor.visitConvertedType(ctx.getChild(0).getText()); - } else { - return super.visitConvertedDataType(ctx); - } - } - - @Override - public T visitComparisonOperator(ComparisonOperatorContext ctx) { - return visitor.visitOperator(ctx.getText()); - } - - @Override - public T visitConstant(ConstantContext ctx) { - if (ctx.REAL_LITERAL() != null) { - return visitor.visitFloat(ctx.getText()); - } - if (ctx.dateType != null) { - return visitor.visitDate(ctx.getText()); - } - if (ctx.nullLiteral != null) { - return visitor.visitNull(); - } - return super.visitConstant(ctx); - } - - @Override - public T visitStringLiteral(StringLiteralContext ctx) { - return visitor.visitString(ctx.getText()); - } - - @Override - public T visitDecimalLiteral(DecimalLiteralContext ctx) { - return visitor.visitInteger(ctx.getText()); - } - - @Override - public T visitBooleanLiteral(BooleanLiteralContext ctx) { - return visitor.visitBoolean(ctx.getText()); - } - - @Override - protected T defaultResult() { - return visitor.defaultValue(); - } - - @Override - protected T aggregateResult(T aggregate, T nextResult) { - if (nextResult != defaultResult()) { // Simply return non-default value for now - return nextResult; - } - return aggregate; - } - - /** - * Named argument, ex. TOPHITS('size'=3), is under FunctionArgs -> Predicate - * And the function name should be contained in openSearchFunctionNameBase - */ - private boolean isNamedArgument(BinaryComparisonPredicateContext ctx) { - if (ctx.getParent() != null && ctx.getParent().getParent() != null - && ctx.getParent().getParent().getParent() != null - && ctx.getParent().getParent().getParent() instanceof ScalarFunctionCallContext) { - - ScalarFunctionCallContext parent = (ScalarFunctionCallContext) ctx.getParent().getParent().getParent(); - return parent.scalarFunctionName().functionNameBase().openSearchFunctionNameBase() != null; - } - return false; - } - - /** Enforce visiting result of table instead of ON clause as result */ - private T visitJoin(List children, TableSourceItemContext tableCtx) { - T result = defaultResult(); - for (ParseTree child : children) { - if (child == tableCtx) { - result = visit(tableCtx); - } else { - visit(child); - } - } - return result; - } - - /** Visit select items for type check and alias definition */ - private T visitSelectItem(ParserRuleContext item, UidContext uid) { - T result = visit(item); - if (uid != null) { - visitor.visitAs(uid.getText(), result); - } - return result; - } - - private T reduce(T reducer, ParserRuleContext ctx) { - return reduce(reducer, (ctx == null) ? emptyList() : ctx.children); - } - - /** Make constructor apply arguments and return result type */ - private T reduce(T reducer, List nodes) { - List args; - if (nodes == null) { - args = emptyList(); - } else { - args = nodes.stream(). - map(this::visit). - filter(type -> type != defaultResult()). - collect(Collectors.toList()); - } - return reducer.reduce(args); - } - - /** Combine an item and a list of items to a single list */ - private - List asList(Node1 first, List rest) { - - List result = new ArrayList<>(singleton(first)); - result.addAll(rest); - return result; - } - + // Always visit FROM clause first to define symbols + FromClauseContext fromClause = ctx.fromClause(); + visit(fromClause.tableSources()); + + if (fromClause.whereExpr != null) { + visit(fromClause.whereExpr); + } + + // Note visit GROUP BY and HAVING later than SELECT for alias definition + T result = visitSelectElements(ctx.selectElements()); + fromClause.groupByItem().forEach(this::visit); + if (fromClause.havingExpr != null) { + visit(fromClause.havingExpr); + } + + if (ctx.orderByClause() != null) { + visitOrderByClause(ctx.orderByClause()); + } + if (ctx.limitClause() != null) { + visitLimitClause(ctx.limitClause()); + } + + visitor.endVisitQuery(); + return result; + } + + @Override + public T visitSubqueryTableItem(SubqueryTableItemContext ctx) { + throw new EarlyExitAnalysisException("Exit when meeting subquery in from"); + } + + /** Visit here instead of tableName because we need alias */ + @Override + public T visitAtomTableItem(AtomTableItemContext ctx) { + String alias = (ctx.alias == null) ? "" : ctx.alias.getText(); + T result = visit(ctx.tableName()); + visitor.visitAs(alias, result); + return result; + } + + @Override + public T visitSimpleTableName(SimpleTableNameContext ctx) { + return visitor.visitIndexName(ctx.getText()); + } + + @Override + public T visitTableNamePattern(TableNamePatternContext ctx) { + return visitor.visitIndexName(ctx.getText()); + } + + @Override + public T visitTableAndTypeName(TableAndTypeNameContext ctx) { + return visitor.visitIndexName(ctx.uid(0).getText()); + } + + @Override + public T visitFullColumnName(FullColumnNameContext ctx) { + return visitor.visitFieldName(ctx.getText()); + } + + @Override + public T visitUdfFunctionCall(UdfFunctionCallContext ctx) { + String funcName = ctx.fullId().getText(); + T func = visitor.visitFunctionName(funcName); + return reduce(func, ctx.functionArgs()); + } + + @Override + public T visitScalarFunctionCall(ScalarFunctionCallContext ctx) { + UnsupportedSemanticVerifier.verify(ctx); + T func = visit(ctx.scalarFunctionName()); + return reduce(func, ctx.functionArgs()); + } + + @Override + public T visitMathOperator(MathOperatorContext ctx) { + UnsupportedSemanticVerifier.verify(ctx); + return super.visitMathOperator(ctx); + } + + @Override + public T visitRegexpPredicate(RegexpPredicateContext ctx) { + UnsupportedSemanticVerifier.verify(ctx); + return super.visitRegexpPredicate(ctx); + } + + @Override + public T visitSelectElements(SelectElementsContext ctx) { + return visitor.visitSelect( + ctx.selectElement().stream().map(this::visit).collect(Collectors.toList())); + } + + @Override + public T visitSelectStarElement(OpenSearchLegacySqlParser.SelectStarElementContext ctx) { + return visitor.visitSelectAllColumn(); + } + + @Override + public T visitSelectColumnElement(SelectColumnElementContext ctx) { + return visitSelectItem(ctx.fullColumnName(), ctx.uid()); + } + + @Override + public T visitSelectFunctionElement(SelectFunctionElementContext ctx) { + return visitSelectItem(ctx.functionCall(), ctx.uid()); + } + + @Override + public T visitSelectExpressionElement(SelectExpressionElementContext ctx) { + return visitSelectItem(ctx.expression(), ctx.uid()); + } + + @Override + public T visitAggregateWindowedFunction(AggregateWindowedFunctionContext ctx) { + String funcName = ctx.getChild(0).getText(); + T func = visitor.visitFunctionName(funcName); + return reduce(func, ctx.functionArg()); + } + + @Override + public T visitFunctionNameBase(FunctionNameBaseContext ctx) { + return visitor.visitFunctionName(ctx.getText()); + } + + @Override + public T visitBinaryComparisonPredicate(BinaryComparisonPredicateContext ctx) { + if (isNamedArgument(ctx)) { // Essentially named argument is assign instead of comparison + return defaultResult(); + } + + T op = visit(ctx.comparisonOperator()); + return reduce(op, Arrays.asList(ctx.left, ctx.right)); + } + + @Override + public T visitIsExpression(IsExpressionContext ctx) { + T op = visitor.visitOperator("IS"); + return op.reduce( + Arrays.asList(visit(ctx.predicate()), visitor.visitBoolean(ctx.testValue.getText()))); + } + + @Override + public T visitConvertedDataType(OpenSearchLegacySqlParser.ConvertedDataTypeContext ctx) { + if (ctx.getChild(0) != null && !Strings.isNullOrEmpty(ctx.getChild(0).getText())) { + return visitor.visitConvertedType(ctx.getChild(0).getText()); + } else { + return super.visitConvertedDataType(ctx); + } + } + + @Override + public T visitComparisonOperator(ComparisonOperatorContext ctx) { + return visitor.visitOperator(ctx.getText()); + } + + @Override + public T visitConstant(ConstantContext ctx) { + if (ctx.REAL_LITERAL() != null) { + return visitor.visitFloat(ctx.getText()); + } + if (ctx.dateType != null) { + return visitor.visitDate(ctx.getText()); + } + if (ctx.nullLiteral != null) { + return visitor.visitNull(); + } + return super.visitConstant(ctx); + } + + @Override + public T visitStringLiteral(StringLiteralContext ctx) { + return visitor.visitString(ctx.getText()); + } + + @Override + public T visitDecimalLiteral(DecimalLiteralContext ctx) { + return visitor.visitInteger(ctx.getText()); + } + + @Override + public T visitBooleanLiteral(BooleanLiteralContext ctx) { + return visitor.visitBoolean(ctx.getText()); + } + + @Override + protected T defaultResult() { + return visitor.defaultValue(); + } + + @Override + protected T aggregateResult(T aggregate, T nextResult) { + if (nextResult != defaultResult()) { // Simply return non-default value for now + return nextResult; + } + return aggregate; + } + + /** + * Named argument, ex. TOPHITS('size'=3), is under FunctionArgs -> Predicate And the function name + * should be contained in openSearchFunctionNameBase + */ + private boolean isNamedArgument(BinaryComparisonPredicateContext ctx) { + if (ctx.getParent() != null + && ctx.getParent().getParent() != null + && ctx.getParent().getParent().getParent() != null + && ctx.getParent().getParent().getParent() instanceof ScalarFunctionCallContext) { + + ScalarFunctionCallContext parent = + (ScalarFunctionCallContext) ctx.getParent().getParent().getParent(); + return parent.scalarFunctionName().functionNameBase().openSearchFunctionNameBase() != null; + } + return false; + } + + /** Enforce visiting result of table instead of ON clause as result */ + private T visitJoin(List children, TableSourceItemContext tableCtx) { + T result = defaultResult(); + for (ParseTree child : children) { + if (child == tableCtx) { + result = visit(tableCtx); + } else { + visit(child); + } + } + return result; + } + + /** Visit select items for type check and alias definition */ + private T visitSelectItem(ParserRuleContext item, UidContext uid) { + T result = visit(item); + if (uid != null) { + visitor.visitAs(uid.getText(), result); + } + return result; + } + + private T reduce(T reducer, ParserRuleContext ctx) { + return reduce(reducer, (ctx == null) ? emptyList() : ctx.children); + } + + /** Make constructor apply arguments and return result type */ + private T reduce(T reducer, List nodes) { + List args; + if (nodes == null) { + args = emptyList(); + } else { + args = + nodes.stream() + .map(this::visit) + .filter(type -> type != defaultResult()) + .collect(Collectors.toList()); + } + return reducer.reduce(args); + } + + /** Combine an item and a list of items to a single list */ + private List asList( + Node1 first, List rest) { + + List result = new ArrayList<>(singleton(first)); + result.addAll(rest); + return result; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/EarlyExitAnalysisException.java b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/EarlyExitAnalysisException.java index b0bd01a093..cf583aab40 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/EarlyExitAnalysisException.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/antlr/visitor/EarlyExitAnalysisException.java @@ -3,15 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.visitor; -/** - * Exit visitor early due to some reason. - */ +/** Exit visitor early due to some reason. */ public class EarlyExitAnalysisException extends RuntimeException { - public EarlyExitAnalysisException(String message) { - super(message); - } + public EarlyExitAnalysisException(String message) { + super(message); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/Cursor.java b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/Cursor.java index d3985259dd..8cc83a5fe2 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/Cursor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/Cursor.java @@ -3,19 +3,17 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.cursor; - public interface Cursor { - NullCursor NULL_CURSOR = new NullCursor(); + NullCursor NULL_CURSOR = new NullCursor(); /** - * All cursor's are of the form : + * All cursor's are of the form :
* The serialized form before encoding is upto Cursor implementation */ String generateCursorId(); - CursorType getType(); + CursorType getType(); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/CursorType.java b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/CursorType.java index 7c96cb8835..fea47e7e39 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/CursorType.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/CursorType.java @@ -3,42 +3,41 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.cursor; import java.util.HashMap; import java.util.Map; /** - * Different types queries for which cursor is supported. - * The result execution, and cursor genreation/parsing will depend on the cursor type. - * NullCursor is the placeholder implementation in case of non-cursor query. + * Different types queries for which cursor is supported. The result execution, and cursor + * generation/parsing will depend on the cursor type. NullCursor is the placeholder implementation + * in case of non-cursor query. */ public enum CursorType { - NULL(null), - DEFAULT("d"), - AGGREGATION("a"), - JOIN("j"); + NULL(null), + DEFAULT("d"), + AGGREGATION("a"), + JOIN("j"); - public String id; + public String id; - CursorType(String id) { - this.id = id; - } + CursorType(String id) { + this.id = id; + } - public String getId() { - return this.id; - } + public String getId() { + return this.id; + } - public static final Map LOOKUP = new HashMap<>(); + public static final Map LOOKUP = new HashMap<>(); - static { - for (CursorType type : CursorType.values()) { - LOOKUP.put(type.getId(), type); - } + static { + for (CursorType type : CursorType.values()) { + LOOKUP.put(type.getId(), type); } + } - public static CursorType getById(String id) { - return LOOKUP.getOrDefault(id, NULL); - } + public static CursorType getById(String id) { + return LOOKUP.getOrDefault(id, NULL); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java index 856c1e5e2b..c5be0066fc 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.cursor; import com.google.common.base.Strings; @@ -21,9 +20,8 @@ import org.json.JSONObject; import org.opensearch.sql.legacy.executor.format.Schema; - /** - * Minimum metdata that will be serialized for generating cursorId for + * Minimum metdata that will be serialized for generating cursorId for
* SELECT .... FROM .. ORDER BY .... queries */ @Getter @@ -31,130 +29,135 @@ @NoArgsConstructor public class DefaultCursor implements Cursor { - /** Make sure all keys are unique to prevent overriding - * and as small as possible to make cursor compact - */ - private static final String FETCH_SIZE = "f"; - private static final String ROWS_LEFT = "l"; - private static final String INDEX_PATTERN = "i"; - private static final String SCROLL_ID = "s"; - private static final String SCHEMA_COLUMNS = "c"; - private static final String FIELD_ALIAS_MAP = "a"; - - /** To get mappings for index to check if type is date needed for - * @see org.opensearch.sql.legacy.executor.format.DateFieldFormatter */ - @NonNull - private String indexPattern; - - /** List of Schema.Column for maintaining field order and generating null values of missing fields */ - @NonNull - private List columns; - - /** To delegate to correct cursor handler to get next page*/ - private final CursorType type = CursorType.DEFAULT; - + /** + * Make sure all keys are unique to prevent overriding and as small as possible to make cursor + * compact + */ + private static final String FETCH_SIZE = "f"; + + private static final String ROWS_LEFT = "l"; + private static final String INDEX_PATTERN = "i"; + private static final String SCROLL_ID = "s"; + private static final String SCHEMA_COLUMNS = "c"; + private static final String FIELD_ALIAS_MAP = "a"; + + /** + * To get mappings for index to check if type is date needed for + * + * @see org.opensearch.sql.legacy.executor.format.DateFieldFormatter + */ + @NonNull private String indexPattern; + + /** + * List of Schema.Column for maintaining field order and generating null values of missing fields + */ + @NonNull private List columns; + + /** To delegate to correct cursor handler to get next page */ + private final CursorType type = CursorType.DEFAULT; + + /** + * Truncate the @see DataRows to respect LIMIT clause and/or to identify last page to close scroll + * context. docsLeft is decremented by fetch_size for call to get page of result. + */ + private long rowsLeft; + + /** + * @see org.opensearch.sql.legacy.executor.format.SelectResultSet + */ + @NonNull private Map fieldAliasMap; + + /** To get next batch of result */ + private String scrollId; + + /** To reduce the number of rows left by fetchSize */ + @NonNull private Integer fetchSize; + + private Integer limit; + + @Override + public CursorType getType() { + return type; + } + + @Override + public String generateCursorId() { + if (rowsLeft <= 0 || Strings.isNullOrEmpty(scrollId)) { + return null; + } + JSONObject json = new JSONObject(); + json.put(FETCH_SIZE, fetchSize); + json.put(ROWS_LEFT, rowsLeft); + json.put(INDEX_PATTERN, indexPattern); + json.put(SCROLL_ID, scrollId); + json.put(SCHEMA_COLUMNS, getSchemaAsJson()); + json.put(FIELD_ALIAS_MAP, fieldAliasMap); + return String.format("%s:%s", type.getId(), encodeCursor(json)); + } + + public static DefaultCursor from(String cursorId) { /** - * Truncate the @see DataRows to respect LIMIT clause and/or to identify last page to close scroll context. - * docsLeft is decremented by fetch_size for call to get page of result. + * It is assumed that cursorId here is the second part of the original cursor passed by the + * client after removing first part which identifies cursor type */ - private long rowsLeft; - - /** @see org.opensearch.sql.legacy.executor.format.SelectResultSet */ - @NonNull - private Map fieldAliasMap; - - /** To get next batch of result */ - private String scrollId; - - /** To reduce the number of rows left by fetchSize */ - @NonNull - private Integer fetchSize; - - private Integer limit; - - @Override - public CursorType getType() { - return type; - } - - @Override - public String generateCursorId() { - if (rowsLeft <=0 || Strings.isNullOrEmpty(scrollId)) { - return null; - } - JSONObject json = new JSONObject(); - json.put(FETCH_SIZE, fetchSize); - json.put(ROWS_LEFT, rowsLeft); - json.put(INDEX_PATTERN, indexPattern); - json.put(SCROLL_ID, scrollId); - json.put(SCHEMA_COLUMNS, getSchemaAsJson()); - json.put(FIELD_ALIAS_MAP, fieldAliasMap); - return String.format("%s:%s", type.getId(), encodeCursor(json)); - } - - public static DefaultCursor from(String cursorId) { - /** - * It is assumed that cursorId here is the second part of the original cursor passed - * by the client after removing first part which identifies cursor type - */ - JSONObject json = decodeCursor(cursorId); - DefaultCursor cursor = new DefaultCursor(); - cursor.setFetchSize(json.getInt(FETCH_SIZE)); - cursor.setRowsLeft(json.getLong(ROWS_LEFT)); - cursor.setIndexPattern(json.getString(INDEX_PATTERN)); - cursor.setScrollId(json.getString(SCROLL_ID)); - cursor.setColumns(getColumnsFromSchema(json.getJSONArray(SCHEMA_COLUMNS))); - cursor.setFieldAliasMap(fieldAliasMap(json.getJSONObject(FIELD_ALIAS_MAP))); - - return cursor; - } - - private JSONArray getSchemaAsJson() { - JSONArray schemaJson = new JSONArray(); - - for (Schema.Column column : columns) { - schemaJson.put(schemaEntry(column.getName(), column.getAlias(), column.getType())); - } - - return schemaJson; + JSONObject json = decodeCursor(cursorId); + DefaultCursor cursor = new DefaultCursor(); + cursor.setFetchSize(json.getInt(FETCH_SIZE)); + cursor.setRowsLeft(json.getLong(ROWS_LEFT)); + cursor.setIndexPattern(json.getString(INDEX_PATTERN)); + cursor.setScrollId(json.getString(SCROLL_ID)); + cursor.setColumns(getColumnsFromSchema(json.getJSONArray(SCHEMA_COLUMNS))); + cursor.setFieldAliasMap(fieldAliasMap(json.getJSONObject(FIELD_ALIAS_MAP))); + + return cursor; + } + + private JSONArray getSchemaAsJson() { + JSONArray schemaJson = new JSONArray(); + + for (Schema.Column column : columns) { + schemaJson.put(schemaEntry(column.getName(), column.getAlias(), column.getType())); } - private JSONObject schemaEntry(String name, String alias, String type) { - JSONObject entry = new JSONObject(); - entry.put("name", name); - if (alias != null) { - entry.put("alias", alias); - } - entry.put("type", type); - return entry; - } - - private static String encodeCursor(JSONObject cursorJson) { - return Base64.getEncoder().encodeToString(cursorJson.toString().getBytes()); - } - - private static JSONObject decodeCursor(String cursorId) { - return new JSONObject(new String(Base64.getDecoder().decode(cursorId))); - } - - private static Map fieldAliasMap(JSONObject json) { - Map fieldToAliasMap = new HashMap<>(); - json.keySet().forEach(key -> fieldToAliasMap.put(key, json.get(key).toString())); - return fieldToAliasMap; - } + return schemaJson; + } - private static List getColumnsFromSchema(JSONArray schema) { - List columns = IntStream. - range(0, schema.length()). - mapToObj(i -> { - JSONObject jsonColumn = schema.getJSONObject(i); - return new Schema.Column( - jsonColumn.getString("name"), - jsonColumn.optString("alias", null), - Schema.Type.valueOf(jsonColumn.getString("type").toUpperCase()) - ); - } - ).collect(Collectors.toList()); - return columns; + private JSONObject schemaEntry(String name, String alias, String type) { + JSONObject entry = new JSONObject(); + entry.put("name", name); + if (alias != null) { + entry.put("alias", alias); } + entry.put("type", type); + return entry; + } + + private static String encodeCursor(JSONObject cursorJson) { + return Base64.getEncoder().encodeToString(cursorJson.toString().getBytes()); + } + + private static JSONObject decodeCursor(String cursorId) { + return new JSONObject(new String(Base64.getDecoder().decode(cursorId))); + } + + private static Map fieldAliasMap(JSONObject json) { + Map fieldToAliasMap = new HashMap<>(); + json.keySet().forEach(key -> fieldToAliasMap.put(key, json.get(key).toString())); + return fieldToAliasMap; + } + + private static List getColumnsFromSchema(JSONArray schema) { + List columns = + IntStream.range(0, schema.length()) + .mapToObj( + i -> { + JSONObject jsonColumn = schema.getJSONObject(i); + return new Schema.Column( + jsonColumn.getString("name"), + jsonColumn.optString("alias", null), + Schema.Type.valueOf(jsonColumn.getString("type").toUpperCase())); + }) + .collect(Collectors.toList()); + return columns; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/ColumnTypeProvider.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/ColumnTypeProvider.java index 3b2691186b..b7d90b66da 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/ColumnTypeProvider.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/ColumnTypeProvider.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import com.google.common.collect.ImmutableList; @@ -17,66 +16,64 @@ import org.opensearch.sql.legacy.antlr.semantic.types.special.Product; import org.opensearch.sql.legacy.executor.format.Schema; -/** - * The definition of column type provider - */ +/** The definition of column type provider */ public class ColumnTypeProvider { - private final List typeList; + private final List typeList; - private static final Map TYPE_MAP = - new ImmutableMap.Builder() - .put(OpenSearchDataType.SHORT, Schema.Type.SHORT) - .put(OpenSearchDataType.LONG, Schema.Type.LONG) - .put(OpenSearchDataType.INTEGER, Schema.Type.INTEGER) - .put(OpenSearchDataType.FLOAT, Schema.Type.FLOAT) - .put(OpenSearchDataType.DOUBLE, Schema.Type.DOUBLE) - .put(OpenSearchDataType.KEYWORD, Schema.Type.KEYWORD) - .put(OpenSearchDataType.TEXT, Schema.Type.TEXT) - .put(OpenSearchDataType.STRING, Schema.Type.TEXT) - .put(OpenSearchDataType.DATE, Schema.Type.DATE) - .put(OpenSearchDataType.BOOLEAN, Schema.Type.BOOLEAN) - .put(OpenSearchDataType.UNKNOWN, Schema.Type.DOUBLE) - .build(); - public static final Schema.Type COLUMN_DEFAULT_TYPE = Schema.Type.DOUBLE; + private static final Map TYPE_MAP = + new ImmutableMap.Builder() + .put(OpenSearchDataType.SHORT, Schema.Type.SHORT) + .put(OpenSearchDataType.LONG, Schema.Type.LONG) + .put(OpenSearchDataType.INTEGER, Schema.Type.INTEGER) + .put(OpenSearchDataType.FLOAT, Schema.Type.FLOAT) + .put(OpenSearchDataType.DOUBLE, Schema.Type.DOUBLE) + .put(OpenSearchDataType.KEYWORD, Schema.Type.KEYWORD) + .put(OpenSearchDataType.TEXT, Schema.Type.TEXT) + .put(OpenSearchDataType.STRING, Schema.Type.TEXT) + .put(OpenSearchDataType.DATE, Schema.Type.DATE) + .put(OpenSearchDataType.BOOLEAN, Schema.Type.BOOLEAN) + .put(OpenSearchDataType.UNKNOWN, Schema.Type.DOUBLE) + .build(); + public static final Schema.Type COLUMN_DEFAULT_TYPE = Schema.Type.DOUBLE; - public ColumnTypeProvider(Type type) { - this.typeList = convertOutputColumnType(type); - } + public ColumnTypeProvider(Type type) { + this.typeList = convertOutputColumnType(type); + } - public ColumnTypeProvider() { - this.typeList = new ArrayList<>(); - } + public ColumnTypeProvider() { + this.typeList = new ArrayList<>(); + } - /** - * Get the type of column by index. - * - * @param index column index. - * @return column type. - */ - public Schema.Type get(int index) { - if (typeList.isEmpty()) { - return COLUMN_DEFAULT_TYPE; - } else { - return typeList.get(index); - } + /** + * Get the type of column by index. + * + * @param index column index. + * @return column type. + */ + public Schema.Type get(int index) { + if (typeList.isEmpty()) { + return COLUMN_DEFAULT_TYPE; + } else { + return typeList.get(index); } + } - private List convertOutputColumnType(Type type) { - if (type instanceof Product) { - List types = ((Product) type).getTypes(); - return types.stream().map(t -> convertType(t)).collect(Collectors.toList()); - } else if (type instanceof OpenSearchDataType) { - return ImmutableList.of(convertType(type)); - } else { - return ImmutableList.of(COLUMN_DEFAULT_TYPE); - } + private List convertOutputColumnType(Type type) { + if (type instanceof Product) { + List types = ((Product) type).getTypes(); + return types.stream().map(t -> convertType(t)).collect(Collectors.toList()); + } else if (type instanceof OpenSearchDataType) { + return ImmutableList.of(convertType(type)); + } else { + return ImmutableList.of(COLUMN_DEFAULT_TYPE); } + } - private Schema.Type convertType(Type type) { - try { - return TYPE_MAP.getOrDefault(type, COLUMN_DEFAULT_TYPE); - } catch (Exception e) { - return COLUMN_DEFAULT_TYPE; - } + private Schema.Type convertType(Type type) { + try { + return TYPE_MAP.getOrDefault(type, COLUMN_DEFAULT_TYPE); + } catch (Exception e) { + return COLUMN_DEFAULT_TYPE; } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Condition.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Condition.java index ff6b016ddb..8804c543f6 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Condition.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Condition.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; import com.alibaba.druid.sql.ast.SQLExpr; @@ -18,363 +17,368 @@ import org.opensearch.sql.legacy.utils.StringUtils; /** - * - * * @author ansj */ public class Condition extends Where { - public enum OPERATOR { - - EQ, - GT, - LT, - GTE, - LTE, - N, - LIKE, - NLIKE, - REGEXP, - IS, - ISN, - IN, - NIN, - BETWEEN, - NBETWEEN, - GEO_INTERSECTS, - GEO_BOUNDING_BOX, - GEO_DISTANCE, - GEO_POLYGON, - IN_TERMS, - TERM, - IDS_QUERY, - NESTED_COMPLEX, - NOT_EXISTS_NESTED_COMPLEX, - CHILDREN_COMPLEX, - SCRIPT, - NIN_TERMS, - NTERM, - NREGEXP; - - public static Map methodNameToOpear; - - public static Map operStringToOpear; - - public static Map simpleOperStringToOpear; - - private static BiMap negatives; - - private static BiMap simpleReverses; - - static { - methodNameToOpear = new HashMap<>(); - methodNameToOpear.put("term", TERM); - methodNameToOpear.put("matchterm", TERM); - methodNameToOpear.put("match_term", TERM); - methodNameToOpear.put("terms", IN_TERMS); - methodNameToOpear.put("in_terms", IN_TERMS); - methodNameToOpear.put("ids", IDS_QUERY); - methodNameToOpear.put("ids_query", IDS_QUERY); - methodNameToOpear.put("regexp", REGEXP); - methodNameToOpear.put("regexp_query", REGEXP); - } - - static { - operStringToOpear = new HashMap<>(); - operStringToOpear.put("=", EQ); - operStringToOpear.put(">", GT); - operStringToOpear.put("<", LT); - operStringToOpear.put(">=", GTE); - operStringToOpear.put("<=", LTE); - operStringToOpear.put("<>", N); - operStringToOpear.put("LIKE", LIKE); - operStringToOpear.put("NOT", N); - operStringToOpear.put("NOT LIKE", NLIKE); - operStringToOpear.put("IS", IS); - operStringToOpear.put("IS NOT", ISN); - operStringToOpear.put("IN", IN); - operStringToOpear.put("NOT IN", NIN); - operStringToOpear.put("BETWEEN", BETWEEN); - operStringToOpear.put("NOT BETWEEN", NBETWEEN); - operStringToOpear.put("GEO_INTERSECTS", GEO_INTERSECTS); - operStringToOpear.put("GEO_BOUNDING_BOX", GEO_BOUNDING_BOX); - operStringToOpear.put("GEO_DISTANCE", GEO_DISTANCE); - operStringToOpear.put("GEO_POLYGON", GEO_POLYGON); - operStringToOpear.put("NESTED", NESTED_COMPLEX); - operStringToOpear.put("CHILDREN", CHILDREN_COMPLEX); - operStringToOpear.put("SCRIPT", SCRIPT); - } - - static { - simpleOperStringToOpear = new HashMap<>(); - simpleOperStringToOpear.put("=", EQ); - simpleOperStringToOpear.put(">", GT); - simpleOperStringToOpear.put("<", LT); - simpleOperStringToOpear.put(">=", GTE); - simpleOperStringToOpear.put("<=", LTE); - simpleOperStringToOpear.put("<>", N); - } - - static { - negatives = HashBiMap.create(7); - negatives.put(EQ, N); - negatives.put(IN_TERMS, NIN_TERMS); - negatives.put(TERM, NTERM); - negatives.put(GT, LTE); - negatives.put(LT, GTE); - negatives.put(LIKE, NLIKE); - negatives.put(IS, ISN); - negatives.put(IN, NIN); - negatives.put(BETWEEN, NBETWEEN); - negatives.put(NESTED_COMPLEX, NOT_EXISTS_NESTED_COMPLEX); - negatives.put(REGEXP, NREGEXP); - } - - static { - simpleReverses = HashBiMap.create(4); - simpleReverses.put(EQ, EQ); - simpleReverses.put(GT, LT); - simpleReverses.put(GTE, LTE); - simpleReverses.put(N, N); - } - - public OPERATOR negative() throws SqlParseException { - OPERATOR negative = negatives.get(this); - negative = negative != null ? negative : negatives.inverse().get(this); - if (negative == null) { - throw new SqlParseException(StringUtils.format("Negative operator [%s] is not supported.", - this.name())); - } - return negative; - } - - public OPERATOR simpleReverse() throws SqlParseException { - OPERATOR reverse = simpleReverses.get(this); - reverse = reverse != null ? reverse : simpleReverses.inverse().get(this); - if (reverse == null) { - throw new SqlParseException(StringUtils.format("Simple reverse operator [%s] is not supported.", - this.name())); - } - return reverse; - } - - public Boolean isSimpleOperator() { - return simpleOperStringToOpear.containsValue(this); - } + public enum OPERATOR { + EQ, + GT, + LT, + GTE, + LTE, + N, + LIKE, + NLIKE, + REGEXP, + IS, + ISN, + IN, + NIN, + BETWEEN, + NBETWEEN, + GEO_INTERSECTS, + GEO_BOUNDING_BOX, + GEO_DISTANCE, + GEO_POLYGON, + IN_TERMS, + TERM, + IDS_QUERY, + NESTED_COMPLEX, + NOT_EXISTS_NESTED_COMPLEX, + CHILDREN_COMPLEX, + SCRIPT, + NIN_TERMS, + NTERM, + NREGEXP; + + public static Map methodNameToOpear; + + public static Map operStringToOpear; + + public static Map simpleOperStringToOpear; + + private static BiMap negatives; + + private static BiMap simpleReverses; + + static { + methodNameToOpear = new HashMap<>(); + methodNameToOpear.put("term", TERM); + methodNameToOpear.put("matchterm", TERM); + methodNameToOpear.put("match_term", TERM); + methodNameToOpear.put("terms", IN_TERMS); + methodNameToOpear.put("in_terms", IN_TERMS); + methodNameToOpear.put("ids", IDS_QUERY); + methodNameToOpear.put("ids_query", IDS_QUERY); + methodNameToOpear.put("regexp", REGEXP); + methodNameToOpear.put("regexp_query", REGEXP); } - private String name; - - private SQLExpr nameExpr; - - private Object value; - - public SQLExpr getNameExpr() { - return nameExpr; + static { + operStringToOpear = new HashMap<>(); + operStringToOpear.put("=", EQ); + operStringToOpear.put(">", GT); + operStringToOpear.put("<", LT); + operStringToOpear.put(">=", GTE); + operStringToOpear.put("<=", LTE); + operStringToOpear.put("<>", N); + operStringToOpear.put("LIKE", LIKE); + operStringToOpear.put("NOT", N); + operStringToOpear.put("NOT LIKE", NLIKE); + operStringToOpear.put("IS", IS); + operStringToOpear.put("IS NOT", ISN); + operStringToOpear.put("IN", IN); + operStringToOpear.put("NOT IN", NIN); + operStringToOpear.put("BETWEEN", BETWEEN); + operStringToOpear.put("NOT BETWEEN", NBETWEEN); + operStringToOpear.put("GEO_INTERSECTS", GEO_INTERSECTS); + operStringToOpear.put("GEO_BOUNDING_BOX", GEO_BOUNDING_BOX); + operStringToOpear.put("GEO_DISTANCE", GEO_DISTANCE); + operStringToOpear.put("GEO_POLYGON", GEO_POLYGON); + operStringToOpear.put("NESTED", NESTED_COMPLEX); + operStringToOpear.put("CHILDREN", CHILDREN_COMPLEX); + operStringToOpear.put("SCRIPT", SCRIPT); } - public SQLExpr getValueExpr() { - return valueExpr; + static { + simpleOperStringToOpear = new HashMap<>(); + simpleOperStringToOpear.put("=", EQ); + simpleOperStringToOpear.put(">", GT); + simpleOperStringToOpear.put("<", LT); + simpleOperStringToOpear.put(">=", GTE); + simpleOperStringToOpear.put("<=", LTE); + simpleOperStringToOpear.put("<>", N); } - private SQLExpr valueExpr; - - private OPERATOR OPERATOR; - - private Object relationshipType; - - private boolean isNested; - private String nestedPath; - - private boolean isChildren; - private String childType; - - public Condition(CONN conn, String field, SQLExpr nameExpr, String condition, Object obj, SQLExpr valueExpr) - throws SqlParseException { - this(conn, field, nameExpr, condition, obj, valueExpr, null); + static { + negatives = HashBiMap.create(7); + negatives.put(EQ, N); + negatives.put(IN_TERMS, NIN_TERMS); + negatives.put(TERM, NTERM); + negatives.put(GT, LTE); + negatives.put(LT, GTE); + negatives.put(LIKE, NLIKE); + negatives.put(IS, ISN); + negatives.put(IN, NIN); + negatives.put(BETWEEN, NBETWEEN); + negatives.put(NESTED_COMPLEX, NOT_EXISTS_NESTED_COMPLEX); + negatives.put(REGEXP, NREGEXP); } - public Condition(CONN conn, String field, SQLExpr nameExpr, OPERATOR condition, Object obj, SQLExpr valueExpr) - throws SqlParseException { - this(conn, field, nameExpr, condition, obj, valueExpr, null); + static { + simpleReverses = HashBiMap.create(4); + simpleReverses.put(EQ, EQ); + simpleReverses.put(GT, LT); + simpleReverses.put(GTE, LTE); + simpleReverses.put(N, N); } - public Condition(CONN conn, String name, SQLExpr nameExpr, String oper, - Object value, SQLExpr valueExpr, Object relationshipType) throws SqlParseException { - super(conn); - - this.OPERATOR = null; - this.name = name; - this.value = value; - this.nameExpr = nameExpr; - this.valueExpr = valueExpr; - - this.relationshipType = relationshipType; - - if (this.relationshipType != null) { - if (this.relationshipType instanceof NestedType) { - NestedType nestedType = (NestedType) relationshipType; - - this.isNested = true; - this.nestedPath = nestedType.path; - this.isChildren = false; - this.childType = ""; - } else if (relationshipType instanceof ChildrenType) { - ChildrenType childrenType = (ChildrenType) relationshipType; - - this.isNested = false; - this.nestedPath = ""; - this.isChildren = true; - this.childType = childrenType.childType; - } - } else { - this.isNested = false; - this.nestedPath = ""; - this.isChildren = false; - this.childType = ""; - } - - if (OPERATOR.operStringToOpear.containsKey(oper)) { - this.OPERATOR = OPERATOR.operStringToOpear.get(oper); - } else { - throw new SqlParseException("Unsupported operation: " + oper); - } + public OPERATOR negative() throws SqlParseException { + OPERATOR negative = negatives.get(this); + negative = negative != null ? negative : negatives.inverse().get(this); + if (negative == null) { + throw new SqlParseException( + StringUtils.format("Negative operator [%s] is not supported.", this.name())); + } + return negative; } - - public Condition(CONN conn, - String name, - SQLExpr nameExpr, - OPERATOR oper, - Object value, - SQLExpr valueExpr, - Object relationshipType - ) throws SqlParseException { - super(conn); - - this.OPERATOR = null; - this.nameExpr = nameExpr; - this.valueExpr = valueExpr; - this.name = name; - this.value = value; - this.OPERATOR = oper; - this.relationshipType = relationshipType; - - if (this.relationshipType != null) { - if (this.relationshipType instanceof NestedType) { - NestedType nestedType = (NestedType) relationshipType; - - this.isNested = true; - this.nestedPath = nestedType.path; - this.isChildren = false; - this.childType = ""; - } else if (relationshipType instanceof ChildrenType) { - ChildrenType childrenType = (ChildrenType) relationshipType; - - this.isNested = false; - this.nestedPath = ""; - this.isChildren = true; - this.childType = childrenType.childType; - } - } else { - this.isNested = false; - this.nestedPath = ""; - this.isChildren = false; - this.childType = ""; - } + public OPERATOR simpleReverse() throws SqlParseException { + OPERATOR reverse = simpleReverses.get(this); + reverse = reverse != null ? reverse : simpleReverses.inverse().get(this); + if (reverse == null) { + throw new SqlParseException( + StringUtils.format("Simple reverse operator [%s] is not supported.", this.name())); + } + return reverse; } - public String getOpertatorSymbol() throws SqlParseException { - switch (OPERATOR) { - case EQ: - return "=="; - case GT: - return ">"; - case LT: - return "<"; - case GTE: - return ">="; - case LTE: - return "<="; - case N: - return "<>"; - case IS: - return "=="; - - case ISN: - return "!="; - default: - throw new SqlParseException(StringUtils.format("Failed to parse operator [%s]", OPERATOR)); - } + public Boolean isSimpleOperator() { + return simpleOperStringToOpear.containsValue(this); } - - - public String getName() { - return name; + } + + private String name; + + private SQLExpr nameExpr; + + private Object value; + + public SQLExpr getNameExpr() { + return nameExpr; + } + + public SQLExpr getValueExpr() { + return valueExpr; + } + + private SQLExpr valueExpr; + + private OPERATOR OPERATOR; + + private Object relationshipType; + + private boolean isNested; + private String nestedPath; + + private boolean isChildren; + private String childType; + + public Condition( + CONN conn, String field, SQLExpr nameExpr, String condition, Object obj, SQLExpr valueExpr) + throws SqlParseException { + this(conn, field, nameExpr, condition, obj, valueExpr, null); + } + + public Condition( + CONN conn, String field, SQLExpr nameExpr, OPERATOR condition, Object obj, SQLExpr valueExpr) + throws SqlParseException { + this(conn, field, nameExpr, condition, obj, valueExpr, null); + } + + public Condition( + CONN conn, + String name, + SQLExpr nameExpr, + String oper, + Object value, + SQLExpr valueExpr, + Object relationshipType) + throws SqlParseException { + super(conn); + + this.OPERATOR = null; + this.name = name; + this.value = value; + this.nameExpr = nameExpr; + this.valueExpr = valueExpr; + + this.relationshipType = relationshipType; + + if (this.relationshipType != null) { + if (this.relationshipType instanceof NestedType) { + NestedType nestedType = (NestedType) relationshipType; + + this.isNested = true; + this.nestedPath = nestedType.path; + this.isChildren = false; + this.childType = ""; + } else if (relationshipType instanceof ChildrenType) { + ChildrenType childrenType = (ChildrenType) relationshipType; + + this.isNested = false; + this.nestedPath = ""; + this.isChildren = true; + this.childType = childrenType.childType; + } + } else { + this.isNested = false; + this.nestedPath = ""; + this.isChildren = false; + this.childType = ""; } - public void setName(String name) { - this.name = name; + if (OPERATOR.operStringToOpear.containsKey(oper)) { + this.OPERATOR = OPERATOR.operStringToOpear.get(oper); + } else { + throw new SqlParseException("Unsupported operation: " + oper); } - - public Object getValue() { - return value; + } + + public Condition( + CONN conn, + String name, + SQLExpr nameExpr, + OPERATOR oper, + Object value, + SQLExpr valueExpr, + Object relationshipType) + throws SqlParseException { + super(conn); + + this.OPERATOR = null; + this.nameExpr = nameExpr; + this.valueExpr = valueExpr; + this.name = name; + this.value = value; + this.OPERATOR = oper; + this.relationshipType = relationshipType; + + if (this.relationshipType != null) { + if (this.relationshipType instanceof NestedType) { + NestedType nestedType = (NestedType) relationshipType; + + this.isNested = true; + this.nestedPath = nestedType.path; + this.isChildren = false; + this.childType = ""; + } else if (relationshipType instanceof ChildrenType) { + ChildrenType childrenType = (ChildrenType) relationshipType; + + this.isNested = false; + this.nestedPath = ""; + this.isChildren = true; + this.childType = childrenType.childType; + } + } else { + this.isNested = false; + this.nestedPath = ""; + this.isChildren = false; + this.childType = ""; } - - public void setValue(Object value) { - this.value = value; + } + + public String getOpertatorSymbol() throws SqlParseException { + switch (OPERATOR) { + case EQ: + return "=="; + case GT: + return ">"; + case LT: + return "<"; + case GTE: + return ">="; + case LTE: + return "<="; + case N: + return "<>"; + case IS: + return "=="; + + case ISN: + return "!="; + default: + throw new SqlParseException(StringUtils.format("Failed to parse operator [%s]", OPERATOR)); } + } - public OPERATOR getOPERATOR() { - return OPERATOR; - } + public String getName() { + return name; + } - public void setOPERATOR(OPERATOR OPERATOR) { - this.OPERATOR = OPERATOR; - } + public void setName(String name) { + this.name = name; + } - public Object getRelationshipType() { - return relationshipType; - } + public Object getValue() { + return value; + } - public void setRelationshipType(Object relationshipType) { - this.relationshipType = relationshipType; - } + public void setValue(Object value) { + this.value = value; + } - public boolean isNested() { - return isNested; - } + public OPERATOR getOPERATOR() { + return OPERATOR; + } - public void setNested(boolean isNested) { - this.isNested = isNested; - } + public void setOPERATOR(OPERATOR OPERATOR) { + this.OPERATOR = OPERATOR; + } - public String getNestedPath() { - return nestedPath; - } + public Object getRelationshipType() { + return relationshipType; + } - public void setNestedPath(String nestedPath) { - this.nestedPath = nestedPath; - } + public void setRelationshipType(Object relationshipType) { + this.relationshipType = relationshipType; + } - public boolean isChildren() { - return isChildren; - } + public boolean isNested() { + return isNested; + } - public void setChildren(boolean isChildren) { - this.isChildren = isChildren; - } + public void setNested(boolean isNested) { + this.isNested = isNested; + } - public String getChildType() { - return childType; - } + public String getNestedPath() { + return nestedPath; + } - public void setChildType(String childType) { - this.childType = childType; - } + public void setNestedPath(String nestedPath) { + this.nestedPath = nestedPath; + } + + public boolean isChildren() { + return isChildren; + } + + public void setChildren(boolean isChildren) { + this.isChildren = isChildren; + } + + public String getChildType() { + return childType; + } + + public void setChildType(String childType) { + this.childType = childType; + } /** - * Return true if the opear is {@link OPERATOR#NESTED_COMPLEX} + * Return true if the opear is {@link OPERATOR#NESTED_COMPLEX}
* For example, the opear is {@link OPERATOR#NESTED_COMPLEX} when condition is * nested('projects', projects.started_year > 2000 OR projects.name LIKE '%security%') */ @@ -382,40 +386,53 @@ public boolean isNestedComplex() { return OPERATOR.NESTED_COMPLEX == OPERATOR; } - @Override - public String toString() { - String result = ""; - - if (this.isNested()) { - result = "nested condition "; - if (this.getNestedPath() != null) { - result += "on path:" + this.getNestedPath() + " "; - } - } else if (this.isChildren()) { - result = "children condition "; - - if (this.getChildType() != null) { - result += "on child: " + this.getChildType() + " "; - } - } - - if (value instanceof Object[]) { - result += this.conn + " " + this.name + " " + this.OPERATOR + " " + Arrays.toString((Object[]) value); - } else { - result += this.conn + " " + this.name + " " + this.OPERATOR + " " + this.value; - } - - return result; + @Override + public String toString() { + String result = ""; + + if (this.isNested()) { + result = "nested condition "; + if (this.getNestedPath() != null) { + result += "on path:" + this.getNestedPath() + " "; + } + } else if (this.isChildren()) { + result = "children condition "; + + if (this.getChildType() != null) { + result += "on child: " + this.getChildType() + " "; + } + } + + if (value instanceof Object[]) { + result += + this.conn + + " " + + this.name + + " " + + this.OPERATOR + + " " + + Arrays.toString((Object[]) value); + } else { + result += this.conn + " " + this.name + " " + this.OPERATOR + " " + this.value; } - @Override - public Object clone() throws CloneNotSupportedException { - try { - return new Condition(this.getConn(), this.getName(), this.getNameExpr(), - this.getOPERATOR(), this.getValue(), this.getValueExpr(), this.getRelationshipType()); - } catch (SqlParseException e) { + return result; + } + + @Override + public Object clone() throws CloneNotSupportedException { + try { + return new Condition( + this.getConn(), + this.getName(), + this.getNameExpr(), + this.getOPERATOR(), + this.getValue(), + this.getValueExpr(), + this.getRelationshipType()); + } catch (SqlParseException e) { - } - return null; } + return null; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Delete.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Delete.java index 587a8b3ef9..efa77da0a5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Delete.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Delete.java @@ -3,12 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain; -/** - * SQL Delete statement. - */ -public class Delete extends Query { - -} +/** SQL Delete statement. */ +public class Delete extends Query {} diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/BucketPath.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/BucketPath.java index 996caae5e2..635d0062a5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/BucketPath.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/bucketpath/BucketPath.java @@ -3,39 +3,35 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.domain.bucketpath; import java.util.ArrayDeque; import java.util.Deque; /** - * The bucket path syntax + * The bucket path syntax
* [ , ]* [ , ] * - * https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-pipeline.html#buckets-path-syntax + *

https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-pipeline.html#buckets-path-syntax */ public class BucketPath { - private Deque pathStack = new ArrayDeque<>(); + private Deque pathStack = new ArrayDeque<>(); - public BucketPath add(Path path) { - if (pathStack.isEmpty()) { - assert path.isMetricPath() : "The last path in the bucket path must be Metric"; - } else { - assert path.isAggPath() : "All the other path in the bucket path must be Agg"; - } - pathStack.push(path); - return this; + public BucketPath add(Path path) { + if (pathStack.isEmpty()) { + assert path.isMetricPath() : "The last path in the bucket path must be Metric"; + } else { + assert path.isAggPath() : "All the other path in the bucket path must be Agg"; } + pathStack.push(path); + return this; + } - /** - * Return the bucket path. - * Return "", if there is no agg or metric available - */ - public String getBucketPath() { - String bucketPath = pathStack.isEmpty() ? "" : pathStack.pop().getPath(); - return pathStack.stream() - .map(path -> path.getSeparator() + path.getPath()) - .reduce(bucketPath, String::concat); - } + /** Return the bucket path. Return "", if there is no agg or metric available */ + public String getBucketPath() { + String bucketPath = pathStack.isEmpty() ? "" : pathStack.pop().getPath(); + return pathStack.stream() + .map(path -> path.getSeparator() + path.getPath()) + .reduce(bucketPath, String::concat); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/AsyncRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/AsyncRestExecutor.java index 1df0036bab..b3cec4648c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/AsyncRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/AsyncRestExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor; import java.io.IOException; @@ -29,135 +28,141 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transports; -/** - * A RestExecutor wrapper to execute request asynchronously to avoid blocking transport thread. - */ +/** A RestExecutor wrapper to execute request asynchronously to avoid blocking transport thread. */ public class AsyncRestExecutor implements RestExecutor { - /** - * Custom thread pool name managed by OpenSearch - */ - public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; - - private static final Logger LOG = LogManager.getLogger(AsyncRestExecutor.class); - - /** - * Treat all actions as blocking which means async all actions, - * ex. execute() in csv executor or pretty format executor - */ - private static final Predicate ALL_ACTION_IS_BLOCKING = anyAction -> true; - - /** - * Delegated rest executor to async - */ - private final RestExecutor executor; - - /** - * Request type that expect to async to avoid blocking - */ - private final Predicate isBlocking; - - - AsyncRestExecutor(RestExecutor executor) { - this(executor, ALL_ACTION_IS_BLOCKING); + /** Custom thread pool name managed by OpenSearch */ + public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; + + private static final Logger LOG = LogManager.getLogger(AsyncRestExecutor.class); + + /** + * Treat all actions as blocking which means async all actions, ex. execute() in csv executor or + * pretty format executor + */ + private static final Predicate ALL_ACTION_IS_BLOCKING = anyAction -> true; + + /** Delegated rest executor to async */ + private final RestExecutor executor; + + /** Request type that expect to async to avoid blocking */ + private final Predicate isBlocking; + + AsyncRestExecutor(RestExecutor executor) { + this(executor, ALL_ACTION_IS_BLOCKING); + } + + AsyncRestExecutor(RestExecutor executor, Predicate isBlocking) { + this.executor = executor; + this.isBlocking = isBlocking; + } + + @Override + public void execute( + Client client, Map params, QueryAction queryAction, RestChannel channel) + throws Exception { + if (isBlockingAction(queryAction) && isRunningInTransportThread()) { + if (LOG.isDebugEnabled()) { + LOG.debug( + "[{}] Async blocking query action [{}] for executor [{}] in current thread [{}]", + QueryContext.getRequestId(), + name(executor), + name(queryAction), + Thread.currentThread().getName()); + } + async(client, params, queryAction, channel); + } else { + if (LOG.isDebugEnabled()) { + LOG.debug( + "[{}] Continue running query action [{}] for executor [{}] in current thread [{}]", + QueryContext.getRequestId(), + name(executor), + name(queryAction), + Thread.currentThread().getName()); + } + doExecuteWithTimeMeasured(client, params, queryAction, channel); } - - AsyncRestExecutor(RestExecutor executor, Predicate isBlocking) { - this.executor = executor; - this.isBlocking = isBlocking; - } - - @Override - public void execute(Client client, Map params, QueryAction queryAction, RestChannel channel) - throws Exception { - if (isBlockingAction(queryAction) && isRunningInTransportThread()) { - if (LOG.isDebugEnabled()) { - LOG.debug("[{}] Async blocking query action [{}] for executor [{}] in current thread [{}]", - QueryContext.getRequestId(), name(executor), name(queryAction), Thread.currentThread().getName()); - } - async(client, params, queryAction, channel); - } else { - if (LOG.isDebugEnabled()) { - LOG.debug("[{}] Continue running query action [{}] for executor [{}] in current thread [{}]", - QueryContext.getRequestId(), name(executor), name(queryAction), Thread.currentThread().getName()); - } + } + + @Override + public String execute(Client client, Map params, QueryAction queryAction) + throws Exception { + // Result is always required and no easy way to async it here. + return executor.execute(client, params, queryAction); + } + + private boolean isBlockingAction(QueryAction queryAction) { + return isBlocking.test(queryAction); + } + + private boolean isRunningInTransportThread() { + return Transports.isTransportThread(Thread.currentThread()); + } + + /** Run given task in thread pool asynchronously */ + private void async( + Client client, Map params, QueryAction queryAction, RestChannel channel) { + + ThreadPool threadPool = client.threadPool(); + Runnable runnable = + () -> { + try { doExecuteWithTimeMeasured(client, params, queryAction, channel); - } - } - - @Override - public String execute(Client client, Map params, QueryAction queryAction) throws Exception { - // Result is always required and no easy way to async it here. - return executor.execute(client, params, queryAction); - } - - private boolean isBlockingAction(QueryAction queryAction) { - return isBlocking.test(queryAction); - } - - private boolean isRunningInTransportThread() { - return Transports.isTransportThread(Thread.currentThread()); - } - - /** - * Run given task in thread pool asynchronously - */ - private void async(Client client, Map params, QueryAction queryAction, RestChannel channel) { - - ThreadPool threadPool = client.threadPool(); - Runnable runnable = () -> { - try { - doExecuteWithTimeMeasured(client, params, queryAction, channel); - } catch (IOException | SqlParseException | OpenSearchException e) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.warn("[{}] [MCB] async task got an IO/SQL exception: {}", QueryContext.getRequestId(), - e.getMessage()); - channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); - } catch (IllegalStateException e) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.warn("[{}] [MCB] async task got a runtime exception: {}", QueryContext.getRequestId(), - e.getMessage()); - channel.sendResponse(new BytesRestResponse(RestStatus.INSUFFICIENT_STORAGE, - "Memory circuit is broken.")); - } catch (Throwable t) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.warn("[{}] [MCB] async task got an unknown throwable: {}", QueryContext.getRequestId(), - t.getMessage()); - channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - String.valueOf(t.getMessage()))); - } finally { - BackOffRetryStrategy.releaseMem(executor); - } + } catch (IOException | SqlParseException | OpenSearchException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.warn( + "[{}] [MCB] async task got an IO/SQL exception: {}", + QueryContext.getRequestId(), + e.getMessage()); + channel.sendResponse( + new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } catch (IllegalStateException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.warn( + "[{}] [MCB] async task got a runtime exception: {}", + QueryContext.getRequestId(), + e.getMessage()); + channel.sendResponse( + new BytesRestResponse( + RestStatus.INSUFFICIENT_STORAGE, "Memory circuit is broken.")); + } catch (Throwable t) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.warn( + "[{}] [MCB] async task got an unknown throwable: {}", + QueryContext.getRequestId(), + t.getMessage()); + channel.sendResponse( + new BytesRestResponse( + RestStatus.INTERNAL_SERVER_ERROR, String.valueOf(t.getMessage()))); + } finally { + BackOffRetryStrategy.releaseMem(executor); + } }; - // Preserve context of calling thread to ensure headers of requests are forwarded when running blocking actions - threadPool.schedule( - QueryContext.withCurrentContext(runnable), - new TimeValue(0L), - SQL_WORKER_THREAD_POOL_NAME - ); + // Preserve context of calling thread to ensure headers of requests are forwarded when running + // blocking actions + threadPool.schedule( + QueryContext.withCurrentContext(runnable), new TimeValue(0L), SQL_WORKER_THREAD_POOL_NAME); + } + + /** Time the real execution of Executor and log slow query for troubleshooting */ + private void doExecuteWithTimeMeasured( + Client client, Map params, QueryAction action, RestChannel channel) + throws Exception { + long startTime = System.nanoTime(); + try { + executor.execute(client, params, action, channel); + } finally { + Duration elapsed = Duration.ofNanos(System.nanoTime() - startTime); + int slowLogThreshold = LocalClusterState.state().getSettingValue(Settings.Key.SQL_SLOWLOG); + if (elapsed.getSeconds() >= slowLogThreshold) { + LOG.warn( + "[{}] Slow query: elapsed={} (ms)", QueryContext.getRequestId(), elapsed.toMillis()); + } } + } - /** - * Time the real execution of Executor and log slow query for troubleshooting - */ - private void doExecuteWithTimeMeasured(Client client, - Map params, - QueryAction action, - RestChannel channel) throws Exception { - long startTime = System.nanoTime(); - try { - executor.execute(client, params, action, channel); - } finally { - Duration elapsed = Duration.ofNanos(System.nanoTime() - startTime); - int slowLogThreshold = LocalClusterState.state().getSettingValue(Settings.Key.SQL_SLOWLOG); - if (elapsed.getSeconds() >= slowLogThreshold) { - LOG.warn("[{}] Slow query: elapsed={} (ms)", QueryContext.getRequestId(), elapsed.toMillis()); - } - } - } - - private String name(Object object) { - return object.getClass().getSimpleName(); - } + private String name(Object object) { + return object.getClass().getSimpleName(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticDefaultRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticDefaultRestExecutor.java index 7ba5f384c0..ce777c4468 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticDefaultRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticDefaultRestExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor; import com.google.common.collect.Maps; @@ -36,90 +35,94 @@ import org.opensearch.sql.legacy.query.join.JoinRequestBuilder; import org.opensearch.sql.legacy.query.multi.MultiQueryRequestBuilder; - public class ElasticDefaultRestExecutor implements RestExecutor { - /** - * Request builder to generate OpenSearch DSL - */ - private final SqlElasticRequestBuilder requestBuilder; + /** Request builder to generate OpenSearch DSL */ + private final SqlElasticRequestBuilder requestBuilder; - private static final Logger LOG = LogManager.getLogger(ElasticDefaultRestExecutor.class); + private static final Logger LOG = LogManager.getLogger(ElasticDefaultRestExecutor.class); - public ElasticDefaultRestExecutor(QueryAction queryAction) { - // Put explain() here to make it run in NIO thread - try { - this.requestBuilder = queryAction.explain(); - } catch (SqlParseException e) { - throw new IllegalStateException("Failed to explain query action", e); - } + public ElasticDefaultRestExecutor(QueryAction queryAction) { + // Put explain() here to make it run in NIO thread + try { + this.requestBuilder = queryAction.explain(); + } catch (SqlParseException e) { + throw new IllegalStateException("Failed to explain query action", e); } + } - /** - * Execute the ActionRequest and returns the REST response using the channel. - */ - @Override - public void execute(Client client, Map params, QueryAction queryAction, RestChannel channel) - throws Exception { - ActionRequest request = requestBuilder.request(); + /** Execute the ActionRequest and returns the REST response using the channel. */ + @Override + public void execute( + Client client, Map params, QueryAction queryAction, RestChannel channel) + throws Exception { + ActionRequest request = requestBuilder.request(); - if (requestBuilder instanceof JoinRequestBuilder) { - ElasticJoinExecutor executor = ElasticJoinExecutor.createJoinExecutor(client, requestBuilder); - executor.run(); - executor.sendResponse(channel); - } else if (requestBuilder instanceof MultiQueryRequestBuilder) { - ElasticHitsExecutor executor = MultiRequestExecutorFactory.createExecutor(client, - (MultiQueryRequestBuilder) requestBuilder); - executor.run(); - sendDefaultResponse(executor.getHits(), channel); - } else if (request instanceof SearchRequest) { - client.search((SearchRequest) request, new RestStatusToXContentListener<>(channel)); - } else if (request instanceof DeleteByQueryRequest) { - requestBuilder.getBuilder().execute( - new BulkIndexByScrollResponseContentListener(channel, Maps.newHashMap())); - } else if (request instanceof GetIndexRequest) { - requestBuilder.getBuilder().execute(new GetIndexRequestRestListener(channel, (GetIndexRequest) request)); - } else if (request instanceof SearchScrollRequest) { - client.searchScroll((SearchScrollRequest) request, new RestStatusToXContentListener<>(channel)); - } else { - throw new Exception(String.format("Unsupported ActionRequest provided: %s", request.getClass().getName())); - } + if (requestBuilder instanceof JoinRequestBuilder) { + ElasticJoinExecutor executor = ElasticJoinExecutor.createJoinExecutor(client, requestBuilder); + executor.run(); + executor.sendResponse(channel); + } else if (requestBuilder instanceof MultiQueryRequestBuilder) { + ElasticHitsExecutor executor = + MultiRequestExecutorFactory.createExecutor( + client, (MultiQueryRequestBuilder) requestBuilder); + executor.run(); + sendDefaultResponse(executor.getHits(), channel); + } else if (request instanceof SearchRequest) { + client.search((SearchRequest) request, new RestStatusToXContentListener<>(channel)); + } else if (request instanceof DeleteByQueryRequest) { + requestBuilder + .getBuilder() + .execute(new BulkIndexByScrollResponseContentListener(channel, Maps.newHashMap())); + } else if (request instanceof GetIndexRequest) { + requestBuilder + .getBuilder() + .execute(new GetIndexRequestRestListener(channel, (GetIndexRequest) request)); + } else if (request instanceof SearchScrollRequest) { + client.searchScroll( + (SearchScrollRequest) request, new RestStatusToXContentListener<>(channel)); + } else { + throw new Exception( + String.format("Unsupported ActionRequest provided: %s", request.getClass().getName())); } + } - @Override - public String execute(Client client, Map params, QueryAction queryAction) throws Exception { - ActionRequest request = requestBuilder.request(); - - if (requestBuilder instanceof JoinRequestBuilder) { - ElasticJoinExecutor executor = ElasticJoinExecutor.createJoinExecutor(client, requestBuilder); - executor.run(); - return ElasticUtils.hitsAsStringResult(executor.getHits(), new MetaSearchResult()); - } else if (requestBuilder instanceof MultiQueryRequestBuilder) { - ElasticHitsExecutor executor = MultiRequestExecutorFactory.createExecutor(client, - (MultiQueryRequestBuilder) requestBuilder); - executor.run(); - return ElasticUtils.hitsAsStringResult(executor.getHits(), new MetaSearchResult()); - } else if (request instanceof SearchRequest) { - ActionFuture future = client.search((SearchRequest) request); - SearchResponse response = future.actionGet(); - return response.toString(); - } else if (request instanceof DeleteByQueryRequest) { - return requestBuilder.get().toString(); - } else if (request instanceof GetIndexRequest) { - return requestBuilder.getBuilder().execute().actionGet().toString(); - } else { - throw new Exception(String.format("Unsupported ActionRequest provided: %s", request.getClass().getName())); - } + @Override + public String execute(Client client, Map params, QueryAction queryAction) + throws Exception { + ActionRequest request = requestBuilder.request(); + if (requestBuilder instanceof JoinRequestBuilder) { + ElasticJoinExecutor executor = ElasticJoinExecutor.createJoinExecutor(client, requestBuilder); + executor.run(); + return ElasticUtils.hitsAsStringResult(executor.getHits(), new MetaSearchResult()); + } else if (requestBuilder instanceof MultiQueryRequestBuilder) { + ElasticHitsExecutor executor = + MultiRequestExecutorFactory.createExecutor( + client, (MultiQueryRequestBuilder) requestBuilder); + executor.run(); + return ElasticUtils.hitsAsStringResult(executor.getHits(), new MetaSearchResult()); + } else if (request instanceof SearchRequest) { + ActionFuture future = client.search((SearchRequest) request); + SearchResponse response = future.actionGet(); + return response.toString(); + } else if (request instanceof DeleteByQueryRequest) { + return requestBuilder.get().toString(); + } else if (request instanceof GetIndexRequest) { + return requestBuilder.getBuilder().execute().actionGet().toString(); + } else { + throw new Exception( + String.format("Unsupported ActionRequest provided: %s", request.getClass().getName())); } + } - private void sendDefaultResponse(SearchHits hits, RestChannel channel) { - try { - String json = ElasticUtils.hitsAsStringResult(hits, new MetaSearchResult()); - BytesRestResponse bytesRestResponse = new BytesRestResponse(RestStatus.OK, json); - channel.sendResponse(bytesRestResponse); - } catch (IOException e) { - e.printStackTrace(); - } + private void sendDefaultResponse(SearchHits hits, RestChannel channel) { + try { + String json = ElasticUtils.hitsAsStringResult(hits, new MetaSearchResult()); + BytesRestResponse bytesRestResponse = new BytesRestResponse(RestStatus.OK, json); + channel.sendResponse(bytesRestResponse); + } catch (IOException e) { + e.printStackTrace(); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticHitsExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticHitsExecutor.java index c48eb673bd..62a6d63ef7 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticHitsExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticHitsExecutor.java @@ -3,18 +3,15 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor; import java.io.IOException; import org.opensearch.search.SearchHits; import org.opensearch.sql.legacy.exception.SqlParseException; -/** - * Created by Eliran on 21/8/2016. - */ +/** Created by Eliran on 21/8/2016. */ public interface ElasticHitsExecutor { - void run() throws IOException, SqlParseException; + void run() throws IOException, SqlParseException; - SearchHits getHits(); + SearchHits getHits(); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticResultHandler.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticResultHandler.java index ff241fce77..6f753a5e7c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticResultHandler.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/ElasticResultHandler.java @@ -3,38 +3,34 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor; import java.util.Map; import org.opensearch.search.SearchHit; -/** - * Created by Eliran on 3/10/2015. - */ +/** Created by Eliran on 3/10/2015. */ public class ElasticResultHandler { - public static Object getFieldValue(SearchHit hit, String field) { - return deepSearchInMap(hit.getSourceAsMap(), field); - } + public static Object getFieldValue(SearchHit hit, String field) { + return deepSearchInMap(hit.getSourceAsMap(), field); + } - private static Object deepSearchInMap(Map fieldsMap, String name) { - if (name.contains(".")) { - String[] path = name.split("\\."); - Map currentObject = fieldsMap; - for (int i = 0; i < path.length - 1; i++) { - Object valueFromCurrentMap = currentObject.get(path[i]); - if (valueFromCurrentMap == null) { - return null; - } - if (!Map.class.isAssignableFrom(valueFromCurrentMap.getClass())) { - return null; - } - currentObject = (Map) valueFromCurrentMap; - } - return currentObject.get(path[path.length - 1]); + private static Object deepSearchInMap(Map fieldsMap, String name) { + if (name.contains(".")) { + String[] path = name.split("\\."); + Map currentObject = fieldsMap; + for (int i = 0; i < path.length - 1; i++) { + Object valueFromCurrentMap = currentObject.get(path[i]); + if (valueFromCurrentMap == null) { + return null; } - - return fieldsMap.get(name); + if (!Map.class.isAssignableFrom(valueFromCurrentMap.getClass())) { + return null; + } + currentObject = (Map) valueFromCurrentMap; + } + return currentObject.get(path[path.length - 1]); } + return fieldsMap.get(name); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResult.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResult.java index 680c0c8e85..28bc559a01 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResult.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResult.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.csv; import com.google.common.collect.ImmutableSet; @@ -12,86 +11,86 @@ import java.util.Set; import java.util.stream.Collectors; -/** - * Created by Eliran on 27/12/2015. - */ +/** Created by Eliran on 27/12/2015. */ public class CSVResult { - private static final Set SENSITIVE_CHAR = ImmutableSet.of("=", "+", "-", "@"); + private static final Set SENSITIVE_CHAR = ImmutableSet.of("=", "+", "-", "@"); - private final List headers; - private final List lines; + private final List headers; + private final List lines; - /** - * Skip sanitizing if string line provided. This constructor is basically used by - * assertion in test code. - */ - public CSVResult(List headers, List lines) { - this.headers = headers; - this.lines = lines; - } + /** + * Skip sanitizing if string line provided. This constructor is basically used by assertion in + * test code. + */ + public CSVResult(List headers, List lines) { + this.headers = headers; + this.lines = lines; + } /** * Sanitize both headers and data lines by: - * 1) First prepend single quote if first char is sensitive (= - + @) - * 2) Second double quote entire cell if any comma found + *

    + *
  1. First prepend single quote if first char is sensitive (= - + @) + *
  2. Second double quote entire cell if any comma found + *
*/ public CSVResult(String separator, List headers, List> lines) { this.headers = sanitizeHeaders(separator, headers); this.lines = sanitizeLines(separator, lines); } - /** - * Return CSV header names which are sanitized because OpenSearch allows - * special character present in field name too. - * @return CSV header name list after sanitized - */ - public List getHeaders() { - return headers; - } - - /** - * Return CSV lines in which each cell is sanitized to avoid CSV injection. - * @return CSV lines after sanitized - */ - public List getLines() { - return lines; + /** + * Return CSV header names which are sanitized because OpenSearch allows special character present + * in field name too. + * + * @return CSV header name list after sanitized + */ + public List getHeaders() { + return headers; + } + + /** + * Return CSV lines in which each cell is sanitized to avoid CSV injection. + * + * @return CSV lines after sanitized + */ + public List getLines() { + return lines; + } + + private List sanitizeHeaders(String separator, List headers) { + return headers.stream() + .map(this::sanitizeCell) + .map(cell -> quoteIfRequired(separator, cell)) + .collect(Collectors.toList()); + } + + private List sanitizeLines(String separator, List> lines) { + List result = new ArrayList<>(); + for (List line : lines) { + result.add( + line.stream() + .map(this::sanitizeCell) + .map(cell -> quoteIfRequired(separator, cell)) + .collect(Collectors.joining(separator))); } + return result; + } - private List sanitizeHeaders(String separator, List headers) { - return headers.stream(). - map(this::sanitizeCell). - map(cell -> quoteIfRequired(separator, cell)). - collect(Collectors.toList()); + private String sanitizeCell(String cell) { + if (isStartWithSensitiveChar(cell)) { + return "'" + cell; } + return cell; + } - private List sanitizeLines(String separator, List> lines) { - List result = new ArrayList<>(); - for (List line : lines) { - result.add(line.stream(). - map(this::sanitizeCell). - map(cell -> quoteIfRequired(separator, cell)). - collect(Collectors.joining(separator))); - } - return result; - } - - private String sanitizeCell(String cell) { - if (isStartWithSensitiveChar(cell)) { - return "'" + cell; - } - return cell; - } - - private String quoteIfRequired(String separator, String cell) { - final String quote = "\""; - return cell.contains(separator) - ? quote + cell.replaceAll("\"", "\"\"") + quote : cell; - } - - private boolean isStartWithSensitiveChar(String cell) { - return SENSITIVE_CHAR.stream(). - anyMatch(cell::startsWith); - } + private String quoteIfRequired(String separator, String cell) { + final String quote = "\""; + return cell.contains(separator) ? quote + cell.replaceAll("\"", "\"\"") + quote : cell; + } + private boolean isStartWithSensitiveChar(String cell) { + return SENSITIVE_CHAR.stream().anyMatch(cell::startsWith); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultRestExecutor.java index ae7623e3a2..81bc35a40d 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultRestExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.csv; import com.google.common.base.Joiner; @@ -18,60 +17,64 @@ import org.opensearch.sql.legacy.query.QueryAction; import org.opensearch.sql.legacy.query.join.BackOffRetryStrategy; -/** - * Created by Eliran on 26/12/2015. - */ +/** Created by Eliran on 26/12/2015. */ public class CSVResultRestExecutor implements RestExecutor { - @Override - public void execute(final Client client, final Map params, final QueryAction queryAction, - final RestChannel channel) throws Exception { - - final String csvString = execute(client, params, queryAction); - final BytesRestResponse bytesRestResponse = new BytesRestResponse(RestStatus.OK, csvString); + @Override + public void execute( + final Client client, + final Map params, + final QueryAction queryAction, + final RestChannel channel) + throws Exception { - if (!BackOffRetryStrategy.isHealthy(2 * bytesRestResponse.content().length(), this)) { - throw new IllegalStateException( - "[CSVResultRestExecutor] Memory could be insufficient when sendResponse()."); - } + final String csvString = execute(client, params, queryAction); + final BytesRestResponse bytesRestResponse = new BytesRestResponse(RestStatus.OK, csvString); - channel.sendResponse(bytesRestResponse); + if (!BackOffRetryStrategy.isHealthy(2 * bytesRestResponse.content().length(), this)) { + throw new IllegalStateException( + "[CSVResultRestExecutor] Memory could be insufficient when sendResponse()."); } - @Override - public String execute(final Client client, final Map params, final QueryAction queryAction) - throws Exception { + channel.sendResponse(bytesRestResponse); + } - final Object queryResult = QueryActionElasticExecutor.executeAnyAction(client, queryAction); + @Override + public String execute( + final Client client, final Map params, final QueryAction queryAction) + throws Exception { - final String separator = params.getOrDefault("separator", ","); - final String newLine = params.getOrDefault("newLine", "\n"); + final Object queryResult = QueryActionElasticExecutor.executeAnyAction(client, queryAction); - final boolean flat = getBooleanOrDefault(params, "flat", false); - final boolean includeScore = getBooleanOrDefault(params, "_score", false); - final boolean includeId = getBooleanOrDefault(params, "_id", false); + final String separator = params.getOrDefault("separator", ","); + final String newLine = params.getOrDefault("newLine", "\n"); - final List fieldNames = queryAction.getFieldNames().orElse(null); - final CSVResult result = new CSVResultsExtractor(includeScore, includeId) - .extractResults(queryResult, flat, separator, fieldNames); + final boolean flat = getBooleanOrDefault(params, "flat", false); + final boolean includeScore = getBooleanOrDefault(params, "_score", false); + final boolean includeId = getBooleanOrDefault(params, "_id", false); - return buildString(separator, result, newLine); - } + final List fieldNames = queryAction.getFieldNames().orElse(null); + final CSVResult result = + new CSVResultsExtractor(includeScore, includeId) + .extractResults(queryResult, flat, separator, fieldNames); - private boolean getBooleanOrDefault(Map params, String param, boolean defaultValue) { - boolean flat = defaultValue; - if (params.containsKey(param)) { - flat = Boolean.parseBoolean(params.get(param)); - } - return flat; - } + return buildString(separator, result, newLine); + } - private String buildString(String separator, CSVResult result, String newLine) { - StringBuilder csv = new StringBuilder(); - csv.append(Joiner.on(separator).join(result.getHeaders())); - csv.append(newLine); - csv.append(Joiner.on(newLine).join(result.getLines())); - return csv.toString(); + private boolean getBooleanOrDefault( + Map params, String param, boolean defaultValue) { + boolean flat = defaultValue; + if (params.containsKey(param)) { + flat = Boolean.parseBoolean(params.get(param)); } - + return flat; + } + + private String buildString(String separator, CSVResult result, String newLine) { + StringBuilder csv = new StringBuilder(); + csv.append(Joiner.on(separator).join(result.getHeaders())); + csv.append(newLine); + csv.append(Joiner.on(newLine).join(result.getLines())); + return csv.toString(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java index 70cdd91452..0c299ac7e0 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CSVResultsExtractor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.csv; import java.util.ArrayList; @@ -31,320 +30,332 @@ import org.opensearch.sql.legacy.expression.model.ExprValue; import org.opensearch.sql.legacy.utils.Util; -/** - * Created by Eliran on 27/12/2015. - */ +/** Created by Eliran on 27/12/2015. */ public class CSVResultsExtractor { - private final boolean includeScore; - private final boolean includeId; - private int currentLineIndex; - - public CSVResultsExtractor(boolean includeScore, boolean includeId) { - this.includeScore = includeScore; - this.includeId = includeId; - this.currentLineIndex = 0; + private final boolean includeScore; + private final boolean includeId; + private int currentLineIndex; + + public CSVResultsExtractor(boolean includeScore, boolean includeId) { + this.includeScore = includeScore; + this.includeId = includeId; + this.currentLineIndex = 0; + } + + public CSVResult extractResults( + Object queryResult, boolean flat, String separator, final List fieldNames) + throws CsvExtractorException { + + if (queryResult instanceof SearchHits) { + SearchHit[] hits = ((SearchHits) queryResult).getHits(); + List> docsAsMap = new ArrayList<>(); + List headers = createHeadersAndFillDocsMap(flat, hits, docsAsMap, fieldNames); + List> csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers); + return new CSVResult(separator, headers, csvLines); } - - public CSVResult extractResults(Object queryResult, boolean flat, String separator, - final List fieldNames) throws CsvExtractorException { - - if (queryResult instanceof SearchHits) { - SearchHit[] hits = ((SearchHits) queryResult).getHits(); - List> docsAsMap = new ArrayList<>(); - List headers = createHeadersAndFillDocsMap(flat, hits, docsAsMap, fieldNames); - List> csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers); - return new CSVResult(separator, headers, csvLines); - } - if (queryResult instanceof Aggregations) { - List headers = new ArrayList<>(); - List> lines = new ArrayList<>(); - lines.add(new ArrayList()); - handleAggregations((Aggregations) queryResult, headers, lines); - return new CSVResult(separator, headers, lines); - } - // Handle List result. - if (queryResult instanceof List) { - List bindingTuples = (List) queryResult; - List> csvLines = bindingTuples.stream().map(tuple -> { - Map bindingMap = tuple.getBindingMap(); - List rowValues = new ArrayList<>(); - for (String fieldName : fieldNames) { - if (bindingMap.containsKey(fieldName)) { + if (queryResult instanceof Aggregations) { + List headers = new ArrayList<>(); + List> lines = new ArrayList<>(); + lines.add(new ArrayList()); + handleAggregations((Aggregations) queryResult, headers, lines); + return new CSVResult(separator, headers, lines); + } + // Handle List result. + if (queryResult instanceof List) { + List bindingTuples = (List) queryResult; + List> csvLines = + bindingTuples.stream() + .map( + tuple -> { + Map bindingMap = tuple.getBindingMap(); + List rowValues = new ArrayList<>(); + for (String fieldName : fieldNames) { + if (bindingMap.containsKey(fieldName)) { rowValues.add(String.valueOf(bindingMap.get(fieldName).value())); - } else { + } else { rowValues.add(""); + } } - } - return rowValues; - }).collect(Collectors.toList()); + return rowValues; + }) + .collect(Collectors.toList()); - return new CSVResult(separator, fieldNames, csvLines); - } - return null; + return new CSVResult(separator, fieldNames, csvLines); } - - private void handleAggregations(Aggregations aggregations, List headers, List> lines) - throws CsvExtractorException { - if (allNumericAggregations(aggregations)) { - lines.get(this.currentLineIndex) - .addAll(fillHeaderAndCreateLineForNumericAggregations(aggregations, headers)); - return; - } - //aggregations with size one only supported when not metrics. - List aggregationList = aggregations.asList(); - if (aggregationList.size() > 1) { - throw new CsvExtractorException( - "currently support only one aggregation at same level (Except for numeric metrics)"); - } - Aggregation aggregation = aggregationList.get(0); - //we want to skip singleBucketAggregations (nested,reverse_nested,filters) - if (aggregation instanceof SingleBucketAggregation) { - Aggregations singleBucketAggs = ((SingleBucketAggregation) aggregation).getAggregations(); - handleAggregations(singleBucketAggs, headers, lines); - return; - } - if (aggregation instanceof NumericMetricsAggregation) { - handleNumericMetricAggregation(headers, lines.get(currentLineIndex), aggregation); - return; - } - if (aggregation instanceof GeoBounds) { - handleGeoBoundsAggregation(headers, lines, (GeoBounds) aggregation); - return; - } - if (aggregation instanceof TopHits) { - //todo: handle this . it returns hits... maby back to normal? - //todo: read about this usages - // TopHits topHitsAggregation = (TopHits) aggregation; - } - if (aggregation instanceof MultiBucketsAggregation) { - MultiBucketsAggregation bucketsAggregation = (MultiBucketsAggregation) aggregation; - String name = bucketsAggregation.getName(); - //checking because it can comes from sub aggregation again - if (!headers.contains(name)) { - headers.add(name); - } - Collection buckets = bucketsAggregation.getBuckets(); - - //clone current line. - List currentLine = lines.get(this.currentLineIndex); - List clonedLine = new ArrayList<>(currentLine); - - //call handle_Agg with current_line++ - boolean firstLine = true; - for (MultiBucketsAggregation.Bucket bucket : buckets) { - //each bucket need to add new line with current line copied => except for first line - String key = bucket.getKeyAsString(); - if (firstLine) { - firstLine = false; - } else { - currentLineIndex++; - currentLine = new ArrayList(clonedLine); - lines.add(currentLine); - } - currentLine.add(key); - handleAggregations(bucket.getAggregations(), headers, lines); - - } - } + return null; + } + + private void handleAggregations( + Aggregations aggregations, List headers, List> lines) + throws CsvExtractorException { + if (allNumericAggregations(aggregations)) { + lines + .get(this.currentLineIndex) + .addAll(fillHeaderAndCreateLineForNumericAggregations(aggregations, headers)); + return; } - - private void handleGeoBoundsAggregation(List headers, List> lines, - GeoBounds geoBoundsAggregation) { - String geoBoundAggName = geoBoundsAggregation.getName(); - headers.add(geoBoundAggName + ".topLeft.lon"); - headers.add(geoBoundAggName + ".topLeft.lat"); - headers.add(geoBoundAggName + ".bottomRight.lon"); - headers.add(geoBoundAggName + ".bottomRight.lat"); - List line = lines.get(this.currentLineIndex); - line.add(String.valueOf(geoBoundsAggregation.topLeft().getLon())); - line.add(String.valueOf(geoBoundsAggregation.topLeft().getLat())); - line.add(String.valueOf(geoBoundsAggregation.bottomRight().getLon())); - line.add(String.valueOf(geoBoundsAggregation.bottomRight().getLat())); - lines.add(line); + // aggregations with size one only supported when not metrics. + List aggregationList = aggregations.asList(); + if (aggregationList.size() > 1) { + throw new CsvExtractorException( + "currently support only one aggregation at same level (Except for numeric metrics)"); } - - private List fillHeaderAndCreateLineForNumericAggregations(Aggregations aggregations, List header) - throws CsvExtractorException { - List line = new ArrayList<>(); - List aggregationList = aggregations.asList(); - for (Aggregation aggregation : aggregationList) { - handleNumericMetricAggregation(header, line, aggregation); - } - return line; + Aggregation aggregation = aggregationList.get(0); + // we want to skip singleBucketAggregations (nested,reverse_nested,filters) + if (aggregation instanceof SingleBucketAggregation) { + Aggregations singleBucketAggs = ((SingleBucketAggregation) aggregation).getAggregations(); + handleAggregations(singleBucketAggs, headers, lines); + return; } - - private void handleNumericMetricAggregation(List header, List line, Aggregation aggregation) - throws CsvExtractorException { - final String name = aggregation.getName(); - - if (aggregation instanceof NumericMetricsAggregation.SingleValue) { - if (!header.contains(name)) { - header.add(name); - } - NumericMetricsAggregation.SingleValue agg = (NumericMetricsAggregation.SingleValue) aggregation; - line.add(!Double.isInfinite(agg.value()) ? agg.getValueAsString() : "null"); - } else if (aggregation instanceof NumericMetricsAggregation.MultiValue) { - //todo:Numeric MultiValue - Stats,ExtendedStats,Percentile... - if (aggregation instanceof Stats) { - String[] statsHeaders = new String[]{"count", "sum", "avg", "min", "max"}; - boolean isExtendedStats = aggregation instanceof ExtendedStats; - if (isExtendedStats) { - String[] extendedHeaders = new String[]{"sumOfSquares", "variance", "stdDeviation"}; - statsHeaders = Util.concatStringsArrays(statsHeaders, extendedHeaders); - } - mergeHeadersWithPrefix(header, name, statsHeaders); - Stats stats = (Stats) aggregation; - line.add(String.valueOf(stats.getCount())); - line.add(stats.getSumAsString()); - line.add(stats.getAvgAsString()); - line.add(stats.getMinAsString()); - line.add(stats.getMaxAsString()); - if (isExtendedStats) { - ExtendedStats extendedStats = (ExtendedStats) aggregation; - line.add(extendedStats.getSumOfSquaresAsString()); - line.add(extendedStats.getVarianceAsString()); - line.add(extendedStats.getStdDeviationAsString()); - } - } else if (aggregation instanceof Percentiles) { - - final List percentileHeaders = new ArrayList<>(7); - final Percentiles percentiles = (Percentiles) aggregation; - - for (final Percentile p : percentiles) { - percentileHeaders.add(String.valueOf(p.getPercent())); - line.add(percentiles.percentileAsString(p.getPercent())); - } - mergeHeadersWithPrefix(header, name, percentileHeaders.toArray(new String[0])); - } else { - throw new CsvExtractorException( - "unknown NumericMetricsAggregation.MultiValue:" + aggregation.getClass()); - } - + if (aggregation instanceof NumericMetricsAggregation) { + handleNumericMetricAggregation(headers, lines.get(currentLineIndex), aggregation); + return; + } + if (aggregation instanceof GeoBounds) { + handleGeoBoundsAggregation(headers, lines, (GeoBounds) aggregation); + return; + } + if (aggregation instanceof TopHits) { + // todo: handle this . it returns hits... maby back to normal? + // todo: read about this usages + // TopHits topHitsAggregation = (TopHits) aggregation; + } + if (aggregation instanceof MultiBucketsAggregation) { + MultiBucketsAggregation bucketsAggregation = (MultiBucketsAggregation) aggregation; + String name = bucketsAggregation.getName(); + // checking because it can comes from sub aggregation again + if (!headers.contains(name)) { + headers.add(name); + } + Collection buckets = + bucketsAggregation.getBuckets(); + + // clone current line. + List currentLine = lines.get(this.currentLineIndex); + List clonedLine = new ArrayList<>(currentLine); + + // call handle_Agg with current_line++ + boolean firstLine = true; + for (MultiBucketsAggregation.Bucket bucket : buckets) { + // each bucket need to add new line with current line copied => except for first line + String key = bucket.getKeyAsString(); + if (firstLine) { + firstLine = false; } else { - throw new CsvExtractorException("unknown NumericMetricsAggregation" + aggregation.getClass()); + currentLineIndex++; + currentLine = new ArrayList(clonedLine); + lines.add(currentLine); } + currentLine.add(key); + handleAggregations(bucket.getAggregations(), headers, lines); + } } - - private void mergeHeadersWithPrefix(List header, String prefix, String[] newHeaders) { - for (int i = 0; i < newHeaders.length; i++) { - String newHeader = newHeaders[i]; - if (prefix != null && !prefix.equals("")) { - newHeader = prefix + "." + newHeader; - } - if (!header.contains(newHeader)) { - header.add(newHeader); - } - } + } + + private void handleGeoBoundsAggregation( + List headers, List> lines, GeoBounds geoBoundsAggregation) { + String geoBoundAggName = geoBoundsAggregation.getName(); + headers.add(geoBoundAggName + ".topLeft.lon"); + headers.add(geoBoundAggName + ".topLeft.lat"); + headers.add(geoBoundAggName + ".bottomRight.lon"); + headers.add(geoBoundAggName + ".bottomRight.lat"); + List line = lines.get(this.currentLineIndex); + line.add(String.valueOf(geoBoundsAggregation.topLeft().getLon())); + line.add(String.valueOf(geoBoundsAggregation.topLeft().getLat())); + line.add(String.valueOf(geoBoundsAggregation.bottomRight().getLon())); + line.add(String.valueOf(geoBoundsAggregation.bottomRight().getLat())); + lines.add(line); + } + + private List fillHeaderAndCreateLineForNumericAggregations( + Aggregations aggregations, List header) throws CsvExtractorException { + List line = new ArrayList<>(); + List aggregationList = aggregations.asList(); + for (Aggregation aggregation : aggregationList) { + handleNumericMetricAggregation(header, line, aggregation); } - - private boolean allNumericAggregations(Aggregations aggregations) { - List aggregationList = aggregations.asList(); - for (Aggregation aggregation : aggregationList) { - if (!(aggregation instanceof NumericMetricsAggregation)) { - return false; - } + return line; + } + + private void handleNumericMetricAggregation( + List header, List line, Aggregation aggregation) + throws CsvExtractorException { + final String name = aggregation.getName(); + + if (aggregation instanceof NumericMetricsAggregation.SingleValue) { + if (!header.contains(name)) { + header.add(name); + } + NumericMetricsAggregation.SingleValue agg = + (NumericMetricsAggregation.SingleValue) aggregation; + line.add(!Double.isInfinite(agg.value()) ? agg.getValueAsString() : "null"); + } else if (aggregation instanceof NumericMetricsAggregation.MultiValue) { + // todo:Numeric MultiValue - Stats,ExtendedStats,Percentile... + if (aggregation instanceof Stats) { + String[] statsHeaders = new String[] {"count", "sum", "avg", "min", "max"}; + boolean isExtendedStats = aggregation instanceof ExtendedStats; + if (isExtendedStats) { + String[] extendedHeaders = new String[] {"sumOfSquares", "variance", "stdDeviation"}; + statsHeaders = Util.concatStringsArrays(statsHeaders, extendedHeaders); } - return true; - } + mergeHeadersWithPrefix(header, name, statsHeaders); + Stats stats = (Stats) aggregation; + line.add(String.valueOf(stats.getCount())); + line.add(stats.getSumAsString()); + line.add(stats.getAvgAsString()); + line.add(stats.getMinAsString()); + line.add(stats.getMaxAsString()); + if (isExtendedStats) { + ExtendedStats extendedStats = (ExtendedStats) aggregation; + line.add(extendedStats.getSumOfSquaresAsString()); + line.add(extendedStats.getVarianceAsString()); + line.add(extendedStats.getStdDeviationAsString()); + } + } else if (aggregation instanceof Percentiles) { + + final List percentileHeaders = new ArrayList<>(7); + final Percentiles percentiles = (Percentiles) aggregation; - private Aggregation skipAggregations(Aggregation firstAggregation) { - while (firstAggregation instanceof SingleBucketAggregation) { - firstAggregation = getFirstAggregation(((SingleBucketAggregation) firstAggregation).getAggregations()); + for (final Percentile p : percentiles) { + percentileHeaders.add(String.valueOf(p.getPercent())); + line.add(percentiles.percentileAsString(p.getPercent())); } - return firstAggregation; + mergeHeadersWithPrefix(header, name, percentileHeaders.toArray(new String[0])); + } else { + throw new CsvExtractorException( + "unknown NumericMetricsAggregation.MultiValue:" + aggregation.getClass()); + } + + } else { + throw new CsvExtractorException("unknown NumericMetricsAggregation" + aggregation.getClass()); } - - private Aggregation getFirstAggregation(Aggregations aggregations) { - return aggregations.asList().get(0); + } + + private void mergeHeadersWithPrefix(List header, String prefix, String[] newHeaders) { + for (int i = 0; i < newHeaders.length; i++) { + String newHeader = newHeaders[i]; + if (prefix != null && !prefix.equals("")) { + newHeader = prefix + "." + newHeader; + } + if (!header.contains(newHeader)) { + header.add(newHeader); + } } + } + + private boolean allNumericAggregations(Aggregations aggregations) { + List aggregationList = aggregations.asList(); + for (Aggregation aggregation : aggregationList) { + if (!(aggregation instanceof NumericMetricsAggregation)) { + return false; + } + } + return true; + } - private List> createCSVLinesFromDocs(boolean flat, String separator, - List> docsAsMap, - List headers) { - List> csvLines = new ArrayList<>(); - for (Map doc : docsAsMap) { - List line = new ArrayList<>(); - for (String header : headers) { - line.add(findFieldValue(header, doc, flat, separator)); - } - csvLines.add(line); - } - return csvLines; + private Aggregation skipAggregations(Aggregation firstAggregation) { + while (firstAggregation instanceof SingleBucketAggregation) { + firstAggregation = + getFirstAggregation(((SingleBucketAggregation) firstAggregation).getAggregations()); + } + return firstAggregation; + } + + private Aggregation getFirstAggregation(Aggregations aggregations) { + return aggregations.asList().get(0); + } + + private List> createCSVLinesFromDocs( + boolean flat, String separator, List> docsAsMap, List headers) { + List> csvLines = new ArrayList<>(); + for (Map doc : docsAsMap) { + List line = new ArrayList<>(); + for (String header : headers) { + line.add(findFieldValue(header, doc, flat, separator)); + } + csvLines.add(line); + } + return csvLines; + } + + private List createHeadersAndFillDocsMap( + final boolean flat, + final SearchHit[] hits, + final List> docsAsMap, + final List fieldNames) { + final Set csvHeaders = new LinkedHashSet<>(); + if (fieldNames != null) { + csvHeaders.addAll(fieldNames); } - private List createHeadersAndFillDocsMap(final boolean flat, final SearchHit[] hits, - final List> docsAsMap, - final List fieldNames) { - final Set csvHeaders = new LinkedHashSet<>(); - if (fieldNames != null) { - csvHeaders.addAll(fieldNames); - } + for (final SearchHit hit : hits) { + final Map doc = hit.getSourceAsMap(); + final Map fields = hit.getFields(); + for (final DocumentField searchHitField : fields.values()) { + doc.put(searchHitField.getName(), searchHitField.getValue()); + } + + if (this.includeId) { + doc.put("_id", hit.getId()); + } + if (this.includeScore) { + doc.put("_score", hit.getScore()); + } + + // select function as field is a special case where each hit has non-null field (function) + // and sourceAsMap is all columns in index (the same as 'SELECT *') + if (fields.isEmpty()) { + mergeHeaders(csvHeaders, doc, flat); + } + docsAsMap.add(doc); + } - for (final SearchHit hit : hits) { - final Map doc = hit.getSourceAsMap(); - final Map fields = hit.getFields(); - for (final DocumentField searchHitField : fields.values()) { - doc.put(searchHitField.getName(), searchHitField.getValue()); - } - - if (this.includeId) { - doc.put("_id", hit.getId()); - } - if (this.includeScore) { - doc.put("_score", hit.getScore()); - } - - // select function as field is a special case where each hit has non-null field (function) - // and sourceAsMap is all columns in index (the same as 'SELECT *') - if (fields.isEmpty()) { - mergeHeaders(csvHeaders, doc, flat); - } - docsAsMap.add(doc); - } + return new ArrayList<>(csvHeaders); + } - return new ArrayList<>(csvHeaders); - } + private String findFieldValue( + String header, Map doc, boolean flat, String separator) { + if (flat && header.contains(".")) { + String[] split = header.split("\\."); + Object innerDoc = doc; - private String findFieldValue(String header, Map doc, boolean flat, String separator) { - if (flat && header.contains(".")) { - String[] split = header.split("\\."); - Object innerDoc = doc; - - for (String innerField : split) { - if (!(innerDoc instanceof Map)) { - return ""; - } - innerDoc = ((Map) innerDoc).get(innerField); - if (innerDoc == null) { - return ""; - } - } - return innerDoc.toString(); - } else { - if (doc.containsKey(header)) { - return String.valueOf(doc.get(header)); - } + for (String innerField : split) { + if (!(innerDoc instanceof Map)) { + return ""; } - return ""; - } - - private void mergeHeaders(Set headers, Map doc, boolean flat) { - if (!flat) { - headers.addAll(doc.keySet()); - return; + innerDoc = ((Map) innerDoc).get(innerField); + if (innerDoc == null) { + return ""; } - mergeFieldNamesRecursive(headers, doc, ""); + } + return innerDoc.toString(); + } else { + if (doc.containsKey(header)) { + return String.valueOf(doc.get(header)); + } } + return ""; + } - private void mergeFieldNamesRecursive(Set headers, Map doc, String prefix) { - for (Map.Entry field : doc.entrySet()) { - Object value = field.getValue(); - if (value instanceof Map) { - mergeFieldNamesRecursive(headers, (Map) value, prefix + field.getKey() + "."); - } else { - headers.add(prefix + field.getKey()); - } - } + private void mergeHeaders(Set headers, Map doc, boolean flat) { + if (!flat) { + headers.addAll(doc.keySet()); + return; + } + mergeFieldNamesRecursive(headers, doc, ""); + } + + private void mergeFieldNamesRecursive( + Set headers, Map doc, String prefix) { + for (Map.Entry field : doc.entrySet()) { + Object value = field.getValue(); + if (value instanceof Map) { + mergeFieldNamesRecursive( + headers, (Map) value, prefix + field.getKey() + "."); + } else { + headers.add(prefix + field.getKey()); + } } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CsvExtractorException.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CsvExtractorException.java index 7e0f8e8ff9..cb289e4625 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CsvExtractorException.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/csv/CsvExtractorException.java @@ -3,14 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.csv; -/** - * Created by Eliran on 29/12/2015. - */ +/** Created by Eliran on 29/12/2015. */ public class CsvExtractorException extends Exception { - public CsvExtractorException(String message) { - super(message); - } + public CsvExtractorException(String message) { + super(message); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorActionRequestRestExecutorFactory.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorActionRequestRestExecutorFactory.java index 7c8ed62a07..b4add64f9c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorActionRequestRestExecutorFactory.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorActionRequestRestExecutorFactory.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.cursor; import org.opensearch.rest.RestRequest; @@ -11,16 +10,17 @@ public class CursorActionRequestRestExecutorFactory { - public static CursorAsyncRestExecutor createExecutor(RestRequest request, String cursorId, Format format) { + public static CursorAsyncRestExecutor createExecutor( + RestRequest request, String cursorId, Format format) { - if (isCursorCloseRequest(request)) { - return new CursorAsyncRestExecutor(new CursorCloseExecutor(cursorId)); - } else { - return new CursorAsyncRestExecutor(new CursorResultExecutor(cursorId, format)); - } + if (isCursorCloseRequest(request)) { + return new CursorAsyncRestExecutor(new CursorCloseExecutor(cursorId)); + } else { + return new CursorAsyncRestExecutor(new CursorResultExecutor(cursorId, format)); } + } - private static boolean isCursorCloseRequest(final RestRequest request) { - return request.path().endsWith("/_sql/close"); - } + private static boolean isCursorCloseRequest(final RestRequest request) { + return request.path().endsWith("/_sql/close"); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorAsyncRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorAsyncRestExecutor.java index 92703dde2a..958bf68703 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorAsyncRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorAsyncRestExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.cursor; import java.io.IOException; @@ -25,84 +24,83 @@ import org.opensearch.threadpool.ThreadPool; public class CursorAsyncRestExecutor { - /** - * Custom thread pool name managed by OpenSearch - */ - public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; + /** Custom thread pool name managed by OpenSearch */ + public static final String SQL_WORKER_THREAD_POOL_NAME = "sql-worker"; - private static final Logger LOG = LogManager.getLogger(CursorAsyncRestExecutor.class); + private static final Logger LOG = LogManager.getLogger(CursorAsyncRestExecutor.class); - /** - * Delegated rest executor to async - */ - private final CursorRestExecutor executor; + /** Delegated rest executor to async */ + private final CursorRestExecutor executor; + CursorAsyncRestExecutor(CursorRestExecutor executor) { + this.executor = executor; + } - CursorAsyncRestExecutor(CursorRestExecutor executor) { - this.executor = executor; - } + public void execute(Client client, Map params, RestChannel channel) { + async(client, params, channel); + } - public void execute(Client client, Map params, RestChannel channel) { - async(client, params, channel); - } + /** Run given task in thread pool asynchronously */ + private void async(Client client, Map params, RestChannel channel) { - /** - * Run given task in thread pool asynchronously - */ - private void async(Client client, Map params, RestChannel channel) { - - ThreadPool threadPool = client.threadPool(); - Runnable runnable = () -> { - try { - doExecuteWithTimeMeasured(client, params, channel); - } catch (IOException e) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.warn("[{}] [MCB] async task got an IO/SQL exception: {}", QueryContext.getRequestId(), - e.getMessage()); - e.printStackTrace(); - channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); - } catch (IllegalStateException e) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.warn("[{}] [MCB] async task got a runtime exception: {}", QueryContext.getRequestId(), - e.getMessage()); - e.printStackTrace(); - channel.sendResponse(new BytesRestResponse(RestStatus.INSUFFICIENT_STORAGE, - "Memory circuit is broken.")); - } catch (Throwable t) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.warn("[{}] [MCB] async task got an unknown throwable: {}", QueryContext.getRequestId(), - t.getMessage()); - t.printStackTrace(); - channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - String.valueOf(t.getMessage()))); - } finally { - BackOffRetryStrategy.releaseMem(executor); - } + ThreadPool threadPool = client.threadPool(); + Runnable runnable = + () -> { + try { + doExecuteWithTimeMeasured(client, params, channel); + } catch (IOException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.warn( + "[{}] [MCB] async task got an IO/SQL exception: {}", + QueryContext.getRequestId(), + e.getMessage()); + e.printStackTrace(); + channel.sendResponse( + new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, e.getMessage())); + } catch (IllegalStateException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.warn( + "[{}] [MCB] async task got a runtime exception: {}", + QueryContext.getRequestId(), + e.getMessage()); + e.printStackTrace(); + channel.sendResponse( + new BytesRestResponse( + RestStatus.INSUFFICIENT_STORAGE, "Memory circuit is broken.")); + } catch (Throwable t) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.warn( + "[{}] [MCB] async task got an unknown throwable: {}", + QueryContext.getRequestId(), + t.getMessage()); + t.printStackTrace(); + channel.sendResponse( + new BytesRestResponse( + RestStatus.INTERNAL_SERVER_ERROR, String.valueOf(t.getMessage()))); + } finally { + BackOffRetryStrategy.releaseMem(executor); + } }; - // Preserve context of calling thread to ensure headers of requests are forwarded when running blocking actions - threadPool.schedule( - QueryContext.withCurrentContext(runnable), - new TimeValue(0L), - SQL_WORKER_THREAD_POOL_NAME - ); - } + // Preserve context of calling thread to ensure headers of requests are forwarded when running + // blocking actions + threadPool.schedule( + QueryContext.withCurrentContext(runnable), new TimeValue(0L), SQL_WORKER_THREAD_POOL_NAME); + } - /** - * Time the real execution of Executor and log slow query for troubleshooting - */ - private void doExecuteWithTimeMeasured(Client client, - Map params, - RestChannel channel) throws Exception { - long startTime = System.nanoTime(); - try { - executor.execute(client, params, channel); - } finally { - Duration elapsed = Duration.ofNanos(System.nanoTime() - startTime); - int slowLogThreshold = LocalClusterState.state().getSettingValue(Settings.Key.SQL_SLOWLOG); - if (elapsed.getSeconds() >= slowLogThreshold) { - LOG.warn("[{}] Slow query: elapsed={} (ms)", QueryContext.getRequestId(), elapsed.toMillis()); - } - } + /** Time the real execution of Executor and log slow query for troubleshooting */ + private void doExecuteWithTimeMeasured( + Client client, Map params, RestChannel channel) throws Exception { + long startTime = System.nanoTime(); + try { + executor.execute(client, params, channel); + } finally { + Duration elapsed = Duration.ofNanos(System.nanoTime() - startTime); + int slowLogThreshold = LocalClusterState.state().getSettingValue(Settings.Key.SQL_SLOWLOG); + if (elapsed.getSeconds() >= slowLogThreshold) { + LOG.warn( + "[{}] Slow query: elapsed={} (ms)", QueryContext.getRequestId(), elapsed.toMillis()); + } } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java index 98e89c12e4..7282eaed4c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.cursor; import static org.opensearch.core.rest.RestStatus.OK; @@ -25,66 +24,69 @@ public class CursorCloseExecutor implements CursorRestExecutor { - private static final Logger LOG = LogManager.getLogger(CursorCloseExecutor.class); - - private static final String SUCCEEDED_TRUE = "{\"succeeded\":true}"; - private static final String SUCCEEDED_FALSE = "{\"succeeded\":false}"; - - private String cursorId; - - public CursorCloseExecutor(String cursorId) { - this.cursorId = cursorId; - } - - public void execute(Client client, Map params, RestChannel channel) throws Exception { - try { - String formattedResponse = execute(client, params); - channel.sendResponse(new BytesRestResponse(OK, "application/json; charset=UTF-8", formattedResponse)); - } catch (IllegalArgumentException | JSONException e) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); - LOG.error("Error parsing the cursor", e); - channel.sendResponse(new BytesRestResponse(channel, e)); - } catch (OpenSearchException e) { - int status = (e.status().getStatus()); - if (status > 399 && status < 500) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); - } else if (status > 499) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - } - LOG.error("Error completing cursor request", e); - channel.sendResponse(new BytesRestResponse(channel, e)); - } + private static final Logger LOG = LogManager.getLogger(CursorCloseExecutor.class); + + private static final String SUCCEEDED_TRUE = "{\"succeeded\":true}"; + private static final String SUCCEEDED_FALSE = "{\"succeeded\":false}"; + + private String cursorId; + + public CursorCloseExecutor(String cursorId) { + this.cursorId = cursorId; + } + + public void execute(Client client, Map params, RestChannel channel) + throws Exception { + try { + String formattedResponse = execute(client, params); + channel.sendResponse( + new BytesRestResponse(OK, "application/json; charset=UTF-8", formattedResponse)); + } catch (IllegalArgumentException | JSONException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); + LOG.error("Error parsing the cursor", e); + channel.sendResponse(new BytesRestResponse(channel, e)); + } catch (OpenSearchException e) { + int status = (e.status().getStatus()); + if (status > 399 && status < 500) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); + } else if (status > 499) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + } + LOG.error("Error completing cursor request", e); + channel.sendResponse(new BytesRestResponse(channel, e)); } + } - public String execute(Client client, Map params) throws Exception { - String[] splittedCursor = cursorId.split(":"); - - if (splittedCursor.length!=2) { - throw new VerificationException("Not able to parse invalid cursor"); - } - - String type = splittedCursor[0]; - CursorType cursorType = CursorType.getById(type); - - switch(cursorType) { - case DEFAULT: - DefaultCursor defaultCursor = DefaultCursor.from(splittedCursor[1]); - return handleDefaultCursorCloseRequest(client, defaultCursor); - case AGGREGATION: - case JOIN: - default: throw new VerificationException("Unsupported cursor type [" + type + "]"); - } + public String execute(Client client, Map params) throws Exception { + String[] splittedCursor = cursorId.split(":"); + if (splittedCursor.length != 2) { + throw new VerificationException("Not able to parse invalid cursor"); } - private String handleDefaultCursorCloseRequest(Client client, DefaultCursor cursor) { - String scrollId = cursor.getScrollId(); - ClearScrollResponse clearScrollResponse = client.prepareClearScroll().addScrollId(scrollId).get(); - if (clearScrollResponse.isSucceeded()) { - return SUCCEEDED_TRUE; - } else { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - return SUCCEEDED_FALSE; - } + String type = splittedCursor[0]; + CursorType cursorType = CursorType.getById(type); + + switch (cursorType) { + case DEFAULT: + DefaultCursor defaultCursor = DefaultCursor.from(splittedCursor[1]); + return handleDefaultCursorCloseRequest(client, defaultCursor); + case AGGREGATION: + case JOIN: + default: + throw new VerificationException("Unsupported cursor type [" + type + "]"); + } + } + + private String handleDefaultCursorCloseRequest(Client client, DefaultCursor cursor) { + String scrollId = cursor.getScrollId(); + ClearScrollResponse clearScrollResponse = + client.prepareClearScroll().addScrollId(scrollId).get(); + if (clearScrollResponse.isSucceeded()) { + return SUCCEEDED_TRUE; + } else { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + return SUCCEEDED_FALSE; } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorRestExecutor.java index 5f294f8e32..4c4b854379 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorRestExecutor.java @@ -3,21 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ - - package org.opensearch.sql.legacy.executor.cursor; import java.util.Map; import org.opensearch.client.Client; import org.opensearch.rest.RestChannel; -/** - * Interface to execute cursor request. - */ +/** Interface to execute cursor request. */ public interface CursorRestExecutor { - void execute(Client client, Map params, RestChannel channel) - throws Exception; + void execute(Client client, Map params, RestChannel channel) throws Exception; - String execute(Client client, Map params) throws Exception; + String execute(Client client, Map params) throws Exception; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java index 9753f8049c..620b8e7b86 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.cursor; import static org.opensearch.core.rest.RestStatus.OK; @@ -34,99 +33,105 @@ public class CursorResultExecutor implements CursorRestExecutor { - private String cursorId; - private Format format; - - private static final Logger LOG = LogManager.getLogger(CursorResultExecutor.class); - - public CursorResultExecutor(String cursorId, Format format) { - this.cursorId = cursorId; - this.format = format; + private String cursorId; + private Format format; + + private static final Logger LOG = LogManager.getLogger(CursorResultExecutor.class); + + public CursorResultExecutor(String cursorId, Format format) { + this.cursorId = cursorId; + this.format = format; + } + + public void execute(Client client, Map params, RestChannel channel) + throws Exception { + try { + String formattedResponse = execute(client, params); + channel.sendResponse( + new BytesRestResponse(OK, "application/json; charset=UTF-8", formattedResponse)); + } catch (IllegalArgumentException | JSONException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); + LOG.error("Error parsing the cursor", e); + channel.sendResponse(new BytesRestResponse(channel, e)); + } catch (OpenSearchException e) { + int status = (e.status().getStatus()); + if (status > 399 && status < 500) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); + } else if (status > 499) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + } + LOG.error("Error completing cursor request", e); + channel.sendResponse(new BytesRestResponse(channel, e)); } + } - public void execute(Client client, Map params, RestChannel channel) throws Exception { - try { - String formattedResponse = execute(client, params); - channel.sendResponse(new BytesRestResponse(OK, "application/json; charset=UTF-8", formattedResponse)); - } catch (IllegalArgumentException | JSONException e) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); - LOG.error("Error parsing the cursor", e); - channel.sendResponse(new BytesRestResponse(channel, e)); - } catch (OpenSearchException e) { - int status = (e.status().getStatus()); - if (status > 399 && status < 500) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CUS).increment(); - } else if (status > 499) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - } - LOG.error("Error completing cursor request", e); - channel.sendResponse(new BytesRestResponse(channel, e)); - } - } - - public String execute(Client client, Map params) throws Exception { - /** - * All cursor's are of the form : - * The serialized form before encoding is upto Cursor implementation - */ - String[] splittedCursor = cursorId.split(":", 2); + public String execute(Client client, Map params) throws Exception { + /** + * All cursor's are of the form : The serialized form before + * encoding is upto Cursor implementation + */ + String[] splittedCursor = cursorId.split(":", 2); - if (splittedCursor.length!=2) { - throw new VerificationException("Not able to parse invalid cursor"); - } - - String type = splittedCursor[0]; - CursorType cursorType = CursorType.getById(type); + if (splittedCursor.length != 2) { + throw new VerificationException("Not able to parse invalid cursor"); + } - switch(cursorType) { - case DEFAULT: - DefaultCursor defaultCursor = DefaultCursor.from(splittedCursor[1]); - return handleDefaultCursorRequest(client, defaultCursor); - case AGGREGATION: - case JOIN: - default: throw new VerificationException("Unsupported cursor type [" + type + "]"); - } + String type = splittedCursor[0]; + CursorType cursorType = CursorType.getById(type); + + switch (cursorType) { + case DEFAULT: + DefaultCursor defaultCursor = DefaultCursor.from(splittedCursor[1]); + return handleDefaultCursorRequest(client, defaultCursor); + case AGGREGATION: + case JOIN: + default: + throw new VerificationException("Unsupported cursor type [" + type + "]"); } + } - private String handleDefaultCursorRequest(Client client, DefaultCursor cursor) { - String previousScrollId = cursor.getScrollId(); - LocalClusterState clusterState = LocalClusterState.state(); - TimeValue scrollTimeout = clusterState.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); - SearchResponse scrollResponse = client.prepareSearchScroll(previousScrollId).setScroll(scrollTimeout).get(); - SearchHits searchHits = scrollResponse.getHits(); - SearchHit[] searchHitArray = searchHits.getHits(); - String newScrollId = scrollResponse.getScrollId(); + private String handleDefaultCursorRequest(Client client, DefaultCursor cursor) { + String previousScrollId = cursor.getScrollId(); + LocalClusterState clusterState = LocalClusterState.state(); + TimeValue scrollTimeout = clusterState.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + SearchResponse scrollResponse = + client.prepareSearchScroll(previousScrollId).setScroll(scrollTimeout).get(); + SearchHits searchHits = scrollResponse.getHits(); + SearchHit[] searchHitArray = searchHits.getHits(); + String newScrollId = scrollResponse.getScrollId(); - int rowsLeft = (int) cursor.getRowsLeft(); - int fetch = cursor.getFetchSize(); + int rowsLeft = (int) cursor.getRowsLeft(); + int fetch = cursor.getFetchSize(); if (rowsLeft < fetch && rowsLeft < searchHitArray.length) { /** * This condition implies we are on the last page, and we might need to truncate the result from SearchHit[] * Avoid truncating in following two scenarios - * 1. number of rows to be sent equals fetchSize - * 2. size of SearchHit[] is already less that rows that needs to be sent - * + *
    + *
  1. number of rows to be sent equals fetchSize + *
  2. size of SearchHit[] is already less that rows that needs to be sent + *
* Else truncate to desired number of rows */ SearchHit[] newSearchHits = Arrays.copyOf(searchHitArray, rowsLeft); searchHits = new SearchHits(newSearchHits, searchHits.getTotalHits(), searchHits.getMaxScore()); } - rowsLeft = rowsLeft - fetch; + rowsLeft = rowsLeft - fetch; - if (rowsLeft <=0) { - /** Clear the scroll context on last page */ - ClearScrollResponse clearScrollResponse = client.prepareClearScroll().addScrollId(newScrollId).get(); - if (!clearScrollResponse.isSucceeded()) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.info("Error closing the cursor context {} ", newScrollId); - } - } - - cursor.setRowsLeft(rowsLeft); - cursor.setScrollId(newScrollId); - Protocol protocol = new Protocol(client, searchHits, format.name().toLowerCase(), cursor); - return protocol.cursorFormat(); + if (rowsLeft <= 0) { + /** Clear the scroll context on last page */ + ClearScrollResponse clearScrollResponse = + client.prepareClearScroll().addScrollId(newScrollId).get(); + if (!clearScrollResponse.isSucceeded()) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.info("Error closing the cursor context {} ", newScrollId); + } } + + cursor.setRowsLeft(rowsLeft); + cursor.setScrollId(newScrollId); + Protocol protocol = new Protocol(client, searchHits, format.name().toLowerCase(), cursor); + return protocol.cursorFormat(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/BindingTupleResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/BindingTupleResultSet.java index d9eb463572..872442f04f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/BindingTupleResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/BindingTupleResultSet.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import static org.opensearch.sql.legacy.executor.format.DateFieldFormatter.FORMAT_JDBC; @@ -18,43 +17,44 @@ import org.opensearch.sql.legacy.expression.model.ExprValue; import org.opensearch.sql.legacy.query.planner.core.ColumnNode; -/** - * The definition of BindingTuple ResultSet. - */ +/** The definition of BindingTuple ResultSet. */ public class BindingTupleResultSet extends ResultSet { - public BindingTupleResultSet(List columnNodes, List bindingTuples) { - this.schema = buildSchema(columnNodes); - this.dataRows = buildDataRows(columnNodes, bindingTuples); - } - - @VisibleForTesting - public static Schema buildSchema(List columnNodes) { - List columnList = columnNodes.stream() - .map(node -> new Schema.Column( - node.getName(), - node.getAlias(), - node.getType())) - .collect(Collectors.toList()); - return new Schema(columnList); - } - - @VisibleForTesting - public static DataRows buildDataRows(List columnNodes, List bindingTuples) { - List rowList = bindingTuples.stream().map(tuple -> { - Map bindingMap = tuple.getBindingMap(); - Map rowMap = new HashMap<>(); - for (ColumnNode column : columnNodes) { - String columnName = column.columnName(); - Object value = bindingMap.get(columnName).value(); - if (column.getType() == Schema.Type.DATE) { - value = DateFormat.getFormattedDate(new Date((Long) value), FORMAT_JDBC); - } - rowMap.put(columnName, value); - } - return new DataRows.Row(rowMap); - }).collect(Collectors.toList()); - - return new DataRows(bindingTuples.size(), bindingTuples.size(), rowList); - } + public BindingTupleResultSet(List columnNodes, List bindingTuples) { + this.schema = buildSchema(columnNodes); + this.dataRows = buildDataRows(columnNodes, bindingTuples); + } + + @VisibleForTesting + public static Schema buildSchema(List columnNodes) { + List columnList = + columnNodes.stream() + .map(node -> new Schema.Column(node.getName(), node.getAlias(), node.getType())) + .collect(Collectors.toList()); + return new Schema(columnList); + } + + @VisibleForTesting + public static DataRows buildDataRows( + List columnNodes, List bindingTuples) { + List rowList = + bindingTuples.stream() + .map( + tuple -> { + Map bindingMap = tuple.getBindingMap(); + Map rowMap = new HashMap<>(); + for (ColumnNode column : columnNodes) { + String columnName = column.columnName(); + Object value = bindingMap.get(columnName).value(); + if (column.getType() == Schema.Type.DATE) { + value = DateFormat.getFormattedDate(new Date((Long) value), FORMAT_JDBC); + } + rowMap.put(columnName, value); + } + return new DataRows.Row(rowMap); + }) + .collect(Collectors.toList()); + + return new DataRows(bindingTuples.size(), bindingTuples.size(), rowList); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DataRows.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DataRows.java index 541d3200a5..fc153afae8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DataRows.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DataRows.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import java.util.Iterator; @@ -12,76 +11,76 @@ public class DataRows implements Iterable { - private long size; - private long totalHits; - private List rows; - - public DataRows(long size, long totalHits, List rows) { - this.size = size; - this.totalHits = totalHits; - this.rows = rows; + private long size; + private long totalHits; + private List rows; + + public DataRows(long size, long totalHits, List rows) { + this.size = size; + this.totalHits = totalHits; + this.rows = rows; + } + + public DataRows(List rows) { + this.size = rows.size(); + this.totalHits = rows.size(); + this.rows = rows; + } + + public long getSize() { + return size; + } + + public long getTotalHits() { + return totalHits; + } + + // Iterator method for DataRows + @Override + public Iterator iterator() { + return new Iterator() { + private final Iterator iter = rows.iterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Row next() { + return iter.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("No changes allowed to DataRows rows"); + } + }; + } + + // Inner class for Row object + public static class Row { + + private Map data; + + public Row(Map data) { + this.data = data; } - public DataRows(List rows) { - this.size = rows.size(); - this.totalHits = rows.size(); - this.rows = rows; + public Map getContents() { + return data; } - public long getSize() { - return size; + public boolean hasField(String field) { + return data.containsKey(field); } - public long getTotalHits() { - return totalHits; + public Object getData(String field) { + return data.get(field); } - // Iterator method for DataRows - @Override - public Iterator iterator() { - return new Iterator() { - private final Iterator iter = rows.iterator(); - - @Override - public boolean hasNext() { - return iter.hasNext(); - } - - @Override - public Row next() { - return iter.next(); - } - - @Override - public void remove() { - throw new UnsupportedOperationException("No changes allowed to DataRows rows"); - } - }; - } - - // Inner class for Row object - public static class Row { - - private Map data; - - public Row(Map data) { - this.data = data; - } - - public Map getContents() { - return data; - } - - public boolean hasField(String field) { - return data.containsKey(field); - } - - public Object getData(String field) { - return data.get(field); - } - - public Object getDataOrDefault(String field, Object defaultValue) { - return data.getOrDefault(field, defaultValue); - } + public Object getDataOrDefault(String field, Object defaultValue) { + return data.getOrDefault(field, defaultValue); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatter.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatter.java index aa803975df..dc239abd84 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatter.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import com.google.common.annotations.VisibleForTesting; @@ -23,163 +22,169 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.esdomain.mapping.FieldMappings; -/** - * Formatter to transform date fields into a consistent format for consumption by clients. - */ +/** Formatter to transform date fields into a consistent format for consumption by clients. */ public class DateFieldFormatter { - private static final Logger LOG = LogManager.getLogger(DateFieldFormatter.class); - public static final String FORMAT_JDBC = "yyyy-MM-dd HH:mm:ss.SSS"; - private static final String FORMAT_DELIMITER = "\\|\\|"; - - private static final String FORMAT_DOT_DATE_AND_TIME = "yyyy-MM-dd'T'HH:mm:ss.SSSZ"; - private static final String FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_LOGS_EXCEPTION = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; - private static final String FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION = "yyyy-MM-dd'T'HH:mm:ss"; - private static final String - FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION_NO_TIME = "yyyy-MM-dd'T'"; - private static final String FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_ECOMMERCE_EXCEPTION = "yyyy-MM-dd'T'HH:mm:ssXXX"; - private static final String FORMAT_DOT_DATE = DateFormat.getFormatString("date"); - - private final Map> dateFieldFormatMap; - private final Map fieldAliasMap; - private Set dateColumns; - - public DateFieldFormatter(String indexName, List columns, Map fieldAliasMap) { - this.dateFieldFormatMap = getDateFieldFormatMap(indexName); - this.dateColumns = getDateColumns(columns); - this.fieldAliasMap = fieldAliasMap; + private static final Logger LOG = LogManager.getLogger(DateFieldFormatter.class); + public static final String FORMAT_JDBC = "yyyy-MM-dd HH:mm:ss.SSS"; + private static final String FORMAT_DELIMITER = "\\|\\|"; + + private static final String FORMAT_DOT_DATE_AND_TIME = "yyyy-MM-dd'T'HH:mm:ss.SSSZ"; + private static final String FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_LOGS_EXCEPTION = + "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; + private static final String FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION = + "yyyy-MM-dd'T'HH:mm:ss"; + private static final String + FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION_NO_TIME = "yyyy-MM-dd'T'"; + private static final String FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_ECOMMERCE_EXCEPTION = + "yyyy-MM-dd'T'HH:mm:ssXXX"; + private static final String FORMAT_DOT_DATE = DateFormat.getFormatString("date"); + + private final Map> dateFieldFormatMap; + private final Map fieldAliasMap; + private Set dateColumns; + + public DateFieldFormatter( + String indexName, List columns, Map fieldAliasMap) { + this.dateFieldFormatMap = getDateFieldFormatMap(indexName); + this.dateColumns = getDateColumns(columns); + this.fieldAliasMap = fieldAliasMap; + } + + @VisibleForTesting + protected DateFieldFormatter( + Map> dateFieldFormatMap, + List columns, + Map fieldAliasMap) { + this.dateFieldFormatMap = dateFieldFormatMap; + this.dateColumns = getDateColumns(columns); + this.fieldAliasMap = fieldAliasMap; + } + + /** + * Apply the JDBC date format ({@code yyyy-MM-dd HH:mm:ss.SSS}) to date values in the current row. + * + * @param rowSource The row in which to format the date values. + */ + public void applyJDBCDateFormat(Map rowSource) { + for (String columnName : dateColumns) { + Object columnOriginalDate = rowSource.get(columnName); + if (columnOriginalDate == null) { + // Don't try to parse null date values + continue; + } + + List formats = getFormatsForColumn(columnName); + if (formats == null) { + LOG.warn( + "Could not determine date formats for column {}; returning original value", columnName); + continue; + } + + Date date = parseDateString(formats, columnOriginalDate.toString()); + if (date != null) { + rowSource.put(columnName, DateFormat.getFormattedDate(date, FORMAT_JDBC)); + break; + } else { + LOG.warn("Could not parse date value; returning original value"); + } } - - @VisibleForTesting - protected DateFieldFormatter(Map> dateFieldFormatMap, - List columns, - Map fieldAliasMap) { - this.dateFieldFormatMap = dateFieldFormatMap; - this.dateColumns = getDateColumns(columns); - this.fieldAliasMap = fieldAliasMap; + } + + private List getFormatsForColumn(String columnName) { + // Handle special cases for column names + if (fieldAliasMap.get(columnName) != null) { + // Column was aliased, and we need to find the base name for the column + columnName = fieldAliasMap.get(columnName); + } else if (columnName.split("\\.").length == 2) { + // Column is part of a join, and is qualified by the table alias + columnName = columnName.split("\\.")[1]; } - - /** - * Apply the JDBC date format ({@code yyyy-MM-dd HH:mm:ss.SSS}) to date values in the current row. - * - * @param rowSource The row in which to format the date values. - */ - public void applyJDBCDateFormat(Map rowSource) { - for (String columnName : dateColumns) { - Object columnOriginalDate = rowSource.get(columnName); - if (columnOriginalDate == null) { - // Don't try to parse null date values - continue; - } - - List formats = getFormatsForColumn(columnName); - if (formats == null) { - LOG.warn("Could not determine date formats for column {}; returning original value", columnName); - continue; - } - - Date date = parseDateString(formats, columnOriginalDate.toString()); - if (date != null) { - rowSource.put(columnName, DateFormat.getFormattedDate(date, FORMAT_JDBC)); - break; - } else { - LOG.warn("Could not parse date value; returning original value"); - } + return dateFieldFormatMap.get(columnName); + } + + private Set getDateColumns(List columns) { + return columns.stream() + .filter(column -> column.getType().equals(Schema.Type.DATE.nameLowerCase())) + .map(Schema.Column::getName) + .collect(Collectors.toSet()); + } + + private Map> getDateFieldFormatMap(String indexName) { + LocalClusterState state = LocalClusterState.state(); + Map> formatMap = new HashMap<>(); + + String[] indices = indexName.split("\\|"); + Collection typeProperties = state.getFieldMappings(indices).allMappings(); + + for (FieldMappings fieldMappings : typeProperties) { + for (Map.Entry> field : fieldMappings.data().entrySet()) { + String fieldName = field.getKey(); + Map properties = field.getValue(); + + if (properties.containsKey("format")) { + formatMap.put(fieldName, getFormatsFromProperties(properties.get("format").toString())); + } else { + // Give all field types a format, since operations such as casts + // can change the output type for a field to `date`. + formatMap.put(fieldName, getFormatsFromProperties("date_optional_time")); } + } } - private List getFormatsForColumn(String columnName) { - // Handle special cases for column names - if (fieldAliasMap.get(columnName) != null) { - // Column was aliased, and we need to find the base name for the column - columnName = fieldAliasMap.get(columnName); - } else if (columnName.split("\\.").length == 2) { - // Column is part of a join, and is qualified by the table alias - columnName = columnName.split("\\.")[1]; - } - return dateFieldFormatMap.get(columnName); - } - - private Set getDateColumns(List columns) { - return columns.stream() - .filter(column -> column.getType().equals(Schema.Type.DATE.nameLowerCase())) - .map(Schema.Column::getName) - .collect(Collectors.toSet()); - } - - private Map> getDateFieldFormatMap(String indexName) { - LocalClusterState state = LocalClusterState.state(); - Map> formatMap = new HashMap<>(); - - String[] indices = indexName.split("\\|"); - Collection typeProperties = state.getFieldMappings(indices) - .allMappings(); - - for (FieldMappings fieldMappings: typeProperties) { - for (Map.Entry> field : fieldMappings.data().entrySet()) { - String fieldName = field.getKey(); - Map properties = field.getValue(); - - if (properties.containsKey("format")) { - formatMap.put(fieldName, getFormatsFromProperties(properties.get("format").toString())); - } else { - // Give all field types a format, since operations such as casts - // can change the output type for a field to `date`. - formatMap.put(fieldName, getFormatsFromProperties("date_optional_time")); - } + return formatMap; + } + + private List getFormatsFromProperties(String formatProperty) { + String[] formats = formatProperty.split(FORMAT_DELIMITER); + return Arrays.asList(formats); + } + + private Date parseDateString(List formats, String columnOriginalDate) { + TimeZone originalDefaultTimeZone = TimeZone.getDefault(); + Date parsedDate = null; + + // Apache Commons DateUtils uses the default TimeZone for the JVM when parsing. + // However, since all dates on OpenSearch are stored as UTC, we need to + // parse these values using the UTC timezone. + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + for (String columnFormat : formats) { + try { + switch (columnFormat) { + case "date_optional_time": + case "strict_date_optional_time": + parsedDate = + DateUtils.parseDate( + columnOriginalDate, + FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_LOGS_EXCEPTION, + FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION, + FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION_NO_TIME, + FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_ECOMMERCE_EXCEPTION, + FORMAT_DOT_DATE_AND_TIME, + FORMAT_DOT_DATE); + break; + case "epoch_millis": + parsedDate = new Date(Long.parseLong(columnOriginalDate)); + break; + case "epoch_second": + parsedDate = new Date(Long.parseLong(columnOriginalDate) * 1000); + break; + default: + String formatString = DateFormat.getFormatString(columnFormat); + if (formatString == null) { + // Custom format; take as-is + formatString = columnFormat; } + parsedDate = DateUtils.parseDate(columnOriginalDate, formatString); } - - return formatMap; - } - - private List getFormatsFromProperties(String formatProperty) { - String[] formats = formatProperty.split(FORMAT_DELIMITER); - return Arrays.asList(formats); + } catch (ParseException | NumberFormatException e) { + LOG.warn( + String.format( + "Could not parse date string %s as %s", columnOriginalDate, columnFormat)); + } } + // Reset default timezone after parsing + TimeZone.setDefault(originalDefaultTimeZone); - private Date parseDateString(List formats, String columnOriginalDate) { - TimeZone originalDefaultTimeZone = TimeZone.getDefault(); - Date parsedDate = null; - - // Apache Commons DateUtils uses the default TimeZone for the JVM when parsing. - // However, since all dates on OpenSearch are stored as UTC, we need to - // parse these values using the UTC timezone. - TimeZone.setDefault(TimeZone.getTimeZone("UTC")); - for (String columnFormat : formats) { - try { - switch (columnFormat) { - case "date_optional_time": - case "strict_date_optional_time": - parsedDate = DateUtils.parseDate( - columnOriginalDate, - FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_LOGS_EXCEPTION, - FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION, - FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_FLIGHTS_EXCEPTION_NO_TIME, - FORMAT_DOT_OPENSEARCH_DASHBOARDS_SAMPLE_DATA_ECOMMERCE_EXCEPTION, - FORMAT_DOT_DATE_AND_TIME, - FORMAT_DOT_DATE); - break; - case "epoch_millis": - parsedDate = new Date(Long.parseLong(columnOriginalDate)); - break; - case "epoch_second": - parsedDate = new Date(Long.parseLong(columnOriginalDate) * 1000); - break; - default: - String formatString = DateFormat.getFormatString(columnFormat); - if (formatString == null) { - // Custom format; take as-is - formatString = columnFormat; - } - parsedDate = DateUtils.parseDate(columnOriginalDate, formatString); - } - } catch (ParseException | NumberFormatException e) { - LOG.warn(String.format("Could not parse date string %s as %s", columnOriginalDate, columnFormat)); - } - } - // Reset default timezone after parsing - TimeZone.setDefault(originalDefaultTimeZone); - - return parsedDate; - } + return parsedDate; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFormat.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFormat.java index 40151c9413..fc9237918c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFormat.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DateFormat.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import java.time.Instant; @@ -15,112 +14,121 @@ public class DateFormat { - private static Map formatMap = new HashMap<>(); - - static { - // Special cases that are parsed separately - formatMap.put("date_optional_time", ""); - formatMap.put("strict_date_optional_time", ""); - formatMap.put("epoch_millis", ""); - formatMap.put("epoch_second", ""); - - formatMap.put("basic_date", Date.BASIC_DATE); - formatMap.put("basic_date_time", Date.BASIC_DATE + Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); - formatMap.put("basic_date_time_no_millis", Date.BASIC_DATE + Time.T + Time.BASIC_TIME + Time.TZ); - - formatMap.put("basic_ordinal_date", Date.BASIC_ORDINAL_DATE); - formatMap.put("basic_ordinal_date_time", - Date.BASIC_ORDINAL_DATE + Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); - formatMap.put("basic_ordinal_date_time_no_millis", Date.BASIC_ORDINAL_DATE+ Time.T + Time.BASIC_TIME + Time.TZ); - - formatMap.put("basic_time", Time.BASIC_TIME + Time.MILLIS + Time.TZ); - formatMap.put("basic_time_no_millis", Time.BASIC_TIME + Time.TZ); - - formatMap.put("basic_t_time", Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); - formatMap.put("basic_t_time_no_millis", Time.T + Time.BASIC_TIME + Time.TZ); - - formatMap.put("basic_week_date", Date.BASIC_WEEK_DATE); - formatMap.put("basic_week_date_time", Date.BASIC_WEEK_DATE + Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); - formatMap.put("basic_week_date_time_no_millis", Date.BASIC_WEEK_DATE + Time.T + Time.BASIC_TIME + Time.TZ); - - formatMap.put("date", Date.DATE); - formatMap.put("date_hour", Date.DATE + Time.T + Time.HOUR); - formatMap.put("date_hour_minute", Date.DATE + Time.T + Time.HOUR_MINUTE); - formatMap.put("date_hour_minute_second", Date.DATE + Time.T + Time.TIME); - formatMap.put("date_hour_minute_second_fraction", Date.DATE + Time.T + Time.TIME + Time.MILLIS); - formatMap.put("date_hour_minute_second_millis", Date.DATE + Time.T + Time.TIME + Time.MILLIS); - formatMap.put("date_time", Date.DATE + Time.T + Time.TIME + Time.MILLIS + Time.TZZ); - formatMap.put("date_time_no_millis", Date.DATE + Time.T + Time.TIME + Time.TZZ); - - formatMap.put("hour", Time.HOUR); - formatMap.put("hour_minute", Time.HOUR_MINUTE); - formatMap.put("hour_minute_second", Time.TIME); - formatMap.put("hour_minute_second_fraction", Time.TIME + Time.MILLIS); - formatMap.put("hour_minute_second_millis", Time.TIME + Time.MILLIS); - - formatMap.put("ordinal_date", Date.ORDINAL_DATE); - formatMap.put("ordinal_date_time", Date.ORDINAL_DATE + Time.T + Time.TIME + Time.MILLIS + Time.TZZ); - formatMap.put("ordinal_date_time_no_millis", Date.ORDINAL_DATE + Time.T + Time.TIME + Time.TZZ); - - formatMap.put("time", Time.TIME + Time.MILLIS + Time.TZZ); - formatMap.put("time_no_millis", Time.TIME + Time.TZZ); - - formatMap.put("t_time", Time.T + Time.TIME + Time.MILLIS + Time.TZZ); - formatMap.put("t_time_no_millis", Time.T + Time.TIME + Time.TZZ); - - formatMap.put("week_date", Date.WEEK_DATE); - formatMap.put("week_date_time", Date.WEEK_DATE + Time.T + Time.TIME + Time.MILLIS + Time.TZZ); - formatMap.put("week_date_time_no_millis", Date.WEEK_DATE + Time.T + Time.TIME + Time.TZZ); - - // Note: input mapping is "weekyear", but output value is "week_year" - formatMap.put("week_year", Date.WEEKYEAR); - formatMap.put("weekyear_week", Date.WEEKYEAR_WEEK); - formatMap.put("weekyear_week_day", Date.WEEK_DATE); - - formatMap.put("year", Date.YEAR); - formatMap.put("year_month", Date.YEAR_MONTH); - formatMap.put("year_month_day", Date.DATE); - } - - private DateFormat() { - } - - public static String getFormatString(String formatName) { - return formatMap.get(formatName); - } - - public static String getFormattedDate(java.util.Date date, String dateFormat) { - Instant instant = date.toInstant(); - ZonedDateTime zdt = ZonedDateTime.ofInstant(instant, ZoneId.of("Etc/UTC")); - return zdt.format(DateTimeFormatter.ofPattern(dateFormat)); - } - - private static class Date { - static String BASIC_DATE = "yyyyMMdd"; - static String BASIC_ORDINAL_DATE = "yyyyDDD"; - static String BASIC_WEEK_DATE = "YYYY'W'wwu"; - - static String DATE = "yyyy-MM-dd"; - static String ORDINAL_DATE = "yyyy-DDD"; - - static String YEAR = "yyyy"; - static String YEAR_MONTH = "yyyy-MM"; - - static String WEEK_DATE = "YYYY-'W'ww-u"; - static String WEEKYEAR = "YYYY"; - static String WEEKYEAR_WEEK = "YYYY-'W'ww"; - } - - private static class Time { - static String T = "'T'"; - static String BASIC_TIME = "HHmmss"; - static String TIME = "HH:mm:ss"; - - static String HOUR = "HH"; - static String HOUR_MINUTE = "HH:mm"; - - static String MILLIS = ".SSS"; - static String TZ = "Z"; - static String TZZ = "XX"; - } + private static Map formatMap = new HashMap<>(); + + static { + // Special cases that are parsed separately + formatMap.put("date_optional_time", ""); + formatMap.put("strict_date_optional_time", ""); + formatMap.put("epoch_millis", ""); + formatMap.put("epoch_second", ""); + + formatMap.put("basic_date", Date.BASIC_DATE); + formatMap.put( + "basic_date_time", Date.BASIC_DATE + Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); + formatMap.put( + "basic_date_time_no_millis", Date.BASIC_DATE + Time.T + Time.BASIC_TIME + Time.TZ); + + formatMap.put("basic_ordinal_date", Date.BASIC_ORDINAL_DATE); + formatMap.put( + "basic_ordinal_date_time", + Date.BASIC_ORDINAL_DATE + Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); + formatMap.put( + "basic_ordinal_date_time_no_millis", + Date.BASIC_ORDINAL_DATE + Time.T + Time.BASIC_TIME + Time.TZ); + + formatMap.put("basic_time", Time.BASIC_TIME + Time.MILLIS + Time.TZ); + formatMap.put("basic_time_no_millis", Time.BASIC_TIME + Time.TZ); + + formatMap.put("basic_t_time", Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); + formatMap.put("basic_t_time_no_millis", Time.T + Time.BASIC_TIME + Time.TZ); + + formatMap.put("basic_week_date", Date.BASIC_WEEK_DATE); + formatMap.put( + "basic_week_date_time", + Date.BASIC_WEEK_DATE + Time.T + Time.BASIC_TIME + Time.MILLIS + Time.TZ); + formatMap.put( + "basic_week_date_time_no_millis", + Date.BASIC_WEEK_DATE + Time.T + Time.BASIC_TIME + Time.TZ); + + formatMap.put("date", Date.DATE); + formatMap.put("date_hour", Date.DATE + Time.T + Time.HOUR); + formatMap.put("date_hour_minute", Date.DATE + Time.T + Time.HOUR_MINUTE); + formatMap.put("date_hour_minute_second", Date.DATE + Time.T + Time.TIME); + formatMap.put("date_hour_minute_second_fraction", Date.DATE + Time.T + Time.TIME + Time.MILLIS); + formatMap.put("date_hour_minute_second_millis", Date.DATE + Time.T + Time.TIME + Time.MILLIS); + formatMap.put("date_time", Date.DATE + Time.T + Time.TIME + Time.MILLIS + Time.TZZ); + formatMap.put("date_time_no_millis", Date.DATE + Time.T + Time.TIME + Time.TZZ); + + formatMap.put("hour", Time.HOUR); + formatMap.put("hour_minute", Time.HOUR_MINUTE); + formatMap.put("hour_minute_second", Time.TIME); + formatMap.put("hour_minute_second_fraction", Time.TIME + Time.MILLIS); + formatMap.put("hour_minute_second_millis", Time.TIME + Time.MILLIS); + + formatMap.put("ordinal_date", Date.ORDINAL_DATE); + formatMap.put( + "ordinal_date_time", Date.ORDINAL_DATE + Time.T + Time.TIME + Time.MILLIS + Time.TZZ); + formatMap.put("ordinal_date_time_no_millis", Date.ORDINAL_DATE + Time.T + Time.TIME + Time.TZZ); + + formatMap.put("time", Time.TIME + Time.MILLIS + Time.TZZ); + formatMap.put("time_no_millis", Time.TIME + Time.TZZ); + + formatMap.put("t_time", Time.T + Time.TIME + Time.MILLIS + Time.TZZ); + formatMap.put("t_time_no_millis", Time.T + Time.TIME + Time.TZZ); + + formatMap.put("week_date", Date.WEEK_DATE); + formatMap.put("week_date_time", Date.WEEK_DATE + Time.T + Time.TIME + Time.MILLIS + Time.TZZ); + formatMap.put("week_date_time_no_millis", Date.WEEK_DATE + Time.T + Time.TIME + Time.TZZ); + + // Note: input mapping is "weekyear", but output value is "week_year" + formatMap.put("week_year", Date.WEEKYEAR); + formatMap.put("weekyear_week", Date.WEEKYEAR_WEEK); + formatMap.put("weekyear_week_day", Date.WEEK_DATE); + + formatMap.put("year", Date.YEAR); + formatMap.put("year_month", Date.YEAR_MONTH); + formatMap.put("year_month_day", Date.DATE); + } + + private DateFormat() {} + + public static String getFormatString(String formatName) { + return formatMap.get(formatName); + } + + public static String getFormattedDate(java.util.Date date, String dateFormat) { + Instant instant = date.toInstant(); + ZonedDateTime zdt = ZonedDateTime.ofInstant(instant, ZoneId.of("Etc/UTC")); + return zdt.format(DateTimeFormatter.ofPattern(dateFormat)); + } + + private static class Date { + static String BASIC_DATE = "yyyyMMdd"; + static String BASIC_ORDINAL_DATE = "yyyyDDD"; + static String BASIC_WEEK_DATE = "YYYY'W'wwu"; + + static String DATE = "yyyy-MM-dd"; + static String ORDINAL_DATE = "yyyy-DDD"; + + static String YEAR = "yyyy"; + static String YEAR_MONTH = "yyyy-MM"; + + static String WEEK_DATE = "YYYY-'W'ww-u"; + static String WEEKYEAR = "YYYY"; + static String WEEKYEAR_WEEK = "YYYY-'W'ww"; + } + + private static class Time { + static String T = "'T'"; + static String BASIC_TIME = "HHmmss"; + static String TIME = "HH:mm:ss"; + + static String HOUR = "HH"; + static String HOUR_MINUTE = "HH:mm"; + + static String MILLIS = ".SSS"; + static String TZ = "Z"; + static String TZZ = "XX"; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DeleteResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DeleteResultSet.java index ccecacc432..24afb0a7af 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DeleteResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DeleteResultSet.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import java.util.Collections; @@ -14,28 +13,28 @@ import org.opensearch.sql.legacy.domain.Delete; public class DeleteResultSet extends ResultSet { - private Delete query; - private Object queryResult; - - public static final String DELETED = "deleted_rows"; - - public DeleteResultSet(Client client, Delete query, Object queryResult) { - this.client = client; - this.query = query; - this.queryResult = queryResult; - this.schema = new Schema(loadColumns()); - this.dataRows = new DataRows(loadRows()); - } - - private List loadColumns() { - return Collections.singletonList(new Schema.Column(DELETED, null, Schema.Type.LONG)); - } - - private List loadRows() { - return Collections.singletonList(new DataRows.Row(loadDeletedData())); - } - - private Map loadDeletedData(){ - return Collections.singletonMap(DELETED, ((BulkByScrollResponse) queryResult).getDeleted()); - } + private Delete query; + private Object queryResult; + + public static final String DELETED = "deleted_rows"; + + public DeleteResultSet(Client client, Delete query, Object queryResult) { + this.client = client; + this.query = query; + this.queryResult = queryResult; + this.schema = new Schema(loadColumns()); + this.dataRows = new DataRows(loadRows()); + } + + private List loadColumns() { + return Collections.singletonList(new Schema.Column(DELETED, null, Schema.Type.LONG)); + } + + private List loadRows() { + return Collections.singletonList(new DataRows.Row(loadDeletedData())); + } + + private Map loadDeletedData() { + return Collections.singletonMap(DELETED, ((BulkByScrollResponse) queryResult).getDeleted()); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DescribeResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DescribeResultSet.java index 0cccf73268..eba6db2453 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DescribeResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/DescribeResultSet.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import java.util.ArrayList; @@ -21,145 +20,142 @@ public class DescribeResultSet extends ResultSet { - private static final int DEFAULT_NUM_PREC_RADIX = 10; - private static final String IS_AUTOINCREMENT = "NO"; - - /** - * You are not required to set the field type to object explicitly, as this is the default value. - * https://www.elastic.co/guide/en/elasticsearch/reference/current/object.html - */ - public static final String DEFAULT_OBJECT_DATATYPE = "object"; - - private IndexStatement statement; - private Object queryResult; - - public DescribeResultSet(Client client, IndexStatement statement, Object queryResult) { - this.client = client; - this.clusterName = getClusterName(); - this.statement = statement; - this.queryResult = queryResult; - - this.schema = new Schema(statement, loadColumns()); - this.dataRows = new DataRows(loadRows()); + private static final int DEFAULT_NUM_PREC_RADIX = 10; + private static final String IS_AUTOINCREMENT = "NO"; + + /** + * You are not required to set the field type to object explicitly, as this is the default value. + * https://www.elastic.co/guide/en/elasticsearch/reference/current/object.html + */ + public static final String DEFAULT_OBJECT_DATATYPE = "object"; + + private IndexStatement statement; + private Object queryResult; + + public DescribeResultSet(Client client, IndexStatement statement, Object queryResult) { + this.client = client; + this.clusterName = getClusterName(); + this.statement = statement; + this.queryResult = queryResult; + + this.schema = new Schema(statement, loadColumns()); + this.dataRows = new DataRows(loadRows()); + } + + private List loadColumns() { + List columns = new ArrayList<>(); + // Unused Columns are still included in Schema to match JDBC/ODBC standard + columns.add(new Column("TABLE_CAT", null, Type.KEYWORD)); + columns.add(new Column("TABLE_SCHEM", null, Type.KEYWORD)); + columns.add(new Column("TABLE_NAME", null, Type.KEYWORD)); + columns.add(new Column("COLUMN_NAME", null, Type.KEYWORD)); + columns.add(new Column("DATA_TYPE", null, Type.INTEGER)); + columns.add(new Column("TYPE_NAME", null, Type.KEYWORD)); + columns.add(new Column("COLUMN_SIZE", null, Type.INTEGER)); + columns.add(new Column("BUFFER_LENGTH", null, Type.INTEGER)); // Not used + columns.add(new Column("DECIMAL_DIGITS", null, Type.INTEGER)); + columns.add(new Column("NUM_PREC_RADIX", null, Type.INTEGER)); + columns.add(new Column("NULLABLE", null, Type.INTEGER)); + columns.add(new Column("REMARKS", null, Type.KEYWORD)); + columns.add(new Column("COLUMN_DEF", null, Type.KEYWORD)); + columns.add(new Column("SQL_DATA_TYPE", null, Type.INTEGER)); // Not used + columns.add(new Column("SQL_DATETIME_SUB", null, Type.INTEGER)); // Not used + columns.add(new Column("CHAR_OCTET_LENGTH", null, Type.INTEGER)); + columns.add(new Column("ORDINAL_POSITION", null, Type.INTEGER)); + columns.add(new Column("IS_NULLABLE", null, Type.KEYWORD)); + columns.add(new Column("SCOPE_CATALOG", null, Type.KEYWORD)); // Not used + columns.add(new Column("SCOPE_SCHEMA", null, Type.KEYWORD)); // Not used + columns.add(new Column("SCOPE_TABLE", null, Type.KEYWORD)); // Not used + columns.add(new Column("SOURCE_DATA_TYPE", null, Type.SHORT)); // Not used + columns.add(new Column("IS_AUTOINCREMENT", null, Type.KEYWORD)); + columns.add(new Column("IS_GENERATEDCOLUMN", null, Type.KEYWORD)); + + return columns; + } + + private List loadRows() { + List rows = new ArrayList<>(); + GetIndexResponse indexResponse = (GetIndexResponse) queryResult; + Map indexMappings = indexResponse.getMappings(); + + // Iterate through indices in indexMappings + for (Entry indexCursor : indexMappings.entrySet()) { + String index = indexCursor.getKey(); + + if (matchesPatternIfRegex(index, statement.getIndexPattern())) { + rows.addAll(loadIndexData(index, indexCursor.getValue().getSourceAsMap())); + } } - - private List loadColumns() { - List columns = new ArrayList<>(); - // Unused Columns are still included in Schema to match JDBC/ODBC standard - columns.add(new Column("TABLE_CAT", null, Type.KEYWORD)); - columns.add(new Column("TABLE_SCHEM", null, Type.KEYWORD)); - columns.add(new Column("TABLE_NAME", null, Type.KEYWORD)); - columns.add(new Column("COLUMN_NAME", null, Type.KEYWORD)); - columns.add(new Column("DATA_TYPE", null, Type.INTEGER)); - columns.add(new Column("TYPE_NAME", null, Type.KEYWORD)); - columns.add(new Column("COLUMN_SIZE", null, Type.INTEGER)); - columns.add(new Column("BUFFER_LENGTH", null, Type.INTEGER)); // Not used - columns.add(new Column("DECIMAL_DIGITS", null, Type.INTEGER)); - columns.add(new Column("NUM_PREC_RADIX", null, Type.INTEGER)); - columns.add(new Column("NULLABLE", null, Type.INTEGER)); - columns.add(new Column("REMARKS", null, Type.KEYWORD)); - columns.add(new Column("COLUMN_DEF", null, Type.KEYWORD)); - columns.add(new Column("SQL_DATA_TYPE", null, Type.INTEGER)); // Not used - columns.add(new Column("SQL_DATETIME_SUB", null, Type.INTEGER)); // Not used - columns.add(new Column("CHAR_OCTET_LENGTH", null, Type.INTEGER)); - columns.add(new Column("ORDINAL_POSITION", null, Type.INTEGER)); - columns.add(new Column("IS_NULLABLE", null, Type.KEYWORD)); - columns.add(new Column("SCOPE_CATALOG", null, Type.KEYWORD)); // Not used - columns.add(new Column("SCOPE_SCHEMA", null, Type.KEYWORD)); // Not used - columns.add(new Column("SCOPE_TABLE", null, Type.KEYWORD)); // Not used - columns.add(new Column("SOURCE_DATA_TYPE", null, Type.SHORT)); // Not used - columns.add(new Column("IS_AUTOINCREMENT", null, Type.KEYWORD)); - columns.add(new Column("IS_GENERATEDCOLUMN", null, Type.KEYWORD)); - - return columns; + return rows; + } + + @SuppressWarnings("unchecked") + private List loadIndexData(String index, Map mappingMetadata) { + List rows = new ArrayList<>(); + + Map flattenedMetaData = + flattenMappingMetaData(mappingMetadata, "", new HashMap<>()); + int position = 1; // Used as an arbitrary ORDINAL_POSITION value for the time being + for (Entry entry : flattenedMetaData.entrySet()) { + String columnPattern = statement.getColumnPattern(); + + // Check to see if column name matches pattern, if given + if (columnPattern == null || matchesPattern(entry.getKey(), columnPattern)) { + rows.add(new Row(loadRowData(index, entry.getKey(), entry.getValue(), position))); + position++; + } } - private List loadRows() { - List rows = new ArrayList<>(); - GetIndexResponse indexResponse = (GetIndexResponse) queryResult; - Map indexMappings = indexResponse.getMappings(); - - // Iterate through indices in indexMappings - for (Entry indexCursor : indexMappings.entrySet()) { - String index = indexCursor.getKey(); - - if (matchesPatternIfRegex(index, statement.getIndexPattern())) { - rows.addAll(loadIndexData(index, indexCursor.getValue().getSourceAsMap())); - } - } - return rows; - } - - @SuppressWarnings("unchecked") - private List loadIndexData(String index, Map mappingMetadata) { - List rows = new ArrayList<>(); - - Map flattenedMetaData = flattenMappingMetaData(mappingMetadata, "", new HashMap<>()); - int position = 1; // Used as an arbitrary ORDINAL_POSITION value for the time being - for (Entry entry : flattenedMetaData.entrySet()) { - String columnPattern = statement.getColumnPattern(); - - // Check to see if column name matches pattern, if given - if (columnPattern == null || matchesPattern(entry.getKey(), columnPattern)) { - rows.add( - new Row( - loadRowData(index, entry.getKey(), entry.getValue(), position) - ) - ); - position++; - } - } - - return rows; + return rows; + } + + private Map loadRowData(String index, String column, String type, int position) { + Map data = new HashMap<>(); + data.put("TABLE_CAT", clusterName); + data.put("TABLE_NAME", index); + data.put("COLUMN_NAME", column); + data.put("TYPE_NAME", type); + data.put("NUM_PREC_RADIX", DEFAULT_NUM_PREC_RADIX); + data.put("NULLABLE", 2); // TODO Defaulting to 2, need to find a way to check this + data.put("ORDINAL_POSITION", position); // There is no deterministic position of column in table + data.put("IS_NULLABLE", ""); // TODO Defaulting to unknown, need to check this + data.put("IS_AUTOINCREMENT", IS_AUTOINCREMENT); // Defaulting to "NO" + data.put("IS_GENERATEDCOLUMN", ""); // TODO Defaulting to unknown, need to check + + return data; + } + + /** + * To not disrupt old logic, for the time being, ShowQueryAction and DescribeQueryAction are using + * the same 'GetIndexRequestBuilder' that was used in the old ShowQueryAction. Since the format of + * the resulting meta data is different, this method is being used to flatten and retrieve types. + * + *

In the future, should look for a way to generalize this since Schema is currently using + * FieldMappingMetaData whereas here we are using MappingMetaData. + */ + @SuppressWarnings("unchecked") + private Map flattenMappingMetaData( + Map mappingMetaData, String currPath, Map flattenedMapping) { + Map properties = (Map) mappingMetaData.get("properties"); + for (Entry entry : properties.entrySet()) { + Map metaData = (Map) entry.getValue(); + + String fullPath = addToPath(currPath, entry.getKey()); + flattenedMapping.put( + fullPath, (String) metaData.getOrDefault("type", DEFAULT_OBJECT_DATATYPE)); + if (metaData.containsKey("properties")) { + flattenedMapping = flattenMappingMetaData(metaData, fullPath, flattenedMapping); + } } - private Map loadRowData(String index, String column, String type, int position) { - Map data = new HashMap<>(); - data.put("TABLE_CAT", clusterName); - data.put("TABLE_NAME", index); - data.put("COLUMN_NAME", column); - data.put("TYPE_NAME", type); - data.put("NUM_PREC_RADIX", DEFAULT_NUM_PREC_RADIX); - data.put("NULLABLE", 2); // TODO Defaulting to 2, need to find a way to check this - data.put("ORDINAL_POSITION", position); // There is no deterministic position of column in table - data.put("IS_NULLABLE", ""); // TODO Defaulting to unknown, need to check this - data.put("IS_AUTOINCREMENT", IS_AUTOINCREMENT); // Defaulting to "NO" - data.put("IS_GENERATEDCOLUMN", ""); // TODO Defaulting to unknown, need to check - - return data; - } + return flattenedMapping; + } - /** - * To not disrupt old logic, for the time being, ShowQueryAction and DescribeQueryAction are using the same - * 'GetIndexRequestBuilder' that was used in the old ShowQueryAction. Since the format of the resulting meta data - * is different, this method is being used to flatten and retrieve types. - *

- * In the future, should look for a way to generalize this since Schema is currently using FieldMappingMetaData - * whereas here we are using MappingMetaData. - */ - @SuppressWarnings("unchecked") - private Map flattenMappingMetaData(Map mappingMetaData, - String currPath, - Map flattenedMapping) { - Map properties = (Map) mappingMetaData.get("properties"); - for (Entry entry : properties.entrySet()) { - Map metaData = (Map) entry.getValue(); - - String fullPath = addToPath(currPath, entry.getKey()); - flattenedMapping.put(fullPath, (String) metaData.getOrDefault("type", DEFAULT_OBJECT_DATATYPE)); - if (metaData.containsKey("properties")) { - flattenedMapping = flattenMappingMetaData(metaData, fullPath, flattenedMapping); - } - } - - return flattenedMapping; + private String addToPath(String currPath, String field) { + if (currPath.isEmpty()) { + return field; } - private String addToPath(String currPath, String field) { - if (currPath.isEmpty()) { - return field; - } - - return currPath + "." + field; - } + return currPath + "." + field; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java index f7d1fbf641..3087d6f041 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.join; import java.io.IOException; @@ -41,219 +40,239 @@ import org.opensearch.sql.legacy.query.join.TableInJoinRequestBuilder; import org.opensearch.sql.legacy.query.planner.HashJoinQueryPlanRequestBuilder; -/** - * Created by Eliran on 15/9/2015. - */ +/** Created by Eliran on 15/9/2015. */ public abstract class ElasticJoinExecutor implements ElasticHitsExecutor { - private static final Logger LOG = LogManager.getLogger(); - - protected List results; // Keep list to avoid copy to new array in SearchHits - protected MetaSearchResult metaResults; - protected final int MAX_RESULTS_ON_ONE_FETCH = 10000; - private Set aliasesOnReturn; - private boolean allFieldsReturn; - - protected ElasticJoinExecutor(JoinRequestBuilder requestBuilder) { - metaResults = new MetaSearchResult(); - aliasesOnReturn = new HashSet<>(); - List firstTableReturnedField = requestBuilder.getFirstTable().getReturnedFields(); - List secondTableReturnedField = requestBuilder.getSecondTable().getReturnedFields(); - allFieldsReturn = (firstTableReturnedField == null || firstTableReturnedField.size() == 0) - && (secondTableReturnedField == null || secondTableReturnedField.size() == 0); - } - - public void sendResponse(RestChannel channel) throws IOException { - XContentBuilder builder = null; - long len; - try { - builder = ElasticUtils.hitsAsStringResultZeroCopy(results, metaResults, this); - BytesRestResponse bytesRestResponse = new BytesRestResponse(RestStatus.OK, builder); - len = bytesRestResponse.content().length(); - channel.sendResponse(bytesRestResponse); - } catch (IOException e) { - try { - if (builder != null) { - builder.close(); - } - } catch (Exception ex) { - // Ignore. Already logged in channel - } - throw e; + private static final Logger LOG = LogManager.getLogger(); + + protected List results; // Keep list to avoid copy to new array in SearchHits + protected MetaSearchResult metaResults; + protected final int MAX_RESULTS_ON_ONE_FETCH = 10000; + private Set aliasesOnReturn; + private boolean allFieldsReturn; + + protected ElasticJoinExecutor(JoinRequestBuilder requestBuilder) { + metaResults = new MetaSearchResult(); + aliasesOnReturn = new HashSet<>(); + List firstTableReturnedField = requestBuilder.getFirstTable().getReturnedFields(); + List secondTableReturnedField = requestBuilder.getSecondTable().getReturnedFields(); + allFieldsReturn = + (firstTableReturnedField == null || firstTableReturnedField.size() == 0) + && (secondTableReturnedField == null || secondTableReturnedField.size() == 0); + } + + public void sendResponse(RestChannel channel) throws IOException { + XContentBuilder builder = null; + long len; + try { + builder = ElasticUtils.hitsAsStringResultZeroCopy(results, metaResults, this); + BytesRestResponse bytesRestResponse = new BytesRestResponse(RestStatus.OK, builder); + len = bytesRestResponse.content().length(); + channel.sendResponse(bytesRestResponse); + } catch (IOException e) { + try { + if (builder != null) { + builder.close(); } - LOG.debug("[MCB] Successfully send response with size of {}. Thread id = {}", len, - Thread.currentThread().getId()); - } - - public void run() throws IOException, SqlParseException { - long timeBefore = System.currentTimeMillis(); - results = innerRun(); - long joinTimeInMilli = System.currentTimeMillis() - timeBefore; - this.metaResults.setTookImMilli(joinTimeInMilli); + } catch (Exception ex) { + // Ignore. Already logged in channel + } + throw e; } - - - protected abstract List innerRun() throws IOException, SqlParseException; - - public SearchHits getHits() { - return new SearchHits(results.toArray(new SearchHit[results.size()]), new TotalHits(results.size(), - Relation.EQUAL_TO), 1.0f); + LOG.debug( + "[MCB] Successfully send response with size of {}. Thread id = {}", + len, + Thread.currentThread().getId()); + } + + public void run() throws IOException, SqlParseException { + long timeBefore = System.currentTimeMillis(); + results = innerRun(); + long joinTimeInMilli = System.currentTimeMillis() - timeBefore; + this.metaResults.setTookImMilli(joinTimeInMilli); + } + + protected abstract List innerRun() throws IOException, SqlParseException; + + public SearchHits getHits() { + return new SearchHits( + results.toArray(new SearchHit[results.size()]), + new TotalHits(results.size(), Relation.EQUAL_TO), + 1.0f); + } + + public static ElasticJoinExecutor createJoinExecutor( + Client client, SqlElasticRequestBuilder requestBuilder) { + if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { + return new QueryPlanElasticExecutor((HashJoinQueryPlanRequestBuilder) requestBuilder); + } else if (requestBuilder instanceof HashJoinElasticRequestBuilder) { + HashJoinElasticRequestBuilder hashJoin = (HashJoinElasticRequestBuilder) requestBuilder; + return new HashJoinElasticExecutor(client, hashJoin); + } else if (requestBuilder instanceof NestedLoopsElasticRequestBuilder) { + NestedLoopsElasticRequestBuilder nestedLoops = + (NestedLoopsElasticRequestBuilder) requestBuilder; + return new NestedLoopsElasticExecutor(client, nestedLoops); + } else { + throw new RuntimeException("Unsuported requestBuilder of type: " + requestBuilder.getClass()); } - - public static ElasticJoinExecutor createJoinExecutor(Client client, SqlElasticRequestBuilder requestBuilder) { - if (requestBuilder instanceof HashJoinQueryPlanRequestBuilder) { - return new QueryPlanElasticExecutor((HashJoinQueryPlanRequestBuilder) requestBuilder); - } else if (requestBuilder instanceof HashJoinElasticRequestBuilder) { - HashJoinElasticRequestBuilder hashJoin = (HashJoinElasticRequestBuilder) requestBuilder; - return new HashJoinElasticExecutor(client, hashJoin); - } else if (requestBuilder instanceof NestedLoopsElasticRequestBuilder) { - NestedLoopsElasticRequestBuilder nestedLoops = (NestedLoopsElasticRequestBuilder) requestBuilder; - return new NestedLoopsElasticExecutor(client, nestedLoops); - } else { - throw new RuntimeException("Unsuported requestBuilder of type: " + requestBuilder.getClass()); - } + } + + protected void mergeSourceAndAddAliases( + Map secondTableHitSource, + SearchHit searchHit, + String t1Alias, + String t2Alias) { + Map results = mapWithAliases(searchHit.getSourceAsMap(), t1Alias); + results.putAll(mapWithAliases(secondTableHitSource, t2Alias)); + searchHit.getSourceAsMap().clear(); + searchHit.getSourceAsMap().putAll(results); + } + + protected Map mapWithAliases(Map source, String alias) { + Map mapWithAliases = new HashMap<>(); + for (Map.Entry fieldNameToValue : source.entrySet()) { + if (!aliasesOnReturn.contains(fieldNameToValue.getKey())) { + mapWithAliases.put(alias + "." + fieldNameToValue.getKey(), fieldNameToValue.getValue()); + } else { + mapWithAliases.put(fieldNameToValue.getKey(), fieldNameToValue.getValue()); + } } - - protected void mergeSourceAndAddAliases(Map secondTableHitSource, SearchHit searchHit, - String t1Alias, String t2Alias) { - Map results = mapWithAliases(searchHit.getSourceAsMap(), t1Alias); - results.putAll(mapWithAliases(secondTableHitSource, t2Alias)); - searchHit.getSourceAsMap().clear(); - searchHit.getSourceAsMap().putAll(results); + return mapWithAliases; + } + + protected void onlyReturnedFields( + Map fieldsMap, List required, boolean allRequired) { + HashMap filteredMap = new HashMap<>(); + if (allFieldsReturn || allRequired) { + filteredMap.putAll(fieldsMap); + return; } - - protected Map mapWithAliases(Map source, String alias) { - Map mapWithAliases = new HashMap<>(); - for (Map.Entry fieldNameToValue : source.entrySet()) { - if (!aliasesOnReturn.contains(fieldNameToValue.getKey())) { - mapWithAliases.put(alias + "." + fieldNameToValue.getKey(), fieldNameToValue.getValue()); - } else { - mapWithAliases.put(fieldNameToValue.getKey(), fieldNameToValue.getValue()); - } - } - return mapWithAliases; + for (Field field : required) { + String name = field.getName(); + String returnName = name; + String alias = field.getAlias(); + if (alias != null && alias != "") { + returnName = alias; + aliasesOnReturn.add(alias); + } + filteredMap.put(returnName, deepSearchInMap(fieldsMap, name)); } - - protected void onlyReturnedFields(Map fieldsMap, List required, boolean allRequired) { - HashMap filteredMap = new HashMap<>(); - if (allFieldsReturn || allRequired) { - filteredMap.putAll(fieldsMap); - return; + fieldsMap.clear(); + fieldsMap.putAll(filteredMap); + } + + protected Object deepSearchInMap(Map fieldsMap, String name) { + if (name.contains(".")) { + String[] path = name.split("\\."); + Map currentObject = fieldsMap; + for (int i = 0; i < path.length - 1; i++) { + Object valueFromCurrentMap = currentObject.get(path[i]); + if (valueFromCurrentMap == null) { + return null; } - for (Field field : required) { - String name = field.getName(); - String returnName = name; - String alias = field.getAlias(); - if (alias != null && alias != "") { - returnName = alias; - aliasesOnReturn.add(alias); - } - filteredMap.put(returnName, deepSearchInMap(fieldsMap, name)); + if (!Map.class.isAssignableFrom(valueFromCurrentMap.getClass())) { + return null; } - fieldsMap.clear(); - fieldsMap.putAll(filteredMap); - + currentObject = (Map) valueFromCurrentMap; + } + return currentObject.get(path[path.length - 1]); } - protected Object deepSearchInMap(Map fieldsMap, String name) { - if (name.contains(".")) { - String[] path = name.split("\\."); - Map currentObject = fieldsMap; - for (int i = 0; i < path.length - 1; i++) { - Object valueFromCurrentMap = currentObject.get(path[i]); - if (valueFromCurrentMap == null) { - return null; - } - if (!Map.class.isAssignableFrom(valueFromCurrentMap.getClass())) { - return null; - } - currentObject = (Map) valueFromCurrentMap; - } - return currentObject.get(path[path.length - 1]); + return fieldsMap.get(name); + } + + protected void addUnmatchedResults( + List combinedResults, + Collection firstTableSearchHits, + List secondTableReturnedFields, + int currentNumOfIds, + int totalLimit, + String t1Alias, + String t2Alias) { + boolean limitReached = false; + for (SearchHitsResult hitsResult : firstTableSearchHits) { + if (!hitsResult.isMatchedWithOtherTable()) { + for (SearchHit hit : hitsResult.getSearchHits()) { + + // todo: decide which id to put or type. or maby its ok this way. just need to doc. + SearchHit unmachedResult = + createUnmachedResult(secondTableReturnedFields, hit.docId(), t1Alias, t2Alias, hit); + combinedResults.add(unmachedResult); + currentNumOfIds++; + if (currentNumOfIds >= totalLimit) { + limitReached = true; + break; + } } - - return fieldsMap.get(name); + } + if (limitReached) { + break; + } } - - - protected void addUnmatchedResults(List combinedResults, - Collection firstTableSearchHits, - List secondTableReturnedFields, int currentNumOfIds, int totalLimit, - String t1Alias, String t2Alias) { - boolean limitReached = false; - for (SearchHitsResult hitsResult : firstTableSearchHits) { - if (!hitsResult.isMatchedWithOtherTable()) { - for (SearchHit hit : hitsResult.getSearchHits()) { - - //todo: decide which id to put or type. or maby its ok this way. just need to doc. - SearchHit unmachedResult = createUnmachedResult(secondTableReturnedFields, hit.docId(), - t1Alias, t2Alias, hit); - combinedResults.add(unmachedResult); - currentNumOfIds++; - if (currentNumOfIds >= totalLimit) { - limitReached = true; - break; - } - - } - } - if (limitReached) { - break; - } - } + } + + protected SearchHit createUnmachedResult( + List secondTableReturnedFields, + int docId, + String t1Alias, + String t2Alias, + SearchHit hit) { + String unmatchedId = hit.getId() + "|0"; + + Map documentFields = new HashMap<>(); + Map metaFields = new HashMap<>(); + hit.getFields() + .forEach( + (fieldName, docField) -> + (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) + ? metaFields + : documentFields) + .put(fieldName, docField)); + SearchHit searchHit = new SearchHit(docId, unmatchedId, documentFields, metaFields); + + searchHit.sourceRef(hit.getSourceRef()); + searchHit.getSourceAsMap().clear(); + searchHit.getSourceAsMap().putAll(hit.getSourceAsMap()); + Map emptySecondTableHitSource = createNullsSource(secondTableReturnedFields); + + mergeSourceAndAddAliases(emptySecondTableHitSource, searchHit, t1Alias, t2Alias); + + return searchHit; + } + + protected Map createNullsSource(List secondTableReturnedFields) { + Map nulledSource = new HashMap<>(); + for (Field field : secondTableReturnedFields) { + if (!field.getName().equals("*")) { + nulledSource.put(field.getName(), null); + } } - - protected SearchHit createUnmachedResult(List secondTableReturnedFields, int docId, String t1Alias, - String t2Alias, SearchHit hit) { - String unmatchedId = hit.getId() + "|0"; - - Map documentFields = new HashMap<>(); - Map metaFields = new HashMap<>(); - hit.getFields().forEach((fieldName, docField) -> - (MapperService.META_FIELDS_BEFORE_7DOT8.contains(fieldName) ? metaFields : documentFields).put(fieldName, docField)); - SearchHit searchHit = new SearchHit(docId, unmatchedId, documentFields, metaFields); - - searchHit.sourceRef(hit.getSourceRef()); - searchHit.getSourceAsMap().clear(); - searchHit.getSourceAsMap().putAll(hit.getSourceAsMap()); - Map emptySecondTableHitSource = createNullsSource(secondTableReturnedFields); - - mergeSourceAndAddAliases(emptySecondTableHitSource, searchHit, t1Alias, t2Alias); - - return searchHit; + return nulledSource; + } + + protected void updateMetaSearchResults(SearchResponse searchResponse) { + this.metaResults.addSuccessfulShards(searchResponse.getSuccessfulShards()); + this.metaResults.addFailedShards(searchResponse.getFailedShards()); + this.metaResults.addTotalNumOfShards(searchResponse.getTotalShards()); + this.metaResults.updateTimeOut(searchResponse.isTimedOut()); + } + + protected SearchResponse scrollOneTimeWithMax( + Client client, TableInJoinRequestBuilder tableRequest) { + SearchRequestBuilder scrollRequest = + tableRequest + .getRequestBuilder() + .setScroll(new TimeValue(60000)) + .setSize(MAX_RESULTS_ON_ONE_FETCH); + boolean ordered = tableRequest.getOriginalSelect().isOrderdSelect(); + if (!ordered) { + scrollRequest.addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC); } - - protected Map createNullsSource(List secondTableReturnedFields) { - Map nulledSource = new HashMap<>(); - for (Field field : secondTableReturnedFields) { - if (!field.getName().equals("*")) { - nulledSource.put(field.getName(), null); - } - } - return nulledSource; - } - - protected void updateMetaSearchResults(SearchResponse searchResponse) { - this.metaResults.addSuccessfulShards(searchResponse.getSuccessfulShards()); - this.metaResults.addFailedShards(searchResponse.getFailedShards()); - this.metaResults.addTotalNumOfShards(searchResponse.getTotalShards()); - this.metaResults.updateTimeOut(searchResponse.isTimedOut()); - } - - protected SearchResponse scrollOneTimeWithMax(Client client, TableInJoinRequestBuilder tableRequest) { - SearchRequestBuilder scrollRequest = tableRequest.getRequestBuilder() - .setScroll(new TimeValue(60000)).setSize(MAX_RESULTS_ON_ONE_FETCH); - boolean ordered = tableRequest.getOriginalSelect().isOrderdSelect(); - if (!ordered) { - scrollRequest.addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC); - } - SearchResponse responseWithHits = scrollRequest.get(); - //on ordered select - not using SCAN , elastic returns hits on first scroll - //es5.0 elastic always return docs on scan - // if(!ordered) - // responseWithHits = client.prepareSearchScroll(responseWithHits.getScrollId()) - // .setScroll(new TimeValue(600000)).get(); - return responseWithHits; - } - - + SearchResponse responseWithHits = scrollRequest.get(); + // on ordered select - not using SCAN , elastic returns hits on first scroll + // es5.0 elastic always return docs on scan + // if(!ordered) + // responseWithHits = client.prepareSearchScroll(responseWithHits.getScrollId()) + // .setScroll(new TimeValue(600000)).get(); + return responseWithHits; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/ComperableHitResult.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/ComperableHitResult.java index 766ecd3692..fa3514600b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/ComperableHitResult.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/ComperableHitResult.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.multi; import com.google.common.base.Joiner; @@ -14,72 +13,70 @@ import org.opensearch.search.SearchHit; import org.opensearch.sql.legacy.utils.Util; -/** - * Created by Eliran on 9/9/2016. - */ +/** Created by Eliran on 9/9/2016. */ public class ComperableHitResult { - private SearchHit hit; - private String comperator; - private boolean isAllNull; - private Map flattenMap; + private SearchHit hit; + private String comperator; + private boolean isAllNull; + private Map flattenMap; - public ComperableHitResult(SearchHit hit, String[] fieldsOrder, String seperator) { - this.hit = hit; - Map hitAsMap = hit.getSourceAsMap(); - this.flattenMap = new HashMap<>(); - List results = new ArrayList<>(); - this.isAllNull = true; + public ComperableHitResult(SearchHit hit, String[] fieldsOrder, String seperator) { + this.hit = hit; + Map hitAsMap = hit.getSourceAsMap(); + this.flattenMap = new HashMap<>(); + List results = new ArrayList<>(); + this.isAllNull = true; - for (int i = 0; i < fieldsOrder.length; i++) { - String field = fieldsOrder[i]; - Object result = Util.deepSearchInMap(hitAsMap, field); - if (result == null) { - results.add(""); - } else { - this.isAllNull = false; - results.add(result.toString()); - this.flattenMap.put(field, result); - } - } - this.comperator = Joiner.on(seperator).join(results); + for (int i = 0; i < fieldsOrder.length; i++) { + String field = fieldsOrder[i]; + Object result = Util.deepSearchInMap(hitAsMap, field); + if (result == null) { + results.add(""); + } else { + this.isAllNull = false; + results.add(result.toString()); + this.flattenMap.put(field, result); + } } + this.comperator = Joiner.on(seperator).join(results); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - ComperableHitResult that = (ComperableHitResult) o; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - if (!comperator.equals(that.comperator)) { - return false; - } + ComperableHitResult that = (ComperableHitResult) o; - return true; + if (!comperator.equals(that.comperator)) { + return false; } - public boolean isAllNull() { - return isAllNull; - } + return true; + } - @Override - public int hashCode() { - return comperator.hashCode(); - } + public boolean isAllNull() { + return isAllNull; + } - public String getComperator() { - return comperator; - } + @Override + public int hashCode() { + return comperator.hashCode(); + } - public Map getFlattenMap() { - return flattenMap; - } + public String getComperator() { + return comperator; + } - public SearchHit getOriginalHit() { - return hit; - } + public Map getFlattenMap() { + return flattenMap; + } + + public SearchHit getOriginalHit() { + return hit; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/ArithmeticFunctionFactory.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/ArithmeticFunctionFactory.java index afa6f6c439..c1de63fe88 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/ArithmeticFunctionFactory.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/ArithmeticFunctionFactory.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.core.builder; import org.opensearch.sql.legacy.expression.core.operator.BinaryScalarOperator; @@ -12,205 +11,130 @@ import org.opensearch.sql.legacy.expression.core.operator.ScalarOperation; import org.opensearch.sql.legacy.expression.core.operator.UnaryScalarOperator; -/** - * The definition of arithmetic function builder factory. - */ +/** The definition of arithmetic function builder factory. */ public class ArithmeticFunctionFactory { - public static ExpressionBuilder add() { - return new BinaryExpressionBuilder( - new BinaryScalarOperator( - ScalarOperation.ADD, - Math::addExact, - Math::addExact, - Double::sum, - Float::sum)); - } - - public static ExpressionBuilder subtract() { - return new BinaryExpressionBuilder( - new BinaryScalarOperator( - ScalarOperation.ADD, - Math::subtractExact, - Math::subtractExact, - (v1, v2) -> v1 - v2, - (v1, v2) -> v1 - v2)); - } - - public static ExpressionBuilder multiply() { - return new BinaryExpressionBuilder( - new BinaryScalarOperator( - ScalarOperation.MULTIPLY, - Math::multiplyExact, - Math::multiplyExact, - (v1, v2) -> v1 * v2, - (v1, v2) -> v1 * v2 - )); - } - - public static ExpressionBuilder divide() { - return new BinaryExpressionBuilder( - new BinaryScalarOperator( - ScalarOperation.DIVIDE, - (v1, v2) -> v1 / v2, - (v1, v2) -> v1 / v2, - (v1, v2) -> v1 / v2, - (v1, v2) -> v1 / v2 - )); - } - - public static ExpressionBuilder modules() { - return new BinaryExpressionBuilder( - new BinaryScalarOperator( - ScalarOperation.MODULES, - (v1, v2) -> v1 % v2, - (v1, v2) -> v1 % v2, - (v1, v2) -> v1 % v2, - (v1, v2) -> v1 % v2 - )); - } - - public static ExpressionBuilder abs() { - return new UnaryExpressionBuilder( - new UnaryScalarOperator( - ScalarOperation.ABS, - Math::abs, - Math::abs, - Math::abs, - Math::abs - )); - } - - public static ExpressionBuilder acos() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.ACOS, - Math::acos - )); - } - - public static ExpressionBuilder asin() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.ASIN, - Math::asin - ) - ); - } - - public static ExpressionBuilder atan() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.ATAN, - Math::atan - ) - ); - } - - public static ExpressionBuilder atan2() { - return new BinaryExpressionBuilder( - new DoubleBinaryScalarOperator( - ScalarOperation.ATAN2, - Math::atan2 - ) - ); - } - - public static ExpressionBuilder tan() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.TAN, - Math::tan - ) - ); - } - - public static ExpressionBuilder cbrt() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.CBRT, - Math::cbrt - ) - ); - } - - public static ExpressionBuilder ceil() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.CEIL, - Math::ceil - ) - ); - } - - public static ExpressionBuilder cos() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.COS, - Math::cos - ) - ); - } - - public static ExpressionBuilder cosh() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.COSH, - Math::cosh - ) - ); - } - - public static ExpressionBuilder exp() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.EXP, - Math::exp - ) - ); - } - - public static ExpressionBuilder floor() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.FLOOR, - Math::floor - ) - ); - } - - public static ExpressionBuilder ln() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.LN, - Math::log - ) - ); - } - - public static ExpressionBuilder log() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.LOG, - Math::log - ) - ); - } - - public static ExpressionBuilder log2() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.LOG2, - (x) -> Math.log(x) / Math.log(2d) - ) - ); - } - - public static ExpressionBuilder log10() { - return new UnaryExpressionBuilder( - new DoubleUnaryScalarOperator( - ScalarOperation.LOG10, - Math::log10 - ) - ); - } + public static ExpressionBuilder add() { + return new BinaryExpressionBuilder( + new BinaryScalarOperator( + ScalarOperation.ADD, Math::addExact, Math::addExact, Double::sum, Float::sum)); + } + + public static ExpressionBuilder subtract() { + return new BinaryExpressionBuilder( + new BinaryScalarOperator( + ScalarOperation.ADD, + Math::subtractExact, + Math::subtractExact, + (v1, v2) -> v1 - v2, + (v1, v2) -> v1 - v2)); + } + + public static ExpressionBuilder multiply() { + return new BinaryExpressionBuilder( + new BinaryScalarOperator( + ScalarOperation.MULTIPLY, + Math::multiplyExact, + Math::multiplyExact, + (v1, v2) -> v1 * v2, + (v1, v2) -> v1 * v2)); + } + + public static ExpressionBuilder divide() { + return new BinaryExpressionBuilder( + new BinaryScalarOperator( + ScalarOperation.DIVIDE, + (v1, v2) -> v1 / v2, + (v1, v2) -> v1 / v2, + (v1, v2) -> v1 / v2, + (v1, v2) -> v1 / v2)); + } + + public static ExpressionBuilder modules() { + return new BinaryExpressionBuilder( + new BinaryScalarOperator( + ScalarOperation.MODULES, + (v1, v2) -> v1 % v2, + (v1, v2) -> v1 % v2, + (v1, v2) -> v1 % v2, + (v1, v2) -> v1 % v2)); + } + + public static ExpressionBuilder abs() { + return new UnaryExpressionBuilder( + new UnaryScalarOperator(ScalarOperation.ABS, Math::abs, Math::abs, Math::abs, Math::abs)); + } + + public static ExpressionBuilder acos() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.ACOS, Math::acos)); + } + + public static ExpressionBuilder asin() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.ASIN, Math::asin)); + } + + public static ExpressionBuilder atan() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.ATAN, Math::atan)); + } + + public static ExpressionBuilder atan2() { + return new BinaryExpressionBuilder( + new DoubleBinaryScalarOperator(ScalarOperation.ATAN2, Math::atan2)); + } + + public static ExpressionBuilder tan() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.TAN, Math::tan)); + } + + public static ExpressionBuilder cbrt() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.CBRT, Math::cbrt)); + } + + public static ExpressionBuilder ceil() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.CEIL, Math::ceil)); + } + + public static ExpressionBuilder cos() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.COS, Math::cos)); + } + + public static ExpressionBuilder cosh() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.COSH, Math::cosh)); + } + + public static ExpressionBuilder exp() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.EXP, Math::exp)); + } + + public static ExpressionBuilder floor() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.FLOOR, Math::floor)); + } + + public static ExpressionBuilder ln() { + return new UnaryExpressionBuilder(new DoubleUnaryScalarOperator(ScalarOperation.LN, Math::log)); + } + + public static ExpressionBuilder log() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.LOG, Math::log)); + } + + public static ExpressionBuilder log2() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.LOG2, (x) -> Math.log(x) / Math.log(2d))); + } + + public static ExpressionBuilder log10() { + return new UnaryExpressionBuilder( + new DoubleUnaryScalarOperator(ScalarOperation.LOG10, Math::log10)); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/BinaryExpressionBuilder.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/BinaryExpressionBuilder.java index 99ddd50248..fcf08180a5 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/BinaryExpressionBuilder.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/builder/BinaryExpressionBuilder.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.core.builder; import java.util.Arrays; @@ -14,33 +13,32 @@ import org.opensearch.sql.legacy.expression.domain.BindingTuple; import org.opensearch.sql.legacy.expression.model.ExprValue; -/** - * The definition of the Expression Builder which has two arguments. - */ +/** The definition of the Expression Builder which has two arguments. */ @RequiredArgsConstructor public class BinaryExpressionBuilder implements ExpressionBuilder { - private final ScalarOperator op; + private final ScalarOperator op; - /** - * Build the expression with two {@link Expression} as arguments. - * @param expressionList expression list. - * @return expression. - */ - @Override - public Expression build(List expressionList) { - Expression e1 = expressionList.get(0); - Expression e2 = expressionList.get(1); + /** + * Build the expression with two {@link Expression} as arguments. + * + * @param expressionList expression list. + * @return expression. + */ + @Override + public Expression build(List expressionList) { + Expression e1 = expressionList.get(0); + Expression e2 = expressionList.get(1); - return new Expression() { - @Override - public ExprValue valueOf(BindingTuple tuple) { - return op.apply(Arrays.asList(e1.valueOf(tuple), e2.valueOf(tuple))); - } + return new Expression() { + @Override + public ExprValue valueOf(BindingTuple tuple) { + return op.apply(Arrays.asList(e1.valueOf(tuple), e2.valueOf(tuple))); + } - @Override - public String toString() { - return String.format("%s(%s,%s)", op.name(), e1, e2); - } - }; - } + @Override + public String toString() { + return String.format("%s(%s,%s)", op.name(), e1, e2); + } + }; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/BinaryScalarOperator.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/BinaryScalarOperator.java index 70d47a3e83..02d29e1ed9 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/BinaryScalarOperator.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/BinaryScalarOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.core.operator; import static org.opensearch.sql.legacy.expression.model.ExprValue.ExprValueKind.DOUBLE_VALUE; @@ -24,54 +23,53 @@ import org.opensearch.sql.legacy.expression.model.ExprValueFactory; /** - * Binary Scalar Operator take two {@link ExprValue} as arguments ans return one {@link ExprValue} as result. + * Binary Scalar Operator take two {@link ExprValue} as arguments ans return one {@link ExprValue} + * as result. */ @RequiredArgsConstructor public class BinaryScalarOperator implements ScalarOperator { - private static final Map numberTypeOrder = - new ImmutableMap.Builder() - .put(INTEGER_VALUE, 0) - .put(LONG_VALUE, 1) - .put(DOUBLE_VALUE, 2) - .put(FLOAT_VALUE, 3) - .build(); + private static final Map numberTypeOrder = + new ImmutableMap.Builder() + .put(INTEGER_VALUE, 0) + .put(LONG_VALUE, 1) + .put(DOUBLE_VALUE, 2) + .put(FLOAT_VALUE, 3) + .build(); - private final ScalarOperation op; - private final BiFunction integerFunc; - private final BiFunction longFunc; - private final BiFunction doubleFunc; - private final BiFunction floatFunc; + private final ScalarOperation op; + private final BiFunction integerFunc; + private final BiFunction longFunc; + private final BiFunction doubleFunc; + private final BiFunction floatFunc; - @Override - public ExprValue apply(List valueList) { - ExprValue v1 = valueList.get(0); - ExprValue v2 = valueList.get(1); - if (!numberTypeOrder.containsKey(v1.kind()) || !numberTypeOrder.containsKey(v2.kind())) { - throw new RuntimeException( - String.format("unexpected operation type: %s(%s, %s) ", op.name(), v1.kind(), v2.kind())); - } - ExprValue.ExprValueKind expectedType = numberTypeOrder.get(v1.kind()) > numberTypeOrder.get(v2.kind()) - ? v1.kind() : v2.kind(); - switch (expectedType) { - case DOUBLE_VALUE: - return ExprValueFactory.from(doubleFunc.apply(getDoubleValue(v1), getDoubleValue(v2))); - case INTEGER_VALUE: - return ExprValueFactory - .from(integerFunc.apply(getIntegerValue(v1), getIntegerValue(v2))); - case LONG_VALUE: - return ExprValueFactory - .from(longFunc.apply(getLongValue(v1), getLongValue(v2))); - case FLOAT_VALUE: - return ExprValueFactory - .from(floatFunc.apply(getFloatValue(v1), getFloatValue(v2))); - default: - throw new RuntimeException(String.format("unexpected operation type: %s(%s, %s)", op.name(), v1.kind(), - v2.kind())); - } + @Override + public ExprValue apply(List valueList) { + ExprValue v1 = valueList.get(0); + ExprValue v2 = valueList.get(1); + if (!numberTypeOrder.containsKey(v1.kind()) || !numberTypeOrder.containsKey(v2.kind())) { + throw new RuntimeException( + String.format("unexpected operation type: %s(%s, %s) ", op.name(), v1.kind(), v2.kind())); } - - @Override - public String name() { - return op.name(); + ExprValue.ExprValueKind expectedType = + numberTypeOrder.get(v1.kind()) > numberTypeOrder.get(v2.kind()) ? v1.kind() : v2.kind(); + switch (expectedType) { + case DOUBLE_VALUE: + return ExprValueFactory.from(doubleFunc.apply(getDoubleValue(v1), getDoubleValue(v2))); + case INTEGER_VALUE: + return ExprValueFactory.from(integerFunc.apply(getIntegerValue(v1), getIntegerValue(v2))); + case LONG_VALUE: + return ExprValueFactory.from(longFunc.apply(getLongValue(v1), getLongValue(v2))); + case FLOAT_VALUE: + return ExprValueFactory.from(floatFunc.apply(getFloatValue(v1), getFloatValue(v2))); + default: + throw new RuntimeException( + String.format( + "unexpected operation type: %s(%s, %s)", op.name(), v1.kind(), v2.kind())); } + } + + @Override + public String name() { + return op.name(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleBinaryScalarOperator.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleBinaryScalarOperator.java index 2555b2a53c..12e7aacbaa 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleBinaryScalarOperator.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleBinaryScalarOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.core.operator; import static org.opensearch.sql.legacy.expression.model.ExprValueUtils.getDoubleValue; @@ -16,37 +15,41 @@ import org.opensearch.sql.legacy.expression.model.ExprValueFactory; /** - * Double Binary Scalar Operator take two {@link ExprValue} which have double value as arguments ans return one - * {@link ExprDoubleValue} as result. + * Double Binary Scalar Operator take two {@link ExprValue} which have double value as arguments ans + * return one {@link ExprDoubleValue} as result. */ @RequiredArgsConstructor public class DoubleBinaryScalarOperator implements ScalarOperator { - private final ScalarOperation op; - private final BiFunction doubleFunc; + private final ScalarOperation op; + private final BiFunction doubleFunc; - @Override - public ExprValue apply(List exprValues) { - ExprValue exprValue1 = exprValues.get(0); - ExprValue exprValue2 = exprValues.get(1); - if (exprValue1.kind() != exprValue2.kind()) { - throw new RuntimeException(String.format("unexpected operation type: %s(%s,%s)", op.name(), - exprValue1.kind(), exprValue2.kind())); - } - switch (exprValue1.kind()) { - case DOUBLE_VALUE: - case INTEGER_VALUE: - case LONG_VALUE: - case FLOAT_VALUE: - return ExprValueFactory.from(doubleFunc.apply(getDoubleValue(exprValue1), - getDoubleValue(exprValue2))); - default: - throw new RuntimeException(String.format("unexpected operation type: %s(%s,%s)", op.name(), - exprValue1.kind(), exprValue2.kind())); - } + @Override + public ExprValue apply(List exprValues) { + ExprValue exprValue1 = exprValues.get(0); + ExprValue exprValue2 = exprValues.get(1); + if (exprValue1.kind() != exprValue2.kind()) { + throw new RuntimeException( + String.format( + "unexpected operation type: %s(%s,%s)", + op.name(), exprValue1.kind(), exprValue2.kind())); } - - @Override - public String name() { - return op.name(); + switch (exprValue1.kind()) { + case DOUBLE_VALUE: + case INTEGER_VALUE: + case LONG_VALUE: + case FLOAT_VALUE: + return ExprValueFactory.from( + doubleFunc.apply(getDoubleValue(exprValue1), getDoubleValue(exprValue2))); + default: + throw new RuntimeException( + String.format( + "unexpected operation type: %s(%s,%s)", + op.name(), exprValue1.kind(), exprValue2.kind())); } + } + + @Override + public String name() { + return op.name(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleUnaryScalarOperator.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleUnaryScalarOperator.java index 736216472f..8242eee8a6 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleUnaryScalarOperator.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/core/operator/DoubleUnaryScalarOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.core.operator; import static org.opensearch.sql.legacy.expression.model.ExprValueUtils.getDoubleValue; @@ -16,31 +15,31 @@ import org.opensearch.sql.legacy.expression.model.ExprValueFactory; /** - * Unary Binary Scalar Operator take one {@link ExprValue} which have double value as arguments ans return one - * {@link ExprDoubleValue} as result. + * Unary Binary Scalar Operator take one {@link ExprValue} which have double value as arguments ans + * return one {@link ExprDoubleValue} as result. */ @RequiredArgsConstructor public class DoubleUnaryScalarOperator implements ScalarOperator { - private final ScalarOperation op; - private final Function doubleFunc; + private final ScalarOperation op; + private final Function doubleFunc; - @Override - public ExprValue apply(List exprValues) { - ExprValue exprValue = exprValues.get(0); - switch (exprValue.kind()) { - case DOUBLE_VALUE: - case INTEGER_VALUE: - case LONG_VALUE: - case FLOAT_VALUE: - return ExprValueFactory.from(doubleFunc.apply(getDoubleValue(exprValue))); - default: - throw new RuntimeException(String.format("unexpected operation type: %s(%s)", - op.name(), exprValue.kind())); - } + @Override + public ExprValue apply(List exprValues) { + ExprValue exprValue = exprValues.get(0); + switch (exprValue.kind()) { + case DOUBLE_VALUE: + case INTEGER_VALUE: + case LONG_VALUE: + case FLOAT_VALUE: + return ExprValueFactory.from(doubleFunc.apply(getDoubleValue(exprValue))); + default: + throw new RuntimeException( + String.format("unexpected operation type: %s(%s)", op.name(), exprValue.kind())); } + } - @Override - public String name() { - return op.name(); - } + @Override + public String name() { + return op.name(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/expression/domain/BindingTuple.java b/legacy/src/main/java/org/opensearch/sql/legacy/expression/domain/BindingTuple.java index badc7c8355..328f63b7ca 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/expression/domain/BindingTuple.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/expression/domain/BindingTuple.java @@ -3,10 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.expression.domain; - import java.util.Map; import java.util.stream.Collectors; import lombok.Builder; @@ -19,42 +17,41 @@ import org.opensearch.sql.legacy.expression.model.ExprValueFactory; /** - * BindingTuple represents the a relationship between bindingName and ExprValue. - * e.g. The operation output column name is bindingName, the value is the ExprValue. + * BindingTuple represents the a relationship between bindingName and ExprValue. e.g. The operation + * output column name is bindingName, the value is the ExprValue. */ @Builder @Getter @EqualsAndHashCode public class BindingTuple { - @Singular("binding") - private final Map bindingMap; - - /** - * Resolve the Binding Name in BindingTuple context. - * - * @param bindingName binding name. - * @return binding value. - */ - public ExprValue resolve(String bindingName) { - return bindingMap.getOrDefault(bindingName, new ExprMissingValue()); - } - - @Override - public String toString() { - return bindingMap.entrySet() - .stream() - .map(entry -> String.format("%s:%s", entry.getKey(), entry.getValue())) - .collect(Collectors.joining(",", "<", ">")); - } - - public static BindingTuple from(Map map) { - return from(new JSONObject(map)); - } - - public static BindingTuple from(JSONObject json) { - Map map = json.toMap(); - BindingTupleBuilder bindingTupleBuilder = BindingTuple.builder(); - map.forEach((key, value) -> bindingTupleBuilder.binding(key, ExprValueFactory.from(value))); - return bindingTupleBuilder.build(); - } + @Singular("binding") + private final Map bindingMap; + + /** + * Resolve the Binding Name in BindingTuple context. + * + * @param bindingName binding name. + * @return binding value. + */ + public ExprValue resolve(String bindingName) { + return bindingMap.getOrDefault(bindingName, new ExprMissingValue()); + } + + @Override + public String toString() { + return bindingMap.entrySet().stream() + .map(entry -> String.format("%s:%s", entry.getKey(), entry.getValue())) + .collect(Collectors.joining(",", "<", ">")); + } + + public static BindingTuple from(Map map) { + return from(new JSONObject(map)); + } + + public static BindingTuple from(JSONObject json) { + Map map = json.toMap(); + BindingTupleBuilder bindingTupleBuilder = BindingTuple.builder(); + map.forEach((key, value) -> bindingTupleBuilder.binding(key, ExprValueFactory.from(value))); + return bindingTupleBuilder.build(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/BasicCounter.java b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/BasicCounter.java index 8bb15eeb74..88d5f817e8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/BasicCounter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/BasicCounter.java @@ -3,32 +3,31 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.metrics; import java.util.concurrent.atomic.LongAdder; public class BasicCounter implements Counter { - private LongAdder count = new LongAdder(); + private LongAdder count = new LongAdder(); - @Override - public void increment() { - count.increment(); - } + @Override + public void increment() { + count.increment(); + } - @Override - public void add(long n) { - count.add(n); - } + @Override + public void add(long n) { + count.add(n); + } - @Override - public Long getValue() { - return count.longValue(); - } + @Override + public Long getValue() { + return count.longValue(); + } - @Override - public void reset() { - count.reset(); - } + @Override + public void reset() { + count.reset(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/Counter.java b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/Counter.java index 7d490704e8..f91731ab0e 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/metrics/Counter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/metrics/Counter.java @@ -3,16 +3,15 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.metrics; public interface Counter { - void increment(); + void increment(); - void add(long n); + void add(long n); - T getValue(); + T getValue(); - void reset(); + void reset(); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/parser/CaseWhenParser.java b/legacy/src/main/java/org/opensearch/sql/legacy/parser/CaseWhenParser.java index c711ee2929..d55ee64601 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/parser/CaseWhenParser.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/parser/CaseWhenParser.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.parser; import com.alibaba.druid.sql.ast.SQLExpr; @@ -19,101 +18,119 @@ import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.utils.Util; -/** - * Created by allwefantasy on 9/3/16. - */ +/** Created by allwefantasy on 9/3/16. */ public class CaseWhenParser { - private SQLCaseExpr caseExpr; - private String alias; - private String tableAlias; - - public CaseWhenParser(SQLCaseExpr caseExpr, String alias, String tableAlias) { - this.alias = alias; - this.tableAlias = tableAlias; - this.caseExpr = caseExpr; + private SQLCaseExpr caseExpr; + private String alias; + private String tableAlias; + + public CaseWhenParser(SQLCaseExpr caseExpr, String alias, String tableAlias) { + this.alias = alias; + this.tableAlias = tableAlias; + this.caseExpr = caseExpr; + } + + public String parse() throws SqlParseException { + List result = new ArrayList<>(); + + if (caseExpr.getValueExpr() != null) { + for (SQLCaseExpr.Item item : caseExpr.getItems()) { + SQLExpr left = caseExpr.getValueExpr(); + SQLExpr right = item.getConditionExpr(); + SQLBinaryOpExpr conditionExpr = + new SQLBinaryOpExpr(left, SQLBinaryOperator.Equality, right); + item.setConditionExpr(conditionExpr); + } + caseExpr.setValueExpr(null); } - public String parse() throws SqlParseException { - List result = new ArrayList<>(); - - if (caseExpr.getValueExpr() != null) { - for (SQLCaseExpr.Item item : caseExpr.getItems()) { - SQLExpr left = caseExpr.getValueExpr(); - SQLExpr right = item.getConditionExpr(); - SQLBinaryOpExpr conditionExpr = new SQLBinaryOpExpr(left, SQLBinaryOperator.Equality, right); - item.setConditionExpr(conditionExpr); - } - caseExpr.setValueExpr(null); - } - - for (SQLCaseExpr.Item item : caseExpr.getItems()) { - SQLExpr conditionExpr = item.getConditionExpr(); - - WhereParser parser = new WhereParser(new SqlParser(), conditionExpr); - String scriptCode = explain(parser.findWhere()); - if (scriptCode.startsWith(" &&")) { - scriptCode = scriptCode.substring(3); - } - if (result.size() == 0) { - result.add("if(" + scriptCode + ")" + "{" + Util.getScriptValueWithQuote(item.getValueExpr(), - "'") + "}"); - } else { - result.add("else if(" + scriptCode + ")" + "{" + Util.getScriptValueWithQuote(item.getValueExpr(), - "'") + "}"); - } - - } - SQLExpr elseExpr = caseExpr.getElseExpr(); - if (elseExpr == null) { - result.add("else { null }"); - } else { - result.add("else {" + Util.getScriptValueWithQuote(elseExpr, "'") + "}"); - } - - - return Joiner.on(" ").join(result); + for (SQLCaseExpr.Item item : caseExpr.getItems()) { + SQLExpr conditionExpr = item.getConditionExpr(); + + WhereParser parser = new WhereParser(new SqlParser(), conditionExpr); + String scriptCode = explain(parser.findWhere()); + if (scriptCode.startsWith(" &&")) { + scriptCode = scriptCode.substring(3); + } + if (result.size() == 0) { + result.add( + "if(" + + scriptCode + + ")" + + "{" + + Util.getScriptValueWithQuote(item.getValueExpr(), "'") + + "}"); + } else { + result.add( + "else if(" + + scriptCode + + ")" + + "{" + + Util.getScriptValueWithQuote(item.getValueExpr(), "'") + + "}"); + } } - - public String explain(Where where) throws SqlParseException { - List codes = new ArrayList<>(); - while (where.getWheres().size() == 1) { - where = where.getWheres().getFirst(); - } - explainWhere(codes, where); - String relation = where.getConn().name().equals("AND") ? " && " : " || "; - return Joiner.on(relation).join(codes); + SQLExpr elseExpr = caseExpr.getElseExpr(); + if (elseExpr == null) { + result.add("else { null }"); + } else { + result.add("else {" + Util.getScriptValueWithQuote(elseExpr, "'") + "}"); } + return Joiner.on(" ").join(result); + } - private void explainWhere(List codes, Where where) throws SqlParseException { - if (where instanceof Condition) { - Condition condition = (Condition) where; - - if (condition.getValue() instanceof ScriptFilter) { - codes.add("(" + ((ScriptFilter) condition.getValue()).getScript() + ")"); - } else if (condition.getOPERATOR() == Condition.OPERATOR.BETWEEN) { - Object[] objs = (Object[]) condition.getValue(); - codes.add("(" + "doc['" + condition.getName() + "'].value >= " + objs[0] + " && doc['" - + condition.getName() + "'].value <=" + objs[1] + ")"); - } else { - SQLExpr nameExpr = condition.getNameExpr(); - SQLExpr valueExpr = condition.getValueExpr(); - if (valueExpr instanceof SQLNullExpr) { - codes.add("(" + "doc['" + nameExpr.toString() + "']" + ".empty)"); - } else { - codes.add("(" + Util.getScriptValueWithQuote(nameExpr, "'") + condition.getOpertatorSymbol() - + Util.getScriptValueWithQuote(valueExpr, "'") + ")"); - } - } + public String explain(Where where) throws SqlParseException { + List codes = new ArrayList<>(); + while (where.getWheres().size() == 1) { + where = where.getWheres().getFirst(); + } + explainWhere(codes, where); + String relation = where.getConn().name().equals("AND") ? " && " : " || "; + return Joiner.on(relation).join(codes); + } + + private void explainWhere(List codes, Where where) throws SqlParseException { + if (where instanceof Condition) { + Condition condition = (Condition) where; + + if (condition.getValue() instanceof ScriptFilter) { + codes.add("(" + ((ScriptFilter) condition.getValue()).getScript() + ")"); + } else if (condition.getOPERATOR() == Condition.OPERATOR.BETWEEN) { + Object[] objs = (Object[]) condition.getValue(); + codes.add( + "(" + + "doc['" + + condition.getName() + + "'].value >= " + + objs[0] + + " && doc['" + + condition.getName() + + "'].value <=" + + objs[1] + + ")"); + } else { + SQLExpr nameExpr = condition.getNameExpr(); + SQLExpr valueExpr = condition.getValueExpr(); + if (valueExpr instanceof SQLNullExpr) { + codes.add("(" + "doc['" + nameExpr.toString() + "']" + ".empty)"); } else { - for (Where subWhere : where.getWheres()) { - List subCodes = new ArrayList<>(); - explainWhere(subCodes, subWhere); - String relation = subWhere.getConn().name().equals("AND") ? "&&" : "||"; - codes.add(Joiner.on(relation).join(subCodes)); - } + codes.add( + "(" + + Util.getScriptValueWithQuote(nameExpr, "'") + + condition.getOpertatorSymbol() + + Util.getScriptValueWithQuote(valueExpr, "'") + + ")"); } + } + } else { + for (Where subWhere : where.getWheres()) { + List subCodes = new ArrayList<>(); + explainWhere(subCodes, subWhere); + String relation = subWhere.getConn().name().equals("AND") ? "&&" : "||"; + codes.add(Joiner.on(relation).join(subCodes)); + } } - + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/parser/ChildrenType.java b/legacy/src/main/java/org/opensearch/sql/legacy/parser/ChildrenType.java index 74945cb94f..27374849df 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/parser/ChildrenType.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/parser/ChildrenType.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.parser; import com.alibaba.druid.sql.ast.SQLExpr; @@ -16,56 +15,55 @@ import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.utils.Util; -/** - * Created by Razma Tazz on 14/04/2016. - */ +/** Created by Razma Tazz on 14/04/2016. */ public class ChildrenType { - public String field; - public String childType; - public Where where; - private boolean simple; + public String field; + public String childType; + public Where where; + private boolean simple; - public boolean tryFillFromExpr(SQLExpr expr) throws SqlParseException { - if (!(expr instanceof SQLMethodInvokeExpr)) { - return false; - } - SQLMethodInvokeExpr method = (SQLMethodInvokeExpr) expr; - - String methodName = method.getMethodName(); + public boolean tryFillFromExpr(SQLExpr expr) throws SqlParseException { + if (!(expr instanceof SQLMethodInvokeExpr)) { + return false; + } + SQLMethodInvokeExpr method = (SQLMethodInvokeExpr) expr; - if (!methodName.toLowerCase().equals("children")) { - return false; - } + String methodName = method.getMethodName(); - List parameters = method.getParameters(); + if (!methodName.toLowerCase().equals("children")) { + return false; + } - if (parameters.size() != 2) { - throw new SqlParseException( - "on children object only allowed 2 parameters (type, field)/(type, conditions...) "); - } + List parameters = method.getParameters(); - String type = Util.extendedToString(parameters.get(0)); - this.childType = type; + if (parameters.size() != 2) { + throw new SqlParseException( + "on children object only allowed 2 parameters (type, field)/(type, conditions...) "); + } - SQLExpr secondParameter = parameters.get(1); - if (secondParameter instanceof SQLTextLiteralExpr || secondParameter instanceof SQLIdentifierExpr - || secondParameter instanceof SQLPropertyExpr) { - this.field = Util.extendedToString(secondParameter); - this.simple = true; - } else { - Where where = Where.newInstance(); - new WhereParser(new SqlParser()).parseWhere(secondParameter, where); - if (where.getWheres().size() == 0) { - throw new SqlParseException("Failed to parse filter condition"); - } - this.where = where; - simple = false; - } + String type = Util.extendedToString(parameters.get(0)); + this.childType = type; - return true; + SQLExpr secondParameter = parameters.get(1); + if (secondParameter instanceof SQLTextLiteralExpr + || secondParameter instanceof SQLIdentifierExpr + || secondParameter instanceof SQLPropertyExpr) { + this.field = Util.extendedToString(secondParameter); + this.simple = true; + } else { + Where where = Where.newInstance(); + new WhereParser(new SqlParser()).parseWhere(secondParameter, where); + if (where.getWheres().size() == 0) { + throw new SqlParseException("Failed to parse filter condition"); + } + this.where = where; + simple = false; } - public boolean isSimple() { - return simple; - } + return true; + } + + public boolean isSimple() { + return simple; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/parser/ElasticLexer.java b/legacy/src/main/java/org/opensearch/sql/legacy/parser/ElasticLexer.java index 8720c3ba85..67b49fb4ad 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/parser/ElasticLexer.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/parser/ElasticLexer.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.parser; import static com.alibaba.druid.sql.parser.CharTypes.isFirstIdentifierChar; @@ -14,86 +13,82 @@ import com.alibaba.druid.sql.parser.ParserException; import com.alibaba.druid.sql.parser.Token; -/** - * Created by Eliran on 18/8/2015. - */ +/** Created by Eliran on 18/8/2015. */ public class ElasticLexer extends MySqlLexer { - public ElasticLexer(String input) { - super(input); - } + public ElasticLexer(String input) { + super(input); + } + public ElasticLexer(char[] input, int inputLength, boolean skipComment) { + super(input, inputLength, skipComment); + } - public ElasticLexer(char[] input, int inputLength, boolean skipComment) { - super(input, inputLength, skipComment); - } + public void scanIdentifier() { + final char first = ch; + + if (ch == '`') { - public void scanIdentifier() { - final char first = ch; + mark = pos; + bufPos = 1; + char ch; + for (; ; ) { + ch = charAt(++pos); if (ch == '`') { + bufPos++; + ch = charAt(++pos); + break; + } else if (ch == EOI) { + throw new ParserException("illegal identifier"); + } - mark = pos; - bufPos = 1; - char ch; - for (; ; ) { - ch = charAt(++pos); - - if (ch == '`') { - bufPos++; - ch = charAt(++pos); - break; - } else if (ch == EOI) { - throw new ParserException("illegal identifier"); - } - - bufPos++; - continue; - } - - this.ch = charAt(pos); - - stringVal = subString(mark, bufPos); - Token tok = keywods.getKeyword(stringVal); - if (tok != null) { - token = tok; - } else { - token = Token.IDENTIFIER; - } - } else { - - final boolean firstFlag = isFirstIdentifierChar(first); - if (!firstFlag) { - throw new ParserException("illegal identifier"); - } - - mark = pos; - bufPos = 1; - char ch; - for (; ; ) { - ch = charAt(++pos); - - if (!isElasticIdentifierChar(ch)) { - break; - } - - bufPos++; - continue; - } - - this.ch = charAt(pos); - - stringVal = addSymbol(); - Token tok = keywods.getKeyword(stringVal); - if (tok != null) { - token = tok; - } else { - token = Token.IDENTIFIER; - } + bufPos++; + continue; + } + + this.ch = charAt(pos); + + stringVal = subString(mark, bufPos); + Token tok = keywods.getKeyword(stringVal); + if (tok != null) { + token = tok; + } else { + token = Token.IDENTIFIER; + } + } else { + + final boolean firstFlag = isFirstIdentifierChar(first); + if (!firstFlag) { + throw new ParserException("illegal identifier"); + } + + mark = pos; + bufPos = 1; + char ch; + for (; ; ) { + ch = charAt(++pos); + + if (!isElasticIdentifierChar(ch)) { + break; } - } + bufPos++; + continue; + } - private boolean isElasticIdentifierChar(char ch) { - return ch == '*' || ch == ':' || ch == '-' || ch == '.' || ch == ';' || isIdentifierChar(ch); + this.ch = charAt(pos); + + stringVal = addSymbol(); + Token tok = keywods.getKeyword(stringVal); + if (tok != null) { + token = tok; + } else { + token = Token.IDENTIFIER; + } } + } + + private boolean isElasticIdentifierChar(char ch) { + return ch == '*' || ch == ':' || ch == '-' || ch == '.' || ch == ';' || isIdentifierChar(ch); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/AggregationQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/AggregationQueryAction.java index 24194e8de5..57af269001 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/AggregationQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/AggregationQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query; import com.alibaba.druid.sql.ast.SQLExpr; @@ -38,457 +37,462 @@ import org.opensearch.sql.legacy.query.maker.AggMaker; import org.opensearch.sql.legacy.query.maker.QueryMaker; -/** - * Transform SQL query to OpenSearch aggregations query - */ +/** Transform SQL query to OpenSearch aggregations query */ public class AggregationQueryAction extends QueryAction { - private final Select select; - private AggMaker aggMaker = new AggMaker(); - private SearchRequestBuilder request; - - public AggregationQueryAction(Client client, Select select) { - super(client, select); - this.select = select; - } - - @Override - public SqlOpenSearchRequestBuilder explain() throws SqlParseException { - this.request = new SearchRequestBuilder(client, SearchAction.INSTANCE); - - if (select.getRowCount() == null) { - select.setRowCount(Select.DEFAULT_LIMIT); - } - - setIndicesAndTypes(); - - setWhere(select.getWhere()); - AggregationBuilder lastAgg = null; - - for (List groupBy : select.getGroupBys()) { - if (!groupBy.isEmpty()) { - Field field = groupBy.get(0); - - //make groupby can reference to field alias - lastAgg = getGroupAgg(field, select); - - if (lastAgg instanceof TermsAggregationBuilder) { - - // TODO: Consider removing that condition - // in theory we should be able to apply this for all types of fields, but - // this change requires too much of related integration tests (e.g. there are comparisons against - // raw javascript dsl, so I'd like to scope the changes as of now to one particular fix for - // scripted functions - - // the condition `field.getName().equals("script")` is to include the CAST cases, since the cast - // method is instance of MethodField with script. => corrects the shard size of CASTs - if (!(field instanceof MethodField) || field instanceof ScriptMethodField - || field.getName().equals("script")) { - //if limit size is too small, increasing shard size is required - if (select.getRowCount() < 200) { - ((TermsAggregationBuilder) lastAgg).shardSize(2000); - for (Hint hint : select.getHints()) { - if (hint.getType() == HintType.SHARD_SIZE) { - if (hint.getParams() != null && hint.getParams().length != 0 - && hint.getParams()[0] != null) { - ((TermsAggregationBuilder) lastAgg).shardSize((Integer) hint.getParams()[0]); - } - } - } - } - - if (select.getRowCount() > 0) { - ((TermsAggregationBuilder) lastAgg).size(select.getRowCount()); - } - } - } - - if (field.isNested()) { - AggregationBuilder nestedBuilder = createNestedAggregation(field); - - if (insertFilterIfExistsAfter(lastAgg, groupBy, nestedBuilder, 1)) { - groupBy.remove(1); - } else { - nestedBuilder.subAggregation(lastAgg); - } + private final Select select; + private AggMaker aggMaker = new AggMaker(); + private SearchRequestBuilder request; - request.addAggregation(wrapNestedIfNeeded(nestedBuilder, field.isReverseNested())); - } else if (field.isChildren()) { - AggregationBuilder childrenBuilder = createChildrenAggregation(field); + public AggregationQueryAction(Client client, Select select) { + super(client, select); + this.select = select; + } - if (insertFilterIfExistsAfter(lastAgg, groupBy, childrenBuilder, 1)) { - groupBy.remove(1); - } else { - childrenBuilder.subAggregation(lastAgg); - } + @Override + public SqlOpenSearchRequestBuilder explain() throws SqlParseException { + this.request = new SearchRequestBuilder(client, SearchAction.INSTANCE); - request.addAggregation(childrenBuilder); - } else { - request.addAggregation(lastAgg); - } + if (select.getRowCount() == null) { + select.setRowCount(Select.DEFAULT_LIMIT); + } - for (int i = 1; i < groupBy.size(); i++) { - field = groupBy.get(i); - AggregationBuilder subAgg = getGroupAgg(field, select); - //ES5.0 termsaggregation with size = 0 not supported anymore -// if (subAgg instanceof TermsAggregationBuilder && !(field instanceof MethodField)) { - -// //((TermsAggregationBuilder) subAgg).size(0); -// } - - if (field.isNested()) { - AggregationBuilder nestedBuilder = createNestedAggregation(field); - - if (insertFilterIfExistsAfter(subAgg, groupBy, nestedBuilder, i + 1)) { - groupBy.remove(i + 1); - i++; - } else { - nestedBuilder.subAggregation(subAgg); - } - - lastAgg.subAggregation(wrapNestedIfNeeded(nestedBuilder, field.isReverseNested())); - } else if (field.isChildren()) { - AggregationBuilder childrenBuilder = createChildrenAggregation(field); - - if (insertFilterIfExistsAfter(subAgg, groupBy, childrenBuilder, i + 1)) { - groupBy.remove(i + 1); - i++; - } else { - childrenBuilder.subAggregation(subAgg); - } - - lastAgg.subAggregation(childrenBuilder); - } else { - lastAgg.subAggregation(subAgg); - } - - lastAgg = subAgg; + setIndicesAndTypes(); + + setWhere(select.getWhere()); + AggregationBuilder lastAgg = null; + + for (List groupBy : select.getGroupBys()) { + if (!groupBy.isEmpty()) { + Field field = groupBy.get(0); + + // make groupby can reference to field alias + lastAgg = getGroupAgg(field, select); + + if (lastAgg instanceof TermsAggregationBuilder) { + + // TODO: Consider removing that condition + // in theory we should be able to apply this for all types of fields, but + // this change requires too much of related integration tests (e.g. there are comparisons + // against + // raw javascript dsl, so I'd like to scope the changes as of now to one particular fix + // for + // scripted functions + + // the condition `field.getName().equals("script")` is to include the CAST cases, since + // the cast + // method is instance of MethodField with script. => corrects the shard size of CASTs + if (!(field instanceof MethodField) + || field instanceof ScriptMethodField + || field.getName().equals("script")) { + // if limit size is too small, increasing shard size is required + if (select.getRowCount() < 200) { + ((TermsAggregationBuilder) lastAgg).shardSize(2000); + for (Hint hint : select.getHints()) { + if (hint.getType() == HintType.SHARD_SIZE) { + if (hint.getParams() != null + && hint.getParams().length != 0 + && hint.getParams()[0] != null) { + ((TermsAggregationBuilder) lastAgg).shardSize((Integer) hint.getParams()[0]); + } } + } } - // explain the field from SELECT and HAVING clause - List combinedList = new ArrayList<>(); - combinedList.addAll(select.getFields()); - if (select.getHaving() != null) { - combinedList.addAll(select.getHaving().getHavingFields()); + if (select.getRowCount() > 0) { + ((TermsAggregationBuilder) lastAgg).size(select.getRowCount()); } - // add aggregation function to each groupBy - explanFields(request, combinedList, lastAgg); - - explainHaving(lastAgg); + } } - if (select.getGroupBys().size() < 1) { - //add aggregation when having no groupBy script - explanFields(request, select.getFields(), lastAgg); + if (field.isNested()) { + AggregationBuilder nestedBuilder = createNestedAggregation(field); - } + if (insertFilterIfExistsAfter(lastAgg, groupBy, nestedBuilder, 1)) { + groupBy.remove(1); + } else { + nestedBuilder.subAggregation(lastAgg); + } - Map groupMap = aggMaker.getGroupMap(); - // add field - if (select.getFields().size() > 0) { - setFields(select.getFields()); -// explanFields(request, select.getFields(), lastAgg); - } + request.addAggregation(wrapNestedIfNeeded(nestedBuilder, field.isReverseNested())); + } else if (field.isChildren()) { + AggregationBuilder childrenBuilder = createChildrenAggregation(field); - // add order - if (lastAgg != null && select.getOrderBys().size() > 0) { - for (Order order : select.getOrderBys()) { - - // check "standard" fields - KVValue temp = groupMap.get(order.getName()); - if (temp != null) { - TermsAggregationBuilder termsBuilder = (TermsAggregationBuilder) temp.value; - switch (temp.key) { - case "COUNT": - termsBuilder.order(BucketOrder.count(isASC(order))); - break; - case "KEY": - termsBuilder.order(BucketOrder.key(isASC(order))); - break; - case "FIELD": - termsBuilder.order(BucketOrder.aggregation(order.getName(), isASC(order))); - break; - default: - throw new SqlParseException(order.getName() + " can not to order"); - } - } else if (order.isScript()) { - // Do not add scripted fields into sort, they must be sorted inside of aggregation - } else { - // TODO: Is there a legit case when we want to add field into sort for aggregation queries? - request.addSort(order.getName(), SortOrder.valueOf(order.getType())); - } - } + if (insertFilterIfExistsAfter(lastAgg, groupBy, childrenBuilder, 1)) { + groupBy.remove(1); + } else { + childrenBuilder.subAggregation(lastAgg); + } + + request.addAggregation(childrenBuilder); + } else { + request.addAggregation(lastAgg); } - setLimitFromHint(this.select.getHints()); + for (int i = 1; i < groupBy.size(); i++) { + field = groupBy.get(i); + AggregationBuilder subAgg = getGroupAgg(field, select); + // ES5.0 termsaggregation with size = 0 not supported anymore + // if (subAgg instanceof TermsAggregationBuilder && !(field instanceof + // MethodField)) { - request.setSearchType(SearchType.DEFAULT); - updateRequestWithIndexAndRoutingOptions(select, request); - updateRequestWithHighlight(select, request); - updateRequestWithCollapse(select, request); - updateRequestWithPostFilter(select, request); - return new SqlOpenSearchRequestBuilder(request); - } + // //((TermsAggregationBuilder) subAgg).size(0); + // } - private AggregationBuilder getGroupAgg(Field groupByField, Select select) throws SqlParseException { - AggregationBuilder lastAgg = null; - Field shadowField = null; - - for (Field selectField : select.getFields()) { - if (selectField instanceof MethodField && selectField.getName().equals("script")) { - MethodField scriptField = (MethodField) selectField; - for (KVValue kv : scriptField.getParams()) { - if (kv.value.equals(groupByField.getName())) { - shadowField = scriptField; - break; - } - } - } - } + if (field.isNested()) { + AggregationBuilder nestedBuilder = createNestedAggregation(field); - if (shadowField == null) { - for (Field selectField: select.getFields()) { - if (selectField.getAlias() != null - && (groupByField.getName().equals(selectField.getAlias()) - || groupByField.getExpression().equals(selectField.getExpression()))) { - shadowField = selectField; - } + if (insertFilterIfExistsAfter(subAgg, groupBy, nestedBuilder, i + 1)) { + groupBy.remove(i + 1); + i++; + } else { + nestedBuilder.subAggregation(subAgg); } - } + lastAgg.subAggregation(wrapNestedIfNeeded(nestedBuilder, field.isReverseNested())); + } else if (field.isChildren()) { + AggregationBuilder childrenBuilder = createChildrenAggregation(field); - if (null != shadowField) { - groupByField.setAlias(shadowField.getAlias()); - groupByField = shadowField; - } + if (insertFilterIfExistsAfter(subAgg, groupBy, childrenBuilder, i + 1)) { + groupBy.remove(i + 1); + i++; + } else { + childrenBuilder.subAggregation(subAgg); + } - lastAgg = aggMaker.makeGroupAgg(groupByField); + lastAgg.subAggregation(childrenBuilder); + } else { + lastAgg.subAggregation(subAgg); + } - // find if we have order for that aggregation. As of now only special case for script fields - if (groupByField.isScriptField()) { - addOrderByScriptFieldIfPresent(select, (TermsAggregationBuilder) lastAgg, groupByField.getExpression()); + lastAgg = subAgg; } + } + + // explain the field from SELECT and HAVING clause + List combinedList = new ArrayList<>(); + combinedList.addAll(select.getFields()); + if (select.getHaving() != null) { + combinedList.addAll(select.getHaving().getHavingFields()); + } + // add aggregation function to each groupBy + explanFields(request, combinedList, lastAgg); + + explainHaving(lastAgg); + } - return lastAgg; + if (select.getGroupBys().size() < 1) { + // add aggregation when having no groupBy script + explanFields(request, select.getFields(), lastAgg); } - private void addOrderByScriptFieldIfPresent(Select select, TermsAggregationBuilder groupByAggregation, - SQLExpr groupByExpression) { - // TODO: Explore other ways to correlate different fields/functions in the query (params?) - // This feels like a hacky way, but it's the best that could be done now. - select - .getOrderBys() - .stream() - .filter(order -> groupByExpression.equals(order.getSortField().getExpression())) - .findFirst() - .ifPresent(orderForGroupBy -> groupByAggregation.order(BucketOrder.key(isASC(orderForGroupBy)))); + Map groupMap = aggMaker.getGroupMap(); + // add field + if (select.getFields().size() > 0) { + setFields(select.getFields()); + // explanFields(request, select.getFields(), lastAgg); } - private AggregationBuilder wrapNestedIfNeeded(AggregationBuilder nestedBuilder, boolean reverseNested) { - if (!reverseNested) { - return nestedBuilder; + // add order + if (lastAgg != null && select.getOrderBys().size() > 0) { + for (Order order : select.getOrderBys()) { + + // check "standard" fields + KVValue temp = groupMap.get(order.getName()); + if (temp != null) { + TermsAggregationBuilder termsBuilder = (TermsAggregationBuilder) temp.value; + switch (temp.key) { + case "COUNT": + termsBuilder.order(BucketOrder.count(isASC(order))); + break; + case "KEY": + termsBuilder.order(BucketOrder.key(isASC(order))); + break; + case "FIELD": + termsBuilder.order(BucketOrder.aggregation(order.getName(), isASC(order))); + break; + default: + throw new SqlParseException(order.getName() + " can not to order"); + } + } else if (order.isScript()) { + // Do not add scripted fields into sort, they must be sorted inside of aggregation + } else { + // TODO: Is there a legit case when we want to add field into sort for aggregation + // queries? + request.addSort(order.getName(), SortOrder.valueOf(order.getType())); } - if (reverseNested && !(nestedBuilder instanceof NestedAggregationBuilder)) { - return nestedBuilder; + } + } + + setLimitFromHint(this.select.getHints()); + + request.setSearchType(SearchType.DEFAULT); + updateRequestWithIndexAndRoutingOptions(select, request); + updateRequestWithHighlight(select, request); + updateRequestWithCollapse(select, request); + updateRequestWithPostFilter(select, request); + return new SqlOpenSearchRequestBuilder(request); + } + + private AggregationBuilder getGroupAgg(Field groupByField, Select select) + throws SqlParseException { + AggregationBuilder lastAgg = null; + Field shadowField = null; + + for (Field selectField : select.getFields()) { + if (selectField instanceof MethodField && selectField.getName().equals("script")) { + MethodField scriptField = (MethodField) selectField; + for (KVValue kv : scriptField.getParams()) { + if (kv.value.equals(groupByField.getName())) { + shadowField = scriptField; + break; + } } - //we need to jump back to root - return AggregationBuilders.reverseNested(nestedBuilder.getName() + "_REVERSED").subAggregation(nestedBuilder); + } } - private AggregationBuilder createNestedAggregation(Field field) { - AggregationBuilder nestedBuilder; + if (shadowField == null) { + for (Field selectField : select.getFields()) { + if (selectField.getAlias() != null + && (groupByField.getName().equals(selectField.getAlias()) + || groupByField.getExpression().equals(selectField.getExpression()))) { + shadowField = selectField; + } + } + } - String nestedPath = field.getNestedPath(); + if (null != shadowField) { + groupByField.setAlias(shadowField.getAlias()); + groupByField = shadowField; + } - if (field.isReverseNested()) { - if (nestedPath == null || !nestedPath.startsWith("~")) { - ReverseNestedAggregationBuilder reverseNestedAggregationBuilder = - AggregationBuilders.reverseNested(getNestedAggName(field)); - if (nestedPath != null) { - reverseNestedAggregationBuilder.path(nestedPath); - } - return reverseNestedAggregationBuilder; - } - nestedPath = nestedPath.substring(1); - } + lastAgg = aggMaker.makeGroupAgg(groupByField); - nestedBuilder = AggregationBuilders.nested(getNestedAggName(field), nestedPath); + // find if we have order for that aggregation. As of now only special case for script fields + if (groupByField.isScriptField()) { + addOrderByScriptFieldIfPresent( + select, (TermsAggregationBuilder) lastAgg, groupByField.getExpression()); + } - return nestedBuilder; + return lastAgg; + } + + private void addOrderByScriptFieldIfPresent( + Select select, TermsAggregationBuilder groupByAggregation, SQLExpr groupByExpression) { + // TODO: Explore other ways to correlate different fields/functions in the query (params?) + // This feels like a hacky way, but it's the best that could be done now. + select.getOrderBys().stream() + .filter(order -> groupByExpression.equals(order.getSortField().getExpression())) + .findFirst() + .ifPresent( + orderForGroupBy -> groupByAggregation.order(BucketOrder.key(isASC(orderForGroupBy)))); + } + + private AggregationBuilder wrapNestedIfNeeded( + AggregationBuilder nestedBuilder, boolean reverseNested) { + if (!reverseNested) { + return nestedBuilder; + } + if (reverseNested && !(nestedBuilder instanceof NestedAggregationBuilder)) { + return nestedBuilder; + } + // we need to jump back to root + return AggregationBuilders.reverseNested(nestedBuilder.getName() + "_REVERSED") + .subAggregation(nestedBuilder); + } + + private AggregationBuilder createNestedAggregation(Field field) { + AggregationBuilder nestedBuilder; + + String nestedPath = field.getNestedPath(); + + if (field.isReverseNested()) { + if (nestedPath == null || !nestedPath.startsWith("~")) { + ReverseNestedAggregationBuilder reverseNestedAggregationBuilder = + AggregationBuilders.reverseNested(getNestedAggName(field)); + if (nestedPath != null) { + reverseNestedAggregationBuilder.path(nestedPath); + } + return reverseNestedAggregationBuilder; + } + nestedPath = nestedPath.substring(1); } - private AggregationBuilder createChildrenAggregation(Field field) { - AggregationBuilder childrenBuilder; + nestedBuilder = AggregationBuilders.nested(getNestedAggName(field), nestedPath); - String childType = field.getChildType(); + return nestedBuilder; + } - childrenBuilder = JoinAggregationBuilders.children(getChildrenAggName(field), childType); + private AggregationBuilder createChildrenAggregation(Field field) { + AggregationBuilder childrenBuilder; - return childrenBuilder; - } + String childType = field.getChildType(); - private String getNestedAggName(Field field) { - String prefix; + childrenBuilder = JoinAggregationBuilders.children(getChildrenAggName(field), childType); - if (field instanceof MethodField) { - String nestedPath = field.getNestedPath(); - if (nestedPath != null) { - prefix = nestedPath; - } else { - prefix = field.getAlias(); - } - } else { - prefix = field.getName(); - } - return prefix + "@NESTED"; - } + return childrenBuilder; + } - private String getChildrenAggName(Field field) { - String prefix; + private String getNestedAggName(Field field) { + String prefix; - if (field instanceof MethodField) { - String childType = field.getChildType(); + if (field instanceof MethodField) { + String nestedPath = field.getNestedPath(); + if (nestedPath != null) { + prefix = nestedPath; + } else { + prefix = field.getAlias(); + } + } else { + prefix = field.getName(); + } + return prefix + "@NESTED"; + } + + private String getChildrenAggName(Field field) { + String prefix; + + if (field instanceof MethodField) { + String childType = field.getChildType(); + + if (childType != null) { + prefix = childType; + } else { + prefix = field.getAlias(); + } + } else { + prefix = field.getName(); + } - if (childType != null) { - prefix = childType; - } else { - prefix = field.getAlias(); - } - } else { - prefix = field.getName(); - } + return prefix + "@CHILDREN"; + } - return prefix + "@CHILDREN"; + private boolean insertFilterIfExistsAfter( + AggregationBuilder agg, List groupBy, AggregationBuilder builder, int nextPosition) + throws SqlParseException { + if (groupBy.size() <= nextPosition) { + return false; } - - private boolean insertFilterIfExistsAfter(AggregationBuilder agg, List groupBy, AggregationBuilder builder, - int nextPosition) throws SqlParseException { - if (groupBy.size() <= nextPosition) { - return false; - } - Field filterFieldCandidate = groupBy.get(nextPosition); - if (!(filterFieldCandidate instanceof MethodField)) { - return false; - } - MethodField methodField = (MethodField) filterFieldCandidate; - if (!methodField.getName().toLowerCase().equals("filter")) { - return false; - } - builder.subAggregation(aggMaker.makeGroupAgg(filterFieldCandidate).subAggregation(agg)); - return true; + Field filterFieldCandidate = groupBy.get(nextPosition); + if (!(filterFieldCandidate instanceof MethodField)) { + return false; } - - private AggregationBuilder updateAggIfNested(AggregationBuilder lastAgg, Field field) { - if (field.isNested()) { - lastAgg = AggregationBuilders.nested(field.getName() + "Nested", field.getNestedPath()) - .subAggregation(lastAgg); - } - return lastAgg; + MethodField methodField = (MethodField) filterFieldCandidate; + if (!methodField.getName().toLowerCase().equals("filter")) { + return false; } - - private boolean isASC(Order order) { - return "ASC".equals(order.getType()); + builder.subAggregation(aggMaker.makeGroupAgg(filterFieldCandidate).subAggregation(agg)); + return true; + } + + private AggregationBuilder updateAggIfNested(AggregationBuilder lastAgg, Field field) { + if (field.isNested()) { + lastAgg = + AggregationBuilders.nested(field.getName() + "Nested", field.getNestedPath()) + .subAggregation(lastAgg); } + return lastAgg; + } - private void setFields(List fields) { - if (select.getFields().size() > 0) { - ArrayList includeFields = new ArrayList<>(); + private boolean isASC(Order order) { + return "ASC".equals(order.getType()); + } - for (Field field : fields) { - if (field != null) { - includeFields.add(field.getName()); - } - } + private void setFields(List fields) { + if (select.getFields().size() > 0) { + ArrayList includeFields = new ArrayList<>(); - request.setFetchSource(includeFields.toArray(new String[0]), null); + for (Field field : fields) { + if (field != null) { + includeFields.add(field.getName()); } - } + } - private void explanFields(SearchRequestBuilder request, List fields, AggregationBuilder groupByAgg) - throws SqlParseException { - for (Field field : fields) { - if (field instanceof MethodField) { - - if (field.getName().equals("script")) { - request.addStoredField(field.getAlias()); - DefaultQueryAction defaultQueryAction = new DefaultQueryAction(client, select); - defaultQueryAction.initialize(request); - List tempFields = Lists.newArrayList(field); - defaultQueryAction.setFields(tempFields); - continue; - } - - AggregationBuilder makeAgg = aggMaker - .withWhere(select.getWhere()) - .makeFieldAgg((MethodField) field, groupByAgg); - if (groupByAgg != null) { - groupByAgg.subAggregation(makeAgg); - } else { - request.addAggregation(makeAgg); - } - } else if (field != null) { - request.addStoredField(field.getName()); - } else { - throw new SqlParseException("it did not support this field method " + field); - } - } + request.setFetchSource(includeFields.toArray(new String[0]), null); } - - private void explainHaving(AggregationBuilder lastAgg) throws SqlParseException { - Having having = select.getHaving(); - if (having != null) { - having.explain(lastAgg, select.getFields()); + } + + private void explanFields( + SearchRequestBuilder request, List fields, AggregationBuilder groupByAgg) + throws SqlParseException { + for (Field field : fields) { + if (field instanceof MethodField) { + + if (field.getName().equals("script")) { + request.addStoredField(field.getAlias()); + DefaultQueryAction defaultQueryAction = new DefaultQueryAction(client, select); + defaultQueryAction.initialize(request); + List tempFields = Lists.newArrayList(field); + defaultQueryAction.setFields(tempFields); + continue; } - } - /** - * Create filters based on - * the Where clause. - * - * @param where the 'WHERE' part of the SQL query. - * @throws SqlParseException - */ - private void setWhere(Where where) throws SqlParseException { - BoolQueryBuilder boolQuery = null; - if (where != null) { - boolQuery = QueryMaker.explain(where, this.select.isQuery); - } - // Used to prevent NullPointerException in old tests as they do not set sqlRequest in QueryAction - if (sqlRequest != null) { - boolQuery = sqlRequest.checkAndAddFilter(boolQuery); + AggregationBuilder makeAgg = + aggMaker.withWhere(select.getWhere()).makeFieldAgg((MethodField) field, groupByAgg); + if (groupByAgg != null) { + groupByAgg.subAggregation(makeAgg); + } else { + request.addAggregation(makeAgg); } - request.setQuery(boolQuery); + } else if (field != null) { + request.addStoredField(field.getName()); + } else { + throw new SqlParseException("it did not support this field method " + field); + } } + } - - /** - * Set indices and types to the search request. - */ - private void setIndicesAndTypes() { - request.setIndices(query.getIndexArr()); + private void explainHaving(AggregationBuilder lastAgg) throws SqlParseException { + Having having = select.getHaving(); + if (having != null) { + having.explain(lastAgg, select.getFields()); } - - private void setLimitFromHint(List hints) { - int from = 0; - int size = 0; - for (Hint hint : hints) { - if (hint.getType() == HintType.DOCS_WITH_AGGREGATION) { - Integer[] params = (Integer[]) hint.getParams(); - if (params.length > 1) { - // if 2 or more are given, use the first as the from and the second as the size - // so it is the same as LIMIT from,size - // except written as /*! DOCS_WITH_AGGREGATION(from,size) */ - from = params[0]; - size = params[1]; - } else if (params.length == 1) { - // if only 1 parameter is given, use it as the size with a from of 0 - size = params[0]; - } - break; - } + } + + /** + * Create filters based on the Where clause. + * + * @param where the 'WHERE' part of the SQL query. + * @throws SqlParseException + */ + private void setWhere(Where where) throws SqlParseException { + BoolQueryBuilder boolQuery = null; + if (where != null) { + boolQuery = QueryMaker.explain(where, this.select.isQuery); + } + // Used to prevent NullPointerException in old tests as they do not set sqlRequest in + // QueryAction + if (sqlRequest != null) { + boolQuery = sqlRequest.checkAndAddFilter(boolQuery); + } + request.setQuery(boolQuery); + } + + /** Set indices and types to the search request. */ + private void setIndicesAndTypes() { + request.setIndices(query.getIndexArr()); + } + + private void setLimitFromHint(List hints) { + int from = 0; + int size = 0; + for (Hint hint : hints) { + if (hint.getType() == HintType.DOCS_WITH_AGGREGATION) { + Integer[] params = (Integer[]) hint.getParams(); + if (params.length > 1) { + // if 2 or more are given, use the first as the from and the second as the size + // so it is the same as LIMIT from,size + // except written as /*! DOCS_WITH_AGGREGATION(from,size) */ + from = params[0]; + size = params[1]; + } else if (params.length == 1) { + // if only 1 parameter is given, use it as the size with a from of 0 + size = params[0]; } - request.setFrom(from); - request.setSize(size); + break; + } } + request.setFrom(from); + request.setSize(size); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java index 0ed5043ac8..18c9708df8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query; import com.alibaba.druid.sql.ast.SQLExpr; @@ -50,264 +49,268 @@ import org.opensearch.sql.legacy.rewriter.nestedfield.NestedFieldProjection; import org.opensearch.sql.legacy.utils.SQLFunctions; -/** - * Transform SQL query to standard OpenSearch search query - */ +/** Transform SQL query to standard OpenSearch search query */ public class DefaultQueryAction extends QueryAction { - private final Select select; - private SearchRequestBuilder request; - - private final List fieldNames = new LinkedList<>(); - - public DefaultQueryAction(Client client, Select select) { - super(client, select); - this.select = select; - } - - public void initialize(SearchRequestBuilder request) { - this.request = request; + private final Select select; + private SearchRequestBuilder request; + + private final List fieldNames = new LinkedList<>(); + + public DefaultQueryAction(Client client, Select select) { + super(client, select); + this.select = select; + } + + public void initialize(SearchRequestBuilder request) { + this.request = request; + } + + @Override + public SqlOpenSearchRequestBuilder explain() throws SqlParseException { + Objects.requireNonNull(this.sqlRequest, "SqlRequest is required for OpenSearch request build"); + buildRequest(); + checkAndSetScroll(); + return new SqlOpenSearchRequestBuilder(request); + } + + private void buildRequest() throws SqlParseException { + this.request = new SearchRequestBuilder(client, SearchAction.INSTANCE); + setIndicesAndTypes(); + setFields(select.getFields()); + setWhere(select.getWhere()); + setSorts(select.getOrderBys()); + updateRequestWithIndexAndRoutingOptions(select, request); + updateRequestWithHighlight(select, request); + updateRequestWithCollapse(select, request); + updateRequestWithPostFilter(select, request); + updateRequestWithInnerHits(select, request); + } + + @VisibleForTesting + public void checkAndSetScroll() { + LocalClusterState clusterState = LocalClusterState.state(); + + Integer fetchSize = sqlRequest.fetchSize(); + TimeValue timeValue = clusterState.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + Integer rowCount = select.getRowCount(); + + if (checkIfScrollNeeded(fetchSize, rowCount)) { + Metrics.getInstance() + .getNumericalMetric(MetricName.DEFAULT_CURSOR_REQUEST_COUNT_TOTAL) + .increment(); + Metrics.getInstance().getNumericalMetric(MetricName.DEFAULT_CURSOR_REQUEST_TOTAL).increment(); + request.setSize(fetchSize).setScroll(timeValue); + } else { + request.setSearchType(SearchType.DFS_QUERY_THEN_FETCH); + setLimit(select.getOffset(), rowCount != null ? rowCount : Select.DEFAULT_LIMIT); } - - @Override - public SqlOpenSearchRequestBuilder explain() throws SqlParseException { - Objects.requireNonNull(this.sqlRequest, "SqlRequest is required for OpenSearch request build"); - buildRequest(); - checkAndSetScroll(); - return new SqlOpenSearchRequestBuilder(request); - } - - private void buildRequest() throws SqlParseException { - this.request = new SearchRequestBuilder(client, SearchAction.INSTANCE); - setIndicesAndTypes(); - setFields(select.getFields()); - setWhere(select.getWhere()); - setSorts(select.getOrderBys()); - updateRequestWithIndexAndRoutingOptions(select, request); - updateRequestWithHighlight(select, request); - updateRequestWithCollapse(select, request); - updateRequestWithPostFilter(select, request); - updateRequestWithInnerHits(select, request); - } - - @VisibleForTesting - public void checkAndSetScroll() { - LocalClusterState clusterState = LocalClusterState.state(); - - Integer fetchSize = sqlRequest.fetchSize(); - TimeValue timeValue = clusterState.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); - Integer rowCount = select.getRowCount(); - - if (checkIfScrollNeeded(fetchSize, rowCount)) { - Metrics.getInstance().getNumericalMetric(MetricName.DEFAULT_CURSOR_REQUEST_COUNT_TOTAL).increment(); - Metrics.getInstance().getNumericalMetric(MetricName.DEFAULT_CURSOR_REQUEST_TOTAL).increment(); - request.setSize(fetchSize).setScroll(timeValue); - } else { - request.setSearchType(SearchType.DFS_QUERY_THEN_FETCH); - setLimit(select.getOffset(), rowCount != null ? rowCount : Select.DEFAULT_LIMIT); + } + + private boolean checkIfScrollNeeded(Integer fetchSize, Integer rowCount) { + return (format != null && format.equals(Format.JDBC)) + && fetchSize > 0 + && (rowCount == null || (rowCount > fetchSize)); + } + + @Override + public Optional> getFieldNames() { + return Optional.of(fieldNames); + } + + public Select getSelect() { + return select; + } + + /** Set indices and types to the search request. */ + private void setIndicesAndTypes() { + request.setIndices(query.getIndexArr()); + } + + /** + * Set source filtering on a search request. + * + * @param fields list of fields to source filter. + */ + public void setFields(List fields) throws SqlParseException { + + if (!select.getFields().isEmpty() && !select.isSelectAll()) { + ArrayList includeFields = new ArrayList<>(); + ArrayList excludeFields = new ArrayList<>(); + + for (Field field : fields) { + if (field instanceof MethodField) { + MethodField method = (MethodField) field; + if (method.getName().toLowerCase().equals("script")) { + handleScriptField(method); + if (method.getExpression() instanceof SQLCastExpr) { + includeFields.add(method.getParams().get(0).toString()); + } + } else if (method.getName().equalsIgnoreCase("include")) { + for (KVValue kvValue : method.getParams()) { + includeFields.add(kvValue.value.toString()); + } + } else if (method.getName().equalsIgnoreCase("exclude")) { + for (KVValue kvValue : method.getParams()) { + excludeFields.add(kvValue.value.toString()); + } + } + } else if (field != null) { + if (isNotNested(field)) { + includeFields.add(field.getName()); + } } - } - - - private boolean checkIfScrollNeeded(Integer fetchSize, Integer rowCount) { - return (format != null && format.equals(Format.JDBC)) - && fetchSize > 0 - && (rowCount == null || (rowCount > fetchSize)); - } + } - @Override - public Optional> getFieldNames() { - return Optional.of(fieldNames); + fieldNames.addAll(includeFields); + request.setFetchSource( + includeFields.toArray(new String[0]), excludeFields.toArray(new String[0])); } + } + private void handleScriptField(final MethodField method) throws SqlParseException { - public Select getSelect() { - return select; - } - - /** - * Set indices and types to the search request. - */ - private void setIndicesAndTypes() { - request.setIndices(query.getIndexArr()); - } - - /** - * Set source filtering on a search request. - * - * @param fields list of fields to source filter. - */ - public void setFields(List fields) throws SqlParseException { - - if (!select.getFields().isEmpty() && !select.isSelectAll()) { - ArrayList includeFields = new ArrayList<>(); - ArrayList excludeFields = new ArrayList<>(); - - for (Field field : fields) { - if (field instanceof MethodField) { - MethodField method = (MethodField) field; - if (method.getName().toLowerCase().equals("script")) { - handleScriptField(method); - if (method.getExpression() instanceof SQLCastExpr) { - includeFields.add(method.getParams().get(0).toString()); - } - } else if (method.getName().equalsIgnoreCase("include")) { - for (KVValue kvValue : method.getParams()) { - includeFields.add(kvValue.value.toString()); - } - } else if (method.getName().equalsIgnoreCase("exclude")) { - for (KVValue kvValue : method.getParams()) { - excludeFields.add(kvValue.value.toString()); - } - } - } else if (field != null) { - if (isNotNested(field)) { - includeFields.add(field.getName()); - } - } - } + final List params = method.getParams(); + final int numOfParams = params.size(); - fieldNames.addAll(includeFields); - request.setFetchSource(includeFields.toArray(new String[0]), excludeFields.toArray(new String[0])); - } + if (2 != numOfParams && 3 != numOfParams) { + throw new SqlParseException( + "scripted_field only allows 'script(name,script)' " + "or 'script(name,lang,script)'"); } - private void handleScriptField(final MethodField method) throws SqlParseException { - - final List params = method.getParams(); - final int numOfParams = params.size(); - - if (2 != numOfParams && 3 != numOfParams) { - throw new SqlParseException("scripted_field only allows 'script(name,script)' " - + "or 'script(name,lang,script)'"); - } - - final String fieldName = params.get(0).value.toString(); - fieldNames.add(fieldName); - - final String secondParam = params.get(1).value.toString(); - final Script script = (2 == numOfParams) ? new Script(secondParam) : - new Script(ScriptType.INLINE, secondParam, params.get(2).value.toString(), Collections.emptyMap()); - request.addScriptField(fieldName, script); + final String fieldName = params.get(0).value.toString(); + fieldNames.add(fieldName); + + final String secondParam = params.get(1).value.toString(); + final Script script = + (2 == numOfParams) + ? new Script(secondParam) + : new Script( + ScriptType.INLINE, + secondParam, + params.get(2).value.toString(), + Collections.emptyMap()); + request.addScriptField(fieldName, script); + } + + /** + * Create filters or queries based on the Where clause. + * + * @param where the 'WHERE' part of the SQL query. + * @throws SqlParseException if the where clause does not represent valid sql + */ + private void setWhere(Where where) throws SqlParseException { + BoolQueryBuilder boolQuery = null; + if (where != null) { + boolQuery = QueryMaker.explain(where, this.select.isQuery); } - - /** - * Create filters or queries based on the Where clause. - * - * @param where the 'WHERE' part of the SQL query. - * @throws SqlParseException if the where clause does not represent valid sql - */ - private void setWhere(Where where) throws SqlParseException { - BoolQueryBuilder boolQuery = null; - if (where != null) { - boolQuery = QueryMaker.explain(where, this.select.isQuery); - } - // Used to prevent NullPointerException in old tests as they do not set sqlRequest in QueryAction - if (sqlRequest != null) { - boolQuery = sqlRequest.checkAndAddFilter(boolQuery); - } - request.setQuery(boolQuery); + // Used to prevent NullPointerException in old tests as they do not set sqlRequest in + // QueryAction + if (sqlRequest != null) { + boolQuery = sqlRequest.checkAndAddFilter(boolQuery); } - - /** - * Add sorts to the OpenSearch query based on the 'ORDER BY' clause. - * - * @param orderBys list of Order object - */ - private void setSorts(List orderBys) { - Map sortBuilderMap = new HashMap<>(); - - for (Order order : orderBys) { - String orderByName = order.getName(); - SortOrder sortOrder = SortOrder.valueOf(order.getType()); - - if (order.getNestedPath() != null) { - request.addSort( - SortBuilders.fieldSort(orderByName) - .order(sortOrder) - .setNestedSort(new NestedSortBuilder(order.getNestedPath()))); - } else if (order.isScript()) { - // TODO: Investigate how to find the type of expression (string or number) - // As of now this shouldn't be a problem, because the support is for date_format function - request.addSort( - SortBuilders - .scriptSort(new Script(orderByName), getScriptSortType(order)) - .order(sortOrder)); - } else if (orderByName.equals(ScoreSortBuilder.NAME)) { - request.addSort(orderByName, sortOrder); - } else { - FieldSortBuilder fieldSortBuilder = sortBuilderMap.computeIfAbsent(orderByName, key -> { - FieldSortBuilder fs = SortBuilders.fieldSort(key); - request.addSort(fs); - return fs; + request.setQuery(boolQuery); + } + + /** + * Add sorts to the OpenSearch query based on the 'ORDER BY' clause. + * + * @param orderBys list of Order object + */ + private void setSorts(List orderBys) { + Map sortBuilderMap = new HashMap<>(); + + for (Order order : orderBys) { + String orderByName = order.getName(); + SortOrder sortOrder = SortOrder.valueOf(order.getType()); + + if (order.getNestedPath() != null) { + request.addSort( + SortBuilders.fieldSort(orderByName) + .order(sortOrder) + .setNestedSort(new NestedSortBuilder(order.getNestedPath()))); + } else if (order.isScript()) { + // TODO: Investigate how to find the type of expression (string or number) + // As of now this shouldn't be a problem, because the support is for date_format function + request.addSort( + SortBuilders.scriptSort(new Script(orderByName), getScriptSortType(order)) + .order(sortOrder)); + } else if (orderByName.equals(ScoreSortBuilder.NAME)) { + request.addSort(orderByName, sortOrder); + } else { + FieldSortBuilder fieldSortBuilder = + sortBuilderMap.computeIfAbsent( + orderByName, + key -> { + FieldSortBuilder fs = SortBuilders.fieldSort(key); + request.addSort(fs); + return fs; }); - setSortParams(fieldSortBuilder, order); - } - } + setSortParams(fieldSortBuilder, order); + } } + } + private void setSortParams(FieldSortBuilder fieldSortBuilder, Order order) { + fieldSortBuilder.order(SortOrder.valueOf(order.getType())); - private void setSortParams(FieldSortBuilder fieldSortBuilder, Order order) { - fieldSortBuilder.order(SortOrder.valueOf(order.getType())); - - SQLExpr expr = order.getSortField().getExpression(); - if (expr instanceof SQLBinaryOpExpr) { - // we set SQLBinaryOpExpr in Field.setExpression() to support ORDER by IS NULL/IS NOT NULL - fieldSortBuilder.missing(getNullOrderString((SQLBinaryOpExpr) expr)); - } + SQLExpr expr = order.getSortField().getExpression(); + if (expr instanceof SQLBinaryOpExpr) { + // we set SQLBinaryOpExpr in Field.setExpression() to support ORDER by IS NULL/IS NOT NULL + fieldSortBuilder.missing(getNullOrderString((SQLBinaryOpExpr) expr)); } - - private String getNullOrderString(SQLBinaryOpExpr expr) { - SQLBinaryOperator operator = expr.getOperator(); - return operator == SQLBinaryOperator.IsNot ? "_first" : "_last"; + } + + private String getNullOrderString(SQLBinaryOpExpr expr) { + SQLBinaryOperator operator = expr.getOperator(); + return operator == SQLBinaryOperator.IsNot ? "_first" : "_last"; + } + + private ScriptSortType getScriptSortType(Order order) { + ScriptSortType scriptSortType; + Schema.Type scriptFunctionReturnType = SQLFunctions.getOrderByFieldType(order.getSortField()); + + // as of now script function return type returns only text and double + switch (scriptFunctionReturnType) { + case TEXT: + scriptSortType = ScriptSortType.STRING; + break; + + case DOUBLE: + case FLOAT: + case INTEGER: + case LONG: + scriptSortType = ScriptSortType.NUMBER; + break; + default: + throw new IllegalStateException("Unknown type: " + scriptFunctionReturnType); } - - private ScriptSortType getScriptSortType(Order order) { - ScriptSortType scriptSortType; - Schema.Type scriptFunctionReturnType = SQLFunctions.getOrderByFieldType(order.getSortField()); - - - // as of now script function return type returns only text and double - switch (scriptFunctionReturnType) { - case TEXT: - scriptSortType = ScriptSortType.STRING; - break; - - case DOUBLE: - case FLOAT: - case INTEGER: - case LONG: - scriptSortType = ScriptSortType.NUMBER; - break; - default: - throw new IllegalStateException("Unknown type: " + scriptFunctionReturnType); - } - return scriptSortType; + return scriptSortType; + } + + /** + * Add from and size to the OpenSearch query based on the 'LIMIT' clause + * + * @param from starts from document at position from + * @param size number of documents to return. + */ + private void setLimit(int from, int size) { + request.setFrom(from); + + if (size > -1) { + request.setSize(size); } + } - /** - * Add from and size to the OpenSearch query based on the 'LIMIT' clause - * - * @param from starts from document at position from - * @param size number of documents to return. - */ - private void setLimit(int from, int size) { - request.setFrom(from); - - if (size > -1) { - request.setSize(size); - } - } + public SearchRequestBuilder getRequestBuilder() { + return request; + } - public SearchRequestBuilder getRequestBuilder() { - return request; - } + private boolean isNotNested(Field field) { + return !field.isNested() || field.isReverseNested(); + } - private boolean isNotNested(Field field) { - return !field.isNested() || field.isReverseNested(); - } - - private void updateRequestWithInnerHits(Select select, SearchRequestBuilder request) { - new NestedFieldProjection(request).project(select.getFields(), select.getNestedJoinType()); - } + private void updateRequestWithInnerHits(Select select, SearchRequestBuilder request) { + new NestedFieldProjection(request).project(select.getFields(), select.getNestedJoinType()); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/DeleteQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/DeleteQueryAction.java index 892c5aeb2d..331921345f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/DeleteQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/DeleteQueryAction.java @@ -3,10 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query; - import org.opensearch.client.Client; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -20,50 +18,44 @@ public class DeleteQueryAction extends QueryAction { - private final Delete delete; - private DeleteByQueryRequestBuilder request; - - public DeleteQueryAction(Client client, Delete delete) { - super(client, delete); - this.delete = delete; - } - - @Override - public SqlElasticDeleteByQueryRequestBuilder explain() throws SqlParseException { - this.request = new DeleteByQueryRequestBuilder(client, DeleteByQueryAction.INSTANCE); - - setIndicesAndTypes(); - setWhere(delete.getWhere()); - SqlElasticDeleteByQueryRequestBuilder deleteByQueryRequestBuilder = - new SqlElasticDeleteByQueryRequestBuilder(request); - return deleteByQueryRequestBuilder; - } - - - /** - * Set indices and types to the delete by query request. - */ - private void setIndicesAndTypes() { - - DeleteByQueryRequest innerRequest = request.request(); - innerRequest.indices(query.getIndexArr()); + private final Delete delete; + private DeleteByQueryRequestBuilder request; + + public DeleteQueryAction(Client client, Delete delete) { + super(client, delete); + this.delete = delete; + } + + @Override + public SqlElasticDeleteByQueryRequestBuilder explain() throws SqlParseException { + this.request = new DeleteByQueryRequestBuilder(client, DeleteByQueryAction.INSTANCE); + + setIndicesAndTypes(); + setWhere(delete.getWhere()); + SqlElasticDeleteByQueryRequestBuilder deleteByQueryRequestBuilder = + new SqlElasticDeleteByQueryRequestBuilder(request); + return deleteByQueryRequestBuilder; + } + + /** Set indices and types to the delete by query request. */ + private void setIndicesAndTypes() { + + DeleteByQueryRequest innerRequest = request.request(); + innerRequest.indices(query.getIndexArr()); + } + + /** + * Create filters based on the Where clause. + * + * @param where the 'WHERE' part of the SQL query. + * @throws SqlParseException + */ + private void setWhere(Where where) throws SqlParseException { + if (where != null) { + QueryBuilder whereQuery = QueryMaker.explain(where); + request.filter(whereQuery); + } else { + request.filter(QueryBuilders.matchAllQuery()); } - - - /** - * Create filters based on - * the Where clause. - * - * @param where the 'WHERE' part of the SQL query. - * @throws SqlParseException - */ - private void setWhere(Where where) throws SqlParseException { - if (where != null) { - QueryBuilder whereQuery = QueryMaker.explain(where); - request.filter(whereQuery); - } else { - request.filter(QueryBuilders.matchAllQuery()); - } - } - + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/DescribeQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/DescribeQueryAction.java index 077d9c28b8..ffc9695d81 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/DescribeQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/DescribeQueryAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query; import org.opensearch.action.admin.indices.get.GetIndexRequestBuilder; @@ -14,22 +13,23 @@ public class DescribeQueryAction extends QueryAction { - private final IndexStatement statement; + private final IndexStatement statement; - public DescribeQueryAction(Client client, IndexStatement statement) { - super(client, null); - this.statement = statement; - } + public DescribeQueryAction(Client client, IndexStatement statement) { + super(client, null); + this.statement = statement; + } - @Override - public QueryStatement getQueryStatement() { - return statement; - } + @Override + public QueryStatement getQueryStatement() { + return statement; + } - @Override - public SqlOpenSearchRequestBuilder explain() { - final GetIndexRequestBuilder indexRequestBuilder = Util.prepareIndexRequestBuilder(client, statement); + @Override + public SqlOpenSearchRequestBuilder explain() { + final GetIndexRequestBuilder indexRequestBuilder = + Util.prepareIndexRequestBuilder(client, statement); - return new SqlOpenSearchRequestBuilder(indexRequestBuilder); - } + return new SqlOpenSearchRequestBuilder(indexRequestBuilder); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/BackOffRetryStrategy.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/BackOffRetryStrategy.java index 06ec21247a..d767268cb1 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/join/BackOffRetryStrategy.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/join/BackOffRetryStrategy.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.join; import java.util.ArrayList; @@ -22,198 +21,198 @@ public class BackOffRetryStrategy { - private static final Logger LOG = LogManager.getLogger(); - - /** - * Interval (ms) between each retry - */ - private static final long[] intervals = milliseconds(new double[]{4, 8 + 4, 16 + 4}); + private static final Logger LOG = LogManager.getLogger(); - /** - * Delta to randomize interval (ms) - */ - private static final long delta = 4 * 1000; + /** Interval (ms) between each retry */ + private static final long[] intervals = milliseconds(new double[] {4, 8 + 4, 16 + 4}); - private static final int threshold = 85; + /** Delta to randomize interval (ms) */ + private static final long delta = 4 * 1000; - private static IdentityHashMap> memUse = new IdentityHashMap<>(); + private static final int threshold = 85; - private static AtomicLong mem = new AtomicLong(0L); + private static IdentityHashMap> memUse = new IdentityHashMap<>(); - private static long lastTimeoutCleanTime = System.currentTimeMillis(); + private static AtomicLong mem = new AtomicLong(0L); - private static final long RELTIMEOUT = 1000 * 60 * 30; + private static long lastTimeoutCleanTime = System.currentTimeMillis(); - private static final int MAXRETRIES = 999; + private static final long RELTIMEOUT = 1000 * 60 * 30; - private static final Object obj = new Object(); + private static final int MAXRETRIES = 999; - public static final Supplier GET_CB_STATE = () -> isMemoryHealthy() ? 0 : 1; + private static final Object obj = new Object(); - private BackOffRetryStrategy() { + public static final Supplier GET_CB_STATE = () -> isMemoryHealthy() ? 0 : 1; - } + private BackOffRetryStrategy() {} - private static boolean isMemoryHealthy() { - final long freeMemory = Runtime.getRuntime().freeMemory(); - final long totalMemory = Runtime.getRuntime().totalMemory(); - final int memoryUsage = (int) Math.round((double) (totalMemory - freeMemory + mem.get()) - / (double) totalMemory * 100); + private static boolean isMemoryHealthy() { + final long freeMemory = Runtime.getRuntime().freeMemory(); + final long totalMemory = Runtime.getRuntime().totalMemory(); + final int memoryUsage = + (int) + Math.round( + (double) (totalMemory - freeMemory + mem.get()) / (double) totalMemory * 100); - LOG.debug("[MCB1] Memory total, free, allocate: {}, {}, {}", totalMemory, freeMemory, mem.get()); - LOG.debug("[MCB1] Memory usage and limit: {}%, {}%", memoryUsage, threshold); + LOG.debug( + "[MCB1] Memory total, free, allocate: {}, {}, {}", totalMemory, freeMemory, mem.get()); + LOG.debug("[MCB1] Memory usage and limit: {}%, {}%", memoryUsage, threshold); - return memoryUsage < threshold; - } + return memoryUsage < threshold; + } - public static boolean isHealthy() { - for (int i = 0; i < intervals.length; i++) { - if (isMemoryHealthy()) { - return true; - } - - LOG.warn("[MCB1] Memory monitor is unhealthy now, back off retrying: {} attempt, thread id = {}", - i, Thread.currentThread().getId()); - if (ThreadLocalRandom.current().nextBoolean()) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CB).increment(); - LOG.warn("[MCB1] Directly abort on idx {}.", i); - return false; - } - backOffSleep(intervals[i]); - } - - boolean isHealthy = isMemoryHealthy(); - if (!isHealthy) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CB).increment(); - } + public static boolean isHealthy() { + for (int i = 0; i < intervals.length; i++) { + if (isMemoryHealthy()) { + return true; + } - return isHealthy; + LOG.warn( + "[MCB1] Memory monitor is unhealthy now, back off retrying: {} attempt, thread id = {}", + i, + Thread.currentThread().getId()); + if (ThreadLocalRandom.current().nextBoolean()) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CB).increment(); + LOG.warn("[MCB1] Directly abort on idx {}.", i); + return false; + } + backOffSleep(intervals[i]); } - private static boolean isMemoryHealthy(long allocateMemory, int idx, Object key) { - long logMem = mem.get(); - - releaseTimeoutMemory(); - if (idx == 0 && allocateMemory > 0) { - logMem = mem.addAndGet(allocateMemory); - synchronized (BackOffRetryStrategy.class) { - if (memUse.containsKey(key)) { - memUse.put(key, Tuple.tuple(memUse.get(key).v1(), memUse.get(key).v2() + allocateMemory)); - } else { - memUse.put(key, Tuple.tuple(System.currentTimeMillis(), allocateMemory)); - } - } - } - - final long freeMemory = Runtime.getRuntime().freeMemory(); - final long totalMemory = Runtime.getRuntime().totalMemory(); - final int memoryUsage = (int) Math.round((double) (totalMemory - freeMemory + logMem) - / (double) totalMemory * 100); + boolean isHealthy = isMemoryHealthy(); + if (!isHealthy) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_CB).increment(); + } - LOG.debug("[MCB] Idx is {}", idx); - LOG.debug("[MCB] Memory total, free, allocate: {}, {}, {}, {}", totalMemory, freeMemory, - allocateMemory, logMem); - LOG.debug("[MCB] Memory usage and limit: {}%, {}%", memoryUsage, threshold); + return isHealthy; + } - return memoryUsage < threshold; + private static boolean isMemoryHealthy(long allocateMemory, int idx, Object key) { + long logMem = mem.get(); + releaseTimeoutMemory(); + if (idx == 0 && allocateMemory > 0) { + logMem = mem.addAndGet(allocateMemory); + synchronized (BackOffRetryStrategy.class) { + if (memUse.containsKey(key)) { + memUse.put(key, Tuple.tuple(memUse.get(key).v1(), memUse.get(key).v2() + allocateMemory)); + } else { + memUse.put(key, Tuple.tuple(System.currentTimeMillis(), allocateMemory)); + } + } } - public static boolean isHealthy(long allocateMemory, Object key) { - if (key == null) { - key = obj; - } + final long freeMemory = Runtime.getRuntime().freeMemory(); + final long totalMemory = Runtime.getRuntime().totalMemory(); + final int memoryUsage = + (int) Math.round((double) (totalMemory - freeMemory + logMem) / (double) totalMemory * 100); + + LOG.debug("[MCB] Idx is {}", idx); + LOG.debug( + "[MCB] Memory total, free, allocate: {}, {}, {}, {}", + totalMemory, + freeMemory, + allocateMemory, + logMem); + LOG.debug("[MCB] Memory usage and limit: {}%, {}%", memoryUsage, threshold); + + return memoryUsage < threshold; + } + + public static boolean isHealthy(long allocateMemory, Object key) { + if (key == null) { + key = obj; + } - for (int i = 0; i < intervals.length; i++) { - if (isMemoryHealthy(allocateMemory, i, key)) { - return true; - } - - LOG.warn("[MCB] Memory monitor is unhealthy now, back off retrying: {} attempt, " - + "executor = {}, thread id = {}", i, key, Thread.currentThread().getId()); - if (ThreadLocalRandom.current().nextBoolean()) { - LOG.warn("[MCB] Directly abort on idx {}, executor is {}.", i, key); - return false; - } - backOffSleep(intervals[i]); - } - return isMemoryHealthy(allocateMemory, MAXRETRIES, key); + for (int i = 0; i < intervals.length; i++) { + if (isMemoryHealthy(allocateMemory, i, key)) { + return true; + } + + LOG.warn( + "[MCB] Memory monitor is unhealthy now, back off retrying: {} attempt, " + + "executor = {}, thread id = {}", + i, + key, + Thread.currentThread().getId()); + if (ThreadLocalRandom.current().nextBoolean()) { + LOG.warn("[MCB] Directly abort on idx {}, executor is {}.", i, key); + return false; + } + backOffSleep(intervals[i]); } + return isMemoryHealthy(allocateMemory, MAXRETRIES, key); + } - public static void backOffSleep(long interval) { - try { - long millis = randomize(interval); + public static void backOffSleep(long interval) { + try { + long millis = randomize(interval); - LOG.info("[MCB] Back off sleeping: {} ms", millis); - Thread.sleep(millis); - } catch (InterruptedException e) { - LOG.error("[MCB] Sleep interrupted", e); - } + LOG.info("[MCB] Back off sleeping: {} ms", millis); + Thread.sleep(millis); + } catch (InterruptedException e) { + LOG.error("[MCB] Sleep interrupted", e); } - - /** - * Generate random interval in [interval-delta, interval+delta) - */ - private static long randomize(long interval) { - // Random number within range generator for JDK 7+ - return ThreadLocalRandom.current().nextLong( - lowerBound(interval), upperBound(interval) - ); + } + + /** Generate random interval in [interval-delta, interval+delta) */ + private static long randomize(long interval) { + // Random number within range generator for JDK 7+ + return ThreadLocalRandom.current().nextLong(lowerBound(interval), upperBound(interval)); + } + + private static long lowerBound(long interval) { + return Math.max(0, interval - delta); + } + + private static long upperBound(long interval) { + return interval + delta; + } + + private static long[] milliseconds(double[] seconds) { + return Arrays.stream(seconds).mapToLong((second) -> (long) (1000 * second)).toArray(); + } + + public static void releaseMem(Object key) { + LOG.debug("[MCB] mem is {} before release", mem); + long v = 0L; + synchronized (BackOffRetryStrategy.class) { + if (memUse.containsKey(key)) { + v = memUse.get(key).v2(); + memUse.remove(key); + } } - - private static long lowerBound(long interval) { - return Math.max(0, interval - delta); + if (v > 0) { + atomicMinusLowBoundZero(mem, v); } + LOG.debug("[MCB] mem is {} after release", mem); + } - private static long upperBound(long interval) { - return interval + delta; + private static void releaseTimeoutMemory() { + long cur = System.currentTimeMillis(); + if (cur - lastTimeoutCleanTime < RELTIMEOUT) { + return; } - private static long[] milliseconds(double[] seconds) { - return Arrays.stream(seconds). - mapToLong((second) -> (long) (1000 * second)). - toArray(); + List bulks = new ArrayList<>(); + Predicate> isTimeout = t -> cur - t.v1() > RELTIMEOUT; + synchronized (BackOffRetryStrategy.class) { + memUse.values().stream().filter(isTimeout).forEach(v -> bulks.add(v.v2())); + memUse.values().removeIf(isTimeout); } - public static void releaseMem(Object key) { - LOG.debug("[MCB] mem is {} before release", mem); - long v = 0L; - synchronized (BackOffRetryStrategy.class) { - if (memUse.containsKey(key)) { - v = memUse.get(key).v2(); - memUse.remove(key); - } - } - if (v > 0) { - atomicMinusLowBoundZero(mem, v); - } - LOG.debug("[MCB] mem is {} after release", mem); + for (long v : bulks) { + atomicMinusLowBoundZero(mem, v); } - private static void releaseTimeoutMemory() { - long cur = System.currentTimeMillis(); - if (cur - lastTimeoutCleanTime < RELTIMEOUT) { - return; - } + lastTimeoutCleanTime = cur; + } - List bulks = new ArrayList<>(); - Predicate> isTimeout = t -> cur - t.v1() > RELTIMEOUT; - synchronized (BackOffRetryStrategy.class) { - memUse.values().stream().filter(isTimeout).forEach(v -> bulks.add(v.v2())); - memUse.values().removeIf(isTimeout); - } - - for (long v : bulks) { - atomicMinusLowBoundZero(mem, v); - } - - lastTimeoutCleanTime = cur; - } - - private static void atomicMinusLowBoundZero(AtomicLong x, Long y) { - long memRes = x.addAndGet(-y); - if (memRes < 0) { - x.compareAndSet(memRes, 0L); - } + private static void atomicMinusLowBoundZero(AtomicLong x, Long y) { + long memRes = x.addAndGet(-y); + if (memRes < 0) { + x.compareAndSet(memRes, 0L); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java index 0c9caab03d..19ccbff5a0 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/maker/AggMaker.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.maker; import com.alibaba.druid.sql.ast.expr.SQLAggregateOption; @@ -65,758 +64,788 @@ public class AggMaker { - /** - * The mapping bettwen group fieldName or Alias to the KVValue. - */ - private Map groupMap = new HashMap<>(); - private Where where; - - /** - * - * - * @param field - * @return - * @throws SqlParseException - */ - public AggregationBuilder makeGroupAgg(Field field) throws SqlParseException { - - if (field instanceof MethodField && field.getName().equals("script")) { - MethodField methodField = (MethodField) field; - TermsAggregationBuilder termsBuilder = AggregationBuilders.terms(methodField.getAlias()) - .script(new Script(methodField.getParams().get(1).value.toString())); - extendGroupMap(methodField, new KVValue("KEY", termsBuilder)); - return termsBuilder; - } - - - if (field instanceof MethodField) { - - MethodField methodField = (MethodField) field; - if (methodField.getName().equals("filter")) { - Map paramsAsMap = methodField.getParamsAsMap(); - Where where = (Where) paramsAsMap.get("where"); - return AggregationBuilders.filter(paramsAsMap.get("alias").toString(), - QueryMaker.explain(where)); - } - return makeRangeGroup(methodField); - } else { - String termName = (Strings.isNullOrEmpty(field.getAlias())) ? field.getName() : field.getAlias(); - TermsAggregationBuilder termsBuilder = AggregationBuilders.terms(termName).field(field.getName()); - final KVValue kvValue = new KVValue("KEY", termsBuilder); - groupMap.put(termName, kvValue); - // map the field name with KVValue if it is not yet. The use case is when alias exist, - // the termName is different with fieldName, both of them should be included in the map. - groupMap.putIfAbsent(field.getName(), kvValue); - return termsBuilder; - } + /** The mapping bettwen group fieldName or Alias to the KVValue. */ + private Map groupMap = new HashMap<>(); + + private Where where; + + /** + * @param field + * @return + * @throws SqlParseException + */ + public AggregationBuilder makeGroupAgg(Field field) throws SqlParseException { + + if (field instanceof MethodField && field.getName().equals("script")) { + MethodField methodField = (MethodField) field; + TermsAggregationBuilder termsBuilder = + AggregationBuilders.terms(methodField.getAlias()) + .script(new Script(methodField.getParams().get(1).value.toString())); + extendGroupMap(methodField, new KVValue("KEY", termsBuilder)); + return termsBuilder; } - - /** - * Create aggregation according to the SQL function. - * - * @param field SQL function - * @param parent parentAggregation - * @return AggregationBuilder represents the SQL function - * @throws SqlParseException in case of unrecognized function - */ - public AggregationBuilder makeFieldAgg(MethodField field, AggregationBuilder parent) throws SqlParseException { - extendGroupMap(field, new KVValue("FIELD", parent)); - ValuesSourceAggregationBuilder builder; - field.setAlias(fixAlias(field.getAlias())); - switch (field.getName().toUpperCase()) { - case "SUM": - builder = AggregationBuilders.sum(field.getAlias()); - return addFieldToAgg(field, builder); - case "MAX": - builder = AggregationBuilders.max(field.getAlias()); - return addFieldToAgg(field, builder); - case "MIN": - builder = AggregationBuilders.min(field.getAlias()); - return addFieldToAgg(field, builder); - case "AVG": - builder = AggregationBuilders.avg(field.getAlias()); - return addFieldToAgg(field, builder); - case "STATS": - builder = AggregationBuilders.stats(field.getAlias()); - return addFieldToAgg(field, builder); - case "EXTENDED_STATS": - builder = AggregationBuilders.extendedStats(field.getAlias()); - return addFieldToAgg(field, builder); - case "PERCENTILES": - builder = AggregationBuilders.percentiles(field.getAlias()); - addSpecificPercentiles((PercentilesAggregationBuilder) builder, field.getParams()); - return addFieldToAgg(field, builder); - case "TOPHITS": - return makeTopHitsAgg(field); - case "SCRIPTED_METRIC": - return scriptedMetric(field); - case "COUNT": - extendGroupMap(field, new KVValue("COUNT", parent)); - return addFieldToAgg(field, makeCountAgg(field)); - default: - throw new SqlParseException("the agg function not to define !"); - } + if (field instanceof MethodField) { + + MethodField methodField = (MethodField) field; + if (methodField.getName().equals("filter")) { + Map paramsAsMap = methodField.getParamsAsMap(); + Where where = (Where) paramsAsMap.get("where"); + return AggregationBuilders.filter( + paramsAsMap.get("alias").toString(), QueryMaker.explain(where)); + } + return makeRangeGroup(methodField); + } else { + String termName = + (Strings.isNullOrEmpty(field.getAlias())) ? field.getName() : field.getAlias(); + TermsAggregationBuilder termsBuilder = + AggregationBuilders.terms(termName).field(field.getName()); + final KVValue kvValue = new KVValue("KEY", termsBuilder); + groupMap.put(termName, kvValue); + // map the field name with KVValue if it is not yet. The use case is when alias exist, + // the termName is different with fieldName, both of them should be included in the map. + groupMap.putIfAbsent(field.getName(), kvValue); + return termsBuilder; } - - /** - * With {@link Where} Condition. - */ - public AggMaker withWhere(Where where) { - this.where = where; - return this; + } + + /** + * Create aggregation according to the SQL function. + * + * @param field SQL function + * @param parent parentAggregation + * @return AggregationBuilder represents the SQL function + * @throws SqlParseException in case of unrecognized function + */ + public AggregationBuilder makeFieldAgg(MethodField field, AggregationBuilder parent) + throws SqlParseException { + extendGroupMap(field, new KVValue("FIELD", parent)); + ValuesSourceAggregationBuilder builder; + field.setAlias(fixAlias(field.getAlias())); + switch (field.getName().toUpperCase()) { + case "SUM": + builder = AggregationBuilders.sum(field.getAlias()); + return addFieldToAgg(field, builder); + case "MAX": + builder = AggregationBuilders.max(field.getAlias()); + return addFieldToAgg(field, builder); + case "MIN": + builder = AggregationBuilders.min(field.getAlias()); + return addFieldToAgg(field, builder); + case "AVG": + builder = AggregationBuilders.avg(field.getAlias()); + return addFieldToAgg(field, builder); + case "STATS": + builder = AggregationBuilders.stats(field.getAlias()); + return addFieldToAgg(field, builder); + case "EXTENDED_STATS": + builder = AggregationBuilders.extendedStats(field.getAlias()); + return addFieldToAgg(field, builder); + case "PERCENTILES": + builder = AggregationBuilders.percentiles(field.getAlias()); + addSpecificPercentiles((PercentilesAggregationBuilder) builder, field.getParams()); + return addFieldToAgg(field, builder); + case "TOPHITS": + return makeTopHitsAgg(field); + case "SCRIPTED_METRIC": + return scriptedMetric(field); + case "COUNT": + extendGroupMap(field, new KVValue("COUNT", parent)); + return addFieldToAgg(field, makeCountAgg(field)); + default: + throw new SqlParseException("the agg function not to define !"); } - - private void addSpecificPercentiles(PercentilesAggregationBuilder percentilesBuilder, List params) { - List percentiles = new ArrayList<>(); - for (KVValue kValue : params) { - if (kValue.value.getClass().equals(BigDecimal.class)) { - BigDecimal percentile = (BigDecimal) kValue.value; - percentiles.add(percentile.doubleValue()); - - } else if (kValue.value instanceof Integer) { - percentiles.add(((Integer) kValue.value).doubleValue()); - } - } - if (percentiles.size() > 0) { - double[] percentilesArr = new double[percentiles.size()]; - int i = 0; - for (Double percentile : percentiles) { - percentilesArr[i] = percentile; - i++; - } - percentilesBuilder.percentiles(percentilesArr); - } + } + + /** With {@link Where} Condition. */ + public AggMaker withWhere(Where where) { + this.where = where; + return this; + } + + private void addSpecificPercentiles( + PercentilesAggregationBuilder percentilesBuilder, List params) { + List percentiles = new ArrayList<>(); + for (KVValue kValue : params) { + if (kValue.value.getClass().equals(BigDecimal.class)) { + BigDecimal percentile = (BigDecimal) kValue.value; + percentiles.add(percentile.doubleValue()); + + } else if (kValue.value instanceof Integer) { + percentiles.add(((Integer) kValue.value).doubleValue()); + } } - - private String fixAlias(String alias) { - //because [ is not legal as alias - return alias.replaceAll("\\[", "(").replaceAll("\\]", ")"); + if (percentiles.size() > 0) { + double[] percentilesArr = new double[percentiles.size()]; + int i = 0; + for (Double percentile : percentiles) { + percentilesArr[i] = percentile; + i++; + } + percentilesBuilder.percentiles(percentilesArr); } + } + + private String fixAlias(String alias) { + // because [ is not legal as alias + return alias.replaceAll("\\[", "(").replaceAll("\\]", ")"); + } + + private AggregationBuilder addFieldToAgg( + MethodField field, ValuesSourceAggregationBuilder builder) throws SqlParseException { + KVValue kvValue = field.getParams().get(0); + if (kvValue.key != null && kvValue.key.equals("script")) { + if (kvValue.value instanceof MethodField) { + return builder.script( + new Script(((MethodField) kvValue.value).getParams().get(1).toString())); + } else { + return builder.script(new Script(kvValue.value.toString())); + } + + } else if (kvValue.key != null && kvValue.value.toString().trim().startsWith("def")) { + return builder.script(new Script(kvValue.value.toString())); + } else if (kvValue.key != null + && (kvValue.key.equals("nested") || kvValue.key.equals("reverse_nested"))) { + NestedType nestedType = (NestedType) kvValue.value; + nestedType.addBucketPath(Path.getMetricPath(builder.getName())); + + if (nestedType.isNestedField()) { + builder.field("_index"); + } else { + builder.field(nestedType.field); + } + + AggregationBuilder nestedBuilder; + + String nestedAggName = nestedType.getNestedAggName(); + + if (nestedType.isReverse()) { + if (nestedType.path != null && nestedType.path.startsWith("~")) { + String realPath = nestedType.path.substring(1); + nestedBuilder = AggregationBuilders.nested(nestedAggName, realPath); + nestedBuilder = nestedBuilder.subAggregation(builder); + return AggregationBuilders.reverseNested(nestedAggName + "_REVERSED") + .subAggregation(nestedBuilder); + } else { + ReverseNestedAggregationBuilder reverseNestedAggregationBuilder = + AggregationBuilders.reverseNested(nestedAggName); + if (nestedType.path != null) { + reverseNestedAggregationBuilder.path(nestedType.path); + } + nestedBuilder = reverseNestedAggregationBuilder; + } + } else { + nestedBuilder = AggregationBuilders.nested(nestedAggName, nestedType.path); + } - private AggregationBuilder addFieldToAgg(MethodField field, ValuesSourceAggregationBuilder builder) - throws SqlParseException { - KVValue kvValue = field.getParams().get(0); - if (kvValue.key != null && kvValue.key.equals("script")) { - if (kvValue.value instanceof MethodField) { - return builder.script(new Script(((MethodField) kvValue.value).getParams().get(1).toString())); - } else { - return builder.script(new Script(kvValue.value.toString())); - } - - } else if (kvValue.key != null && kvValue.value.toString().trim().startsWith("def")) { - return builder.script(new Script(kvValue.value.toString())); - } else if (kvValue.key != null && (kvValue.key.equals("nested") || kvValue.key.equals("reverse_nested"))) { - NestedType nestedType = (NestedType) kvValue.value; - nestedType.addBucketPath(Path.getMetricPath(builder.getName())); - - if (nestedType.isNestedField()) { - builder.field("_index"); - } else { - builder.field(nestedType.field); - } - - AggregationBuilder nestedBuilder; - - String nestedAggName = nestedType.getNestedAggName(); - - if (nestedType.isReverse()) { - if (nestedType.path != null && nestedType.path.startsWith("~")) { - String realPath = nestedType.path.substring(1); - nestedBuilder = AggregationBuilders.nested(nestedAggName, realPath); - nestedBuilder = nestedBuilder.subAggregation(builder); - return AggregationBuilders.reverseNested(nestedAggName + "_REVERSED") - .subAggregation(nestedBuilder); - } else { - ReverseNestedAggregationBuilder reverseNestedAggregationBuilder = - AggregationBuilders.reverseNested(nestedAggName); - if (nestedType.path != null) { - reverseNestedAggregationBuilder.path(nestedType.path); - } - nestedBuilder = reverseNestedAggregationBuilder; - } - } else { - nestedBuilder = AggregationBuilders.nested(nestedAggName, nestedType.path); - } - - AggregationBuilder aggregation = nestedBuilder.subAggregation(wrapWithFilterAgg( - nestedType, - builder)); - nestedType.addBucketPath(Path.getAggPath(nestedBuilder.getName())); - return aggregation; - } else if (kvValue.key != null && (kvValue.key.equals("children"))) { - ChildrenType childrenType = (ChildrenType) kvValue.value; - - builder.field(childrenType.field); + AggregationBuilder aggregation = + nestedBuilder.subAggregation(wrapWithFilterAgg(nestedType, builder)); + nestedType.addBucketPath(Path.getAggPath(nestedBuilder.getName())); + return aggregation; + } else if (kvValue.key != null && (kvValue.key.equals("children"))) { + ChildrenType childrenType = (ChildrenType) kvValue.value; - AggregationBuilder childrenBuilder; + builder.field(childrenType.field); - String childrenAggName = childrenType.field + "@CHILDREN"; + AggregationBuilder childrenBuilder; - childrenBuilder = JoinAggregationBuilders.children(childrenAggName, childrenType.childType); + String childrenAggName = childrenType.field + "@CHILDREN"; - return childrenBuilder; - } + childrenBuilder = JoinAggregationBuilders.children(childrenAggName, childrenType.childType); - return builder.field(kvValue.toString()); + return childrenBuilder; } - private AggregationBuilder makeRangeGroup(MethodField field) throws SqlParseException { - switch (field.getName().toLowerCase()) { - case "range": - return rangeBuilder(field); - case "date_histogram": - return dateHistogram(field); - case "date_range": - case "month": - return dateRange(field); - case "histogram": - return histogram(field); - case "geohash_grid": - return geohashGrid(field); - case "geo_bounds": - return geoBounds(field); - case "terms": - return termsAgg(field); - default: - throw new SqlParseException("can define this method " + field); - } - + return builder.field(kvValue.toString()); + } + + private AggregationBuilder makeRangeGroup(MethodField field) throws SqlParseException { + switch (field.getName().toLowerCase()) { + case "range": + return rangeBuilder(field); + case "date_histogram": + return dateHistogram(field); + case "date_range": + case "month": + return dateRange(field); + case "histogram": + return histogram(field); + case "geohash_grid": + return geohashGrid(field); + case "geo_bounds": + return geoBounds(field); + case "terms": + return termsAgg(field); + default: + throw new SqlParseException("can define this method " + field); } - - private AggregationBuilder geoBounds(MethodField field) throws SqlParseException { - String aggName = gettAggNameFromParamsOrAlias(field); - GeoBoundsAggregationBuilder boundsBuilder = new GeoBoundsAggregationBuilder(aggName); - String value; - for (KVValue kv : field.getParams()) { - value = kv.value.toString(); - switch (kv.key.toLowerCase()) { - case "field": - boundsBuilder.field(value); - break; - case "wrap_longitude": - boundsBuilder.wrapLongitude(Boolean.getBoolean(value)); - break; - case "alias": - case "nested": - case "reverse_nested": - case "children": - break; - default: - throw new SqlParseException("geo_bounds err or not define field " + kv.toString()); - } - } - return boundsBuilder; + } + + private AggregationBuilder geoBounds(MethodField field) throws SqlParseException { + String aggName = gettAggNameFromParamsOrAlias(field); + GeoBoundsAggregationBuilder boundsBuilder = new GeoBoundsAggregationBuilder(aggName); + String value; + for (KVValue kv : field.getParams()) { + value = kv.value.toString(); + switch (kv.key.toLowerCase()) { + case "field": + boundsBuilder.field(value); + break; + case "wrap_longitude": + boundsBuilder.wrapLongitude(Boolean.getBoolean(value)); + break; + case "alias": + case "nested": + case "reverse_nested": + case "children": + break; + default: + throw new SqlParseException("geo_bounds err or not define field " + kv.toString()); + } } - - private AggregationBuilder termsAgg(MethodField field) throws SqlParseException { - String aggName = gettAggNameFromParamsOrAlias(field); - TermsAggregationBuilder terms = AggregationBuilders.terms(aggName); - String value; - IncludeExclude include = null, exclude = null; - for (KVValue kv : field.getParams()) { - if (kv.value.toString().contains("doc[")) { - String script = kv.value + "; return " + kv.key; - terms.script(new Script(script)); + return boundsBuilder; + } + + private AggregationBuilder termsAgg(MethodField field) throws SqlParseException { + String aggName = gettAggNameFromParamsOrAlias(field); + TermsAggregationBuilder terms = AggregationBuilders.terms(aggName); + String value; + IncludeExclude include = null, exclude = null; + for (KVValue kv : field.getParams()) { + if (kv.value.toString().contains("doc[")) { + String script = kv.value + "; return " + kv.key; + terms.script(new Script(script)); + } else { + value = kv.value.toString(); + switch (kv.key.toLowerCase()) { + case "field": + terms.field(value); + break; + case "size": + terms.size(Integer.parseInt(value)); + break; + case "shard_size": + terms.shardSize(Integer.parseInt(value)); + break; + case "min_doc_count": + terms.minDocCount(Integer.parseInt(value)); + break; + case "missing": + terms.missing(value); + break; + case "order": + if ("asc".equalsIgnoreCase(value)) { + terms.order(BucketOrder.key(true)); + } else if ("desc".equalsIgnoreCase(value)) { + terms.order(BucketOrder.key(false)); } else { - value = kv.value.toString(); - switch (kv.key.toLowerCase()) { - case "field": - terms.field(value); - break; - case "size": - terms.size(Integer.parseInt(value)); - break; - case "shard_size": - terms.shardSize(Integer.parseInt(value)); - break; - case "min_doc_count": - terms.minDocCount(Integer.parseInt(value)); - break; - case "missing": - terms.missing(value); - break; - case "order": - if ("asc".equalsIgnoreCase(value)) { - terms.order(BucketOrder.key(true)); - } else if ("desc".equalsIgnoreCase(value)) { - terms.order(BucketOrder.key(false)); - } else { - List orderElements = new ArrayList<>(); - try (JsonXContentParser parser = new JsonXContentParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, new JsonFactory().createParser(value))) { - XContentParser.Token currentToken = parser.nextToken(); - if (currentToken == XContentParser.Token.START_OBJECT) { - orderElements.add(InternalOrder.Parser.parseOrderParam(parser)); - } else if (currentToken == XContentParser.Token.START_ARRAY) { - for (currentToken = parser.nextToken(); - currentToken != XContentParser.Token.END_ARRAY; - currentToken = parser.nextToken()) { - if (currentToken == XContentParser.Token.START_OBJECT) { - orderElements.add(InternalOrder.Parser.parseOrderParam(parser)); - } else { - throw new ParsingException(parser.getTokenLocation(), - "Invalid token in order array"); - } - } - } - } catch (IOException e) { - throw new SqlParseException("couldn't parse order: " + e.getMessage()); - } - terms.order(orderElements); - } - break; - case "alias": - case "nested": - case "reverse_nested": - case "children": - break; - case "execution_hint": - terms.executionHint(value); - break; - case "include": - try (XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, value)) { - parser.nextToken(); - include = IncludeExclude.parseInclude(parser); - } catch (IOException e) { - throw new SqlParseException("parse include[" + value + "] error: " + e.getMessage()); - } - break; - case "exclude": - try (XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, value)) { - parser.nextToken(); - exclude = IncludeExclude.parseExclude(parser); - } catch (IOException e) { - throw new SqlParseException("parse exclude[" + value + "] error: " + e.getMessage()); - } - break; - default: - throw new SqlParseException("terms aggregation err or not define field " + kv.toString()); + List orderElements = new ArrayList<>(); + try (JsonXContentParser parser = + new JsonXContentParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + new JsonFactory().createParser(value))) { + XContentParser.Token currentToken = parser.nextToken(); + if (currentToken == XContentParser.Token.START_OBJECT) { + orderElements.add(InternalOrder.Parser.parseOrderParam(parser)); + } else if (currentToken == XContentParser.Token.START_ARRAY) { + for (currentToken = parser.nextToken(); + currentToken != XContentParser.Token.END_ARRAY; + currentToken = parser.nextToken()) { + if (currentToken == XContentParser.Token.START_OBJECT) { + orderElements.add(InternalOrder.Parser.parseOrderParam(parser)); + } else { + throw new ParsingException( + parser.getTokenLocation(), "Invalid token in order array"); + } + } } + } catch (IOException e) { + throw new SqlParseException("couldn't parse order: " + e.getMessage()); + } + terms.order(orderElements); } - } - terms.includeExclude(IncludeExclude.merge(include, exclude)); - return terms; - } - - private AbstractAggregationBuilder scriptedMetric(MethodField field) throws SqlParseException { - String aggName = gettAggNameFromParamsOrAlias(field); - ScriptedMetricAggregationBuilder scriptedMetricBuilder = AggregationBuilders.scriptedMetric(aggName); - Map scriptedMetricParams = field.getParamsAsMap(); - if (!scriptedMetricParams.containsKey("map_script") && !scriptedMetricParams.containsKey("map_script_id") - && !scriptedMetricParams.containsKey("map_script_file")) { - throw new SqlParseException( - "scripted metric parameters must contain map_script/map_script_id/map_script_file parameter"); - } - HashMap scriptAdditionalParams = new HashMap<>(); - HashMap reduceScriptAdditionalParams = new HashMap<>(); - for (Map.Entry param : scriptedMetricParams.entrySet()) { - String paramValue = param.getValue().toString(); - if (param.getKey().startsWith("@")) { - if (param.getKey().startsWith("@reduce_")) { - reduceScriptAdditionalParams.put(param.getKey().replace("@reduce_", ""), - param.getValue()); - } else { - scriptAdditionalParams.put(param.getKey().replace("@", ""), param.getValue()); - } - continue; + break; + case "alias": + case "nested": + case "reverse_nested": + case "children": + break; + case "execution_hint": + terms.executionHint(value); + break; + case "include": + try (XContentParser parser = + JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, value)) { + parser.nextToken(); + include = IncludeExclude.parseInclude(parser); + } catch (IOException e) { + throw new SqlParseException("parse include[" + value + "] error: " + e.getMessage()); } - - switch (param.getKey().toLowerCase()) { - case "map_script": - scriptedMetricBuilder.mapScript(new Script(paramValue)); - break; - case "map_script_id": - scriptedMetricBuilder.mapScript(new Script(ScriptType.STORED, Script.DEFAULT_SCRIPT_LANG, - paramValue, new HashMap<>())); - break; - case "init_script": - scriptedMetricBuilder.initScript(new Script(paramValue)); - break; - case "init_script_id": - scriptedMetricBuilder.initScript(new Script(ScriptType.STORED, Script.DEFAULT_SCRIPT_LANG, - paramValue, new HashMap<>())); - break; - case "combine_script": - scriptedMetricBuilder.combineScript(new Script(paramValue)); - break; - case "combine_script_id": - scriptedMetricBuilder.combineScript(new Script(ScriptType.STORED, Script.DEFAULT_SCRIPT_LANG, - paramValue, new HashMap<>())); - break; - case "reduce_script": - scriptedMetricBuilder.reduceScript(new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, - paramValue, reduceScriptAdditionalParams)); - break; - case "reduce_script_id": - scriptedMetricBuilder.reduceScript(new Script(ScriptType.STORED, Script.DEFAULT_SCRIPT_LANG, - paramValue, reduceScriptAdditionalParams)); - break; - case "alias": - case "nested": - case "reverse_nested": - case "children": - break; - default: - throw new SqlParseException("scripted_metric err or not define field " + param.getKey()); + break; + case "exclude": + try (XContentParser parser = + JsonXContent.jsonXContent.createParser( + NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, value)) { + parser.nextToken(); + exclude = IncludeExclude.parseExclude(parser); + } catch (IOException e) { + throw new SqlParseException("parse exclude[" + value + "] error: " + e.getMessage()); } + break; + default: + throw new SqlParseException( + "terms aggregation err or not define field " + kv.toString()); } - if (scriptAdditionalParams.size() > 0) { - scriptAdditionalParams.put("_agg", new HashMap<>()); - scriptedMetricBuilder.params(scriptAdditionalParams); - } - - return scriptedMetricBuilder; + } } - - private AggregationBuilder geohashGrid(MethodField field) throws SqlParseException { - String aggName = gettAggNameFromParamsOrAlias(field); - GeoGridAggregationBuilder geoHashGrid = new GeoHashGridAggregationBuilder(aggName); - String value; - for (KVValue kv : field.getParams()) { - value = kv.value.toString(); - switch (kv.key.toLowerCase()) { - case "precision": - geoHashGrid.precision(Integer.parseInt(value)); - break; - case "field": - geoHashGrid.field(value); - break; - case "size": - geoHashGrid.size(Integer.parseInt(value)); - break; - case "shard_size": - geoHashGrid.shardSize(Integer.parseInt(value)); - break; - case "alias": - case "nested": - case "reverse_nested": - case "children": - break; - default: - throw new SqlParseException("geohash grid err or not define field " + kv.toString()); - } - } - return geoHashGrid; + terms.includeExclude(IncludeExclude.merge(include, exclude)); + return terms; + } + + private AbstractAggregationBuilder scriptedMetric(MethodField field) throws SqlParseException { + String aggName = gettAggNameFromParamsOrAlias(field); + ScriptedMetricAggregationBuilder scriptedMetricBuilder = + AggregationBuilders.scriptedMetric(aggName); + Map scriptedMetricParams = field.getParamsAsMap(); + if (!scriptedMetricParams.containsKey("map_script") + && !scriptedMetricParams.containsKey("map_script_id") + && !scriptedMetricParams.containsKey("map_script_file")) { + throw new SqlParseException( + "scripted metric parameters must contain map_script/map_script_id/map_script_file" + + " parameter"); } - - private static final String TIME_FARMAT = "yyyy-MM-dd HH:mm:ss"; - - private ValuesSourceAggregationBuilder dateRange(MethodField field) { - String alias = gettAggNameFromParamsOrAlias(field); - DateRangeAggregationBuilder dateRange = AggregationBuilders.dateRange(alias).format(TIME_FARMAT); - - String value; - List ranges = new ArrayList<>(); - for (KVValue kv : field.getParams()) { - value = kv.value.toString(); - if ("field".equals(kv.key)) { - dateRange.field(value); - } else if ("format".equals(kv.key)) { - dateRange.format(value); - } else if ("time_zone".equals(kv.key)) { - dateRange.timeZone(ZoneOffset.of(value)); - } else if ("from".equals(kv.key)) { - dateRange.addUnboundedFrom(kv.value.toString()); - } else if ("to".equals(kv.key)) { - dateRange.addUnboundedTo(kv.value.toString()); - } else if (!"alias".equals(kv.key) && !"nested".equals(kv.key) && !"children".equals(kv.key)) { - ranges.add(value); - } - } - - for (int i = 1; i < ranges.size(); i++) { - dateRange.addRange(ranges.get(i - 1), ranges.get(i)); + HashMap scriptAdditionalParams = new HashMap<>(); + HashMap reduceScriptAdditionalParams = new HashMap<>(); + for (Map.Entry param : scriptedMetricParams.entrySet()) { + String paramValue = param.getValue().toString(); + if (param.getKey().startsWith("@")) { + if (param.getKey().startsWith("@reduce_")) { + reduceScriptAdditionalParams.put( + param.getKey().replace("@reduce_", ""), param.getValue()); + } else { + scriptAdditionalParams.put(param.getKey().replace("@", ""), param.getValue()); } + continue; + } + + switch (param.getKey().toLowerCase()) { + case "map_script": + scriptedMetricBuilder.mapScript(new Script(paramValue)); + break; + case "map_script_id": + scriptedMetricBuilder.mapScript( + new Script( + ScriptType.STORED, Script.DEFAULT_SCRIPT_LANG, paramValue, new HashMap<>())); + break; + case "init_script": + scriptedMetricBuilder.initScript(new Script(paramValue)); + break; + case "init_script_id": + scriptedMetricBuilder.initScript( + new Script( + ScriptType.STORED, Script.DEFAULT_SCRIPT_LANG, paramValue, new HashMap<>())); + break; + case "combine_script": + scriptedMetricBuilder.combineScript(new Script(paramValue)); + break; + case "combine_script_id": + scriptedMetricBuilder.combineScript( + new Script( + ScriptType.STORED, Script.DEFAULT_SCRIPT_LANG, paramValue, new HashMap<>())); + break; + case "reduce_script": + scriptedMetricBuilder.reduceScript( + new Script( + ScriptType.INLINE, + Script.DEFAULT_SCRIPT_LANG, + paramValue, + reduceScriptAdditionalParams)); + break; + case "reduce_script_id": + scriptedMetricBuilder.reduceScript( + new Script( + ScriptType.STORED, + Script.DEFAULT_SCRIPT_LANG, + paramValue, + reduceScriptAdditionalParams)); + break; + case "alias": + case "nested": + case "reverse_nested": + case "children": + break; + default: + throw new SqlParseException("scripted_metric err or not define field " + param.getKey()); + } + } + if (scriptAdditionalParams.size() > 0) { + scriptAdditionalParams.put("_agg", new HashMap<>()); + scriptedMetricBuilder.params(scriptAdditionalParams); + } - return dateRange; + return scriptedMetricBuilder; + } + + private AggregationBuilder geohashGrid(MethodField field) throws SqlParseException { + String aggName = gettAggNameFromParamsOrAlias(field); + GeoGridAggregationBuilder geoHashGrid = new GeoHashGridAggregationBuilder(aggName); + String value; + for (KVValue kv : field.getParams()) { + value = kv.value.toString(); + switch (kv.key.toLowerCase()) { + case "precision": + geoHashGrid.precision(Integer.parseInt(value)); + break; + case "field": + geoHashGrid.field(value); + break; + case "size": + geoHashGrid.size(Integer.parseInt(value)); + break; + case "shard_size": + geoHashGrid.shardSize(Integer.parseInt(value)); + break; + case "alias": + case "nested": + case "reverse_nested": + case "children": + break; + default: + throw new SqlParseException("geohash grid err or not define field " + kv.toString()); + } + } + return geoHashGrid; + } + + private static final String TIME_FARMAT = "yyyy-MM-dd HH:mm:ss"; + + private ValuesSourceAggregationBuilder dateRange(MethodField field) { + String alias = gettAggNameFromParamsOrAlias(field); + DateRangeAggregationBuilder dateRange = + AggregationBuilders.dateRange(alias).format(TIME_FARMAT); + + String value; + List ranges = new ArrayList<>(); + for (KVValue kv : field.getParams()) { + value = kv.value.toString(); + if ("field".equals(kv.key)) { + dateRange.field(value); + } else if ("format".equals(kv.key)) { + dateRange.format(value); + } else if ("time_zone".equals(kv.key)) { + dateRange.timeZone(ZoneOffset.of(value)); + } else if ("from".equals(kv.key)) { + dateRange.addUnboundedFrom(kv.value.toString()); + } else if ("to".equals(kv.key)) { + dateRange.addUnboundedTo(kv.value.toString()); + } else if (!"alias".equals(kv.key) + && !"nested".equals(kv.key) + && !"children".equals(kv.key)) { + ranges.add(value); + } } - /** - * - * - * @param field - * @return - * @throws SqlParseException - */ - private DateHistogramAggregationBuilder dateHistogram(MethodField field) throws SqlParseException { - String alias = gettAggNameFromParamsOrAlias(field); - DateHistogramAggregationBuilder dateHistogram = AggregationBuilders.dateHistogram(alias).format(TIME_FARMAT); - String value; - for (KVValue kv : field.getParams()) { - if (kv.value.toString().contains("doc[")) { - String script = kv.value + "; return " + kv.key; - dateHistogram.script(new Script(script)); - } else { - value = kv.value.toString(); - switch (kv.key.toLowerCase()) { - case "interval": - dateHistogram.dateHistogramInterval(new DateHistogramInterval(kv.value.toString())); - break; - case "fixed_interval": - dateHistogram.fixedInterval(new DateHistogramInterval(kv.value.toString())); - break; - case "field": - dateHistogram.field(value); - break; - case "format": - dateHistogram.format(value); - break; - case "time_zone": - dateHistogram.timeZone(ZoneOffset.of(value)); - break; - case "min_doc_count": - dateHistogram.minDocCount(Long.parseLong(value)); - break; - case "order": - dateHistogram.order("desc".equalsIgnoreCase(value) ? BucketOrder.key(false) : - BucketOrder.key(true)); - break; - case "extended_bounds": - String[] bounds = value.split(":"); - if (bounds.length == 2) { - dateHistogram.extendedBounds(new LongBounds(bounds[0], bounds[1])); - } - break; - - case "alias": - case "nested": - case "reverse_nested": - case "children": - break; - default: - throw new SqlParseException("date range err or not define field " + kv.toString()); - } - } - } - return dateHistogram; + for (int i = 1; i < ranges.size(); i++) { + dateRange.addRange(ranges.get(i - 1), ranges.get(i)); } - private String gettAggNameFromParamsOrAlias(MethodField field) { - String alias = field.getAlias(); - for (KVValue kv : field.getParams()) { - if (kv.key != null && kv.key.equals("alias")) { - alias = kv.value.toString(); + return dateRange; + } + + /** + * @param field + * @return + * @throws SqlParseException + */ + private DateHistogramAggregationBuilder dateHistogram(MethodField field) + throws SqlParseException { + String alias = gettAggNameFromParamsOrAlias(field); + DateHistogramAggregationBuilder dateHistogram = + AggregationBuilders.dateHistogram(alias).format(TIME_FARMAT); + String value; + for (KVValue kv : field.getParams()) { + if (kv.value.toString().contains("doc[")) { + String script = kv.value + "; return " + kv.key; + dateHistogram.script(new Script(script)); + } else { + value = kv.value.toString(); + switch (kv.key.toLowerCase()) { + case "interval": + dateHistogram.dateHistogramInterval(new DateHistogramInterval(kv.value.toString())); + break; + case "fixed_interval": + dateHistogram.fixedInterval(new DateHistogramInterval(kv.value.toString())); + break; + case "field": + dateHistogram.field(value); + break; + case "format": + dateHistogram.format(value); + break; + case "time_zone": + dateHistogram.timeZone(ZoneOffset.of(value)); + break; + case "min_doc_count": + dateHistogram.minDocCount(Long.parseLong(value)); + break; + case "order": + dateHistogram.order( + "desc".equalsIgnoreCase(value) ? BucketOrder.key(false) : BucketOrder.key(true)); + break; + case "extended_bounds": + String[] bounds = value.split(":"); + if (bounds.length == 2) { + dateHistogram.extendedBounds(new LongBounds(bounds[0], bounds[1])); } + break; + + case "alias": + case "nested": + case "reverse_nested": + case "children": + break; + default: + throw new SqlParseException("date range err or not define field " + kv.toString()); } - return alias; + } } - - private HistogramAggregationBuilder histogram(MethodField field) throws SqlParseException { - String aggName = gettAggNameFromParamsOrAlias(field); - HistogramAggregationBuilder histogram = AggregationBuilders.histogram(aggName); - String value; - for (KVValue kv : field.getParams()) { - if (kv.value.toString().contains("doc[")) { - String script = kv.value + "; return " + kv.key; - histogram.script(new Script(script)); - } else { - value = kv.value.toString(); - switch (kv.key.toLowerCase()) { - case "interval": - histogram.interval(Long.parseLong(value)); - break; - case "field": - histogram.field(value); - break; - case "min_doc_count": - histogram.minDocCount(Long.parseLong(value)); - break; - case "extended_bounds": - String[] bounds = value.split(":"); - if (bounds.length == 2) { - histogram.extendedBounds(Long.valueOf(bounds[0]), Long.valueOf(bounds[1])); - } - break; - case "alias": - case "nested": - case "reverse_nested": - case "children": - break; - case "order": - final BucketOrder order; - switch (value) { - case "key_desc": - order = BucketOrder.key(false); - break; - case "count_asc": - order = BucketOrder.count(true); - break; - case "count_desc": - order = BucketOrder.count(false); - break; - case "key_asc": - default: - order = BucketOrder.key(true); - break; - } - histogram.order(order); - break; - default: - throw new SqlParseException("histogram err or not define field " + kv.toString()); - } + return dateHistogram; + } + + private String gettAggNameFromParamsOrAlias(MethodField field) { + String alias = field.getAlias(); + for (KVValue kv : field.getParams()) { + if (kv.key != null && kv.key.equals("alias")) { + alias = kv.value.toString(); + } + } + return alias; + } + + private HistogramAggregationBuilder histogram(MethodField field) throws SqlParseException { + String aggName = gettAggNameFromParamsOrAlias(field); + HistogramAggregationBuilder histogram = AggregationBuilders.histogram(aggName); + String value; + for (KVValue kv : field.getParams()) { + if (kv.value.toString().contains("doc[")) { + String script = kv.value + "; return " + kv.key; + histogram.script(new Script(script)); + } else { + value = kv.value.toString(); + switch (kv.key.toLowerCase()) { + case "interval": + histogram.interval(Long.parseLong(value)); + break; + case "field": + histogram.field(value); + break; + case "min_doc_count": + histogram.minDocCount(Long.parseLong(value)); + break; + case "extended_bounds": + String[] bounds = value.split(":"); + if (bounds.length == 2) { + histogram.extendedBounds(Long.valueOf(bounds[0]), Long.valueOf(bounds[1])); + } + break; + case "alias": + case "nested": + case "reverse_nested": + case "children": + break; + case "order": + final BucketOrder order; + switch (value) { + case "key_desc": + order = BucketOrder.key(false); + break; + case "count_asc": + order = BucketOrder.count(true); + break; + case "count_desc": + order = BucketOrder.count(false); + break; + case "key_asc": + default: + order = BucketOrder.key(true); + break; } + histogram.order(order); + break; + default: + throw new SqlParseException("histogram err or not define field " + kv.toString()); } - return histogram; + } } + return histogram; + } - /** - * - * - * @param field - * @return - */ - private RangeAggregationBuilder rangeBuilder(MethodField field) { + /** + * @param field + * @return + */ + private RangeAggregationBuilder rangeBuilder(MethodField field) { - // ignore alias param - LinkedList params = field.getParams().stream().filter(kv -> !"alias".equals(kv.key)) - .collect(Collectors.toCollection(LinkedList::new)); + // ignore alias param + LinkedList params = + field.getParams().stream() + .filter(kv -> !"alias".equals(kv.key)) + .collect(Collectors.toCollection(LinkedList::new)); - String fieldName = params.poll().toString(); + String fieldName = params.poll().toString(); - double[] ds = Util.KV2DoubleArr(params); + double[] ds = Util.KV2DoubleArr(params); - RangeAggregationBuilder range = AggregationBuilders.range(field.getAlias()).field(fieldName); + RangeAggregationBuilder range = AggregationBuilders.range(field.getAlias()).field(fieldName); - for (int i = 1; i < ds.length; i++) { - range.addRange(ds[i - 1], ds[i]); - } - - return range; + for (int i = 1; i < ds.length; i++) { + range.addRange(ds[i - 1], ds[i]); } + return range; + } + + /** + * Create count aggregation. + * + * @param field The count function + * @return AggregationBuilder use to count result + */ + private ValuesSourceAggregationBuilder makeCountAgg(MethodField field) { + + // Cardinality is approximate DISTINCT. + if (SQLAggregateOption.DISTINCT.equals(field.getOption())) { + + if (field.getParams().size() == 1) { + return AggregationBuilders.cardinality(field.getAlias()) + .field(field.getParams().get(0).value.toString()); + } else { + Integer precision_threshold = (Integer) (field.getParams().get(1).value); + return AggregationBuilders.cardinality(field.getAlias()) + .precisionThreshold(precision_threshold) + .field(field.getParams().get(0).value.toString()); + } + } - /** - * Create count aggregation. - * - * @param field The count function - * @return AggregationBuilder use to count result - */ - private ValuesSourceAggregationBuilder makeCountAgg(MethodField field) { - - // Cardinality is approximate DISTINCT. - if (SQLAggregateOption.DISTINCT.equals(field.getOption())) { - - if (field.getParams().size() == 1) { - return AggregationBuilders.cardinality(field.getAlias()).field(field.getParams().get(0).value - .toString()); - } else { - Integer precision_threshold = (Integer) (field.getParams().get(1).value); - return AggregationBuilders.cardinality(field.getAlias()).precisionThreshold(precision_threshold) - .field(field.getParams().get(0).value.toString()); - } - - } - - String fieldName = field.getParams().get(0).value.toString(); + String fieldName = field.getParams().get(0).value.toString(); - // In case of count(*) we use '_index' as field parameter to count all documents - if ("*".equals(fieldName)) { - KVValue kvValue = new KVValue(null, "_index"); - field.getParams().set(0, kvValue); - return AggregationBuilders.count(field.getAlias()).field(kvValue.toString()); - } else { - return AggregationBuilders.count(field.getAlias()).field(fieldName); - } + // In case of count(*) we use '_index' as field parameter to count all documents + if ("*".equals(fieldName)) { + KVValue kvValue = new KVValue(null, "_index"); + field.getParams().set(0, kvValue); + return AggregationBuilders.count(field.getAlias()).field(kvValue.toString()); + } else { + return AggregationBuilders.count(field.getAlias()).field(fieldName); } - - /** - * TOPHITS - * - * @param field - * @return - */ - private AbstractAggregationBuilder makeTopHitsAgg(MethodField field) { - String alias = gettAggNameFromParamsOrAlias(field); - TopHitsAggregationBuilder topHits = AggregationBuilders.topHits(alias); - List params = field.getParams(); - String[] include = null; - String[] exclude = null; - for (KVValue kv : params) { - switch (kv.key) { - case "from": - topHits.from((int) kv.value); - break; - case "size": - topHits.size((int) kv.value); - break; - case "include": - include = kv.value.toString().split(","); - break; - case "exclude": - exclude = kv.value.toString().split(","); - break; - case "alias": - case "nested": - case "reverse_nested": - case "children": - break; - default: - topHits.sort(kv.key, SortOrder.valueOf(kv.value.toString().toUpperCase())); - break; - } - } - if (include != null || exclude != null) { - topHits.fetchSource(include, exclude); - } - return topHits; + } + + /** + * TOPHITS + * + * @param field + * @return + */ + private AbstractAggregationBuilder makeTopHitsAgg(MethodField field) { + String alias = gettAggNameFromParamsOrAlias(field); + TopHitsAggregationBuilder topHits = AggregationBuilders.topHits(alias); + List params = field.getParams(); + String[] include = null; + String[] exclude = null; + for (KVValue kv : params) { + switch (kv.key) { + case "from": + topHits.from((int) kv.value); + break; + case "size": + topHits.size((int) kv.value); + break; + case "include": + include = kv.value.toString().split(","); + break; + case "exclude": + exclude = kv.value.toString().split(","); + break; + case "alias": + case "nested": + case "reverse_nested": + case "children": + break; + default: + topHits.sort(kv.key, SortOrder.valueOf(kv.value.toString().toUpperCase())); + break; + } } - - public Map getGroupMap() { - return this.groupMap; + if (include != null || exclude != null) { + topHits.fetchSource(include, exclude); } - - /** - * Wrap the Metric Aggregation with Filter Aggregation if necessary. - * The Filter Aggregation condition is constructed from the nested condition in where clause. - */ - private AggregationBuilder wrapWithFilterAgg(NestedType nestedType, ValuesSourceAggregationBuilder builder) - throws SqlParseException { - if (where != null && where.getWheres() != null) { - List nestedConditionList = where.getWheres().stream() - .filter(condition -> condition instanceof Condition) - .map(condition -> (Condition) condition) - .filter(condition -> condition.isNestedComplex() - || nestedType.path.equalsIgnoreCase(condition.getNestedPath())) - // ignore the OR condition on nested field. - .filter(condition -> CONN.AND.equals(condition.getConn())) - .collect(Collectors.toList()); - if (!nestedConditionList.isEmpty()) { - Where filterWhere = new Where(where.getConn()); - nestedConditionList.forEach(condition -> { - if (condition.isNestedComplex()) { - ((Where) condition.getValue()).getWheres().forEach(filterWhere::addWhere); - } else { - // Since the filter condition is used inside Nested Aggregation,remove the nested attribute. - condition.setNested(false); - condition.setNestedPath(""); - filterWhere.addWhere(condition); - } - }); - FilterAggregationBuilder filterAgg = AggregationBuilders.filter( - nestedType.getFilterAggName(), - QueryMaker.explain(filterWhere)); - nestedType.addBucketPath(Path.getAggPath(filterAgg.getName())); - return filterAgg.subAggregation(builder); - } - } - return builder; + return topHits; + } + + public Map getGroupMap() { + return this.groupMap; + } + + /** + * Wrap the Metric Aggregation with Filter Aggregation if necessary. The Filter Aggregation + * condition is constructed from the nested condition in where clause. + */ + private AggregationBuilder wrapWithFilterAgg( + NestedType nestedType, ValuesSourceAggregationBuilder builder) throws SqlParseException { + if (where != null && where.getWheres() != null) { + List nestedConditionList = + where.getWheres().stream() + .filter(condition -> condition instanceof Condition) + .map(condition -> (Condition) condition) + .filter( + condition -> + condition.isNestedComplex() + || nestedType.path.equalsIgnoreCase(condition.getNestedPath())) + // ignore the OR condition on nested field. + .filter(condition -> CONN.AND.equals(condition.getConn())) + .collect(Collectors.toList()); + if (!nestedConditionList.isEmpty()) { + Where filterWhere = new Where(where.getConn()); + nestedConditionList.forEach( + condition -> { + if (condition.isNestedComplex()) { + ((Where) condition.getValue()).getWheres().forEach(filterWhere::addWhere); + } else { + // Since the filter condition is used inside Nested Aggregation,remove the nested + // attribute. + condition.setNested(false); + condition.setNestedPath(""); + filterWhere.addWhere(condition); + } + }); + FilterAggregationBuilder filterAgg = + AggregationBuilders.filter( + nestedType.getFilterAggName(), QueryMaker.explain(filterWhere)); + nestedType.addBucketPath(Path.getAggPath(filterAgg.getName())); + return filterAgg.subAggregation(builder); + } } - - /** - * The groupMap is used when parsing order by to find out the corresponding field in aggregation. - * There are two cases. - * 1) using alias in order by, e.g. SELECT COUNT(*) as c FROM T GROUP BY age ORDER BY c - * 2) using full name in order by, e.g. SELECT COUNT(*) as c FROM T GROUP BY age ORDER BY COUNT(*) - * Then, the groupMap should support these two cases by maintain the mapping of - * {alias, value} and {full_name, value} - */ - private void extendGroupMap(Field field, KVValue value) { - groupMap.put(field.toString(), value); - if (!StringUtils.isEmpty(field.getAlias())) { - groupMap.putIfAbsent(field.getAlias(), value); - } + return builder; + } + + /** + * The groupMap is used when parsing order by to find out the corresponding field in aggregation. + * There are two cases. + * + *

    + *
  1. using alias in order by, e.g. SELECT COUNT(*) as c FROM T GROUP BY age ORDER BY c + *
  2. using full name in order by, e.g. SELECT COUNT(*) as c FROM T GROUP BY age ORDER BY + * COUNT(*) + *
+ * + * Then, the groupMap should support these two cases by maintain the mapping of {alias, value} and + * {full_name, value} + */ + private void extendGroupMap(Field field, KVValue value) { + groupMap.put(field.toString(), value); + if (!StringUtils.isEmpty(field.getAlias())) { + groupMap.putIfAbsent(field.getAlias(), value); } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/BindingTupleQueryPlanner.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/BindingTupleQueryPlanner.java index 01a0e78484..a8fb7cc53c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/BindingTupleQueryPlanner.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/BindingTupleQueryPlanner.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.core; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; @@ -17,63 +16,63 @@ import org.opensearch.sql.legacy.query.planner.physical.PhysicalOperator; import org.opensearch.sql.legacy.query.planner.physical.node.scroll.PhysicalScroll; -/** - * The definition of QueryPlanner which return the {@link BindingTuple} as result. - */ +/** The definition of QueryPlanner which return the {@link BindingTuple} as result. */ public class BindingTupleQueryPlanner { - private PhysicalOperator physicalOperator; - @Getter - private List columnNodes; - - public BindingTupleQueryPlanner(Client client, SQLQueryExpr sqlExpr, ColumnTypeProvider columnTypeProvider) { - SQLToOperatorConverter converter = new SQLToOperatorConverter(client, columnTypeProvider); - sqlExpr.accept(converter); - this.physicalOperator = converter.getPhysicalOperator(); - this.columnNodes = converter.getColumnNodes(); - } + private PhysicalOperator physicalOperator; + @Getter private List columnNodes; - /** - * Execute the QueryPlanner. - * @return list of {@link BindingTuple}. - */ - public List execute() { - PhysicalOperator op = physicalOperator; - List tuples = new ArrayList<>(); - try { - op.open(null); - } catch (Exception e) { - throw new RuntimeException(e); - } + public BindingTupleQueryPlanner( + Client client, SQLQueryExpr sqlExpr, ColumnTypeProvider columnTypeProvider) { + SQLToOperatorConverter converter = new SQLToOperatorConverter(client, columnTypeProvider); + sqlExpr.accept(converter); + this.physicalOperator = converter.getPhysicalOperator(); + this.columnNodes = converter.getColumnNodes(); + } - while (op.hasNext()) { - tuples.add(op.next().data()); - } - return tuples; + /** + * Execute the QueryPlanner. + * + * @return list of {@link BindingTuple}. + */ + public List execute() { + PhysicalOperator op = physicalOperator; + List tuples = new ArrayList<>(); + try { + op.open(null); + } catch (Exception e) { + throw new RuntimeException(e); } - /** - * Explain the physical execution plan. - * @return execution plan. - */ - public String explain() { - Explanation explanation = new Explanation(); - physicalOperator.accept(explanation); - return explanation.explain(); + while (op.hasNext()) { + tuples.add(op.next().data()); } + return tuples; + } + + /** + * Explain the physical execution plan. + * + * @return execution plan. + */ + public String explain() { + Explanation explanation = new Explanation(); + physicalOperator.accept(explanation); + return explanation.explain(); + } - private static class Explanation implements PlanNode.Visitor { - private String explain; + private static class Explanation implements PlanNode.Visitor { + private String explain; - public String explain() { - return explain; - } + public String explain() { + return explain; + } - @Override - public boolean visit(PlanNode planNode) { - if (planNode instanceof PhysicalScroll) { - explain = planNode.toString(); - } - return true; - } + @Override + public boolean visit(PlanNode planNode) { + if (planNode instanceof PhysicalScroll) { + explain = planNode.toString(); + } + return true; } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/ColumnNode.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/ColumnNode.java index 753d5ac001..9dd969fb83 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/ColumnNode.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/ColumnNode.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.core; import com.google.common.base.Strings; @@ -14,20 +13,18 @@ import org.opensearch.sql.legacy.executor.format.Schema; import org.opensearch.sql.legacy.expression.core.Expression; -/** - * The definition of column node. - */ +/** The definition of column node. */ @Builder @Setter @Getter @ToString public class ColumnNode { - private String name; - private String alias; - private Schema.Type type; - private Expression expr; + private String name; + private String alias; + private Schema.Type type; + private Expression expr; - public String columnName() { - return Strings.isNullOrEmpty(alias) ? name : alias; - } + public String columnName() { + return Strings.isNullOrEmpty(alias) ? name : alias; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Config.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Config.java index 6e04c674cb..304a16756b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Config.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/core/Config.java @@ -3,156 +3,134 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.core; import org.opensearch.sql.legacy.query.planner.resource.blocksize.AdaptiveBlockSize; import org.opensearch.sql.legacy.query.planner.resource.blocksize.BlockSize; import org.opensearch.sql.legacy.query.planner.resource.blocksize.BlockSize.FixedBlockSize; -/** - * Query planner configuration - */ +/** Query planner configuration */ public class Config { - public static final int DEFAULT_BLOCK_SIZE = 10000; - public static final int DEFAULT_SCROLL_PAGE_SIZE = 10000; - public static final int DEFAULT_CIRCUIT_BREAK_LIMIT = 85; - public static final double[] DEFAULT_BACK_OFF_RETRY_INTERVALS = {4, 8 + 4, 16 + 4}; - public static final int DEFAULT_TIME_OUT = 60; - - /** - * Block size for join algorithm - */ - private BlockSize blockSize = new FixedBlockSize(DEFAULT_BLOCK_SIZE); - - /** - * Page size for scroll on each index - */ - private Integer[] scrollPageSizes = {DEFAULT_SCROLL_PAGE_SIZE, DEFAULT_SCROLL_PAGE_SIZE}; - - /** - * Circuit breaker trigger limit (percentage) - */ - private Integer circuitBreakLimit = DEFAULT_CIRCUIT_BREAK_LIMIT; - - /** - * Intervals for back off retry - */ - private double[] backOffRetryIntervals = DEFAULT_BACK_OFF_RETRY_INTERVALS; - - /** - * Total number of rows in final result specified by LIMIT - */ - private int totalLimit; - - /** - * Number of rows fetched from each table specified by JOIN_TABLES_LIMIT hint - */ - private int tableLimit1; - private int tableLimit2; - - /** - * Push down column values in ON of first table to query against second table - */ - private boolean isUseTermsFilterOptimization = false; - - /** - * Total time out (seconds) for the execution - */ - private int timeout = DEFAULT_TIME_OUT; - - - public BlockSize blockSize() { - return blockSize; - } + public static final int DEFAULT_BLOCK_SIZE = 10000; + public static final int DEFAULT_SCROLL_PAGE_SIZE = 10000; + public static final int DEFAULT_CIRCUIT_BREAK_LIMIT = 85; + public static final double[] DEFAULT_BACK_OFF_RETRY_INTERVALS = {4, 8 + 4, 16 + 4}; + public static final int DEFAULT_TIME_OUT = 60; - public void configureBlockSize(Object[] params) { - if (params.length > 0) { - Integer size = (Integer) params[0]; - if (size > 0) { - blockSize = new FixedBlockSize(size); - } else { - blockSize = new AdaptiveBlockSize(0); - } - } - } + /** Block size for join algorithm */ + private BlockSize blockSize = new FixedBlockSize(DEFAULT_BLOCK_SIZE); - public Integer[] scrollPageSize() { - return scrollPageSizes; - } + /** Page size for scroll on each index */ + private Integer[] scrollPageSizes = {DEFAULT_SCROLL_PAGE_SIZE, DEFAULT_SCROLL_PAGE_SIZE}; - public void configureScrollPageSize(Object[] params) { - if (params.length == 1) { - scrollPageSizes = new Integer[]{ - (Integer) params[0], - (Integer) params[0] - }; - } else if (params.length >= 2) { - scrollPageSizes = (Integer[]) params; - } - } + /** Circuit breaker trigger limit (percentage) */ + private Integer circuitBreakLimit = DEFAULT_CIRCUIT_BREAK_LIMIT; - public int circuitBreakLimit() { - return circuitBreakLimit; - } + /** Intervals for back off retry */ + private double[] backOffRetryIntervals = DEFAULT_BACK_OFF_RETRY_INTERVALS; - public void configureCircuitBreakLimit(Object[] params) { - if (params.length > 0) { - circuitBreakLimit = (Integer) params[0]; - } - } + /** Total number of rows in final result specified by LIMIT */ + private int totalLimit; - public double[] backOffRetryIntervals() { - return backOffRetryIntervals; - } + /** Number of rows fetched from each table specified by JOIN_TABLES_LIMIT hint */ + private int tableLimit1; - public void configureBackOffRetryIntervals(Object[] params) { - backOffRetryIntervals = new double[params.length]; - for (int i = 0; i < params.length; i++) { - backOffRetryIntervals[i] = (Integer) params[i]; //Only support integer interval for now - } - } + private int tableLimit2; - public void configureLimit(Integer totalLimit, Integer tableLimit1, Integer tableLimit2) { - if (totalLimit != null) { - this.totalLimit = totalLimit; - } - if (tableLimit1 != null) { - this.tableLimit1 = tableLimit1; - } - if (tableLimit2 != null) { - this.tableLimit2 = tableLimit2; - } - } + /** Push down column values in ON of first table to query against second table */ + private boolean isUseTermsFilterOptimization = false; - public int totalLimit() { - return totalLimit; - } + /** Total time out (seconds) for the execution */ + private int timeout = DEFAULT_TIME_OUT; - public int tableLimit1() { - return tableLimit1; + public BlockSize blockSize() { + return blockSize; + } + + public void configureBlockSize(Object[] params) { + if (params.length > 0) { + Integer size = (Integer) params[0]; + if (size > 0) { + blockSize = new FixedBlockSize(size); + } else { + blockSize = new AdaptiveBlockSize(0); + } } + } + + public Integer[] scrollPageSize() { + return scrollPageSizes; + } - public int tableLimit2() { - return tableLimit2; + public void configureScrollPageSize(Object[] params) { + if (params.length == 1) { + scrollPageSizes = new Integer[] {(Integer) params[0], (Integer) params[0]}; + } else if (params.length >= 2) { + scrollPageSizes = (Integer[]) params; } + } + + public int circuitBreakLimit() { + return circuitBreakLimit; + } - public void configureTermsFilterOptimization(boolean isUseTermFiltersOptimization) { - this.isUseTermsFilterOptimization = isUseTermFiltersOptimization; + public void configureCircuitBreakLimit(Object[] params) { + if (params.length > 0) { + circuitBreakLimit = (Integer) params[0]; } + } - public boolean isUseTermsFilterOptimization() { - return isUseTermsFilterOptimization; + public double[] backOffRetryIntervals() { + return backOffRetryIntervals; + } + + public void configureBackOffRetryIntervals(Object[] params) { + backOffRetryIntervals = new double[params.length]; + for (int i = 0; i < params.length; i++) { + backOffRetryIntervals[i] = (Integer) params[i]; // Only support integer interval for now } + } - public void configureTimeOut(Object[] params) { - if (params.length > 0) { - timeout = (Integer) params[0]; - } + public void configureLimit(Integer totalLimit, Integer tableLimit1, Integer tableLimit2) { + if (totalLimit != null) { + this.totalLimit = totalLimit; + } + if (tableLimit1 != null) { + this.tableLimit1 = tableLimit1; + } + if (tableLimit2 != null) { + this.tableLimit2 = tableLimit2; } + } - public int timeout() { - return timeout; + public int totalLimit() { + return totalLimit; + } + + public int tableLimit1() { + return tableLimit1; + } + + public int tableLimit2() { + return tableLimit2; + } + + public void configureTermsFilterOptimization(boolean isUseTermFiltersOptimization) { + this.isUseTermsFilterOptimization = isUseTermFiltersOptimization; + } + + public boolean isUseTermsFilterOptimization() { + return isUseTermsFilterOptimization; + } + + public void configureTimeOut(Object[] params) { + if (params.length > 0) { + timeout = (Integer) params[0]; } + } + + public int timeout() { + return timeout; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/estimation/Cost.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/estimation/Cost.java index efaf7057b6..86f155d626 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/estimation/Cost.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/estimation/Cost.java @@ -3,22 +3,20 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.estimation; public class Cost implements Comparable { - public static final Cost INFINITY = new Cost(); + public static final Cost INFINITY = new Cost(); - private long inputSize; + private long inputSize; - private long time; + private long time; - public Cost() { - } + public Cost() {} - @Override - public int compareTo(Cost o) { - return 0; - } + @Override + public int compareTo(Cost o) { + return 0; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/BatchPhysicalOperator.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/BatchPhysicalOperator.java index 3b4eb2b48e..19ee573652 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/BatchPhysicalOperator.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/BatchPhysicalOperator.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node; import static org.opensearch.sql.legacy.query.planner.core.ExecuteParams.ExecuteParamType.RESOURCE_MANAGER; @@ -19,78 +18,74 @@ import org.opensearch.sql.legacy.query.planner.resource.ResourceManager; /** - * Abstraction for physical operators that load large volume of data and generally prefetch for efficiency. + * Abstraction for physical operators that load large volume of data and generally prefetch for + * efficiency. * * @param */ public abstract class BatchPhysicalOperator implements PhysicalOperator { - protected static final Logger LOG = LogManager.getLogger(); + protected static final Logger LOG = LogManager.getLogger(); - /** - * Resource monitor to avoid consuming too much resource - */ - private ResourceManager resourceMgr; + /** Resource monitor to avoid consuming too much resource */ + private ResourceManager resourceMgr; - /** - * Current batch of data - */ - private Iterator> curBatch; + /** Current batch of data */ + private Iterator> curBatch; - @Override - public void open(ExecuteParams params) throws Exception { - //PhysicalOperator.super.open(params); // Child needs to call this super.open() and open its next node too - resourceMgr = params.get(RESOURCE_MANAGER); - } + @Override + public void open(ExecuteParams params) throws Exception { + // Child needs to call this super.open() and open its next node too + // PhysicalOperator.super.open(params); + resourceMgr = params.get(RESOURCE_MANAGER); + } - @Override - public boolean hasNext() { - if (isNoMoreDataInCurrentBatch()) { - LOG.debug("{} No more data in current batch, pre-fetching next batch", this); - Collection> nextBatch = prefetchSafely(); + @Override + public boolean hasNext() { + if (isNoMoreDataInCurrentBatch()) { + LOG.debug("{} No more data in current batch, pre-fetching next batch", this); + Collection> nextBatch = prefetchSafely(); - LOG.debug("{} Pre-fetched {} rows", this, nextBatch.size()); - if (LOG.isTraceEnabled()) { - nextBatch.forEach(row -> LOG.trace("Row pre-fetched: {}", row)); - } + LOG.debug("{} Pre-fetched {} rows", this, nextBatch.size()); + if (LOG.isTraceEnabled()) { + nextBatch.forEach(row -> LOG.trace("Row pre-fetched: {}", row)); + } - curBatch = nextBatch.iterator(); - } - return curBatch.hasNext(); + curBatch = nextBatch.iterator(); } - - @Override - public Row next() { - return curBatch.next(); - } - - /** - * Prefetch next batch safely by checking resource monitor - */ - private Collection> prefetchSafely() { - Objects.requireNonNull(resourceMgr, "ResourceManager is not set so unable to do sanity check"); - - boolean isHealthy = resourceMgr.isHealthy(); - boolean isTimeout = resourceMgr.isTimeout(); - if (isHealthy && !isTimeout) { - try { - return prefetch(); - } catch (Exception e) { - throw new IllegalStateException("Failed to prefetch next batch", e); - } - } - throw new IllegalStateException("Exit due to " + (isHealthy ? "time out" : "insufficient resource")); + return curBatch.hasNext(); + } + + @Override + public Row next() { + return curBatch.next(); + } + + /** Prefetch next batch safely by checking resource monitor */ + private Collection> prefetchSafely() { + Objects.requireNonNull(resourceMgr, "ResourceManager is not set so unable to do sanity check"); + + boolean isHealthy = resourceMgr.isHealthy(); + boolean isTimeout = resourceMgr.isTimeout(); + if (isHealthy && !isTimeout) { + try { + return prefetch(); + } catch (Exception e) { + throw new IllegalStateException("Failed to prefetch next batch", e); + } } - - /** - * Prefetch next batch if current is exhausted. - * - * @return next batch - */ - protected abstract Collection> prefetch() throws Exception; - - private boolean isNoMoreDataInCurrentBatch() { - return curBatch == null || !curBatch.hasNext(); - } - + throw new IllegalStateException( + "Exit due to " + (isHealthy ? "time out" : "insufficient resource")); + } + + /** + * Prefetch next batch if current is exhausted. + * + * @return next batch + */ + protected abstract Collection> prefetch() throws Exception; + + private boolean isNoMoreDataInCurrentBatch() { + return curBatch == null || !curBatch.hasNext(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/BlockHashJoin.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/BlockHashJoin.java index 19c0ae41d2..90bf9923d3 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/BlockHashJoin.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/BlockHashJoin.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.join; import static com.alibaba.druid.sql.ast.statement.SQLJoinTableSource.JoinType; @@ -25,91 +24,87 @@ import org.opensearch.sql.legacy.query.planner.physical.estimation.Cost; import org.opensearch.sql.legacy.query.planner.resource.blocksize.BlockSize; -/** - * Block-based Hash Join implementation - */ +/** Block-based Hash Join implementation */ public class BlockHashJoin extends JoinAlgorithm { - /** - * Use terms filter optimization or not - */ - private final boolean isUseTermsFilterOptimization; + /** Use terms filter optimization or not */ + private final boolean isUseTermsFilterOptimization; - public BlockHashJoin(PhysicalOperator left, - PhysicalOperator right, - JoinType type, - JoinCondition condition, - BlockSize blockSize, - boolean isUseTermsFilterOptimization) { - super(left, right, type, condition, blockSize); + public BlockHashJoin( + PhysicalOperator left, + PhysicalOperator right, + JoinType type, + JoinCondition condition, + BlockSize blockSize, + boolean isUseTermsFilterOptimization) { + super(left, right, type, condition, blockSize); - this.isUseTermsFilterOptimization = isUseTermsFilterOptimization; - } + this.isUseTermsFilterOptimization = isUseTermsFilterOptimization; + } - @Override - public Cost estimate() { - return new Cost(); - } + @Override + public Cost estimate() { + return new Cost(); + } - @Override - protected void reopenRight() throws Exception { - Objects.requireNonNull(params, "Execute params is not set so unable to add extra filter"); + @Override + protected void reopenRight() throws Exception { + Objects.requireNonNull(params, "Execute params is not set so unable to add extra filter"); - if (isUseTermsFilterOptimization) { - params.add(ExecuteParams.ExecuteParamType.EXTRA_QUERY_FILTER, queryForPushedDownOnConds()); - } - right.open(params); + if (isUseTermsFilterOptimization) { + params.add(ExecuteParams.ExecuteParamType.EXTRA_QUERY_FILTER, queryForPushedDownOnConds()); } - - @Override - protected List> probe() { - List> combinedRows = new ArrayList<>(); - int totalSize = 0; - - /* Return if already found enough matched rows to give ResourceMgr a chance to check resource usage */ - while (right.hasNext() && totalSize < hashTable.size()) { - Row rightRow = right.next(); - Collection> matchedLeftRows = hashTable.match(rightRow); - - if (!matchedLeftRows.isEmpty()) { - combinedRows.add(new CombinedRow<>(rightRow, matchedLeftRows)); - totalSize += matchedLeftRows.size(); - } - } - return combinedRows; + right.open(params); + } + + @Override + protected List> probe() { + List> combinedRows = new ArrayList<>(); + int totalSize = 0; + + /* Return if already found enough matched rows to give ResourceMgr a chance to check resource usage */ + while (right.hasNext() && totalSize < hashTable.size()) { + Row rightRow = right.next(); + Collection> matchedLeftRows = hashTable.match(rightRow); + + if (!matchedLeftRows.isEmpty()) { + combinedRows.add(new CombinedRow<>(rightRow, matchedLeftRows)); + totalSize += matchedLeftRows.size(); + } } - - /** - * Build query for pushed down conditions in ON - */ - private BoolQueryBuilder queryForPushedDownOnConds() { - BoolQueryBuilder orQuery = boolQuery(); - Map>[] rightNameToLeftValuesGroup = hashTable.rightFieldWithLeftValues(); - - for (Map> rightNameToLeftValues : rightNameToLeftValuesGroup) { - if (LOG.isTraceEnabled()) { - rightNameToLeftValues.forEach((rightName, leftValues) -> - LOG.trace("Right name to left values mapping: {} => {}", rightName, leftValues)); - } - - BoolQueryBuilder andQuery = boolQuery(); - rightNameToLeftValues.forEach( - (rightName, leftValues) -> andQuery.must(termsQuery(rightName, leftValues)) - ); - - if (LOG.isTraceEnabled()) { - LOG.trace("Terms filter optimization: {}", Strings.toString(XContentType.JSON, andQuery)); - } - orQuery.should(andQuery); - } - return orQuery; + return combinedRows; + } + + /** Build query for pushed down conditions in ON */ + private BoolQueryBuilder queryForPushedDownOnConds() { + BoolQueryBuilder orQuery = boolQuery(); + Map>[] rightNameToLeftValuesGroup = + hashTable.rightFieldWithLeftValues(); + + for (Map> rightNameToLeftValues : rightNameToLeftValuesGroup) { + if (LOG.isTraceEnabled()) { + rightNameToLeftValues.forEach( + (rightName, leftValues) -> + LOG.trace("Right name to left values mapping: {} => {}", rightName, leftValues)); + } + + BoolQueryBuilder andQuery = boolQuery(); + rightNameToLeftValues.forEach( + (rightName, leftValues) -> andQuery.must(termsQuery(rightName, leftValues))); + + if (LOG.isTraceEnabled()) { + LOG.trace("Terms filter optimization: {}", Strings.toString(XContentType.JSON, andQuery)); + } + orQuery.should(andQuery); } + return orQuery; + } - /********************************************* - * Getters for Explain - *********************************************/ + /********************************************* + * Getters for Explain + *********************************************/ - public boolean isUseTermsFilterOptimization() { - return isUseTermsFilterOptimization; - } + public boolean isUseTermsFilterOptimization() { + return isUseTermsFilterOptimization; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/CombinedRow.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/CombinedRow.java index e83bbb7d0e..b1fb43441e 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/CombinedRow.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/CombinedRow.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.join; import java.util.ArrayList; @@ -19,28 +18,28 @@ */ public class CombinedRow { - private Row rightRow; - private Collection> leftRows; + private Row rightRow; + private Collection> leftRows; - public CombinedRow(Row rightRow, Collection> leftRows) { - this.rightRow = rightRow; - this.leftRows = leftRows; - } + public CombinedRow(Row rightRow, Collection> leftRows) { + this.rightRow = rightRow; + this.leftRows = leftRows; + } - public List> combine() { - List> combinedRows = new ArrayList<>(); - for (Row leftRow : leftRows) { - combinedRows.add(leftRow.combine(rightRow)); - } - return combinedRows; + public List> combine() { + List> combinedRows = new ArrayList<>(); + for (Row leftRow : leftRows) { + combinedRows.add(leftRow.combine(rightRow)); } + return combinedRows; + } - public Collection> leftMatchedRows() { - return Collections.unmodifiableCollection(leftRows); - } + public Collection> leftMatchedRows() { + return Collections.unmodifiableCollection(leftRows); + } - @Override - public String toString() { - return "CombinedRow{rightRow=" + rightRow + ", leftRows=" + leftRows + '}'; - } + @Override + public String toString() { + return "CombinedRow{rightRow=" + rightRow + ", leftRows=" + leftRows + '}'; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/DefaultHashTable.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/DefaultHashTable.java index 733d7a78ab..23e79d2c31 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/DefaultHashTable.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/join/DefaultHashTable.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.join; import static java.util.Collections.emptyList; @@ -22,102 +21,98 @@ import org.opensearch.sql.legacy.query.planner.physical.Row.RowKey; /** - * Hash table implementation. - * In the case of no join condition, hash table degrades to linked list with all rows in block paired to RowKey.NULL + * Hash table implementation. In the case of no join condition, hash table degrades to linked list + * with all rows in block paired to RowKey.NULL * * @param Row data type */ public class DefaultHashTable implements HashTable { - private static final Logger LOG = LogManager.getLogger(); - - /** - * Hash table implementation - */ - private final Multimap> table = ArrayListMultimap.create(); - - /** - * Left join conditions to generate key to build hash table by left rows from block - */ - private final String[] leftJoinFields; - - /** - * Right join conditions to generate key to probe hash table by right rows - */ - private final String[] rightJoinFields; - - - public DefaultHashTable(String[] leftJoinFields, String[] rightJoinFields) { - this.leftJoinFields = leftJoinFields; - this.rightJoinFields = rightJoinFields; + private static final Logger LOG = LogManager.getLogger(); + + /** Hash table implementation */ + private final Multimap> table = ArrayListMultimap.create(); + + /** Left join conditions to generate key to build hash table by left rows from block */ + private final String[] leftJoinFields; + + /** Right join conditions to generate key to probe hash table by right rows */ + private final String[] rightJoinFields; + + public DefaultHashTable(String[] leftJoinFields, String[] rightJoinFields) { + this.leftJoinFields = leftJoinFields; + this.rightJoinFields = rightJoinFields; + } + + /** + * Add row in block to hash table by left conditions in ON. For the duplicate key, append them to + * the list in value (MultiMap) + */ + @Override + public void add(Row row) { + RowKey key = row.key(leftJoinFields); + if (key == RowKey.NULL) { + LOG.debug( + "Skip rows with NULL column value during build: row={}, conditions={}", + row, + leftJoinFields); + } else { + table.put(key, row); } - - /** - * Add row in block to hash table by left conditions in ON. - * For the duplicate key, append them to the list in value (MultiMap) - */ - @Override - public void add(Row row) { - RowKey key = row.key(leftJoinFields); - if (key == RowKey.NULL) { - LOG.debug("Skip rows with NULL column value during build: row={}, conditions={}", row, leftJoinFields); - } else { - table.put(key, row); - } + } + + /** Probe hash table to match right rows by values of right conditions */ + @Override + public Collection> match(Row row) { + RowKey key = row.key(rightJoinFields); + if (key == RowKey.NULL) { + LOG.debug( + "Skip rows with NULL column value during probing: row={}, conditions={}", + row, + rightJoinFields); + return emptyList(); } - - /** - * Probe hash table to match right rows by values of right conditions - */ - @Override - public Collection> match(Row row) { - RowKey key = row.key(rightJoinFields); - if (key == RowKey.NULL) { - LOG.debug("Skip rows with NULL column value during probing: row={}, conditions={}", row, rightJoinFields); - return emptyList(); - } - return table.get(key); // Multimap returns empty list rather null. + return table.get(key); // Multimap returns empty list rather null. + } + + /** Right joined field name with according column value list to push down */ + @SuppressWarnings("unchecked") + @Override + public Map>[] rightFieldWithLeftValues() { + Map> result = + new HashMap<>(); // Eliminate potential duplicate in values + for (RowKey key : table.keySet()) { + Object[] keys = key.keys(); + for (int i = 0; i < keys.length; i++) { + result + .computeIfAbsent(rightJoinFields[i], (k -> new HashSet<>())) + .add(lowercaseIfStr(keys[i])); // Terms stored in lower case in OpenSearch + } } - /** - * Right joined field name with according column value list to push down - */ - @SuppressWarnings("unchecked") - @Override - public Map>[] rightFieldWithLeftValues() { - Map> result = new HashMap<>(); // Eliminate potential duplicate in values - for (RowKey key : table.keySet()) { - Object[] keys = key.keys(); - for (int i = 0; i < keys.length; i++) { - result.computeIfAbsent(rightJoinFields[i], (k -> new HashSet<>())). - add(lowercaseIfStr(keys[i])); // Terms stored in lower case in OpenSearch - } - } - - // Convert value of Map from Guava's Set to JDK list which is expected by OpenSearch writer - for (Entry> entry : result.entrySet()) { - entry.setValue(new ArrayList<>(entry.getValue())); - } - return new Map[]{result}; + // Convert value of Map from Guava's Set to JDK list which is expected by OpenSearch writer + for (Entry> entry : result.entrySet()) { + entry.setValue(new ArrayList<>(entry.getValue())); } - - @Override - public int size() { - return table.size(); - } - - @Override - public boolean isEmpty() { - return table.isEmpty(); - } - - @Override - public void clear() { - table.clear(); - } - - private Object lowercaseIfStr(Object key) { - return key instanceof String ? ((String) key).toLowerCase() : key; - } - + return new Map[] {result}; + } + + @Override + public int size() { + return table.size(); + } + + @Override + public boolean isEmpty() { + return table.isEmpty(); + } + + @Override + public void clear() { + table.clear(); + } + + private Object lowercaseIfStr(Object key) { + return key instanceof String ? ((String) key).toLowerCase() : key; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/BindingTupleRow.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/BindingTupleRow.java index 9e3a190e30..41f500fed1 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/BindingTupleRow.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/physical/node/scroll/BindingTupleRow.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.physical.node.scroll; import java.util.Map; @@ -13,25 +12,25 @@ @RequiredArgsConstructor public class BindingTupleRow implements Row { - private final BindingTuple bindingTuple; - - @Override - public RowKey key(String[] colNames) { - return null; - } - - @Override - public Row combine(Row otherRow) { - throw new RuntimeException("unsupported operation"); - } - - @Override - public void retain(Map colNameAlias) { - // do nothing - } - - @Override - public BindingTuple data() { - return bindingTuple; - } + private final BindingTuple bindingTuple; + + @Override + public RowKey key(String[] colNames) { + return null; + } + + @Override + public Row combine(Row otherRow) { + throw new RuntimeException("unsupported operation"); + } + + @Override + public void retain(Map colNameAlias) { + // do nothing + } + + @Override + public BindingTuple data() { + return bindingTuple; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/blocksize/BlockSize.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/blocksize/BlockSize.java index d68b16b8bb..6e5a2703f4 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/blocksize/BlockSize.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/planner/resource/blocksize/BlockSize.java @@ -3,42 +3,35 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.query.planner.resource.blocksize; -/** - * Block size calculating logic. - */ +/** Block size calculating logic. */ public interface BlockSize { - /** - * Get block size configured or dynamically. Integer should be sufficient for single block size. - * - * @return block size. - */ - int size(); - + /** + * Get block size configured or dynamically. Integer should be sufficient for single block size. + * + * @return block size. + */ + int size(); - /** - * Default implementation with fixed block size - */ - class FixedBlockSize implements BlockSize { + /** Default implementation with fixed block size */ + class FixedBlockSize implements BlockSize { - private int blockSize; + private int blockSize; - public FixedBlockSize(int blockSize) { - this.blockSize = blockSize; - } - - @Override - public int size() { - return blockSize; - } + public FixedBlockSize(int blockSize) { + this.blockSize = blockSize; + } - @Override - public String toString() { - return "FixedBlockSize with " + "size=" + blockSize; - } + @Override + public int size() { + return blockSize; } + @Override + public String toString() { + return "FixedBlockSize with size=" + blockSize; + } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/identifier/AnonymizeSensitiveDataRule.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/identifier/AnonymizeSensitiveDataRule.java index 2768b269bf..c4f3ee5a10 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/identifier/AnonymizeSensitiveDataRule.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/identifier/AnonymizeSensitiveDataRule.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.rewriter.identifier; import com.alibaba.druid.sql.ast.expr.SQLBooleanExpr; @@ -17,53 +16,53 @@ import org.opensearch.sql.legacy.rewriter.RewriteRule; /** - * Rewrite rule to anonymize sensitive data in logging queries. - * This rule replace the content of specific nodes (that might involve index data) in AST - * to anonymous content. + * Rewrite rule to anonymize sensitive data in logging queries. This rule replace the content of + * specific nodes (that might involve index data) in AST to anonymous content. */ -public class AnonymizeSensitiveDataRule extends MySqlASTVisitorAdapter implements RewriteRule { +public class AnonymizeSensitiveDataRule extends MySqlASTVisitorAdapter + implements RewriteRule { - @Override - public boolean visit(SQLIdentifierExpr identifierExpr) { - if (identifierExpr.getParent() instanceof SQLExprTableSource) { - identifierExpr.setName("table"); - } else { - identifierExpr.setName("identifier"); - } - return true; + @Override + public boolean visit(SQLIdentifierExpr identifierExpr) { + if (identifierExpr.getParent() instanceof SQLExprTableSource) { + identifierExpr.setName("table"); + } else { + identifierExpr.setName("identifier"); } + return true; + } - @Override - public boolean visit(SQLIntegerExpr integerExpr) { - integerExpr.setNumber(0); - return true; - } + @Override + public boolean visit(SQLIntegerExpr integerExpr) { + integerExpr.setNumber(0); + return true; + } - @Override - public boolean visit(SQLNumberExpr numberExpr) { - numberExpr.setNumber(0); - return true; - } + @Override + public boolean visit(SQLNumberExpr numberExpr) { + numberExpr.setNumber(0); + return true; + } - @Override - public boolean visit(SQLCharExpr charExpr) { - charExpr.setText("string_literal"); - return true; - } + @Override + public boolean visit(SQLCharExpr charExpr) { + charExpr.setText("string_literal"); + return true; + } - @Override - public boolean visit(SQLBooleanExpr booleanExpr) { - booleanExpr.setValue(false); - return true; - } + @Override + public boolean visit(SQLBooleanExpr booleanExpr) { + booleanExpr.setValue(false); + return true; + } - @Override - public boolean match(SQLQueryExpr expr) { - return true; - } + @Override + public boolean match(SQLQueryExpr expr) { + return true; + } - @Override - public void rewrite(SQLQueryExpr expr) { - expr.accept(this); - } + @Override + public void rewrite(SQLQueryExpr expr) { + expr.accept(this); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/BoundingBoxFilterParams.java b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/BoundingBoxFilterParams.java index df9f4c88b2..fb62f60ae7 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/BoundingBoxFilterParams.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/BoundingBoxFilterParams.java @@ -3,26 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.spatial; -/** - * Created by Eliran on 1/8/2015. - */ +/** Created by Eliran on 1/8/2015. */ public class BoundingBoxFilterParams { - private Point topLeft; - private Point bottomRight; + private Point topLeft; + private Point bottomRight; - public BoundingBoxFilterParams(Point topLeft, Point bottomRight) { - this.topLeft = topLeft; - this.bottomRight = bottomRight; - } + public BoundingBoxFilterParams(Point topLeft, Point bottomRight) { + this.topLeft = topLeft; + this.bottomRight = bottomRight; + } - public Point getTopLeft() { - return topLeft; - } + public Point getTopLeft() { + return topLeft; + } - public Point getBottomRight() { - return bottomRight; - } + public Point getBottomRight() { + return bottomRight; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/CellFilterParams.java b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/CellFilterParams.java index fc3dc35f07..6c50c17467 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/CellFilterParams.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/CellFilterParams.java @@ -3,36 +3,33 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.spatial; -/** - * Created by Eliran on 15/8/2015. - */ +/** Created by Eliran on 15/8/2015. */ public class CellFilterParams { - private Point geohashPoint; - private int precision; - private boolean neighbors; - - public CellFilterParams(Point geohashPoint, int precision, boolean neighbors) { - this.geohashPoint = geohashPoint; - this.precision = precision; - this.neighbors = neighbors; - } - - public CellFilterParams(Point geohashPoint, int precision) { - this(geohashPoint, precision, false); - } - - public Point getGeohashPoint() { - return geohashPoint; - } - - public int getPrecision() { - return precision; - } - - public boolean isNeighbors() { - return neighbors; - } + private Point geohashPoint; + private int precision; + private boolean neighbors; + + public CellFilterParams(Point geohashPoint, int precision, boolean neighbors) { + this.geohashPoint = geohashPoint; + this.precision = precision; + this.neighbors = neighbors; + } + + public CellFilterParams(Point geohashPoint, int precision) { + this(geohashPoint, precision, false); + } + + public Point getGeohashPoint() { + return geohashPoint; + } + + public int getPrecision() { + return precision; + } + + public boolean isNeighbors() { + return neighbors; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/DistanceFilterParams.java b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/DistanceFilterParams.java index 1141da08ca..8c419de58d 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/spatial/DistanceFilterParams.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/spatial/DistanceFilterParams.java @@ -3,26 +3,23 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.spatial; -/** - * Created by Eliran on 1/8/2015. - */ +/** Created by Eliran on 1/8/2015. */ public class DistanceFilterParams { - private String distance; - private Point from; + private String distance; + private Point from; - public DistanceFilterParams(String distance, Point from) { - this.distance = distance; - this.from = from; - } + public DistanceFilterParams(String distance, Point from) { + this.distance = distance; + this.from = from; + } - public String getDistance() { - return distance; - } + public String getDistance() { + return distance; + } - public Point getFrom() { - return from; - } + public Point getFrom() { + return from; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/BaseTypeTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/BaseTypeTest.java index a8ddfd43e8..0269c6b01c 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/BaseTypeTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/types/BaseTypeTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.semantic.types; import static org.junit.Assert.assertEquals; @@ -30,78 +29,75 @@ import org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchDataType; import org.opensearch.sql.legacy.antlr.semantic.types.base.OpenSearchIndex; -/** - * Test base type compatibility - */ +/** Test base type compatibility */ public class BaseTypeTest { - @Test - public void unknownTypeNameShouldReturnUnknown() { - assertEquals(UNKNOWN, OpenSearchDataType.typeOf("this_is_a_new_es_type_we_arent_aware")); - } - - @Test - public void typeOfShouldIgnoreCase() { - assertEquals(INTEGER, OpenSearchDataType.typeOf("Integer")); - } - - @Test - public void sameBaseTypeShouldBeCompatible() { - assertTrue(INTEGER.isCompatible(INTEGER)); - assertTrue(BOOLEAN.isCompatible(BOOLEAN)); - } - - @Test - public void parentBaseTypeShouldBeCompatibleWithSubBaseType() { - assertTrue(NUMBER.isCompatible(DOUBLE)); - assertTrue(DOUBLE.isCompatible(FLOAT)); - assertTrue(FLOAT.isCompatible(INTEGER)); - assertTrue(INTEGER.isCompatible(SHORT)); - assertTrue(INTEGER.isCompatible(LONG)); - assertTrue(STRING.isCompatible(TEXT)); - assertTrue(STRING.isCompatible(KEYWORD)); - assertTrue(DATE.isCompatible(STRING)); - } - - @Test - public void ancestorBaseTypeShouldBeCompatibleWithSubBaseType() { - assertTrue(NUMBER.isCompatible(LONG)); - assertTrue(NUMBER.isCompatible(DOUBLE)); - assertTrue(DOUBLE.isCompatible(INTEGER)); - assertTrue(INTEGER.isCompatible(SHORT)); - assertTrue(INTEGER.isCompatible(LONG)); - } - - @Ignore("Two way compatibility is not necessary") - @Test - public void subBaseTypeShouldBeCompatibleWithParentBaseType() { - assertTrue(KEYWORD.isCompatible(STRING)); - } - - @Test - public void nonRelatedBaseTypeShouldNotBeCompatible() { - assertFalse(SHORT.isCompatible(TEXT)); - assertFalse(DATE.isCompatible(BOOLEAN)); - } - - @Test - public void unknownBaseTypeShouldBeCompatibleWithAnyBaseType() { - assertTrue(UNKNOWN.isCompatible(INTEGER)); - assertTrue(UNKNOWN.isCompatible(KEYWORD)); - assertTrue(UNKNOWN.isCompatible(BOOLEAN)); - } - - @Test - public void anyBaseTypeShouldBeCompatibleWithUnknownBaseType() { - assertTrue(LONG.isCompatible(UNKNOWN)); - assertTrue(TEXT.isCompatible(UNKNOWN)); - assertTrue(DATE.isCompatible(UNKNOWN)); - } - - @Test - public void nestedIndexTypeShouldBeCompatibleWithNestedDataType() { - assertTrue(NESTED.isCompatible(new OpenSearchIndex("test", NESTED_FIELD))); - assertTrue(OPENSEARCH_TYPE.isCompatible(new OpenSearchIndex("test", NESTED_FIELD))); - } - + @Test + public void unknownTypeNameShouldReturnUnknown() { + assertEquals(UNKNOWN, OpenSearchDataType.typeOf("this_is_a_new_es_type_we_arent_aware")); + } + + @Test + public void typeOfShouldIgnoreCase() { + assertEquals(INTEGER, OpenSearchDataType.typeOf("Integer")); + } + + @Test + public void sameBaseTypeShouldBeCompatible() { + assertTrue(INTEGER.isCompatible(INTEGER)); + assertTrue(BOOLEAN.isCompatible(BOOLEAN)); + } + + @Test + public void parentBaseTypeShouldBeCompatibleWithSubBaseType() { + assertTrue(NUMBER.isCompatible(DOUBLE)); + assertTrue(DOUBLE.isCompatible(FLOAT)); + assertTrue(FLOAT.isCompatible(INTEGER)); + assertTrue(INTEGER.isCompatible(SHORT)); + assertTrue(INTEGER.isCompatible(LONG)); + assertTrue(STRING.isCompatible(TEXT)); + assertTrue(STRING.isCompatible(KEYWORD)); + assertTrue(DATE.isCompatible(STRING)); + } + + @Test + public void ancestorBaseTypeShouldBeCompatibleWithSubBaseType() { + assertTrue(NUMBER.isCompatible(LONG)); + assertTrue(NUMBER.isCompatible(DOUBLE)); + assertTrue(DOUBLE.isCompatible(INTEGER)); + assertTrue(INTEGER.isCompatible(SHORT)); + assertTrue(INTEGER.isCompatible(LONG)); + } + + @Ignore("Two way compatibility is not necessary") + @Test + public void subBaseTypeShouldBeCompatibleWithParentBaseType() { + assertTrue(KEYWORD.isCompatible(STRING)); + } + + @Test + public void nonRelatedBaseTypeShouldNotBeCompatible() { + assertFalse(SHORT.isCompatible(TEXT)); + assertFalse(DATE.isCompatible(BOOLEAN)); + } + + @Test + public void unknownBaseTypeShouldBeCompatibleWithAnyBaseType() { + assertTrue(UNKNOWN.isCompatible(INTEGER)); + assertTrue(UNKNOWN.isCompatible(KEYWORD)); + assertTrue(UNKNOWN.isCompatible(BOOLEAN)); + } + + @Test + public void anyBaseTypeShouldBeCompatibleWithUnknownBaseType() { + assertTrue(LONG.isCompatible(UNKNOWN)); + assertTrue(TEXT.isCompatible(UNKNOWN)); + assertTrue(DATE.isCompatible(UNKNOWN)); + } + + @Test + public void nestedIndexTypeShouldBeCompatibleWithNestedDataType() { + assertTrue(NESTED.isCompatible(new OpenSearchIndex("test", NESTED_FIELD))); + assertTrue(OPENSEARCH_TYPE.isCompatible(new OpenSearchIndex("test", NESTED_FIELD))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitorTest.java index c4e7a7e725..be4b5a5197 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/visitor/AntlrSqlParseTreeVisitorTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.antlr.visitor; import static java.util.Collections.emptyList; @@ -25,95 +24,97 @@ import org.opensearch.sql.legacy.antlr.semantic.visitor.TypeChecker; import org.opensearch.sql.legacy.exception.SqlFeatureNotImplementedException; -/** - * Test cases for AntlrSqlParseTreeVisitor - */ +/** Test cases for AntlrSqlParseTreeVisitor */ public class AntlrSqlParseTreeVisitorTest { - private TypeChecker analyzer = new TypeChecker(new SemanticContext()) { + private TypeChecker analyzer = + new TypeChecker(new SemanticContext()) { @Override public Type visitIndexName(String indexName) { - return null; // avoid querying mapping on null LocalClusterState + return null; // avoid querying mapping on null LocalClusterState } @Override public Type visitFieldName(String fieldName) { - switch (fieldName) { - case "age": return INTEGER; - case "birthday": return DATE; - default: return UNKNOWN; - } + switch (fieldName) { + case "age": + return INTEGER; + case "birthday": + return DATE; + default: + return UNKNOWN; + } } - }; - - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - - @Test - public void selectNumberShouldReturnNumberAsQueryVisitingResult() { - Type result = visit("SELECT age FROM test"); - Assert.assertSame(result, INTEGER); - } - - @Test - public void selectNumberAndDateShouldReturnProductOfThemAsQueryVisitingResult() { - Type result = visit("SELECT age, birthday FROM test"); - Assert.assertTrue(result instanceof Product ); - Assert.assertTrue(result.isCompatible(new Product(Arrays.asList(INTEGER, DATE)))); - } - - @Test - public void selectStarShouldReturnEmptyProductAsQueryVisitingResult() { - Type result = visit("SELECT * FROM test"); - Assert.assertTrue(result instanceof Product); - Assert.assertTrue(result.isCompatible(new Product(emptyList()))); - } - - @Test - public void visitSelectNestedFunctionShouldThrowException() { - exceptionRule.expect(SqlFeatureNotImplementedException.class); - exceptionRule.expectMessage("Nested function calls like [abs(log(age))] are not supported yet"); - visit("SELECT abs(log(age)) FROM test"); - } - - @Test - public void visitWhereNestedFunctionShouldThrowException() { - exceptionRule.expect(SqlFeatureNotImplementedException.class); - exceptionRule.expectMessage("Nested function calls like [abs(log(age))] are not supported yet"); - visit("SELECT age FROM test WHERE abs(log(age)) = 1"); - } - - @Test - public void visitMathConstantAsNestedFunctionShouldPass() { - visit("SELECT abs(pi()) FROM test"); - } - - @Test - public void visitSupportedNestedFunctionShouldPass() { - visit("SELECT sum(nested(name.balance)) FROM test"); - } - - @Test - public void visitFunctionAsAggregatorShouldThrowException() { - exceptionRule.expect(SqlFeatureNotImplementedException.class); - exceptionRule.expectMessage("Aggregation calls with function aggregator like [max(abs(age))] are not supported yet"); - visit("SELECT max(abs(age)) FROM test"); - } - - @Test - public void visitUnsupportedOperatorShouldThrowException() { - exceptionRule.expect(SqlFeatureNotImplementedException.class); - exceptionRule.expectMessage("Operator [DIV] is not supported yet"); - visit("SELECT balance DIV age FROM test"); - } - - private ParseTree createParseTree(String sql) { - return new OpenSearchLegacySqlAnalyzer(new SqlAnalysisConfig(true, true, 1000)).analyzeSyntax(sql); - } - - private Type visit(String sql) { - ParseTree parseTree = createParseTree(sql); - return parseTree.accept(new AntlrSqlParseTreeVisitor<>(analyzer)); - } - + }; + + @Rule public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void selectNumberShouldReturnNumberAsQueryVisitingResult() { + Type result = visit("SELECT age FROM test"); + Assert.assertSame(result, INTEGER); + } + + @Test + public void selectNumberAndDateShouldReturnProductOfThemAsQueryVisitingResult() { + Type result = visit("SELECT age, birthday FROM test"); + Assert.assertTrue(result instanceof Product); + Assert.assertTrue(result.isCompatible(new Product(Arrays.asList(INTEGER, DATE)))); + } + + @Test + public void selectStarShouldReturnEmptyProductAsQueryVisitingResult() { + Type result = visit("SELECT * FROM test"); + Assert.assertTrue(result instanceof Product); + Assert.assertTrue(result.isCompatible(new Product(emptyList()))); + } + + @Test + public void visitSelectNestedFunctionShouldThrowException() { + exceptionRule.expect(SqlFeatureNotImplementedException.class); + exceptionRule.expectMessage("Nested function calls like [abs(log(age))] are not supported yet"); + visit("SELECT abs(log(age)) FROM test"); + } + + @Test + public void visitWhereNestedFunctionShouldThrowException() { + exceptionRule.expect(SqlFeatureNotImplementedException.class); + exceptionRule.expectMessage("Nested function calls like [abs(log(age))] are not supported yet"); + visit("SELECT age FROM test WHERE abs(log(age)) = 1"); + } + + @Test + public void visitMathConstantAsNestedFunctionShouldPass() { + visit("SELECT abs(pi()) FROM test"); + } + + @Test + public void visitSupportedNestedFunctionShouldPass() { + visit("SELECT sum(nested(name.balance)) FROM test"); + } + + @Test + public void visitFunctionAsAggregatorShouldThrowException() { + exceptionRule.expect(SqlFeatureNotImplementedException.class); + exceptionRule.expectMessage( + "Aggregation calls with function aggregator like [max(abs(age))] are not supported yet"); + visit("SELECT max(abs(age)) FROM test"); + } + + @Test + public void visitUnsupportedOperatorShouldThrowException() { + exceptionRule.expect(SqlFeatureNotImplementedException.class); + exceptionRule.expectMessage("Operator [DIV] is not supported yet"); + visit("SELECT balance DIV age FROM test"); + } + + private ParseTree createParseTree(String sql) { + return new OpenSearchLegacySqlAnalyzer(new SqlAnalysisConfig(true, true, 1000)) + .analyzeSyntax(sql); + } + + private Type visit(String sql) { + ParseTree parseTree = createParseTree(sql); + return parseTree.accept(new AntlrSqlParseTreeVisitor<>(analyzer)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/executor/AsyncRestExecutorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/executor/AsyncRestExecutorTest.java index b26e171ce7..9be2517c4a 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/executor/AsyncRestExecutorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/executor/AsyncRestExecutorTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor; import static java.util.Collections.emptyList; @@ -34,77 +33,69 @@ import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.threadpool.ThreadPool; -/** - * Test AsyncRestExecutor behavior. - */ +/** Test AsyncRestExecutor behavior. */ @RunWith(MockitoJUnitRunner.Silent.class) public class AsyncRestExecutorTest { - private static final boolean NON_BLOCKING = false; - - @Mock - private RestExecutor executor; + private static final boolean NON_BLOCKING = false; - @Mock - private Client client; + @Mock private RestExecutor executor; - private Map params = emptyMap(); + @Mock private Client client; - @Mock - private QueryAction action; + private Map params = emptyMap(); - @Mock - private RestChannel channel; + @Mock private QueryAction action; - @Mock - private ClusterSettings clusterSettings; + @Mock private RestChannel channel; - @Before - public void setUp() { - when(client.threadPool()).thenReturn(mock(ThreadPool.class)); - when(action.getSqlRequest()).thenReturn(SqlRequest.NULL); - when(clusterSettings.get(ClusterName.CLUSTER_NAME_SETTING)).thenReturn(ClusterName.DEFAULT); + @Mock private ClusterSettings clusterSettings; - OpenSearchSettings settings = spy(new OpenSearchSettings(clusterSettings)); - doReturn(emptyList()).when(settings).getSettings(); - LocalClusterState.state().setPluginSettings(settings); - } + @Before + public void setUp() { + when(client.threadPool()).thenReturn(mock(ThreadPool.class)); + when(action.getSqlRequest()).thenReturn(SqlRequest.NULL); + when(clusterSettings.get(ClusterName.CLUSTER_NAME_SETTING)).thenReturn(ClusterName.DEFAULT); - @Test - public void executeBlockingQuery() throws Exception { - Thread.currentThread().setName(TRANSPORT_WORKER_THREAD_NAME_PREFIX); - execute(); - verifyRunInWorkerThread(); - } + OpenSearchSettings settings = spy(new OpenSearchSettings(clusterSettings)); + doReturn(emptyList()).when(settings).getSettings(); + LocalClusterState.state().setPluginSettings(settings); + } - @Test - public void executeBlockingQueryButNotInTransport() throws Exception { - execute(); - verifyRunInCurrentThread(); - } + @Test + public void executeBlockingQuery() throws Exception { + Thread.currentThread().setName(TRANSPORT_WORKER_THREAD_NAME_PREFIX); + execute(); + verifyRunInWorkerThread(); + } - @Test - public void executeNonBlockingQuery() throws Exception { - execute(anyAction -> NON_BLOCKING); - verifyRunInCurrentThread(); - } + @Test + public void executeBlockingQueryButNotInTransport() throws Exception { + execute(); + verifyRunInCurrentThread(); + } - private void execute() throws Exception { - AsyncRestExecutor asyncExecutor = new AsyncRestExecutor(executor); - asyncExecutor.execute(client, params, action, channel); - } + @Test + public void executeNonBlockingQuery() throws Exception { + execute(anyAction -> NON_BLOCKING); + verifyRunInCurrentThread(); + } - private void execute(Predicate isBlocking) throws Exception { - AsyncRestExecutor asyncExecutor = new AsyncRestExecutor(executor, isBlocking); - asyncExecutor.execute(client, params, action, channel); - } + private void execute() throws Exception { + AsyncRestExecutor asyncExecutor = new AsyncRestExecutor(executor); + asyncExecutor.execute(client, params, action, channel); + } - private void verifyRunInCurrentThread() { - verify(client, never()).threadPool(); - } + private void execute(Predicate isBlocking) throws Exception { + AsyncRestExecutor asyncExecutor = new AsyncRestExecutor(executor, isBlocking); + asyncExecutor.execute(client, params, action, channel); + } - private void verifyRunInWorkerThread() { - verify(client, times(1)).threadPool(); - } + private void verifyRunInCurrentThread() { + verify(client, never()).threadPool(); + } + private void verifyRunInWorkerThread() { + verify(client, times(1)).threadPool(); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/executor/csv/CSVResultTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/executor/csv/CSVResultTest.java index 1a24045881..c877095d8f 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/executor/csv/CSVResultTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/executor/csv/CSVResultTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.csv; import static org.junit.Assert.assertEquals; @@ -13,25 +12,21 @@ import java.util.stream.Collectors; import org.junit.Test; -/** - * Unit tests for {@link CSVResult} - */ +/** Unit tests for {@link CSVResult} */ public class CSVResultTest { - private static final String SEPARATOR = ","; + private static final String SEPARATOR = ","; - @Test - public void getHeadersShouldReturnHeadersSanitized() { - CSVResult csv = csv(headers("name", "=age"), lines(line("John", "30"))); - assertEquals( - headers("name", "'=age"), - csv.getHeaders() - ); - } + @Test + public void getHeadersShouldReturnHeadersSanitized() { + CSVResult csv = csv(headers("name", "=age"), lines(line("John", "30"))); + assertEquals(headers("name", "'=age"), csv.getHeaders()); + } - @Test - public void getLinesShouldReturnLinesSanitized() { - CSVResult csv = csv( + @Test + public void getLinesShouldReturnLinesSanitized() { + CSVResult csv = + csv( headers("name", "city"), lines( line("John", "Seattle"), @@ -39,53 +34,42 @@ public void getLinesShouldReturnLinesSanitized() { line("John", "+Seattle"), line("-John", "Seattle"), line("@John", "Seattle"), - line("John", "Seattle=") - ) - ); - - assertEquals( - line( - "John,Seattle", - "John,'=Seattle", - "John,'+Seattle", - "'-John,Seattle", - "'@John,Seattle", - "John,Seattle=" - ), - csv.getLines() - ); - } - - @Test - public void getHeadersShouldReturnHeadersQuotedIfRequired() { - CSVResult csv = csv(headers("na,me", ",,age"), lines(line("John", "30"))); - assertEquals( - headers("\"na,me\"", "\",,age\""), - csv.getHeaders() - ); - } - - @Test - public void getLinesShouldReturnLinesQuotedIfRequired() { - CSVResult csv = csv(headers("name", "age"), lines(line("John,Smith", "30,,,"))); - assertEquals( - line("\"John,Smith\",\"30,,,\""), - csv.getLines() - ); - } - - @Test - public void getHeadersShouldReturnHeadersBothSanitizedAndQuotedIfRequired() { - CSVResult csv = csv(headers("na,+me", ",,,=age", "=city,"), lines(line("John", "30", "Seattle"))); - assertEquals( - headers("\"na,+me\"", "\",,,=age\"", "\"'=city,\""), - csv.getHeaders() - ); - } - - @Test - public void getLinesShouldReturnLinesBothSanitizedAndQuotedIfRequired() { - CSVResult csv = csv( + line("John", "Seattle="))); + + assertEquals( + line( + "John,Seattle", + "John,'=Seattle", + "John,'+Seattle", + "'-John,Seattle", + "'@John,Seattle", + "John,Seattle="), + csv.getLines()); + } + + @Test + public void getHeadersShouldReturnHeadersQuotedIfRequired() { + CSVResult csv = csv(headers("na,me", ",,age"), lines(line("John", "30"))); + assertEquals(headers("\"na,me\"", "\",,age\""), csv.getHeaders()); + } + + @Test + public void getLinesShouldReturnLinesQuotedIfRequired() { + CSVResult csv = csv(headers("name", "age"), lines(line("John,Smith", "30,,,"))); + assertEquals(line("\"John,Smith\",\"30,,,\""), csv.getLines()); + } + + @Test + public void getHeadersShouldReturnHeadersBothSanitizedAndQuotedIfRequired() { + CSVResult csv = + csv(headers("na,+me", ",,,=age", "=city,"), lines(line("John", "30", "Seattle"))); + assertEquals(headers("\"na,+me\"", "\",,,=age\"", "\"'=city,\""), csv.getHeaders()); + } + + @Test + public void getLinesShouldReturnLinesBothSanitizedAndQuotedIfRequired() { + CSVResult csv = + csv( headers("name", "city"), lines( line("John", "Seattle"), @@ -93,38 +77,33 @@ public void getLinesShouldReturnLinesBothSanitizedAndQuotedIfRequired() { line("John", "+Sea,ttle"), line(",-John", "Seattle"), line(",,,@John", "Seattle"), - line("John", "Seattle=") - ) - ); - - assertEquals( - line( - "John,Seattle", - "John,'=Seattle", - "John,\"'+Sea,ttle\"", - "\",-John\",Seattle", - "\",,,@John\",Seattle", - "John,Seattle=" - ), - csv.getLines() - ); - } - - private CSVResult csv(List headers, List> lines) { - return new CSVResult(SEPARATOR, headers, lines); - } - - private List headers(String... headers) { - return Arrays.stream(headers).collect(Collectors.toList()); - } - - private List line(String... line) { - return Arrays.stream(line).collect(Collectors.toList()); - } - - @SafeVarargs - private final List> lines(List... lines) { - return Arrays.stream(lines).collect(Collectors.toList()); - } - + line("John", "Seattle="))); + + assertEquals( + line( + "John,Seattle", + "John,'=Seattle", + "John,\"'+Sea,ttle\"", + "\",-John\",Seattle", + "\",,,@John\",Seattle", + "John,Seattle="), + csv.getLines()); + } + + private CSVResult csv(List headers, List> lines) { + return new CSVResult(SEPARATOR, headers, lines); + } + + private List headers(String... headers) { + return Arrays.stream(headers).collect(Collectors.toList()); + } + + private List line(String... line) { + return Arrays.stream(line).collect(Collectors.toList()); + } + + @SafeVarargs + private final List> lines(List... lines) { + return Arrays.stream(lines).collect(Collectors.toList()); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatterTest.java index 5807ee2c44..1c2d1bae62 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/DateFieldFormatterTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.executor.format; import static org.junit.Assert.assertEquals; @@ -18,685 +17,634 @@ public class DateFieldFormatterTest { - @Test - public void testOpenSearchDashboardsSampleDataEcommerceOrderDateField() - { - String columnName = "order_date"; - String dateFormat = "date_optional_time"; - String originalDateValue = "2020-02-24T09:28:48+00:00"; - String expectedDateValue = "2020-02-24 09:28:48.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testOpenSearchDashboardsSampleDataFlightsTimestampField() - { - String columnName = "timestamp"; - String dateFormat = "date_optional_time"; - String originalDateValue = "2020-02-03T00:00:00"; - String expectedDateValue = "2020-02-03 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testOpenSearchDashboardsSampleDataFlightsTimestampFieldNoTime() - { - String columnName = "timestamp"; - String dateFormat = "date_optional_time"; - String originalDateValue = "2020-02-03T"; - String expectedDateValue = "2020-02-03 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testOpenSearchDashboardsSampleDataLogsUtcDateField() - { - String columnName = "utc_date"; - String dateFormat = "date_optional_time"; - String originalDateValue = "2020-02-02T00:39:02.912Z"; - String expectedDateValue = "2020-02-02 00:39:02.912"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testEpochMillis() - { - String columnName = "date_field"; - String dateFormat = "epoch_millis"; - String originalDateValue = "727430805000"; - String expectedDateValue = "1993-01-19 08:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testEpochSecond() - { - String columnName = "date_field"; - String dateFormat = "epoch_second"; - String originalDateValue = "727430805"; - String expectedDateValue = "1993-01-19 08:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateOptionalTimeDateOnly() - { - String columnName = "date_field"; - String dateFormat = "date_optional_time"; - String originalDateValue = "1993-01-19"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateOptionalTimeDateAndTime() - { - String columnName = "date_field"; - String dateFormat = "date_optional_time"; - String originalDateValue = "1993-01-19T00:06:45.123-0800"; - String expectedDateValue = "1993-01-19 08:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicDate() - { - String columnName = "date_field"; - String dateFormat = "basic_date"; - String originalDateValue = "19930119"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicDateTime() - { - String columnName = "date_field"; - String dateFormat = "basic_date_time"; - String originalDateValue = "19930119T120645.123-0800"; - String expectedDateValue = "1993-01-19 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicDateTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "basic_date_time_no_millis"; - String originalDateValue = "19930119T120645-0800"; - String expectedDateValue = "1993-01-19 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicOrdinalDate() - { - String columnName = "date_field"; - String dateFormat = "basic_ordinal_date"; - String originalDateValue = "1993019"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicOrdinalDateTime() - { - String columnName = "date_field"; - String dateFormat = "basic_ordinal_date_time"; - String originalDateValue = "1993019T120645.123-0800"; - String expectedDateValue = "1993-01-19 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicOrdinalDateTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "basic_ordinal_date_time_no_millis"; - String originalDateValue = "1993019T120645-0800"; - String expectedDateValue = "1993-01-19 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicTime() - { - String columnName = "date_field"; - String dateFormat = "basic_time"; - String originalDateValue = "120645.123-0800"; - String expectedDateValue = "1970-01-01 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "basic_time_no_millis"; - String originalDateValue = "120645-0800"; - String expectedDateValue = "1970-01-01 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicTTime() - { - String columnName = "date_field"; - String dateFormat = "basic_t_time"; - String originalDateValue = "T120645.123-0800"; - String expectedDateValue = "1970-01-01 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicTTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "basic_t_time_no_millis"; - String originalDateValue = "T120645-0800"; - String expectedDateValue = "1970-01-01 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicWeekDate() - { - String columnName = "date_field"; - String dateFormat = "basic_week_date"; - String originalDateValue = "1993W042"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicWeekDateTime() - { - String columnName = "date_field"; - String dateFormat = "basic_week_date_time"; - String originalDateValue = "1993W042T120645.123-0800"; - String expectedDateValue = "1993-01-19 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testBasicWeekDateTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "basic_week_date_time_no_millis"; - String originalDateValue = "1993W042T120645-0800"; - String expectedDateValue = "1993-01-19 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDate() - { - String columnName = "date_field"; - String dateFormat = "date"; - String originalDateValue = "1993-01-19"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateHour() - { - String columnName = "date_field"; - String dateFormat = "date_hour"; - String originalDateValue = "1993-01-19T12"; - String expectedDateValue = "1993-01-19 12:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateHourMinute() - { - String columnName = "date_field"; - String dateFormat = "date_hour_minute"; - String originalDateValue = "1993-01-19T12:06"; - String expectedDateValue = "1993-01-19 12:06:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateHourMinuteSecond() - { - String columnName = "date_field"; - String dateFormat = "date_hour_minute_second"; - String originalDateValue = "1993-01-19T12:06:45"; - String expectedDateValue = "1993-01-19 12:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateHourMinuteSecondFraction() - { - String columnName = "date_field"; - String dateFormat = "date_hour_minute_second_fraction"; - String originalDateValue = "1993-01-19T12:06:45.123"; - String expectedDateValue = "1993-01-19 12:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateHourMinuteSecondMillis() - { - String columnName = "date_field"; - String dateFormat = "date_hour_minute_second_millis"; - String originalDateValue = "1993-01-19T12:06:45.123"; - String expectedDateValue = "1993-01-19 12:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateTime() - { - String columnName = "date_field"; - String dateFormat = "date_time"; - String originalDateValue = "1993-01-19T12:06:45.123-0800"; - String expectedDateValue = "1993-01-19 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testDateTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "date_time_no_millis"; - String originalDateValue = "1993-01-19T12:06:45-0800"; - String expectedDateValue = "1993-01-19 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testHour() - { - String columnName = "date_field"; - String dateFormat = "hour"; - String originalDateValue = "12"; - String expectedDateValue = "1970-01-01 12:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testHourMinute() - { - String columnName = "date_field"; - String dateFormat = "hour_minute"; - String originalDateValue = "12:06"; - String expectedDateValue = "1970-01-01 12:06:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testHourMinuteSecond() - { - String columnName = "date_field"; - String dateFormat = "hour_minute_second"; - String originalDateValue = "12:06:45"; - String expectedDateValue = "1970-01-01 12:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testHourMinuteSecondFraction() - { - String columnName = "date_field"; - String dateFormat = "hour_minute_second_fraction"; - String originalDateValue = "12:06:45.123"; - String expectedDateValue = "1970-01-01 12:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testHourMinuteSecondMillis() - { - String columnName = "date_field"; - String dateFormat = "hour_minute_second_millis"; - String originalDateValue = "12:06:45.123"; - String expectedDateValue = "1970-01-01 12:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testOrdinalDate() - { - String columnName = "date_field"; - String dateFormat = "ordinal_date"; - String originalDateValue = "1993-019"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testOrdinalDateTime() - { - String columnName = "date_field"; - String dateFormat = "ordinal_date_time"; - String originalDateValue = "1993-019T12:06:45.123-0800"; - String expectedDateValue = "1993-01-19 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testOrdinalDateTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "ordinal_date_time_no_millis"; - String originalDateValue = "1993-019T12:06:45-0800"; - String expectedDateValue = "1993-01-19 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testTime() - { - String columnName = "date_field"; - String dateFormat = "time"; - String originalDateValue = "12:06:45.123-0800"; - String expectedDateValue = "1970-01-01 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "time_no_millis"; - String originalDateValue = "12:06:45-0800"; - String expectedDateValue = "1970-01-01 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testTTime() - { - String columnName = "date_field"; - String dateFormat = "t_time"; - String originalDateValue = "T12:06:45.123-0800"; - String expectedDateValue = "1970-01-01 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testTTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "t_time_no_millis"; - String originalDateValue = "T12:06:45-0800"; - String expectedDateValue = "1970-01-01 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testWeekDate() - { - String columnName = "date_field"; - String dateFormat = "week_date"; - String originalDateValue = "1993-W04-2"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testWeekDateTime() - { - String columnName = "date_field"; - String dateFormat = "week_date_time"; - String originalDateValue = "1993-W04-2T12:06:45.123-0800"; - String expectedDateValue = "1993-01-19 20:06:45.123"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testWeekDateTimeNoMillis() - { - String columnName = "date_field"; - String dateFormat = "week_date_time_no_millis"; - String originalDateValue = "1993-W04-2T12:06:45-0800"; - String expectedDateValue = "1993-01-19 20:06:45.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testWeekyear() - { - String columnName = "date_field"; - String dateFormat = "week_year"; - String originalDateValue = "1993"; - String expectedDateValue = "1993-01-01 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testWeekyearWeek() - { - String columnName = "date_field"; - String dateFormat = "weekyear_week"; - String originalDateValue = "1993-W04"; - String expectedDateValue = "1993-01-17 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testWeekyearWeekDay() - { - String columnName = "date_field"; - String dateFormat = "weekyear_week_day"; - String originalDateValue = "1993-W04-2"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testYear() - { - String columnName = "date_field"; - String dateFormat = "year"; - String originalDateValue = "1993"; - String expectedDateValue = "1993-01-01 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testYearMonth() - { - String columnName = "date_field"; - String dateFormat = "year_month"; - String originalDateValue = "1993-01"; - String expectedDateValue = "1993-01-01 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testYearMonthDay() - { - String columnName = "date_field"; - String dateFormat = "year_month_day"; - String originalDateValue = "1993-01-19"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testCustomFormat() - { - String columnName = "date_field"; - String dateFormat = "EEE, MMM d, ''yy"; - - String originalDateValue = "Tue, Jan 19, '93"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testMultipleFormats() - { - String columnName = "date_field"; - String dateFormat = "date_optional_time||epoch_millis"; - - String originalDateValue = "1993-01-19"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - - originalDateValue = "727401600000"; - expectedDateValue = "1993-01-19 00:00:00.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testMultipleCustomFormats() - { - String columnName = "date_field"; - String dateFormat = "EEE, MMM d, ''yy||yyMMddHHmmssZ"; - - String originalDateValue = "Tue, Jan 19, '93"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - - originalDateValue = "930119000000-0000"; - expectedDateValue = "1993-01-19 00:00:00.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testNamedAndCustomFormats() - { - String columnName = "date_field"; - String dateFormat = "EEE, MMM d, ''yy||hour_minute_second"; - - String originalDateValue = "Tue, Jan 19, '93"; - String expectedDateValue = "1993-01-19 00:00:00.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - - originalDateValue = "12:06:45"; - expectedDateValue = "1970-01-01 12:06:45.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testIncorrectFormat() - { - String columnName = "date_field"; - String dateFormat = "date_optional_time"; - String originalDateValue = "1581724085"; - // Invalid format for date value; should return original value - String expectedDateValue = "1581724085"; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testNullDateData() - { - String columnName = "date_field"; - String dateFormat = "date_optional_time"; - String originalDateValue = null; - // Nulls should be preserved - String expectedDateValue = null; - - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - @Test - public void testStrictDateOptionalTimeOrEpochMillsShouldPass() - { - String columnName = "date_field"; - String dateFormat = "strict_date_optional_time||epoch_millis"; - - String originalDateValue = "2015-01-01"; - String expectedDateValue = "2015-01-01 00:00:00.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - - originalDateValue = "2015-01-01T12:10:30Z"; - expectedDateValue = "2015-01-01 12:10:30.000"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - - originalDateValue = "1420070400001"; - expectedDateValue = "2015-01-01 00:00:00.001"; - verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); - } - - private void verifyFormatting(String columnName, String dateFormatProperty, String originalDateValue, String expectedDateValue) - { - List columns = buildColumnList(columnName); - Map> dateFieldFormatMap = buildDateFieldFormatMap(columnName, dateFormatProperty); - - Map rowSource = new HashMap<>(); - rowSource.put(columnName, originalDateValue); - - DateFieldFormatter dateFieldFormatter = new DateFieldFormatter(dateFieldFormatMap, columns, new HashMap<>()); - executeFormattingAndCompare(dateFieldFormatter, rowSource, columnName, expectedDateValue); - } - - private void executeFormattingAndCompare( - DateFieldFormatter formatter, - Map rowSource, - String columnToCheck, - String expectedDateValue) { - formatter.applyJDBCDateFormat(rowSource); - assertEquals(expectedDateValue, rowSource.get(columnToCheck)); - } - - private List buildColumnList(String columnName) { - return ImmutableList.builder() - .add(new Schema.Column(columnName, null, Schema.Type.DATE)) - .build(); - } - - private Map> buildDateFieldFormatMap(String columnName, String dateFormatProperty) { - return ImmutableMap.>builder() - .put(columnName, Arrays.asList(dateFormatProperty.split("\\|\\|"))) - .build(); - } + @Test + public void testOpenSearchDashboardsSampleDataEcommerceOrderDateField() { + String columnName = "order_date"; + String dateFormat = "date_optional_time"; + String originalDateValue = "2020-02-24T09:28:48+00:00"; + String expectedDateValue = "2020-02-24 09:28:48.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testOpenSearchDashboardsSampleDataFlightsTimestampField() { + String columnName = "timestamp"; + String dateFormat = "date_optional_time"; + String originalDateValue = "2020-02-03T00:00:00"; + String expectedDateValue = "2020-02-03 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testOpenSearchDashboardsSampleDataFlightsTimestampFieldNoTime() { + String columnName = "timestamp"; + String dateFormat = "date_optional_time"; + String originalDateValue = "2020-02-03T"; + String expectedDateValue = "2020-02-03 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testOpenSearchDashboardsSampleDataLogsUtcDateField() { + String columnName = "utc_date"; + String dateFormat = "date_optional_time"; + String originalDateValue = "2020-02-02T00:39:02.912Z"; + String expectedDateValue = "2020-02-02 00:39:02.912"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testEpochMillis() { + String columnName = "date_field"; + String dateFormat = "epoch_millis"; + String originalDateValue = "727430805000"; + String expectedDateValue = "1993-01-19 08:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testEpochSecond() { + String columnName = "date_field"; + String dateFormat = "epoch_second"; + String originalDateValue = "727430805"; + String expectedDateValue = "1993-01-19 08:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateOptionalTimeDateOnly() { + String columnName = "date_field"; + String dateFormat = "date_optional_time"; + String originalDateValue = "1993-01-19"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateOptionalTimeDateAndTime() { + String columnName = "date_field"; + String dateFormat = "date_optional_time"; + String originalDateValue = "1993-01-19T00:06:45.123-0800"; + String expectedDateValue = "1993-01-19 08:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicDate() { + String columnName = "date_field"; + String dateFormat = "basic_date"; + String originalDateValue = "19930119"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicDateTime() { + String columnName = "date_field"; + String dateFormat = "basic_date_time"; + String originalDateValue = "19930119T120645.123-0800"; + String expectedDateValue = "1993-01-19 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicDateTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "basic_date_time_no_millis"; + String originalDateValue = "19930119T120645-0800"; + String expectedDateValue = "1993-01-19 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicOrdinalDate() { + String columnName = "date_field"; + String dateFormat = "basic_ordinal_date"; + String originalDateValue = "1993019"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicOrdinalDateTime() { + String columnName = "date_field"; + String dateFormat = "basic_ordinal_date_time"; + String originalDateValue = "1993019T120645.123-0800"; + String expectedDateValue = "1993-01-19 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicOrdinalDateTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "basic_ordinal_date_time_no_millis"; + String originalDateValue = "1993019T120645-0800"; + String expectedDateValue = "1993-01-19 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicTime() { + String columnName = "date_field"; + String dateFormat = "basic_time"; + String originalDateValue = "120645.123-0800"; + String expectedDateValue = "1970-01-01 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "basic_time_no_millis"; + String originalDateValue = "120645-0800"; + String expectedDateValue = "1970-01-01 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicTTime() { + String columnName = "date_field"; + String dateFormat = "basic_t_time"; + String originalDateValue = "T120645.123-0800"; + String expectedDateValue = "1970-01-01 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicTTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "basic_t_time_no_millis"; + String originalDateValue = "T120645-0800"; + String expectedDateValue = "1970-01-01 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicWeekDate() { + String columnName = "date_field"; + String dateFormat = "basic_week_date"; + String originalDateValue = "1993W042"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicWeekDateTime() { + String columnName = "date_field"; + String dateFormat = "basic_week_date_time"; + String originalDateValue = "1993W042T120645.123-0800"; + String expectedDateValue = "1993-01-19 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testBasicWeekDateTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "basic_week_date_time_no_millis"; + String originalDateValue = "1993W042T120645-0800"; + String expectedDateValue = "1993-01-19 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDate() { + String columnName = "date_field"; + String dateFormat = "date"; + String originalDateValue = "1993-01-19"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateHour() { + String columnName = "date_field"; + String dateFormat = "date_hour"; + String originalDateValue = "1993-01-19T12"; + String expectedDateValue = "1993-01-19 12:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateHourMinute() { + String columnName = "date_field"; + String dateFormat = "date_hour_minute"; + String originalDateValue = "1993-01-19T12:06"; + String expectedDateValue = "1993-01-19 12:06:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateHourMinuteSecond() { + String columnName = "date_field"; + String dateFormat = "date_hour_minute_second"; + String originalDateValue = "1993-01-19T12:06:45"; + String expectedDateValue = "1993-01-19 12:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateHourMinuteSecondFraction() { + String columnName = "date_field"; + String dateFormat = "date_hour_minute_second_fraction"; + String originalDateValue = "1993-01-19T12:06:45.123"; + String expectedDateValue = "1993-01-19 12:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateHourMinuteSecondMillis() { + String columnName = "date_field"; + String dateFormat = "date_hour_minute_second_millis"; + String originalDateValue = "1993-01-19T12:06:45.123"; + String expectedDateValue = "1993-01-19 12:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateTime() { + String columnName = "date_field"; + String dateFormat = "date_time"; + String originalDateValue = "1993-01-19T12:06:45.123-0800"; + String expectedDateValue = "1993-01-19 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testDateTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "date_time_no_millis"; + String originalDateValue = "1993-01-19T12:06:45-0800"; + String expectedDateValue = "1993-01-19 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testHour() { + String columnName = "date_field"; + String dateFormat = "hour"; + String originalDateValue = "12"; + String expectedDateValue = "1970-01-01 12:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testHourMinute() { + String columnName = "date_field"; + String dateFormat = "hour_minute"; + String originalDateValue = "12:06"; + String expectedDateValue = "1970-01-01 12:06:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testHourMinuteSecond() { + String columnName = "date_field"; + String dateFormat = "hour_minute_second"; + String originalDateValue = "12:06:45"; + String expectedDateValue = "1970-01-01 12:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testHourMinuteSecondFraction() { + String columnName = "date_field"; + String dateFormat = "hour_minute_second_fraction"; + String originalDateValue = "12:06:45.123"; + String expectedDateValue = "1970-01-01 12:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testHourMinuteSecondMillis() { + String columnName = "date_field"; + String dateFormat = "hour_minute_second_millis"; + String originalDateValue = "12:06:45.123"; + String expectedDateValue = "1970-01-01 12:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testOrdinalDate() { + String columnName = "date_field"; + String dateFormat = "ordinal_date"; + String originalDateValue = "1993-019"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testOrdinalDateTime() { + String columnName = "date_field"; + String dateFormat = "ordinal_date_time"; + String originalDateValue = "1993-019T12:06:45.123-0800"; + String expectedDateValue = "1993-01-19 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testOrdinalDateTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "ordinal_date_time_no_millis"; + String originalDateValue = "1993-019T12:06:45-0800"; + String expectedDateValue = "1993-01-19 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testTime() { + String columnName = "date_field"; + String dateFormat = "time"; + String originalDateValue = "12:06:45.123-0800"; + String expectedDateValue = "1970-01-01 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "time_no_millis"; + String originalDateValue = "12:06:45-0800"; + String expectedDateValue = "1970-01-01 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testTTime() { + String columnName = "date_field"; + String dateFormat = "t_time"; + String originalDateValue = "T12:06:45.123-0800"; + String expectedDateValue = "1970-01-01 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testTTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "t_time_no_millis"; + String originalDateValue = "T12:06:45-0800"; + String expectedDateValue = "1970-01-01 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testWeekDate() { + String columnName = "date_field"; + String dateFormat = "week_date"; + String originalDateValue = "1993-W04-2"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testWeekDateTime() { + String columnName = "date_field"; + String dateFormat = "week_date_time"; + String originalDateValue = "1993-W04-2T12:06:45.123-0800"; + String expectedDateValue = "1993-01-19 20:06:45.123"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testWeekDateTimeNoMillis() { + String columnName = "date_field"; + String dateFormat = "week_date_time_no_millis"; + String originalDateValue = "1993-W04-2T12:06:45-0800"; + String expectedDateValue = "1993-01-19 20:06:45.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testWeekyear() { + String columnName = "date_field"; + String dateFormat = "week_year"; + String originalDateValue = "1993"; + String expectedDateValue = "1993-01-01 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testWeekyearWeek() { + String columnName = "date_field"; + String dateFormat = "weekyear_week"; + String originalDateValue = "1993-W04"; + String expectedDateValue = "1993-01-17 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testWeekyearWeekDay() { + String columnName = "date_field"; + String dateFormat = "weekyear_week_day"; + String originalDateValue = "1993-W04-2"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testYear() { + String columnName = "date_field"; + String dateFormat = "year"; + String originalDateValue = "1993"; + String expectedDateValue = "1993-01-01 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testYearMonth() { + String columnName = "date_field"; + String dateFormat = "year_month"; + String originalDateValue = "1993-01"; + String expectedDateValue = "1993-01-01 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testYearMonthDay() { + String columnName = "date_field"; + String dateFormat = "year_month_day"; + String originalDateValue = "1993-01-19"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testCustomFormat() { + String columnName = "date_field"; + String dateFormat = "EEE, MMM d, ''yy"; + + String originalDateValue = "Tue, Jan 19, '93"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testMultipleFormats() { + String columnName = "date_field"; + String dateFormat = "date_optional_time||epoch_millis"; + + String originalDateValue = "1993-01-19"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + + originalDateValue = "727401600000"; + expectedDateValue = "1993-01-19 00:00:00.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testMultipleCustomFormats() { + String columnName = "date_field"; + String dateFormat = "EEE, MMM d, ''yy||yyMMddHHmmssZ"; + + String originalDateValue = "Tue, Jan 19, '93"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + + originalDateValue = "930119000000-0000"; + expectedDateValue = "1993-01-19 00:00:00.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testNamedAndCustomFormats() { + String columnName = "date_field"; + String dateFormat = "EEE, MMM d, ''yy||hour_minute_second"; + + String originalDateValue = "Tue, Jan 19, '93"; + String expectedDateValue = "1993-01-19 00:00:00.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + + originalDateValue = "12:06:45"; + expectedDateValue = "1970-01-01 12:06:45.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testIncorrectFormat() { + String columnName = "date_field"; + String dateFormat = "date_optional_time"; + String originalDateValue = "1581724085"; + // Invalid format for date value; should return original value + String expectedDateValue = "1581724085"; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testNullDateData() { + String columnName = "date_field"; + String dateFormat = "date_optional_time"; + String originalDateValue = null; + // Nulls should be preserved + String expectedDateValue = null; + + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + @Test + public void testStrictDateOptionalTimeOrEpochMillsShouldPass() { + String columnName = "date_field"; + String dateFormat = "strict_date_optional_time||epoch_millis"; + + String originalDateValue = "2015-01-01"; + String expectedDateValue = "2015-01-01 00:00:00.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + + originalDateValue = "2015-01-01T12:10:30Z"; + expectedDateValue = "2015-01-01 12:10:30.000"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + + originalDateValue = "1420070400001"; + expectedDateValue = "2015-01-01 00:00:00.001"; + verifyFormatting(columnName, dateFormat, originalDateValue, expectedDateValue); + } + + private void verifyFormatting( + String columnName, + String dateFormatProperty, + String originalDateValue, + String expectedDateValue) { + List columns = buildColumnList(columnName); + Map> dateFieldFormatMap = + buildDateFieldFormatMap(columnName, dateFormatProperty); + + Map rowSource = new HashMap<>(); + rowSource.put(columnName, originalDateValue); + + DateFieldFormatter dateFieldFormatter = + new DateFieldFormatter(dateFieldFormatMap, columns, new HashMap<>()); + executeFormattingAndCompare(dateFieldFormatter, rowSource, columnName, expectedDateValue); + } + + private void executeFormattingAndCompare( + DateFieldFormatter formatter, + Map rowSource, + String columnToCheck, + String expectedDateValue) { + formatter.applyJDBCDateFormat(rowSource); + assertEquals(expectedDateValue, rowSource.get(columnToCheck)); + } + + private List buildColumnList(String columnName) { + return ImmutableList.builder() + .add(new Schema.Column(columnName, null, Schema.Type.DATE)) + .build(); + } + + private Map> buildDateFieldFormatMap( + String columnName, String dateFormatProperty) { + return ImmutableMap.>builder() + .put(columnName, Arrays.asList(dateFormatProperty.split("\\|\\|"))) + .build(); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/AggregationOptionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/AggregationOptionTest.java index e5f44eacf0..526642e8ea 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/AggregationOptionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/AggregationOptionTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import com.alibaba.druid.sql.ast.expr.SQLAggregateOption; @@ -17,55 +16,53 @@ import org.opensearch.sql.legacy.parser.SqlParser; import org.opensearch.sql.legacy.util.SqlParserUtils; -/** - * Unit test class for feature of aggregation options: DISTINCT, ALL, UNIQUE, DEDUPLICATION - */ +/** Unit test class for feature of aggregation options: DISTINCT, ALL, UNIQUE, DEDUPLICATION */ public class AggregationOptionTest { - @Test - public void selectDistinctFieldsShouldHaveAggregationOption() { - List fields = getSelectFields("SELECT DISTINCT gender, city FROM accounts"); - for (Field field: fields) { - Assert.assertEquals(field.getOption(), SQLAggregateOption.DISTINCT); - } + @Test + public void selectDistinctFieldsShouldHaveAggregationOption() { + List fields = getSelectFields("SELECT DISTINCT gender, city FROM accounts"); + for (Field field : fields) { + Assert.assertEquals(field.getOption(), SQLAggregateOption.DISTINCT); } + } - @Test - public void selectWithoutDistinctFieldsShouldNotHaveAggregationOption() { - List fields = getSelectFields("SELECT gender, city FROM accounts"); - for (Field field: fields) { - Assert.assertNull(field.getOption()); - } + @Test + public void selectWithoutDistinctFieldsShouldNotHaveAggregationOption() { + List fields = getSelectFields("SELECT gender, city FROM accounts"); + for (Field field : fields) { + Assert.assertNull(field.getOption()); } + } - @Test - public void selectDistinctWithoutGroupByShouldHaveGroupByItems() { - List> groupBys = getGroupBys("SELECT DISTINCT gender, city FROM accounts"); - Assert.assertFalse(groupBys.isEmpty()); - } + @Test + public void selectDistinctWithoutGroupByShouldHaveGroupByItems() { + List> groupBys = getGroupBys("SELECT DISTINCT gender, city FROM accounts"); + Assert.assertFalse(groupBys.isEmpty()); + } - @Test - public void selectWithoutDistinctWithoutGroupByShouldNotHaveGroupByItems() { - List> groupBys = getGroupBys("SELECT gender, city FROM accounts"); - Assert.assertTrue(groupBys.isEmpty()); - } + @Test + public void selectWithoutDistinctWithoutGroupByShouldNotHaveGroupByItems() { + List> groupBys = getGroupBys("SELECT gender, city FROM accounts"); + Assert.assertTrue(groupBys.isEmpty()); + } - private List> getGroupBys(String query) { - return getSelect(query).getGroupBys(); - } + private List> getGroupBys(String query) { + return getSelect(query).getGroupBys(); + } - private List getSelectFields(String query) { - return getSelect(query).getFields(); - } + private List getSelectFields(String query) { + return getSelect(query).getFields(); + } - private Select getSelect(String query) { - SQLQueryExpr queryExpr = SqlParserUtils.parse(query); - Select select = null; - try { - select = new SqlParser().parseSelect(queryExpr); - } catch (SqlParseException e) { - e.printStackTrace(); - } - return select; + private Select getSelect(String query) { + SQLQueryExpr queryExpr = SqlParserUtils.parse(query); + Select select = null; + try { + select = new SqlParser().parseSelect(queryExpr); + } catch (SqlParseException e) { + e.printStackTrace(); } + return select; + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFormatTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFormatTest.java index 89ac8b4563..3bb7b4a2b6 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFormatTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFormatTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import static org.hamcrest.MatcherAssert.assertThat; @@ -40,196 +39,238 @@ public class DateFormatTest { - private static final String SELECT_CNT_FROM_DATE = "SELECT COUNT(*) AS c FROM dates "; - - @Test - public void simpleFormatCondition() { - List q = query(SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY') < '2018'"); - - assertThat(q, hasQueryWithValue("fieldName", equalTo("creationDate"))); - assertThat(q, hasQueryWithValueGetter(MatcherUtils.featureValueOf("has format", equalTo("YYYY"), f->((RangeQueryBuilder)f).format()))); - } - - @Test - public void equalCondition() { - List q = query(SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') = '2018-04-02'"); - - assertThat(q, hasQueryWithValueGetter(MatcherUtils.featureValueOf("has format", equalTo("YYYY-MM-dd"), f->((RangeQueryBuilder)f).format()))); - - // Equality query for date_format is created with a rangeQuery where the 'from' and 'to' values are equal to the value we are equating to - assertThat(q, hasQueryWithValue("from", equalTo(BytesRefs.toBytesRef("2018-04-02")))); // converting string to bytes ref as RangeQueryBuilder stores it this way - assertThat(q, hasQueryWithValue("to", equalTo(BytesRefs.toBytesRef("2018-04-02")))); - } - - @Test - public void orderByTest() { - String query = "SELECT agent, ip, date_format(utc_time, 'dd-MM-YYYY') date " + - "FROM opensearch_dashboards_sample_data_logs " + - "ORDER BY date_format(utc_time, 'dd-MM-YYYY') desc, ip"; - - Select select = getSelect(query); - - List orderBys = select.getOrderBys(); - assertThat(orderBys.size(), equalTo(2)); - - Order formula = orderBys.get(0); - - assertThat(formula.isScript(), is(true)); - assertThat(formula.getType(), is("DESC")); - assertThat(formula.getName(), containsString("DateTimeFormatter.ofPattern")); - - Order ip = orderBys.get(1); - - assertThat(ip.isScript(), is(false)); - assertThat(ip.getName(), is("ip")); - assertThat(ip.getType(), is("ASC")); - } - - @Test - public void groupByWithDescOrder() throws SqlParseException { - String query = "SELECT date_format(utc_time, 'dd-MM-YYYY'), count(*) " + - "FROM opensearch_dashboards_sample_data_logs " + - "GROUP BY date_format(utc_time, 'dd-MM-YYYY') " + - "ORDER BY date_format(utc_time, 'dd-MM-YYYY') DESC"; - - JSONObject aggregation = getAggregation(query); - assertThat(aggregation.getInt("size"), is(getSelect(query).getRowCount())); - assertThat(aggregation.getJSONObject("order").getString("_key"), is("desc")); - } - - @Test - public void groupByWithAscOrder() throws SqlParseException { - String query = "SELECT date_format(utc_time, 'dd-MM-YYYY'), count(*) " + - "FROM opensearch_dashboards_sample_data_logs " + - "GROUP BY date_format(utc_time, 'dd-MM-YYYY') " + - "ORDER BY date_format(utc_time, 'dd-MM-YYYY')"; - - JSONObject aggregation = getAggregation(query); - - assertThat(aggregation.getJSONObject("order").getString("_key"), is("asc")); - } - - @Test - @Ignore("https://github.com/opendistro-for-elasticsearch/sql/issues/158") - public void groupByWithAndAlias() throws SqlParseException { - String query = "SELECT date_format(utc_time, 'dd-MM-YYYY') x, count(*) " + - "FROM opensearch_dashboards_sample_data_logs " + - "GROUP BY x " + - "ORDER BY x"; - - JSONObject aggregation = getAggregation(query); - assertThat(aggregation.getJSONObject("order").getString("_key"), is("asc")); - } - - public JSONObject getAggregation(String query) throws SqlParseException { - Select select = getSelect(query); - - Client client = mock(Client.class); - AggregationQueryAction queryAction = new AggregationQueryAction(client, select); - - String elasticDsl = queryAction.explain().explain(); - JSONObject elasticQuery = new JSONObject(elasticDsl); - - JSONObject aggregations = elasticQuery.getJSONObject("aggregations"); - String dateFormatAggregationKey = getScriptAggregationKey(aggregations, "date_format"); - - return aggregations.getJSONObject(dateFormatAggregationKey).getJSONObject("terms"); - } - - public static String getScriptAggregationKey(JSONObject aggregation, String prefix) { - return aggregation.keySet() - .stream() - .filter(x -> x.startsWith(prefix)) - .findFirst() - .orElseThrow(()-> new RuntimeException("Can't find key" + prefix + " in aggregation " + aggregation)); - } - - @Test - public void notEqualCondition() { - List q = query(SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') <> '2018-04-02'"); - - assertThat(q, hasNotQueryWithValue("from", equalTo(BytesRefs.toBytesRef("2018-04-02")))); - assertThat(q, hasNotQueryWithValue("to", equalTo(BytesRefs.toBytesRef("2018-04-02")))); - } - - @Test - public void greaterThanCondition() { - List q = query(SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') > '2018-04-02'"); - - assertThat(q, hasQueryWithValue("from", equalTo(BytesRefs.toBytesRef("2018-04-02")))); - assertThat(q, hasQueryWithValue("includeLower", equalTo(false))); - assertThat(q, hasQueryWithValue("includeUpper", equalTo(true))); - } - - @Test - public void greaterThanOrEqualToCondition() { - List q = query(SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') >= '2018-04-02'"); - - assertThat(q, hasQueryWithValue("from", equalTo(BytesRefs.toBytesRef("2018-04-02")))); - assertThat(q, hasQueryWithValue("to", equalTo(null))); - assertThat(q, hasQueryWithValue("includeLower", equalTo(true))); - assertThat(q, hasQueryWithValue("includeUpper", equalTo(true))); + private static final String SELECT_CNT_FROM_DATE = "SELECT COUNT(*) AS c FROM dates "; + + @Test + public void simpleFormatCondition() { + List q = + query(SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY') < '2018'"); + + assertThat(q, hasQueryWithValue("fieldName", equalTo("creationDate"))); + assertThat( + q, + hasQueryWithValueGetter( + MatcherUtils.featureValueOf( + "has format", equalTo("YYYY"), f -> ((RangeQueryBuilder) f).format()))); + } + + @Test + public void equalCondition() { + List q = + query( + SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') = '2018-04-02'"); + + assertThat( + q, + hasQueryWithValueGetter( + MatcherUtils.featureValueOf( + "has format", equalTo("YYYY-MM-dd"), f -> ((RangeQueryBuilder) f).format()))); + + // Equality query for date_format is created with a rangeQuery where the 'from' and 'to' values + // are equal to the value we are equating to + assertThat( + q, + hasQueryWithValue( + "from", + equalTo( + BytesRefs.toBytesRef( + "2018-04-02")))); // converting string to bytes ref as RangeQueryBuilder stores + // it this way + assertThat(q, hasQueryWithValue("to", equalTo(BytesRefs.toBytesRef("2018-04-02")))); + } + + @Test + public void orderByTest() { + String query = + "SELECT agent, ip, date_format(utc_time, 'dd-MM-YYYY') date " + + "FROM opensearch_dashboards_sample_data_logs " + + "ORDER BY date_format(utc_time, 'dd-MM-YYYY') desc, ip"; + + Select select = getSelect(query); + + List orderBys = select.getOrderBys(); + assertThat(orderBys.size(), equalTo(2)); + + Order formula = orderBys.get(0); + + assertThat(formula.isScript(), is(true)); + assertThat(formula.getType(), is("DESC")); + assertThat(formula.getName(), containsString("DateTimeFormatter.ofPattern")); + + Order ip = orderBys.get(1); + + assertThat(ip.isScript(), is(false)); + assertThat(ip.getName(), is("ip")); + assertThat(ip.getType(), is("ASC")); + } + + @Test + public void groupByWithDescOrder() throws SqlParseException { + String query = + "SELECT date_format(utc_time, 'dd-MM-YYYY'), count(*) " + + "FROM opensearch_dashboards_sample_data_logs " + + "GROUP BY date_format(utc_time, 'dd-MM-YYYY') " + + "ORDER BY date_format(utc_time, 'dd-MM-YYYY') DESC"; + + JSONObject aggregation = getAggregation(query); + assertThat(aggregation.getInt("size"), is(getSelect(query).getRowCount())); + assertThat(aggregation.getJSONObject("order").getString("_key"), is("desc")); + } + + @Test + public void groupByWithAscOrder() throws SqlParseException { + String query = + "SELECT date_format(utc_time, 'dd-MM-YYYY'), count(*) " + + "FROM opensearch_dashboards_sample_data_logs " + + "GROUP BY date_format(utc_time, 'dd-MM-YYYY') " + + "ORDER BY date_format(utc_time, 'dd-MM-YYYY')"; + + JSONObject aggregation = getAggregation(query); + + assertThat(aggregation.getJSONObject("order").getString("_key"), is("asc")); + } + + @Test + @Ignore("https://github.com/opendistro-for-elasticsearch/sql/issues/158") + public void groupByWithAndAlias() throws SqlParseException { + String query = + "SELECT date_format(utc_time, 'dd-MM-YYYY') x, count(*) " + + "FROM opensearch_dashboards_sample_data_logs " + + "GROUP BY x " + + "ORDER BY x"; + + JSONObject aggregation = getAggregation(query); + assertThat(aggregation.getJSONObject("order").getString("_key"), is("asc")); + } + + public JSONObject getAggregation(String query) throws SqlParseException { + Select select = getSelect(query); + + Client client = mock(Client.class); + AggregationQueryAction queryAction = new AggregationQueryAction(client, select); + + String elasticDsl = queryAction.explain().explain(); + JSONObject elasticQuery = new JSONObject(elasticDsl); + + JSONObject aggregations = elasticQuery.getJSONObject("aggregations"); + String dateFormatAggregationKey = getScriptAggregationKey(aggregations, "date_format"); + + return aggregations.getJSONObject(dateFormatAggregationKey).getJSONObject("terms"); + } + + public static String getScriptAggregationKey(JSONObject aggregation, String prefix) { + return aggregation.keySet().stream() + .filter(x -> x.startsWith(prefix)) + .findFirst() + .orElseThrow( + () -> + new RuntimeException("Can't find key" + prefix + " in aggregation " + aggregation)); + } + + @Test + public void notEqualCondition() { + List q = + query( + SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') <> '2018-04-02'"); + + assertThat(q, hasNotQueryWithValue("from", equalTo(BytesRefs.toBytesRef("2018-04-02")))); + assertThat(q, hasNotQueryWithValue("to", equalTo(BytesRefs.toBytesRef("2018-04-02")))); + } + + @Test + public void greaterThanCondition() { + List q = + query( + SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') > '2018-04-02'"); + + assertThat(q, hasQueryWithValue("from", equalTo(BytesRefs.toBytesRef("2018-04-02")))); + assertThat(q, hasQueryWithValue("includeLower", equalTo(false))); + assertThat(q, hasQueryWithValue("includeUpper", equalTo(true))); + } + + @Test + public void greaterThanOrEqualToCondition() { + List q = + query( + SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd') >= '2018-04-02'"); + + assertThat(q, hasQueryWithValue("from", equalTo(BytesRefs.toBytesRef("2018-04-02")))); + assertThat(q, hasQueryWithValue("to", equalTo(null))); + assertThat(q, hasQueryWithValue("includeLower", equalTo(true))); + assertThat(q, hasQueryWithValue("includeUpper", equalTo(true))); + } + + @Test + public void timeZoneCondition() { + List q = + query( + SELECT_CNT_FROM_DATE + + "WHERE date_format(creationDate, 'YYYY-MM-dd', 'America/Phoenix') >" + + " '2018-04-02'"); + + // Used hasProperty here as getter followed convention for obtaining ID and Feature Matcher was + // having issues with generic type to obtain value + assertThat(q, hasQueryWithValue("timeZone", hasProperty("id", equalTo("America/Phoenix")))); + } + + private List query(String sql) { + return translate(parseSql(sql)); + } + + private List translate(SQLQueryExpr expr) { + try { + Select select = new SqlParser().parseSelect(expr); + QueryBuilder whereQuery = QueryMaker.explain(select.getWhere(), select.isQuery); + return ((BoolQueryBuilder) whereQuery).filter(); + } catch (SqlParseException e) { + throw new ParserException("Illegal sql expr: " + expr.toString()); } + } - @Test - public void timeZoneCondition() { - List q = query(SELECT_CNT_FROM_DATE + "WHERE date_format(creationDate, 'YYYY-MM-dd', 'America/Phoenix') > '2018-04-02'"); - - // Used hasProperty here as getter followed convention for obtaining ID and Feature Matcher was having issues with generic type to obtain value - assertThat(q, hasQueryWithValue("timeZone", hasProperty("id", equalTo("America/Phoenix")))); - } - - private List query(String sql) { - return translate(parseSql(sql)); - } - - private List translate(SQLQueryExpr expr) { - try { - Select select = new SqlParser().parseSelect(expr); - QueryBuilder whereQuery = QueryMaker.explain(select.getWhere(), select.isQuery); - return ((BoolQueryBuilder) whereQuery).filter(); - } catch (SqlParseException e) { - throw new ParserException("Illegal sql expr: " + expr.toString()); - } - } - - private SQLQueryExpr parseSql(String sql) { - ElasticSqlExprParser parser = new ElasticSqlExprParser(sql); - SQLExpr expr = parser.expr(); - if (parser.getLexer().token() != Token.EOF) { - throw new ParserException("Illegal sql: " + sql); - } - return (SQLQueryExpr) expr; + private SQLQueryExpr parseSql(String sql) { + ElasticSqlExprParser parser = new ElasticSqlExprParser(sql); + SQLExpr expr = parser.expr(); + if (parser.getLexer().token() != Token.EOF) { + throw new ParserException("Illegal sql: " + sql); } - - private Select getSelect(String query) { - try { - Select select = new SqlParser().parseSelect(parseSql(query)); - if (select.getRowCount() == null){ - select.setRowCount(Select.DEFAULT_LIMIT); - } - return select; - } catch (SqlParseException e) { - throw new RuntimeException(e); - } - } - - private Matcher> hasQueryWithValue(String name, Matcher matcher) { - return hasItem( - hasFieldWithValue("mustClauses", "has mustClauses", - hasItem(hasFieldWithValue(name, "has " + name, matcher)))); - } - - private Matcher> hasNotQueryWithValue(String name, Matcher matcher) { - return hasItem( - hasFieldWithValue("mustClauses", "has mustClauses", - hasItem(hasFieldWithValue("mustNotClauses", "has mustNotClauses", - hasItem(hasFieldWithValue(name, "has " + name, matcher)))))); - } - - private Matcher> hasQueryWithValueGetter(Matcher matcher) { - return hasItem( - hasFieldWithValue("mustClauses", "has mustClauses", - hasItem(matcher))); + return (SQLQueryExpr) expr; + } + + private Select getSelect(String query) { + try { + Select select = new SqlParser().parseSelect(parseSql(query)); + if (select.getRowCount() == null) { + select.setRowCount(Select.DEFAULT_LIMIT); + } + return select; + } catch (SqlParseException e) { + throw new RuntimeException(e); } + } + + private Matcher> hasQueryWithValue( + String name, Matcher matcher) { + return hasItem( + hasFieldWithValue( + "mustClauses", + "has mustClauses", + hasItem(hasFieldWithValue(name, "has " + name, matcher)))); + } + + private Matcher> hasNotQueryWithValue( + String name, Matcher matcher) { + return hasItem( + hasFieldWithValue( + "mustClauses", + "has mustClauses", + hasItem( + hasFieldWithValue( + "mustNotClauses", + "has mustNotClauses", + hasItem(hasFieldWithValue(name, "has " + name, matcher)))))); + } + + private Matcher> hasQueryWithValueGetter(Matcher matcher) { + return hasItem(hasFieldWithValue("mustClauses", "has mustClauses", hasItem(matcher))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFunctionsTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFunctionsTest.java index 771b0ce1bf..cf1be90665 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFunctionsTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/DateFunctionsTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest; import static org.junit.Assert.assertTrue; @@ -20,181 +19,132 @@ public class DateFunctionsTest { - private static SqlParser parser; - - @BeforeClass - public static void init() { parser = new SqlParser(); } - - /** - * The following unit tests will only cover a subset of the available date functions as the painless script is - * generated from the same template. More thorough testing will be done in integration tests since output will - * differ for each function. - */ - - @Test - public void yearInSelect() { - String query = "SELECT YEAR(creationDate) " + - "FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.year")); - } - - @Test - public void yearInWhere() { - String query = "SELECT * " + - "FROM dates " + - "WHERE YEAR(creationDate) > 2012"; - ScriptFilter scriptFilter = getScriptFilterFromQuery(query, parser); - assertTrue( - scriptContainsString( - scriptFilter, - "doc['creationDate'].value.year")); - assertTrue( - scriptHasPattern( - scriptFilter, - "year_\\d+ > 2012")); - } - - @Test - public void weekOfYearInSelect() { - String query = "SELECT WEEK_OF_YEAR(creationDate) " + - "FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.get(WeekFields.ISO.weekOfWeekBasedYear())")); - } - - @Test - public void weekOfYearInWhere() { - String query = "SELECT * " + - "FROM dates " + - "WHERE WEEK_OF_YEAR(creationDate) > 15"; - ScriptFilter scriptFilter = getScriptFilterFromQuery(query, parser); - assertTrue( - scriptContainsString( - scriptFilter, - "doc['creationDate'].value.get(WeekFields.ISO.weekOfWeekBasedYear())")); - assertTrue( - scriptHasPattern( - scriptFilter, - "weekOfWeekyear_\\d+ > 15")); - } - - @Test - public void dayOfMonth() { - String query = "SELECT DAY_OF_MONTH(creationDate) " + - "FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.dayOfMonth")); - } - - @Test - public void hourOfDay() { - String query = "SELECT HOUR_OF_DAY(creationDate) " + - "FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.hour")); - } - - @Test - public void secondOfMinute() { - String query = "SELECT SECOND_OF_MINUTE(creationDate) " + - "FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.second")); - } - - @Test - public void month() { - String query = "SELECT MONTH(creationDate) FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.monthValue")); - } - - @Test - public void dayofmonth() { - String query = "SELECT DAY_OF_MONTH(creationDate) FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.dayOfMonth")); - } - - @Test - public void date() { - String query = "SELECT DATE(creationDate) FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "LocalDate.parse(doc['creationDate'].value.toString(),DateTimeFormatter.ISO_DATE_TIME)")); - } - - @Test - public void monthname() { - String query = "SELECT MONTHNAME(creationDate) FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "doc['creationDate'].value.month")); - } - - @Test - public void timestamp() { - String query = "SELECT TIMESTAMP(creationDate) FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "DateTimeFormatter.ofPattern('yyyy-MM-dd HH:mm:ss')")); - } - - @Test - public void maketime() { - String query = "SELECT MAKETIME(1, 1, 1) FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "LocalTime.of(1, 1, 1).format(DateTimeFormatter.ofPattern('HH:mm:ss'))")); - } - - @Test - public void now() { - String query = "SELECT NOW() FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "System.currentTimeMillis()")); - } - - @Test - public void curdate() { - String query = "SELECT CURDATE() FROM dates"; - ScriptField scriptField = getScriptFieldFromQuery(query); - assertTrue( - scriptContainsString( - scriptField, - "System.currentTimeMillis()")); - } + private static SqlParser parser; + + @BeforeClass + public static void init() { + parser = new SqlParser(); + } + + /** + * The following unit tests will only cover a subset of the available date functions as the + * painless script is generated from the same template. More thorough testing will be done in + * integration tests since output will differ for each function. + */ + @Test + public void yearInSelect() { + String query = "SELECT YEAR(creationDate) " + "FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "doc['creationDate'].value.year")); + } + + @Test + public void yearInWhere() { + String query = "SELECT * " + "FROM dates " + "WHERE YEAR(creationDate) > 2012"; + ScriptFilter scriptFilter = getScriptFilterFromQuery(query, parser); + assertTrue(scriptContainsString(scriptFilter, "doc['creationDate'].value.year")); + assertTrue(scriptHasPattern(scriptFilter, "year_\\d+ > 2012")); + } + + @Test + public void weekOfYearInSelect() { + String query = "SELECT WEEK_OF_YEAR(creationDate) " + "FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue( + scriptContainsString( + scriptField, "doc['creationDate'].value.get(WeekFields.ISO.weekOfWeekBasedYear())")); + } + + @Test + public void weekOfYearInWhere() { + String query = "SELECT * " + "FROM dates " + "WHERE WEEK_OF_YEAR(creationDate) > 15"; + ScriptFilter scriptFilter = getScriptFilterFromQuery(query, parser); + assertTrue( + scriptContainsString( + scriptFilter, "doc['creationDate'].value.get(WeekFields.ISO.weekOfWeekBasedYear())")); + assertTrue(scriptHasPattern(scriptFilter, "weekOfWeekyear_\\d+ > 15")); + } + + @Test + public void dayOfMonth() { + String query = "SELECT DAY_OF_MONTH(creationDate) " + "FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "doc['creationDate'].value.dayOfMonth")); + } + + @Test + public void hourOfDay() { + String query = "SELECT HOUR_OF_DAY(creationDate) " + "FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "doc['creationDate'].value.hour")); + } + + @Test + public void secondOfMinute() { + String query = "SELECT SECOND_OF_MINUTE(creationDate) " + "FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "doc['creationDate'].value.second")); + } + + @Test + public void month() { + String query = "SELECT MONTH(creationDate) FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "doc['creationDate'].value.monthValue")); + } + + @Test + public void dayofmonth() { + String query = "SELECT DAY_OF_MONTH(creationDate) FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "doc['creationDate'].value.dayOfMonth")); + } + + @Test + public void date() { + String query = "SELECT DATE(creationDate) FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue( + scriptContainsString( + scriptField, + "LocalDate.parse(doc['creationDate'].value.toString(),DateTimeFormatter.ISO_DATE_TIME)")); + } + + @Test + public void monthname() { + String query = "SELECT MONTHNAME(creationDate) FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "doc['creationDate'].value.month")); + } + + @Test + public void timestamp() { + String query = "SELECT TIMESTAMP(creationDate) FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue( + scriptContainsString(scriptField, "DateTimeFormatter.ofPattern('yyyy-MM-dd HH:mm:ss')")); + } + + @Test + public void maketime() { + String query = "SELECT MAKETIME(1, 1, 1) FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue( + scriptContainsString( + scriptField, "LocalTime.of(1, 1, 1).format(DateTimeFormatter.ofPattern('HH:mm:ss'))")); + } + + @Test + public void now() { + String query = "SELECT NOW() FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "System.currentTimeMillis()")); + } + + @Test + public void curdate() { + String query = "SELECT CURDATE() FROM dates"; + ScriptField scriptField = getScriptFieldFromQuery(query); + assertTrue(scriptContainsString(scriptField, "System.currentTimeMillis()")); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java index cfb70dc83c..0eb198ad84 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.cursor; import static org.hamcrest.Matchers.emptyOrNullString; @@ -19,40 +18,39 @@ public class DefaultCursorTest { - @Test - public void checkCursorType() { - DefaultCursor cursor = new DefaultCursor(); - assertEquals(cursor.getType(), CursorType.DEFAULT); - } - - - @Test - public void cursorShouldStartWithCursorTypeID() { - DefaultCursor cursor = new DefaultCursor(); - cursor.setRowsLeft(50); - cursor.setScrollId("dbdskbcdjksbcjkdsbcjk+//"); - cursor.setIndexPattern("myIndex"); - cursor.setFetchSize(500); - cursor.setFieldAliasMap(Collections.emptyMap()); - cursor.setColumns(new ArrayList<>()); - assertThat(cursor.generateCursorId(), startsWith(cursor.getType().getId()+ ":") ); - } - - @Test - public void nullCursorWhenRowLeftIsLessThanEqualZero() { - DefaultCursor cursor = new DefaultCursor(); - assertThat(cursor.generateCursorId(), emptyOrNullString()); - - cursor.setRowsLeft(-10); - assertThat(cursor.generateCursorId(), emptyOrNullString()); - } - - @Test - public void nullCursorWhenScrollIDIsNullOrEmpty() { - DefaultCursor cursor = new DefaultCursor(); - assertThat(cursor.generateCursorId(), emptyOrNullString()); - - cursor.setScrollId(""); - assertThat(cursor.generateCursorId(), emptyOrNullString()); - } + @Test + public void checkCursorType() { + DefaultCursor cursor = new DefaultCursor(); + assertEquals(cursor.getType(), CursorType.DEFAULT); + } + + @Test + public void cursorShouldStartWithCursorTypeID() { + DefaultCursor cursor = new DefaultCursor(); + cursor.setRowsLeft(50); + cursor.setScrollId("dbdskbcdjksbcjkdsbcjk+//"); + cursor.setIndexPattern("myIndex"); + cursor.setFetchSize(500); + cursor.setFieldAliasMap(Collections.emptyMap()); + cursor.setColumns(new ArrayList<>()); + assertThat(cursor.generateCursorId(), startsWith(cursor.getType().getId() + ":")); + } + + @Test + public void nullCursorWhenRowLeftIsLessThanEqualZero() { + DefaultCursor cursor = new DefaultCursor(); + assertThat(cursor.generateCursorId(), emptyOrNullString()); + + cursor.setRowsLeft(-10); + assertThat(cursor.generateCursorId(), emptyOrNullString()); + } + + @Test + public void nullCursorWhenScrollIDIsNullOrEmpty() { + DefaultCursor cursor = new DefaultCursor(); + assertThat(cursor.generateCursorId(), emptyOrNullString()); + + cursor.setScrollId(""); + assertThat(cursor.generateCursorId(), emptyOrNullString()); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/domain/ColumnTypeProviderTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/domain/ColumnTypeProviderTest.java index 205c63ad1d..6599d576b3 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/domain/ColumnTypeProviderTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/domain/ColumnTypeProviderTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.domain; import static org.junit.Assert.assertEquals; @@ -18,28 +17,29 @@ import org.opensearch.sql.legacy.executor.format.Schema; public class ColumnTypeProviderTest { - @Test - public void singleESDataTypeShouldReturnCorrectSchemaType() { - assertEquals(Schema.Type.LONG, new ColumnTypeProvider(OpenSearchDataType.LONG).get(0)); - } - - @Test - public void productTypeShouldReturnCorrectSchemaType() { - ColumnTypeProvider columnTypeProvider = - new ColumnTypeProvider(new Product(ImmutableList.of(OpenSearchDataType.LONG, OpenSearchDataType.SHORT))); - assertEquals(Schema.Type.LONG, columnTypeProvider.get(0)); - assertEquals(Schema.Type.SHORT, columnTypeProvider.get(1)); - } - - @Test - public void unSupportedTypeShouldReturnDefaultSchemaType() { - ColumnTypeProvider columnTypeProvider = new ColumnTypeProvider(SetOperator.UNION); - assertEquals(COLUMN_DEFAULT_TYPE, columnTypeProvider.get(0)); - } - - @Test - public void providerWithoutColumnTypeShouldReturnDefaultSchemaType() { - ColumnTypeProvider columnTypeProvider = new ColumnTypeProvider(); - assertEquals(COLUMN_DEFAULT_TYPE, columnTypeProvider.get(0)); - } + @Test + public void singleESDataTypeShouldReturnCorrectSchemaType() { + assertEquals(Schema.Type.LONG, new ColumnTypeProvider(OpenSearchDataType.LONG).get(0)); + } + + @Test + public void productTypeShouldReturnCorrectSchemaType() { + ColumnTypeProvider columnTypeProvider = + new ColumnTypeProvider( + new Product(ImmutableList.of(OpenSearchDataType.LONG, OpenSearchDataType.SHORT))); + assertEquals(Schema.Type.LONG, columnTypeProvider.get(0)); + assertEquals(Schema.Type.SHORT, columnTypeProvider.get(1)); + } + + @Test + public void unSupportedTypeShouldReturnDefaultSchemaType() { + ColumnTypeProvider columnTypeProvider = new ColumnTypeProvider(SetOperator.UNION); + assertEquals(COLUMN_DEFAULT_TYPE, columnTypeProvider.get(0)); + } + + @Test + public void providerWithoutColumnTypeShouldReturnDefaultSchemaType() { + ColumnTypeProvider columnTypeProvider = new ColumnTypeProvider(); + assertEquals(COLUMN_DEFAULT_TYPE, columnTypeProvider.get(0)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/DeleteResultSetTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/DeleteResultSetTest.java index 31388e79e3..533c2b2989 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/DeleteResultSetTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/DeleteResultSetTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.executor; import static org.hamcrest.MatcherAssert.assertThat; @@ -23,53 +22,52 @@ import org.opensearch.sql.legacy.executor.format.DeleteResultSet; import org.opensearch.sql.legacy.executor.format.Schema; - public class DeleteResultSetTest { - @Mock - NodeClient client; + @Mock NodeClient client; - @Mock - Delete deleteQuery; - - @Test - public void testDeleteResponseForJdbcFormat() throws IOException { + @Mock Delete deleteQuery; - String jsonDeleteResponse = "{\n" + - " \"took\" : 73,\n" + - " \"timed_out\" : false,\n" + - " \"total\" : 1,\n" + - " \"updated\" : 0,\n" + - " \"created\" : 0,\n" + - " \"deleted\" : 10,\n" + - " \"batches\" : 1,\n" + - " \"version_conflicts\" : 0,\n" + - " \"noops\" : 0,\n" + - " \"retries\" : {\n" + - " \"bulk\" : 0,\n" + - " \"search\" : 0\n" + - " },\n" + - " \"throttled_millis\" : 0,\n" + - " \"requests_per_second\" : -1.0,\n" + - " \"throttled_until_millis\" : 0,\n" + - " \"failures\" : [ ]\n" + - "}\n"; + @Test + public void testDeleteResponseForJdbcFormat() throws IOException { - XContentType xContentType = XContentType.JSON; - XContentParser parser = xContentType.xContent().createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - jsonDeleteResponse - ); + String jsonDeleteResponse = + "{\n" + + " \"took\" : 73,\n" + + " \"timed_out\" : false,\n" + + " \"total\" : 1,\n" + + " \"updated\" : 0,\n" + + " \"created\" : 0,\n" + + " \"deleted\" : 10,\n" + + " \"batches\" : 1,\n" + + " \"version_conflicts\" : 0,\n" + + " \"noops\" : 0,\n" + + " \"retries\" : {\n" + + " \"bulk\" : 0,\n" + + " \"search\" : 0\n" + + " },\n" + + " \"throttled_millis\" : 0,\n" + + " \"requests_per_second\" : -1.0,\n" + + " \"throttled_until_millis\" : 0,\n" + + " \"failures\" : [ ]\n" + + "}\n"; - BulkByScrollResponse deleteResponse = BulkByScrollResponse.fromXContent(parser); - DeleteResultSet deleteResultSet = new DeleteResultSet(client, deleteQuery, deleteResponse); - Schema schema = deleteResultSet.getSchema(); - DataRows dataRows = deleteResultSet.getDataRows(); + XContentType xContentType = XContentType.JSON; + XContentParser parser = + xContentType + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + jsonDeleteResponse); - assertThat(schema.getHeaders().size(), equalTo(1)); - assertThat(dataRows.getSize(), equalTo(1L)); - assertThat(dataRows.iterator().next().getData(DeleteResultSet.DELETED), equalTo(10L)); - } + BulkByScrollResponse deleteResponse = BulkByScrollResponse.fromXContent(parser); + DeleteResultSet deleteResultSet = new DeleteResultSet(client, deleteQuery, deleteResponse); + Schema schema = deleteResultSet.getSchema(); + DataRows dataRows = deleteResultSet.getDataRows(); + assertThat(schema.getHeaders().size(), equalTo(1)); + assertThat(dataRows.getSize(), equalTo(1L)); + assertThat(dataRows.iterator().next().getData(DeleteResultSet.DELETED), equalTo(10L)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/BindingTupleResultSetTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/BindingTupleResultSetTest.java index d76aa84a5d..fa385fa14b 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/BindingTupleResultSetTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/BindingTupleResultSetTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.executor.format; import static org.hamcrest.MatcherAssert.assertThat; @@ -27,53 +26,65 @@ public class BindingTupleResultSetTest { - @Test - public void buildDataRowsFromBindingTupleShouldPass() { - assertThat(row( + @Test + public void buildDataRowsFromBindingTupleShouldPass() { + assertThat( + row( Arrays.asList( ColumnNode.builder().name("age").type(Schema.Type.INTEGER).build(), ColumnNode.builder().name("gender").type(Schema.Type.TEXT).build()), - Arrays.asList(BindingTuple.from(ImmutableMap.of("age", 31, "gender", "m")), + Arrays.asList( + BindingTuple.from(ImmutableMap.of("age", 31, "gender", "m")), BindingTuple.from(ImmutableMap.of("age", 31, "gender", "f")), BindingTuple.from(ImmutableMap.of("age", 39, "gender", "m")), BindingTuple.from(ImmutableMap.of("age", 39, "gender", "f")))), - containsInAnyOrder(rowContents(allOf(hasEntry("age", 31), hasEntry("gender", (Object) "m"))), - rowContents(allOf(hasEntry("age", 31), hasEntry("gender", (Object) "f"))), - rowContents(allOf(hasEntry("age", 39), hasEntry("gender", (Object) "m"))), - rowContents(allOf(hasEntry("age", 39), hasEntry("gender", (Object) "f"))))); - } + containsInAnyOrder( + rowContents(allOf(hasEntry("age", 31), hasEntry("gender", (Object) "m"))), + rowContents(allOf(hasEntry("age", 31), hasEntry("gender", (Object) "f"))), + rowContents(allOf(hasEntry("age", 39), hasEntry("gender", (Object) "m"))), + rowContents(allOf(hasEntry("age", 39), hasEntry("gender", (Object) "f"))))); + } - @Test - public void buildDataRowsFromBindingTupleIncludeLongValueShouldPass() { - assertThat(row( + @Test + public void buildDataRowsFromBindingTupleIncludeLongValueShouldPass() { + assertThat( + row( Arrays.asList( ColumnNode.builder().name("longValue").type(Schema.Type.LONG).build(), ColumnNode.builder().name("gender").type(Schema.Type.TEXT).build()), Arrays.asList( BindingTuple.from(ImmutableMap.of("longValue", Long.MAX_VALUE, "gender", "m")), BindingTuple.from(ImmutableMap.of("longValue", Long.MIN_VALUE, "gender", "f")))), - containsInAnyOrder( - rowContents(allOf(hasEntry("longValue", Long.MAX_VALUE), hasEntry("gender", (Object) "m"))), - rowContents(allOf(hasEntry("longValue", Long.MIN_VALUE), hasEntry("gender", (Object) "f"))))); - } + containsInAnyOrder( + rowContents( + allOf(hasEntry("longValue", Long.MAX_VALUE), hasEntry("gender", (Object) "m"))), + rowContents( + allOf(hasEntry("longValue", Long.MIN_VALUE), hasEntry("gender", (Object) "f"))))); + } - @Test - public void buildDataRowsFromBindingTupleIncludeDateShouldPass() { - assertThat(row( + @Test + public void buildDataRowsFromBindingTupleIncludeDateShouldPass() { + assertThat( + row( Arrays.asList( ColumnNode.builder().alias("dateValue").type(Schema.Type.DATE).build(), ColumnNode.builder().alias("gender").type(Schema.Type.TEXT).build()), Arrays.asList( BindingTuple.from(ImmutableMap.of("dateValue", 1529712000000L, "gender", "m")))), - containsInAnyOrder( - rowContents(allOf(hasEntry("dateValue", "2018-06-23 00:00:00.000"), hasEntry("gender", (Object) "m"))))); - } + containsInAnyOrder( + rowContents( + allOf( + hasEntry("dateValue", "2018-06-23 00:00:00.000"), + hasEntry("gender", (Object) "m"))))); + } - private static Matcher rowContents(Matcher> matcher) { - return featureValueOf("DataRows.Row", matcher, DataRows.Row::getContents); - } + private static Matcher rowContents(Matcher> matcher) { + return featureValueOf("DataRows.Row", matcher, DataRows.Row::getContents); + } - private List row(List columnNodes, List bindingTupleList) { - return ImmutableList.copyOf(BindingTupleResultSet.buildDataRows(columnNodes, bindingTupleList).iterator()); - } + private List row( + List columnNodes, List bindingTupleList) { + return ImmutableList.copyOf( + BindingTupleResultSet.buildDataRows(columnNodes, bindingTupleList).iterator()); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/CSVResultsExtractorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/CSVResultsExtractorTest.java index b3afff2ce1..be6029f9af 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/CSVResultsExtractorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/executor/format/CSVResultsExtractorTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.executor.format; import static org.hamcrest.MatcherAssert.assertThat; @@ -19,21 +18,25 @@ import org.opensearch.sql.legacy.expression.domain.BindingTuple; public class CSVResultsExtractorTest { - private final CSVResultsExtractor csvResultsExtractor = new CSVResultsExtractor(false, false); - - @Test - public void extractResultsFromBindingTupleListShouldPass() throws CsvExtractorException { - CSVResult csvResult = csv(Arrays.asList(BindingTuple.from(ImmutableMap.of("age", 31, "gender", "m")), - BindingTuple.from(ImmutableMap.of("age", 31, "gender", "f")), - BindingTuple.from(ImmutableMap.of("age", 39, "gender", "m")), - BindingTuple.from(ImmutableMap.of("age", 39, "gender", "f"))), - Arrays.asList("age", "gender")); - - assertThat(csvResult.getHeaders(), contains("age", "gender")); - assertThat(csvResult.getLines(), contains("31,m", "31,f", "39,m", "39,f")); - } - - private CSVResult csv(List bindingTupleList, List fieldNames) throws CsvExtractorException { - return csvResultsExtractor.extractResults(bindingTupleList, false, ",", fieldNames); - } + private final CSVResultsExtractor csvResultsExtractor = new CSVResultsExtractor(false, false); + + @Test + public void extractResultsFromBindingTupleListShouldPass() throws CsvExtractorException { + CSVResult csvResult = + csv( + Arrays.asList( + BindingTuple.from(ImmutableMap.of("age", 31, "gender", "m")), + BindingTuple.from(ImmutableMap.of("age", 31, "gender", "f")), + BindingTuple.from(ImmutableMap.of("age", 39, "gender", "m")), + BindingTuple.from(ImmutableMap.of("age", 39, "gender", "f"))), + Arrays.asList("age", "gender")); + + assertThat(csvResult.getHeaders(), contains("age", "gender")); + assertThat(csvResult.getLines(), contains("31,m", "31,f", "39,m", "39,f")); + } + + private CSVResult csv(List bindingTupleList, List fieldNames) + throws CsvExtractorException { + return csvResultsExtractor.extractResults(bindingTupleList, false, ",", fieldNames); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/BinaryExpressionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/BinaryExpressionTest.java index 2f802f4f91..37a0666ad3 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/BinaryExpressionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/BinaryExpressionTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.expression.core; import static org.junit.Assert.assertEquals; @@ -21,69 +20,65 @@ @RunWith(MockitoJUnitRunner.class) public class BinaryExpressionTest extends ExpressionTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - - @Test - public void addIntegerValueShouldPass() { - assertEquals(2, - apply(ScalarOperation.ADD, ref("intValue"), ref("intValue"))); - } - - @Test - public void multipleAddIntegerValueShouldPass() { - assertEquals(3, - apply(ScalarOperation.ADD, ref("intValue"), - of(ScalarOperation.ADD, ref("intValue"), ref("intValue")))); - } - - @Test - public void addDoubleValueShouldPass() { - assertEquals(4d, - apply(ScalarOperation.ADD, ref("doubleValue"), ref("doubleValue"))); - } - - @Test - public void addDoubleAndIntegerShouldPass() { - assertEquals(3d, - apply(ScalarOperation.ADD, ref("doubleValue"), ref("intValue"))); - } - - @Test - public void divideIntegerValueShouldPass() { - assertEquals(0, - apply(ScalarOperation.DIVIDE, ref("intValue"), ref("intValue2"))); - } - - @Test - public void divideIntegerAndDoubleShouldPass() { - assertEquals(0.5d, - apply(ScalarOperation.DIVIDE, ref("intValue"), ref("doubleValue"))); - } - - @Test - public void subtractIntAndDoubleShouldPass() { - assertEquals(-1d, - apply(ScalarOperation.SUBTRACT, ref("intValue"), ref("doubleValue"))); - } - - @Test - public void multiplyIntAndDoubleShouldPass() { - assertEquals(2d, - apply(ScalarOperation.MULTIPLY, ref("intValue"), ref("doubleValue"))); - } - - @Test - public void modulesIntAndDoubleShouldPass() { - assertEquals(1d, - apply(ScalarOperation.MODULES, ref("intValue"), ref("doubleValue"))); - } - - @Test - public void addIntAndStringShouldPass() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("unexpected operation type: ADD(INTEGER_VALUE, STRING_VALUE)"); - - assertEquals(2, apply(ScalarOperation.ADD, literal(integerValue(1)), literal(stringValue("stringValue")))); - } + @Rule public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void addIntegerValueShouldPass() { + assertEquals(2, apply(ScalarOperation.ADD, ref("intValue"), ref("intValue"))); + } + + @Test + public void multipleAddIntegerValueShouldPass() { + assertEquals( + 3, + apply( + ScalarOperation.ADD, + ref("intValue"), + of(ScalarOperation.ADD, ref("intValue"), ref("intValue")))); + } + + @Test + public void addDoubleValueShouldPass() { + assertEquals(4d, apply(ScalarOperation.ADD, ref("doubleValue"), ref("doubleValue"))); + } + + @Test + public void addDoubleAndIntegerShouldPass() { + assertEquals(3d, apply(ScalarOperation.ADD, ref("doubleValue"), ref("intValue"))); + } + + @Test + public void divideIntegerValueShouldPass() { + assertEquals(0, apply(ScalarOperation.DIVIDE, ref("intValue"), ref("intValue2"))); + } + + @Test + public void divideIntegerAndDoubleShouldPass() { + assertEquals(0.5d, apply(ScalarOperation.DIVIDE, ref("intValue"), ref("doubleValue"))); + } + + @Test + public void subtractIntAndDoubleShouldPass() { + assertEquals(-1d, apply(ScalarOperation.SUBTRACT, ref("intValue"), ref("doubleValue"))); + } + + @Test + public void multiplyIntAndDoubleShouldPass() { + assertEquals(2d, apply(ScalarOperation.MULTIPLY, ref("intValue"), ref("doubleValue"))); + } + + @Test + public void modulesIntAndDoubleShouldPass() { + assertEquals(1d, apply(ScalarOperation.MODULES, ref("intValue"), ref("doubleValue"))); + } + + @Test + public void addIntAndStringShouldPass() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("unexpected operation type: ADD(INTEGER_VALUE, STRING_VALUE)"); + + assertEquals( + 2, + apply(ScalarOperation.ADD, literal(integerValue(1)), literal(stringValue("stringValue")))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/CompoundExpressionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/CompoundExpressionTest.java index 2e75ee0c8b..3315024a13 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/CompoundExpressionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/expression/core/CompoundExpressionTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.expression.core; import static org.junit.Assert.assertEquals; @@ -16,10 +15,12 @@ public class CompoundExpressionTest extends ExpressionTest { - @Test - public void absAndAddShouldPass() { - assertEquals(2.0d, apply(ScalarOperation.ABS, of(ScalarOperation.ADD, - literal(doubleValue(-1.0d)), - literal(integerValue(-1))))); - } + @Test + public void absAndAddShouldPass() { + assertEquals( + 2.0d, + apply( + ScalarOperation.ABS, + of(ScalarOperation.ADD, literal(doubleValue(-1.0d)), literal(integerValue(-1))))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/BasicCounterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/BasicCounterTest.java index ebe61109a7..34dc170a37 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/BasicCounterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/metrics/BasicCounterTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.metrics; import static org.hamcrest.MatcherAssert.assertThat; @@ -14,22 +13,21 @@ public class BasicCounterTest { - @Test - public void increment() { - BasicCounter counter = new BasicCounter(); - for (int i=0; i<5; ++i) { - counter.increment(); - } - - assertThat(counter.getValue(), equalTo(5L)); + @Test + public void increment() { + BasicCounter counter = new BasicCounter(); + for (int i = 0; i < 5; ++i) { + counter.increment(); } - @Test - public void incrementN() { - BasicCounter counter = new BasicCounter(); - counter.add(5); + assertThat(counter.getValue(), equalTo(5L)); + } - assertThat(counter.getValue(), equalTo(5L)); - } + @Test + public void incrementN() { + BasicCounter counter = new BasicCounter(); + counter.add(5); + assertThat(counter.getValue(), equalTo(5L)); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/parser/BucketPathTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/parser/BucketPathTest.java index 067143716d..c26740a04c 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/parser/BucketPathTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/parser/BucketPathTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.parser; import static org.junit.Assert.assertEquals; @@ -16,46 +15,45 @@ public class BucketPathTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private final Path agg1 = Path.getAggPath("projects@NESTED"); - private final Path agg2 = Path.getAggPath("projects@FILTERED"); - private final Path metric = Path.getMetricPath("c"); + private final Path agg1 = Path.getAggPath("projects@NESTED"); + private final Path agg2 = Path.getAggPath("projects@FILTERED"); + private final Path metric = Path.getMetricPath("c"); - @Test - public void bucketPath() { - BucketPath bucketPath = new BucketPath(); - bucketPath.add(metric); - bucketPath.add(agg2); - bucketPath.add(agg1); + @Test + public void bucketPath() { + BucketPath bucketPath = new BucketPath(); + bucketPath.add(metric); + bucketPath.add(agg2); + bucketPath.add(agg1); - assertEquals("projects@NESTED>projects@FILTERED.c", bucketPath.getBucketPath()); - } + assertEquals("projects@NESTED>projects@FILTERED.c", bucketPath.getBucketPath()); + } - @Test - public void bucketPathEmpty() { - BucketPath bucketPath = new BucketPath(); + @Test + public void bucketPathEmpty() { + BucketPath bucketPath = new BucketPath(); - assertEquals("", bucketPath.getBucketPath()); - } + assertEquals("", bucketPath.getBucketPath()); + } - @Test - public void theLastMustBeMetric() { - BucketPath bucketPath = new BucketPath(); + @Test + public void theLastMustBeMetric() { + BucketPath bucketPath = new BucketPath(); - exceptionRule.expect(AssertionError.class); - exceptionRule.expectMessage("The last path in the bucket path must be Metric"); - bucketPath.add(agg1); - } + exceptionRule.expect(AssertionError.class); + exceptionRule.expectMessage("The last path in the bucket path must be Metric"); + bucketPath.add(agg1); + } - @Test - public void allTheOtherMustBeAgg() { - BucketPath bucketPath = new BucketPath(); + @Test + public void allTheOtherMustBeAgg() { + BucketPath bucketPath = new BucketPath(); - exceptionRule.expect(AssertionError.class); - exceptionRule.expectMessage("All the other path in the bucket path must be Agg"); - bucketPath.add(metric); - bucketPath.add(metric); - } + exceptionRule.expect(AssertionError.class); + exceptionRule.expectMessage("All the other path in the bucket path must be Agg"); + bucketPath.add(metric); + bucketPath.add(metric); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/BindingTupleQueryPlannerExecuteTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/BindingTupleQueryPlannerExecuteTest.java index 9f6fcbcc6d..1260b551fb 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/BindingTupleQueryPlannerExecuteTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/BindingTupleQueryPlannerExecuteTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.planner; import static org.hamcrest.MatcherAssert.assertThat; @@ -35,79 +34,78 @@ @RunWith(MockitoJUnitRunner.class) public class BindingTupleQueryPlannerExecuteTest { - @Mock - private Client client; - @Mock - private SearchResponse aggResponse; - @Mock - private ColumnTypeProvider columnTypeProvider; + @Mock private Client client; + @Mock private SearchResponse aggResponse; + @Mock private ColumnTypeProvider columnTypeProvider; - @Before - public void init() { - MockitoAnnotations.initMocks(this); + @Before + public void init() { + MockitoAnnotations.initMocks(this); - ActionFuture mockFuture = mock(ActionFuture.class); - when(client.execute(any(), any())).thenReturn(mockFuture); - when(mockFuture.actionGet()).thenAnswer(invocationOnMock -> aggResponse); - } + ActionFuture mockFuture = mock(ActionFuture.class); + when(client.execute(any(), any())).thenReturn(mockFuture); + when(mockFuture.actionGet()).thenAnswer(invocationOnMock -> aggResponse); + } - @Test - public void testAggregationShouldPass() { - assertThat(query("SELECT gender, MAX(age) as max, MIN(age) as min FROM accounts GROUP BY gender", - mockSearchAggregation()), - containsInAnyOrder( - BindingTuple.from(ImmutableMap.of("gender", "m", "max", 20d, "min", 10d)), - BindingTuple.from(ImmutableMap.of("gender", "f", "max", 40d, "min", 20d)))); - } + @Test + public void testAggregationShouldPass() { + assertThat( + query( + "SELECT gender, MAX(age) as max, MIN(age) as min FROM accounts GROUP BY gender", + mockSearchAggregation()), + containsInAnyOrder( + BindingTuple.from(ImmutableMap.of("gender", "m", "max", 20d, "min", 10d)), + BindingTuple.from(ImmutableMap.of("gender", "f", "max", 40d, "min", 20d)))); + } + protected List query(String sql, MockSearchAggregation mockAgg) { + doAnswer(mockAgg).when(aggResponse).getAggregations(); - protected List query(String sql, MockSearchAggregation mockAgg) { - doAnswer(mockAgg).when(aggResponse).getAggregations(); + BindingTupleQueryPlanner queryPlanner = + new BindingTupleQueryPlanner(client, SqlParserUtils.parse(sql), columnTypeProvider); + return queryPlanner.execute(); + } - BindingTupleQueryPlanner queryPlanner = - new BindingTupleQueryPlanner(client, SqlParserUtils.parse(sql), columnTypeProvider); - return queryPlanner.execute(); - } + private MockSearchAggregation mockSearchAggregation() { + return new MockSearchAggregation( + "{\n" + + " \"sterms#gender\": {\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": \"m\",\n" + + " \"doc_count\": 507,\n" + + " \"min#min\": {\n" + + " \"value\": 10\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 20\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": \"f\",\n" + + " \"doc_count\": 493,\n" + + " \"min#min\": {\n" + + " \"value\": 20\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 40\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"); + } - private MockSearchAggregation mockSearchAggregation() { - return new MockSearchAggregation("{\n" - + " \"sterms#gender\": {\n" - + " \"buckets\": [\n" - + " {\n" - + " \"key\": \"m\",\n" - + " \"doc_count\": 507,\n" - + " \"min#min\": {\n" - + " \"value\": 10\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 20\n" - + " }\n" - + " },\n" - + " {\n" - + " \"key\": \"f\",\n" - + " \"doc_count\": 493,\n" - + " \"min#min\": {\n" - + " \"value\": 20\n" - + " },\n" - + " \"max#max\": {\n" - + " \"value\": 40\n" - + " }\n" - + " }\n" - + " ]\n" - + " }\n" - + "}"); - } - - protected static class MockSearchAggregation implements Answer { - private final Aggregations aggregation; + protected static class MockSearchAggregation implements Answer { + private final Aggregations aggregation; - public MockSearchAggregation(String agg) { - aggregation = AggregationUtils.fromJson(agg); - } + public MockSearchAggregation(String agg) { + aggregation = AggregationUtils.fromJson(agg); + } - @Override - public Aggregations answer(InvocationOnMock invocationOnMock) throws Throwable { - return aggregation; - } + @Override + public Aggregations answer(InvocationOnMock invocationOnMock) throws Throwable { + return aggregation; } + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java index 57530692d4..11e14e9b48 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.query; import static org.hamcrest.Matchers.equalTo; @@ -42,228 +41,225 @@ public class DefaultQueryActionTest { - private DefaultQueryAction queryAction; + private DefaultQueryAction queryAction; - private Client mockClient; + private Client mockClient; - private Select mockSelect; + private Select mockSelect; - private SearchRequestBuilder mockRequestBuilder; + private SearchRequestBuilder mockRequestBuilder; - @Before - public void initDefaultQueryAction() { + @Before + public void initDefaultQueryAction() { - mockClient = mock(Client.class); - mockSelect = mock(Select.class); - mockRequestBuilder = mock(SearchRequestBuilder.class); + mockClient = mock(Client.class); + mockSelect = mock(Select.class); + mockRequestBuilder = mock(SearchRequestBuilder.class); - List fields = new LinkedList<>(); - fields.add(new Field("balance", "bbb")); + List fields = new LinkedList<>(); + fields.add(new Field("balance", "bbb")); - doReturn(fields).when(mockSelect).getFields(); - doReturn(null).when(mockRequestBuilder).setFetchSource(any(String[].class), any(String[].class)); - doReturn(null).when(mockRequestBuilder).addScriptField(anyString(), any(Script.class)); + doReturn(fields).when(mockSelect).getFields(); + doReturn(null) + .when(mockRequestBuilder) + .setFetchSource(any(String[].class), any(String[].class)); + doReturn(null).when(mockRequestBuilder).addScriptField(anyString(), any(Script.class)); - queryAction = new DefaultQueryAction(mockClient, mockSelect); - queryAction.initialize(mockRequestBuilder); - } + queryAction = new DefaultQueryAction(mockClient, mockSelect); + queryAction.initialize(mockRequestBuilder); + } - @After - public void cleanup() { - LocalClusterState.state(null); - } + @After + public void cleanup() { + LocalClusterState.state(null); + } - @Test - public void scriptFieldWithTwoParams() throws SqlParseException { + @Test + public void scriptFieldWithTwoParams() throws SqlParseException { - List fields = new LinkedList<>(); - fields.add(createScriptField("script1", "doc['balance'] * 2", - false, true, false)); + List fields = new LinkedList<>(); + fields.add(createScriptField("script1", "doc['balance'] * 2", false, true, false)); - queryAction.setFields(fields); + queryAction.setFields(fields); - final Optional> fieldNames = queryAction.getFieldNames(); - Assert.assertTrue("Field names have not been set", fieldNames.isPresent()); - Assert.assertThat(fieldNames.get().size(), equalTo(1)); - Assert.assertThat(fieldNames.get().get(0), equalTo("script1")); + final Optional> fieldNames = queryAction.getFieldNames(); + Assert.assertTrue("Field names have not been set", fieldNames.isPresent()); + Assert.assertThat(fieldNames.get().size(), equalTo(1)); + Assert.assertThat(fieldNames.get().get(0), equalTo("script1")); - Mockito.verify(mockRequestBuilder).addScriptField(eq("script1"), any(Script.class)); - } + Mockito.verify(mockRequestBuilder).addScriptField(eq("script1"), any(Script.class)); + } - @Test - public void scriptFieldWithThreeParams() throws SqlParseException { + @Test + public void scriptFieldWithThreeParams() throws SqlParseException { - List fields = new LinkedList<>(); - fields.add(createScriptField("script1", "doc['balance'] * 2", - true, true, false)); + List fields = new LinkedList<>(); + fields.add(createScriptField("script1", "doc['balance'] * 2", true, true, false)); - queryAction.setFields(fields); + queryAction.setFields(fields); - final Optional> fieldNames = queryAction.getFieldNames(); - Assert.assertTrue("Field names have not been set", fieldNames.isPresent()); - Assert.assertThat(fieldNames.get().size(), equalTo(1)); - Assert.assertThat(fieldNames.get().get(0), equalTo("script1")); + final Optional> fieldNames = queryAction.getFieldNames(); + Assert.assertTrue("Field names have not been set", fieldNames.isPresent()); + Assert.assertThat(fieldNames.get().size(), equalTo(1)); + Assert.assertThat(fieldNames.get().get(0), equalTo("script1")); - Mockito.verify(mockRequestBuilder).addScriptField(eq("script1"), any(Script.class)); - } + Mockito.verify(mockRequestBuilder).addScriptField(eq("script1"), any(Script.class)); + } - @Test(expected = SqlParseException.class) - public void scriptFieldWithLessThanTwoParams() throws SqlParseException { + @Test(expected = SqlParseException.class) + public void scriptFieldWithLessThanTwoParams() throws SqlParseException { - List fields = new LinkedList<>(); - fields.add(createScriptField("script1", "doc['balance'] * 2", - false, false, false)); + List fields = new LinkedList<>(); + fields.add(createScriptField("script1", "doc['balance'] * 2", false, false, false)); - queryAction.setFields(fields); - } + queryAction.setFields(fields); + } - @Test - public void scriptFieldWithMoreThanThreeParams() throws SqlParseException { + @Test + public void scriptFieldWithMoreThanThreeParams() throws SqlParseException { - List fields = new LinkedList<>(); - fields.add(createScriptField("script1", "doc['balance'] * 2", - false, true, true)); - - queryAction.setFields(fields); - } - - @Test - public void testIfScrollShouldBeOpenWithDifferentFormats() { - int settingFetchSize = 500; - TimeValue timeValue = new TimeValue(120000); - int limit = 2300; - mockLocalClusterStateAndInitializeMetrics(timeValue); - - doReturn(limit).when(mockSelect).getRowCount(); - doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(settingFetchSize); - SqlRequest mockSqlRequest = mock(SqlRequest.class); - doReturn(settingFetchSize).when(mockSqlRequest).fetchSize(); - queryAction.setSqlRequest(mockSqlRequest); - - Format[] formats = new Format[] {Format.CSV, Format.RAW, Format.JSON, Format.TABLE}; - for (Format format : formats) { - queryAction.setFormat(format); - queryAction.checkAndSetScroll(); - } - - Mockito.verify(mockRequestBuilder, times(4)).setSize(limit); - Mockito.verify(mockRequestBuilder, never()).setScroll(any(TimeValue.class)); - - queryAction.setFormat(Format.JDBC); - queryAction.checkAndSetScroll(); - Mockito.verify(mockRequestBuilder).setSize(settingFetchSize); - Mockito.verify(mockRequestBuilder).setScroll(timeValue); - - } + List fields = new LinkedList<>(); + fields.add(createScriptField("script1", "doc['balance'] * 2", false, true, true)); - @Test - public void testIfScrollShouldBeOpen() { - int settingFetchSize = 500; - TimeValue timeValue = new TimeValue(120000); - int limit = 2300; + queryAction.setFields(fields); + } - doReturn(limit).when(mockSelect).getRowCount(); - doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(settingFetchSize); - SqlRequest mockSqlRequest = mock(SqlRequest.class); - doReturn(settingFetchSize).when(mockSqlRequest).fetchSize(); - queryAction.setSqlRequest(mockSqlRequest); - queryAction.setFormat(Format.JDBC); + @Test + public void testIfScrollShouldBeOpenWithDifferentFormats() { + int settingFetchSize = 500; + TimeValue timeValue = new TimeValue(120000); + int limit = 2300; + mockLocalClusterStateAndInitializeMetrics(timeValue); - mockLocalClusterStateAndInitializeMetrics(timeValue); - queryAction.checkAndSetScroll(); - Mockito.verify(mockRequestBuilder).setSize(settingFetchSize); - Mockito.verify(mockRequestBuilder).setScroll(timeValue); + doReturn(limit).when(mockSelect).getRowCount(); + doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(settingFetchSize); + SqlRequest mockSqlRequest = mock(SqlRequest.class); + doReturn(settingFetchSize).when(mockSqlRequest).fetchSize(); + queryAction.setSqlRequest(mockSqlRequest); + Format[] formats = new Format[] {Format.CSV, Format.RAW, Format.JSON, Format.TABLE}; + for (Format format : formats) { + queryAction.setFormat(format); + queryAction.checkAndSetScroll(); } - @Test - public void testIfScrollShouldBeOpenWithDifferentFetchSize() { - TimeValue timeValue = new TimeValue(120000); - int limit = 2300; - mockLocalClusterStateAndInitializeMetrics(timeValue); - - doReturn(limit).when(mockSelect).getRowCount(); - SqlRequest mockSqlRequest = mock(SqlRequest.class); - queryAction.setSqlRequest(mockSqlRequest); - queryAction.setFormat(Format.JDBC); - - int[] fetchSizes = new int[] {0, -10}; - for (int fetch : fetchSizes) { - doReturn(fetch).when(mockSqlRequest).fetchSize(); - queryAction.checkAndSetScroll(); - } - Mockito.verify(mockRequestBuilder, times(2)).setSize(limit); - Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); - - int userFetchSize = 20; - doReturn(userFetchSize).when(mockSqlRequest).fetchSize(); - doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(userFetchSize); - queryAction.checkAndSetScroll(); - Mockito.verify(mockRequestBuilder).setSize(20); - Mockito.verify(mockRequestBuilder).setScroll(timeValue); + Mockito.verify(mockRequestBuilder, times(4)).setSize(limit); + Mockito.verify(mockRequestBuilder, never()).setScroll(any(TimeValue.class)); + + queryAction.setFormat(Format.JDBC); + queryAction.checkAndSetScroll(); + Mockito.verify(mockRequestBuilder).setSize(settingFetchSize); + Mockito.verify(mockRequestBuilder).setScroll(timeValue); + } + + @Test + public void testIfScrollShouldBeOpen() { + int settingFetchSize = 500; + TimeValue timeValue = new TimeValue(120000); + int limit = 2300; + + doReturn(limit).when(mockSelect).getRowCount(); + doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(settingFetchSize); + SqlRequest mockSqlRequest = mock(SqlRequest.class); + doReturn(settingFetchSize).when(mockSqlRequest).fetchSize(); + queryAction.setSqlRequest(mockSqlRequest); + queryAction.setFormat(Format.JDBC); + + mockLocalClusterStateAndInitializeMetrics(timeValue); + queryAction.checkAndSetScroll(); + Mockito.verify(mockRequestBuilder).setSize(settingFetchSize); + Mockito.verify(mockRequestBuilder).setScroll(timeValue); + } + + @Test + public void testIfScrollShouldBeOpenWithDifferentFetchSize() { + TimeValue timeValue = new TimeValue(120000); + int limit = 2300; + mockLocalClusterStateAndInitializeMetrics(timeValue); + + doReturn(limit).when(mockSelect).getRowCount(); + SqlRequest mockSqlRequest = mock(SqlRequest.class); + queryAction.setSqlRequest(mockSqlRequest); + queryAction.setFormat(Format.JDBC); + + int[] fetchSizes = new int[] {0, -10}; + for (int fetch : fetchSizes) { + doReturn(fetch).when(mockSqlRequest).fetchSize(); + queryAction.checkAndSetScroll(); } - - - @Test - public void testIfScrollShouldBeOpenWithDifferentValidFetchSizeAndLimit() { - TimeValue timeValue = new TimeValue(120000); - mockLocalClusterStateAndInitializeMetrics(timeValue); - - int limit = 2300; - doReturn(limit).when(mockSelect).getRowCount(); - SqlRequest mockSqlRequest = mock(SqlRequest.class); - - /** fetchSize <= LIMIT - open scroll*/ - int userFetchSize = 1500; - doReturn(userFetchSize).when(mockSqlRequest).fetchSize(); - doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(userFetchSize); - queryAction.setSqlRequest(mockSqlRequest); - queryAction.setFormat(Format.JDBC); - - queryAction.checkAndSetScroll(); - Mockito.verify(mockRequestBuilder).setSize(userFetchSize); - Mockito.verify(mockRequestBuilder).setScroll(timeValue); - - /** fetchSize > LIMIT - no scroll */ - userFetchSize = 5000; - doReturn(userFetchSize).when(mockSqlRequest).fetchSize(); - mockRequestBuilder = mock(SearchRequestBuilder.class); - queryAction.initialize(mockRequestBuilder); - queryAction.checkAndSetScroll(); - Mockito.verify(mockRequestBuilder).setSize(limit); - Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); + Mockito.verify(mockRequestBuilder, times(2)).setSize(limit); + Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); + + int userFetchSize = 20; + doReturn(userFetchSize).when(mockSqlRequest).fetchSize(); + doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(userFetchSize); + queryAction.checkAndSetScroll(); + Mockito.verify(mockRequestBuilder).setSize(20); + Mockito.verify(mockRequestBuilder).setScroll(timeValue); + } + + @Test + public void testIfScrollShouldBeOpenWithDifferentValidFetchSizeAndLimit() { + TimeValue timeValue = new TimeValue(120000); + mockLocalClusterStateAndInitializeMetrics(timeValue); + + int limit = 2300; + doReturn(limit).when(mockSelect).getRowCount(); + SqlRequest mockSqlRequest = mock(SqlRequest.class); + + /** fetchSize <= LIMIT - open scroll */ + int userFetchSize = 1500; + doReturn(userFetchSize).when(mockSqlRequest).fetchSize(); + doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(userFetchSize); + queryAction.setSqlRequest(mockSqlRequest); + queryAction.setFormat(Format.JDBC); + + queryAction.checkAndSetScroll(); + Mockito.verify(mockRequestBuilder).setSize(userFetchSize); + Mockito.verify(mockRequestBuilder).setScroll(timeValue); + + /** fetchSize > LIMIT - no scroll */ + userFetchSize = 5000; + doReturn(userFetchSize).when(mockSqlRequest).fetchSize(); + mockRequestBuilder = mock(SearchRequestBuilder.class); + queryAction.initialize(mockRequestBuilder); + queryAction.checkAndSetScroll(); + Mockito.verify(mockRequestBuilder).setSize(limit); + Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); + } + + private void mockLocalClusterStateAndInitializeMetrics(TimeValue time) { + LocalClusterState mockLocalClusterState = mock(LocalClusterState.class); + LocalClusterState.state(mockLocalClusterState); + doReturn(time).when(mockLocalClusterState).getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + doReturn(3600L) + .when(mockLocalClusterState) + .getSettingValue(Settings.Key.METRICS_ROLLING_WINDOW); + doReturn(2L).when(mockLocalClusterState).getSettingValue(Settings.Key.METRICS_ROLLING_INTERVAL); + + Metrics.getInstance().registerDefaultMetrics(); + } + + private Field createScriptField( + final String name, + final String script, + final boolean addScriptLanguage, + final boolean addScriptParam, + final boolean addRedundantParam) { + + final List params = new ArrayList<>(); + + params.add(new KVValue("alias", name)); + if (addScriptLanguage) { + params.add(new KVValue("painless")); } - - private void mockLocalClusterStateAndInitializeMetrics(TimeValue time) { - LocalClusterState mockLocalClusterState = mock(LocalClusterState.class); - LocalClusterState.state(mockLocalClusterState); - doReturn(time).when(mockLocalClusterState).getSettingValue( - Settings.Key.SQL_CURSOR_KEEP_ALIVE); - doReturn(3600L).when(mockLocalClusterState).getSettingValue( - Settings.Key.METRICS_ROLLING_WINDOW); - doReturn(2L).when(mockLocalClusterState).getSettingValue( - Settings.Key.METRICS_ROLLING_INTERVAL); - - Metrics.getInstance().registerDefaultMetrics(); - + if (addScriptParam) { + params.add(new KVValue(script)); } - - private Field createScriptField(final String name, final String script, final boolean addScriptLanguage, - final boolean addScriptParam, final boolean addRedundantParam) { - - final List params = new ArrayList<>(); - - params.add(new KVValue("alias", name)); - if (addScriptLanguage) { - params.add(new KVValue("painless")); - } - if (addScriptParam) { - params.add(new KVValue(script)); - } - if (addRedundantParam) { - params.add(new KVValue("Fail the test")); - } - - return new MethodField("script", params, null, null); + if (addRedundantParam) { + params.add(new KVValue("Fail the test")); } + + return new MethodField("script", params, null, null); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/inline/AliasInliningTests.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/inline/AliasInliningTests.java index 0c16a3264a..168725ed11 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/inline/AliasInliningTests.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/inline/AliasInliningTests.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.rewriter.inline; import static org.hamcrest.MatcherAssert.assertThat; @@ -29,101 +28,111 @@ public class AliasInliningTests { - private static final String TEST_MAPPING_FILE = "mappings/semantics.json"; - @Before - public void setUp() throws IOException { - URL url = Resources.getResource(TEST_MAPPING_FILE); - String mappings = Resources.toString(url, Charsets.UTF_8); - mockLocalClusterState(mappings); - } - - @Test - public void orderByAliasedFieldTest() throws SqlParseException { - String originalQuery = "SELECT utc_time date " + - "FROM opensearch_dashboards_sample_data_logs " + - "ORDER BY date DESC"; - String originalDsl = parseAsSimpleQuery(originalQuery); - - String rewrittenQuery = - "SELECT utc_time date " + - "FROM opensearch_dashboards_sample_data_logs " + - "ORDER BY utc_time DESC"; - - String rewrittenDsl = parseAsSimpleQuery(rewrittenQuery); - - assertThat(originalDsl, equalTo(rewrittenDsl)); - } - - @Test - public void orderByAliasedScriptedField() throws SqlParseException { - String originalDsl = parseAsSimpleQuery("SELECT date_format(birthday, 'dd-MM-YYYY') date " + - "FROM bank " + - "ORDER BY date"); - String rewrittenQuery = "SELECT date_format(birthday, 'dd-MM-YYYY') date " + - "FROM bank " + - "ORDER BY date_format(birthday, 'dd-MM-YYYY')"; - - String rewrittenDsl = parseAsSimpleQuery(rewrittenQuery); - assertThat(originalDsl, equalTo(rewrittenDsl)); - } - - @Test - public void groupByAliasedFieldTest() throws SqlParseException { - String originalQuery = "SELECT utc_time date " + - "FROM opensearch_dashboards_sample_data_logs " + - "GROUP BY date"; - - String originalDsl = parseAsAggregationQuery(originalQuery); - - String rewrittenQuery = "SELECT utc_time date " + - "FROM opensearch_dashboards_sample_data_logs " + - "GROUP BY utc_time DESC"; - - String rewrittenDsl = parseAsAggregationQuery(rewrittenQuery); - - assertThat(originalDsl, equalTo(rewrittenDsl)); - } - - @Test - public void groupAndSortBySameExprAlias() throws SqlParseException { - String query = "SELECT date_format(timestamp, 'yyyy-MM') opensearch-table.timestamp_tg, COUNT(*) count, COUNT(DistanceKilometers) opensearch-table.DistanceKilometers_count\n" + - "FROM opensearch_dashboards_sample_data_flights\n" + - "GROUP BY date_format(timestamp, 'yyyy-MM')\n" + - "ORDER BY date_format(timestamp, 'yyyy-MM') DESC\n" + - "LIMIT 2500"; - String dsl = parseAsAggregationQuery(query); - - JSONObject parseQuery = new JSONObject(dsl); - - assertThat(parseQuery.query("/aggregations/opensearch-table.timestamp_tg/terms/script"), notNullValue()); - - } - - @Test - public void groupByAndSortAliased() throws SqlParseException { - String dsl = parseAsAggregationQuery( - "SELECT date_format(utc_time, 'dd-MM-YYYY') date " + - "FROM opensearch_dashboards_sample_data_logs " + - "GROUP BY date " + - "ORDER BY date DESC"); - - JSONObject parsedQuery = new JSONObject(dsl); - - JSONObject query = (JSONObject)parsedQuery.query("/aggregations/date/terms/script"); - - assertThat(query, notNullValue()); - } - - private String parseAsSimpleQuery(String originalQuery) throws SqlParseException { - SqlRequest sqlRequest = new SqlRequest(originalQuery, new JSONObject()); - DefaultQueryAction defaultQueryAction = new DefaultQueryAction(mock(Client.class), - new SqlParser().parseSelect(parse(originalQuery))); - defaultQueryAction.setSqlRequest(sqlRequest); - return defaultQueryAction.explain().explain(); - } - - private String parseAsAggregationQuery(String originalQuery) throws SqlParseException { - return new AggregationQueryAction(mock(Client.class), - new SqlParser().parseSelect(parse(originalQuery))).explain().explain(); - } + private static final String TEST_MAPPING_FILE = "mappings/semantics.json"; + + @Before + public void setUp() throws IOException { + URL url = Resources.getResource(TEST_MAPPING_FILE); + String mappings = Resources.toString(url, Charsets.UTF_8); + mockLocalClusterState(mappings); + } + + @Test + public void orderByAliasedFieldTest() throws SqlParseException { + String originalQuery = + "SELECT utc_time date " + + "FROM opensearch_dashboards_sample_data_logs " + + "ORDER BY date DESC"; + String originalDsl = parseAsSimpleQuery(originalQuery); + + String rewrittenQuery = + "SELECT utc_time date " + + "FROM opensearch_dashboards_sample_data_logs " + + "ORDER BY utc_time DESC"; + + String rewrittenDsl = parseAsSimpleQuery(rewrittenQuery); + + assertThat(originalDsl, equalTo(rewrittenDsl)); + } + + @Test + public void orderByAliasedScriptedField() throws SqlParseException { + String originalDsl = + parseAsSimpleQuery( + "SELECT date_format(birthday, 'dd-MM-YYYY') date " + "FROM bank " + "ORDER BY date"); + String rewrittenQuery = + "SELECT date_format(birthday, 'dd-MM-YYYY') date " + + "FROM bank " + + "ORDER BY date_format(birthday, 'dd-MM-YYYY')"; + + String rewrittenDsl = parseAsSimpleQuery(rewrittenQuery); + assertThat(originalDsl, equalTo(rewrittenDsl)); + } + + @Test + public void groupByAliasedFieldTest() throws SqlParseException { + String originalQuery = + "SELECT utc_time date " + "FROM opensearch_dashboards_sample_data_logs " + "GROUP BY date"; + + String originalDsl = parseAsAggregationQuery(originalQuery); + + String rewrittenQuery = + "SELECT utc_time date " + + "FROM opensearch_dashboards_sample_data_logs " + + "GROUP BY utc_time DESC"; + + String rewrittenDsl = parseAsAggregationQuery(rewrittenQuery); + + assertThat(originalDsl, equalTo(rewrittenDsl)); + } + + @Test + public void groupAndSortBySameExprAlias() throws SqlParseException { + String query = + "SELECT date_format(timestamp, 'yyyy-MM') opensearch-table.timestamp_tg, COUNT(*) count," + + " COUNT(DistanceKilometers) opensearch-table.DistanceKilometers_count\n" + + "FROM opensearch_dashboards_sample_data_flights\n" + + "GROUP BY date_format(timestamp, 'yyyy-MM')\n" + + "ORDER BY date_format(timestamp, 'yyyy-MM') DESC\n" + + "LIMIT 2500"; + String dsl = parseAsAggregationQuery(query); + + JSONObject parseQuery = new JSONObject(dsl); + + assertThat( + parseQuery.query("/aggregations/opensearch-table.timestamp_tg/terms/script"), + notNullValue()); + } + + @Test + public void groupByAndSortAliased() throws SqlParseException { + String dsl = + parseAsAggregationQuery( + "SELECT date_format(utc_time, 'dd-MM-YYYY') date " + + "FROM opensearch_dashboards_sample_data_logs " + + "GROUP BY date " + + "ORDER BY date DESC"); + + JSONObject parsedQuery = new JSONObject(dsl); + + JSONObject query = (JSONObject) parsedQuery.query("/aggregations/date/terms/script"); + + assertThat(query, notNullValue()); + } + + private String parseAsSimpleQuery(String originalQuery) throws SqlParseException { + SqlRequest sqlRequest = new SqlRequest(originalQuery, new JSONObject()); + DefaultQueryAction defaultQueryAction = + new DefaultQueryAction( + mock(Client.class), new SqlParser().parseSelect(parse(originalQuery))); + defaultQueryAction.setSqlRequest(sqlRequest); + return defaultQueryAction.explain().explain(); + } + + private String parseAsAggregationQuery(String originalQuery) throws SqlParseException { + return new AggregationQueryAction( + mock(Client.class), new SqlParser().parseSelect(parse(originalQuery))) + .explain() + .explain(); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/BackticksUnquoterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/BackticksUnquoterTest.java index b0c6b8a2d8..c7e7f22d5c 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/BackticksUnquoterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/utils/BackticksUnquoterTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.unittest.utils; import static org.hamcrest.MatcherAssert.assertThat; @@ -15,28 +14,29 @@ import org.opensearch.sql.legacy.utils.StringUtils; /** - * To test the functionality of {@link StringUtils#unquoteSingleField} - * and {@link StringUtils#unquoteFullColumn(String, String)} + * To test the functionality of {@link StringUtils#unquoteSingleField} and {@link + * StringUtils#unquoteFullColumn(String, String)} */ public class BackticksUnquoterTest { - @Test - public void assertNotQuotedStringShouldKeepTheSame() { - assertThat(unquoteSingleField("identifier"), equalTo("identifier")); - assertThat(unquoteFullColumn("identifier"), equalTo("identifier")); - } - - @Test - public void assertStringWithOneBackTickShouldKeepTheSame() { - assertThat(unquoteSingleField("`identifier"), equalTo("`identifier")); - assertThat(unquoteFullColumn("`identifier"), equalTo("`identifier")); - } - - @Test - public void assertBackticksQuotedStringShouldBeUnquoted() { - assertThat("identifier", equalTo(unquoteSingleField("`identifier`"))); - - assertThat("identifier1.identifier2", equalTo(unquoteFullColumn("`identifier1`.`identifier2`"))); - assertThat("identifier1.identifier2", equalTo(unquoteFullColumn("`identifier1`.identifier2"))); - } + @Test + public void assertNotQuotedStringShouldKeepTheSame() { + assertThat(unquoteSingleField("identifier"), equalTo("identifier")); + assertThat(unquoteFullColumn("identifier"), equalTo("identifier")); + } + + @Test + public void assertStringWithOneBackTickShouldKeepTheSame() { + assertThat(unquoteSingleField("`identifier"), equalTo("`identifier")); + assertThat(unquoteFullColumn("`identifier"), equalTo("`identifier")); + } + + @Test + public void assertBackticksQuotedStringShouldBeUnquoted() { + assertThat("identifier", equalTo(unquoteSingleField("`identifier`"))); + + assertThat( + "identifier1.identifier2", equalTo(unquoteFullColumn("`identifier1`.`identifier2`"))); + assertThat("identifier1.identifier2", equalTo(unquoteFullColumn("`identifier1`.identifier2"))); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/util/AggregationUtils.java b/legacy/src/test/java/org/opensearch/sql/legacy/util/AggregationUtils.java index 58fa8793ff..85da1d990f 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/util/AggregationUtils.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/util/AggregationUtils.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.util; import com.fasterxml.jackson.core.JsonFactory; @@ -41,42 +40,52 @@ import org.opensearch.search.aggregations.pipeline.PercentilesBucketPipelineAggregationBuilder; public class AggregationUtils { - private final static List entryList = - new ImmutableMap.Builder>().put( - MinAggregationBuilder.NAME, (p, c) -> ParsedMin.fromXContent(p, (String) c)) - .put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)) - .put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)) - .put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)) - .put(StringTerms.NAME, (p, c) -> ParsedStringTerms.fromXContent(p, (String) c)) - .put(LongTerms.NAME, (p, c) -> ParsedLongTerms.fromXContent(p, (String) c)) - .put(DoubleTerms.NAME, (p, c) -> ParsedDoubleTerms.fromXContent(p, (String) c)) - .put(ValueCountAggregationBuilder.NAME, (p, c) -> ParsedValueCount.fromXContent(p, (String) c)) - .put(PercentilesBucketPipelineAggregationBuilder.NAME, - (p, c) -> ParsedPercentilesBucket.fromXContent(p, (String) c)) - .put(DateHistogramAggregationBuilder.NAME, (p, c) -> ParsedDateHistogram.fromXContent(p, (String) c)) - .build() - .entrySet() - .stream() - .map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), - entry.getValue())) - .collect(Collectors.toList()); - private final static NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(entryList); + private static final List entryList = + new ImmutableMap.Builder>() + .put(MinAggregationBuilder.NAME, (p, c) -> ParsedMin.fromXContent(p, (String) c)) + .put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)) + .put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)) + .put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)) + .put(StringTerms.NAME, (p, c) -> ParsedStringTerms.fromXContent(p, (String) c)) + .put(LongTerms.NAME, (p, c) -> ParsedLongTerms.fromXContent(p, (String) c)) + .put(DoubleTerms.NAME, (p, c) -> ParsedDoubleTerms.fromXContent(p, (String) c)) + .put( + ValueCountAggregationBuilder.NAME, + (p, c) -> ParsedValueCount.fromXContent(p, (String) c)) + .put( + PercentilesBucketPipelineAggregationBuilder.NAME, + (p, c) -> ParsedPercentilesBucket.fromXContent(p, (String) c)) + .put( + DateHistogramAggregationBuilder.NAME, + (p, c) -> ParsedDateHistogram.fromXContent(p, (String) c)) + .build() + .entrySet() + .stream() + .map( + entry -> + new NamedXContentRegistry.Entry( + Aggregation.class, new ParseField(entry.getKey()), entry.getValue())) + .collect(Collectors.toList()); + private static final NamedXContentRegistry namedXContentRegistry = + new NamedXContentRegistry(entryList); - /** - * Populate {@link Aggregations} from JSON string. - * @param json json string - * @return {@link Aggregations} - */ - public static Aggregations fromJson(String json) { - try { - XContentParser xContentParser = new JsonXContentParser( - namedXContentRegistry, - LoggingDeprecationHandler.INSTANCE, - new JsonFactory().createParser(json)); - xContentParser.nextToken(); - return Aggregations.fromXContent(xContentParser); - } catch (IOException e) { - throw new RuntimeException(e); - } + /** + * Populate {@link Aggregations} from JSON string. + * + * @param json json string + * @return {@link Aggregations} + */ + public static Aggregations fromJson(String json) { + try { + XContentParser xContentParser = + new JsonXContentParser( + namedXContentRegistry, + LoggingDeprecationHandler.INSTANCE, + new JsonFactory().createParser(json)); + xContentParser.nextToken(); + return Aggregations.fromXContent(xContentParser); + } catch (IOException e) { + throw new RuntimeException(e); } + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/util/CheckScriptContents.java b/legacy/src/test/java/org/opensearch/sql/legacy/util/CheckScriptContents.java index 2396ca5924..7578720624 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/util/CheckScriptContents.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/util/CheckScriptContents.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.util; import static java.util.Collections.emptyList; @@ -13,7 +12,6 @@ import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import static org.opensearch.search.builder.SearchSourceBuilder.ScriptField; @@ -59,205 +57,211 @@ public class CheckScriptContents { - private static SQLExpr queryToExpr(String query) { - return new ElasticSqlExprParser(query).expr(); - } + private static SQLExpr queryToExpr(String query) { + return new ElasticSqlExprParser(query).expr(); + } - public static ScriptField getScriptFieldFromQuery(String query) { - try { - Client mockClient = mock(Client.class); - stubMockClient(mockClient); - QueryAction queryAction = OpenSearchActionFactory.create(mockClient, query); - SqlElasticRequestBuilder requestBuilder = queryAction.explain(); + public static ScriptField getScriptFieldFromQuery(String query) { + try { + Client mockClient = mock(Client.class); + stubMockClient(mockClient); + QueryAction queryAction = OpenSearchActionFactory.create(mockClient, query); + SqlElasticRequestBuilder requestBuilder = queryAction.explain(); - SearchRequestBuilder request = (SearchRequestBuilder) requestBuilder.getBuilder(); - List scriptFields = request.request().source().scriptFields(); + SearchRequestBuilder request = (SearchRequestBuilder) requestBuilder.getBuilder(); + List scriptFields = request.request().source().scriptFields(); - assertTrue(scriptFields.size() == 1); + assertTrue(scriptFields.size() == 1); - return scriptFields.get(0); + return scriptFields.get(0); - } catch (SQLFeatureNotSupportedException | SqlParseException | SQLFeatureDisabledException e) { - throw new ParserException("Unable to parse query: " + query, e); - } + } catch (SQLFeatureNotSupportedException | SqlParseException | SQLFeatureDisabledException e) { + throw new ParserException("Unable to parse query: " + query, e); } + } - public static ScriptFilter getScriptFilterFromQuery(String query, SqlParser parser) { - try { - Select select = parser.parseSelect((SQLQueryExpr) queryToExpr(query)); - Where where = select.getWhere(); - - assertTrue(where.getWheres().size() == 1); - assertTrue(((Condition) (where.getWheres().get(0))).getValue() instanceof ScriptFilter); + public static ScriptFilter getScriptFilterFromQuery(String query, SqlParser parser) { + try { + Select select = parser.parseSelect((SQLQueryExpr) queryToExpr(query)); + Where where = select.getWhere(); - return (ScriptFilter) (((Condition) (where.getWheres().get(0))).getValue()); + assertTrue(where.getWheres().size() == 1); + assertTrue(((Condition) (where.getWheres().get(0))).getValue() instanceof ScriptFilter); - } catch (SqlParseException e) { - throw new ParserException("Unable to parse query: " + query); - } - } + return (ScriptFilter) (((Condition) (where.getWheres().get(0))).getValue()); - public static boolean scriptContainsString(ScriptField scriptField, String string) { - return scriptField.script().getIdOrCode().contains(string); + } catch (SqlParseException e) { + throw new ParserException("Unable to parse query: " + query); } + } - public static boolean scriptContainsString(ScriptFilter scriptFilter, String string) { - return scriptFilter.getScript().contains(string); - } + public static boolean scriptContainsString(ScriptField scriptField, String string) { + return scriptField.script().getIdOrCode().contains(string); + } - public static boolean scriptHasPattern(ScriptField scriptField, String regex) { - Pattern pattern = Pattern.compile(regex); - Matcher matcher = pattern.matcher(scriptField.script().getIdOrCode()); - return matcher.find(); - } + public static boolean scriptContainsString(ScriptFilter scriptFilter, String string) { + return scriptFilter.getScript().contains(string); + } - public static boolean scriptHasPattern(ScriptFilter scriptFilter, String regex) { - Pattern pattern = Pattern.compile(regex); - Matcher matcher = pattern.matcher(scriptFilter.getScript()); - return matcher.find(); - } + public static boolean scriptHasPattern(ScriptField scriptField, String regex) { + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(scriptField.script().getIdOrCode()); + return matcher.find(); + } - public static void stubMockClient(Client mockClient) { - String mappings = "{\n" + - " \"opensearch-sql_test_index_bank\": {\n" + - " \"mappings\": {\n" + - " \"account\": {\n" + - " \"properties\": {\n" + - " \"account_number\": {\n" + - " \"type\": \"long\"\n" + - " },\n" + - " \"address\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"age\": {\n" + - " \"type\": \"integer\"\n" + - " },\n" + - " \"balance\": {\n" + - " \"type\": \"long\"\n" + - " },\n" + - " \"birthdate\": {\n" + - " \"type\": \"date\"\n" + - " },\n" + - " \"city\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \"email\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"employer\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \"firstname\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"gender\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \"lastname\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \"male\": {\n" + - " \"type\": \"boolean\"\n" + - " },\n" + - " \"state\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"raw\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " },\n" + - // ==== All required by IndexMetaData.fromXContent() ==== - " \"settings\": {\n" + - " \"index\": {\n" + - " \"number_of_shards\": 5,\n" + - " \"number_of_replicas\": 0,\n" + - " \"version\": {\n" + - " \"created\": \"6050399\"\n" + - " }\n" + - " }\n" + - " },\n" + - " \"mapping_version\": \"1\",\n" + - " \"settings_version\": \"1\"\n" + - //======================================================= - " }\n" + - "}"; + public static boolean scriptHasPattern(ScriptFilter scriptFilter, String regex) { + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(scriptFilter.getScript()); + return matcher.find(); + } - AdminClient mockAdminClient = mock(AdminClient.class); - when(mockClient.admin()).thenReturn(mockAdminClient); + public static void stubMockClient(Client mockClient) { + String mappings = + "{\n" + + " \"opensearch-sql_test_index_bank\": {\n" + + " \"mappings\": {\n" + + " \"account\": {\n" + + " \"properties\": {\n" + + " \"account_number\": {\n" + + " \"type\": \"long\"\n" + + " },\n" + + " \"address\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"age\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"balance\": {\n" + + " \"type\": \"long\"\n" + + " },\n" + + " \"birthdate\": {\n" + + " \"type\": \"date\"\n" + + " },\n" + + " \"city\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"email\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"employer\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"firstname\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"gender\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"lastname\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"male\": {\n" + + " \"type\": \"boolean\"\n" + + " },\n" + + " \"state\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"raw\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + + // ==== All required by IndexMetaData.fromXContent() ==== + " \"settings\": {\n" + + " \"index\": {\n" + + " \"number_of_shards\": 5,\n" + + " \"number_of_replicas\": 0,\n" + + " \"version\": {\n" + + " \"created\": \"6050399\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"mapping_version\": \"1\",\n" + + " \"settings_version\": \"1\",\n" + + " \"aliases_version\": \"1\"\n" + + + // ======================================================= + " }\n" + + "}"; - IndicesAdminClient mockIndexClient = mock(IndicesAdminClient.class); - when(mockAdminClient.indices()).thenReturn(mockIndexClient); + AdminClient mockAdminClient = mock(AdminClient.class); + when(mockClient.admin()).thenReturn(mockAdminClient); - ActionFuture mockActionResp = mock(ActionFuture.class); - when(mockIndexClient.getFieldMappings(any(GetFieldMappingsRequest.class))).thenReturn(mockActionResp); - mockLocalClusterState(mappings); - } + IndicesAdminClient mockIndexClient = mock(IndicesAdminClient.class); + when(mockAdminClient.indices()).thenReturn(mockIndexClient); - public static XContentParser createParser(String mappings) throws IOException { - return XContentType.JSON.xContent().createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - mappings - ); - } + ActionFuture mockActionResp = mock(ActionFuture.class); + when(mockIndexClient.getFieldMappings(any(GetFieldMappingsRequest.class))) + .thenReturn(mockActionResp); + mockLocalClusterState(mappings); + } - public static void mockLocalClusterState(String mappings) { - LocalClusterState.state().setClusterService(mockClusterService(mappings)); - LocalClusterState.state().setResolver(mockIndexNameExpressionResolver()); - LocalClusterState.state().setPluginSettings(mockPluginSettings()); - } + public static XContentParser createParser(String mappings) throws IOException { + return XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, mappings); + } - public static ClusterService mockClusterService(String mappings) { - ClusterService mockService = mock(ClusterService.class); - ClusterState mockState = mock(ClusterState.class); - Metadata mockMetaData = mock(Metadata.class); + public static void mockLocalClusterState(String mappings) { + LocalClusterState.state().setClusterService(mockClusterService(mappings)); + LocalClusterState.state().setResolver(mockIndexNameExpressionResolver()); + LocalClusterState.state().setPluginSettings(mockPluginSettings()); + } - when(mockService.state()).thenReturn(mockState); - when(mockState.metadata()).thenReturn(mockMetaData); - try { - when(mockMetaData.findMappings(any(), any())).thenReturn( - Map.of(TestsConstants.TEST_INDEX_BANK, IndexMetadata.fromXContent( - createParser(mappings)).mapping())); - } - catch (IOException e) { - throw new IllegalStateException(e); - } - return mockService; - } + public static ClusterService mockClusterService(String mappings) { + ClusterService mockService = mock(ClusterService.class); + ClusterState mockState = mock(ClusterState.class); + Metadata mockMetaData = mock(Metadata.class); - public static IndexNameExpressionResolver mockIndexNameExpressionResolver() { - IndexNameExpressionResolver mockResolver = mock(IndexNameExpressionResolver.class); - when(mockResolver.concreteIndexNames(any(), any(), anyBoolean(), anyString())).thenAnswer( - (Answer) invocation -> { - // Return index expression directly without resolving - Object indexExprs = invocation.getArguments()[3]; - if (indexExprs instanceof String) { - return new String[]{ (String) indexExprs }; - } - return (String[]) indexExprs; - } - ); - return mockResolver; + when(mockService.state()).thenReturn(mockState); + when(mockState.metadata()).thenReturn(mockMetaData); + try { + when(mockMetaData.findMappings(any(), any())) + .thenReturn( + Map.of( + TestsConstants.TEST_INDEX_BANK, + IndexMetadata.fromXContent(createParser(mappings)).mapping())); + } catch (IOException e) { + throw new IllegalStateException(e); } + return mockService; + } - public static OpenSearchSettings mockPluginSettings() { - OpenSearchSettings settings = mock(OpenSearchSettings.class); + public static IndexNameExpressionResolver mockIndexNameExpressionResolver() { + IndexNameExpressionResolver mockResolver = mock(IndexNameExpressionResolver.class); + when(mockResolver.concreteIndexNames(any(), any(), anyBoolean(), anyString())) + .thenAnswer( + (Answer) + invocation -> { + // Return index expression directly without resolving + Object indexExprs = invocation.getArguments()[3]; + if (indexExprs instanceof String) { + return new String[] {(String) indexExprs}; + } + return (String[]) indexExprs; + }); + return mockResolver; + } - // Force return empty list to avoid ClusterSettings be invoked which is a final class and hard to mock. - // In this case, default value in Setting will be returned all the time. - doReturn(emptyList()).when(settings).getSettings(); - return settings; - } + public static OpenSearchSettings mockPluginSettings() { + OpenSearchSettings settings = mock(OpenSearchSettings.class); + // Force return empty list to avoid ClusterSettings be invoked which is a final class and hard + // to mock. + // In this case, default value in Setting will be returned all the time. + doReturn(emptyList()).when(settings).getSettings(); + return settings; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchTextType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchTextType.java index 67b7296834..706d49afda 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchTextType.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchTextType.java @@ -15,8 +15,8 @@ import org.opensearch.sql.data.type.ExprType; /** - * The type of a text value. See - * doc + * The type of a text value. See doc */ public class OpenSearchTextType extends OpenSearchDataType { @@ -24,8 +24,7 @@ public class OpenSearchTextType extends OpenSearchDataType { // text could have fields // a read-only collection - @EqualsAndHashCode.Exclude - Map fields = ImmutableMap.of(); + @EqualsAndHashCode.Exclude Map fields = ImmutableMap.of(); private OpenSearchTextType() { super(MappingType.Text); @@ -34,6 +33,7 @@ private OpenSearchTextType() { /** * Constructs a Text Type using the passed in fields argument. + * * @param fields The fields to be used to construct the text type. * @return A new OpenSeachTextTypeObject */ @@ -67,7 +67,7 @@ protected OpenSearchDataType cloneEmpty() { } /** - * Text field doesn't have doc value (exception thrown even when you call "get") + * Text field doesn't have doc value (exception thrown even when you call "get")
* Limitation: assume inner field name is always "keyword". */ public static String convertTextToKeyword(String fieldName, ExprType fieldType) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java index 827606a961..bfc06b94c0 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactoryTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.opensearch.data.value; import static org.junit.jupiter.api.Assertions.assertAll; @@ -88,8 +87,8 @@ class OpenSearchExprValueFactoryTest { .put("timeNoMillisOrTimeV", OpenSearchDateType.of("time_no_millis || time")) .put("dateOrOrdinalDateV", OpenSearchDateType.of("date || ordinal_date")) .put("customFormatV", OpenSearchDateType.of("yyyy-MM-dd-HH-mm-ss")) - .put("customAndEpochMillisV", - OpenSearchDateType.of("yyyy-MM-dd-HH-mm-ss || epoch_millis")) + .put( + "customAndEpochMillisV", OpenSearchDateType.of("yyyy-MM-dd-HH-mm-ss || epoch_millis")) .put("incompleteFormatV", OpenSearchDateType.of("year")) .put("boolV", OpenSearchDataType.of(BOOLEAN)) .put("structV", OpenSearchDataType.of(STRUCT)) @@ -98,20 +97,22 @@ class OpenSearchExprValueFactoryTest { .put("arrayV", OpenSearchDataType.of(ARRAY)) .put("arrayV.info", OpenSearchDataType.of(STRING)) .put("arrayV.author", OpenSearchDataType.of(STRING)) - .put("deepNestedV", OpenSearchDataType.of( - OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) - ) - .put("deepNestedV.year", OpenSearchDataType.of( - OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) - ) + .put( + "deepNestedV", + OpenSearchDataType.of(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested))) + .put( + "deepNestedV.year", + OpenSearchDataType.of(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested))) .put("deepNestedV.year.timeV", OpenSearchDateType.of(TIME)) - .put("nestedV", OpenSearchDataType.of( - OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested)) - ) + .put( + "nestedV", + OpenSearchDataType.of(OpenSearchDataType.of(OpenSearchDataType.MappingType.Nested))) .put("nestedV.count", OpenSearchDataType.of(INTEGER)) .put("textV", OpenSearchDataType.of(OpenSearchDataType.MappingType.Text)) - .put("textKeywordV", OpenSearchTextType.of(Map.of("words", - OpenSearchDataType.of(OpenSearchDataType.MappingType.Keyword)))) + .put( + "textKeywordV", + OpenSearchTextType.of( + Map.of("words", OpenSearchDataType.of(OpenSearchDataType.MappingType.Keyword)))) .put("ipV", OpenSearchDataType.of(OpenSearchDataType.MappingType.Ip)) .put("geoV", OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint)) .put("binaryV", OpenSearchDataType.of(OpenSearchDataType.MappingType.Binary)) @@ -124,9 +125,8 @@ class OpenSearchExprValueFactoryTest { public void constructNullValue() { assertAll( () -> assertEquals(nullValue(), tupleValue("{\"intV\":null}").get("intV")), - () -> assertEquals(nullValue(), constructFromObject("intV", null)), - () -> assertTrue(new OpenSearchJsonContent(null).isNull()) - ); + () -> assertEquals(nullValue(), constructFromObject("intV", null)), + () -> assertTrue(new OpenSearchJsonContent(null).isNull())); } @Test @@ -136,8 +136,7 @@ public void iterateArrayValue() throws JsonProcessingException { assertAll( () -> assertEquals("zz", arrayIt.next().stringValue()), () -> assertEquals("bb", arrayIt.next().stringValue()), - () -> assertFalse(arrayIt.hasNext()) - ); + () -> assertFalse(arrayIt.hasNext())); } @Test @@ -146,8 +145,7 @@ public void iterateArrayValueWithOneElement() throws JsonProcessingException { var arrayIt = new OpenSearchJsonContent(mapper.readTree("[\"zz\"]")).array(); assertAll( () -> assertEquals("zz", arrayIt.next().stringValue()), - () -> assertFalse(arrayIt.hasNext()) - ); + () -> assertFalse(arrayIt.hasNext())); } @Test @@ -160,8 +158,7 @@ public void constructByte() { assertAll( () -> assertEquals(byteValue((byte) 1), tupleValue("{\"byteV\":1}").get("byteV")), () -> assertEquals(byteValue((byte) 1), constructFromObject("byteV", 1)), - () -> assertEquals(byteValue((byte) 1), constructFromObject("byteV", "1.0")) - ); + () -> assertEquals(byteValue((byte) 1), constructFromObject("byteV", "1.0"))); } @Test @@ -169,8 +166,7 @@ public void constructShort() { assertAll( () -> assertEquals(shortValue((short) 1), tupleValue("{\"shortV\":1}").get("shortV")), () -> assertEquals(shortValue((short) 1), constructFromObject("shortV", 1)), - () -> assertEquals(shortValue((short) 1), constructFromObject("shortV", "1.0")) - ); + () -> assertEquals(shortValue((short) 1), constructFromObject("shortV", "1.0"))); } @Test @@ -178,8 +174,7 @@ public void constructInteger() { assertAll( () -> assertEquals(integerValue(1), tupleValue("{\"intV\":1}").get("intV")), () -> assertEquals(integerValue(1), constructFromObject("intV", 1)), - () -> assertEquals(integerValue(1), constructFromObject("intV", "1.0")) - ); + () -> assertEquals(integerValue(1), constructFromObject("intV", "1.0"))); } @Test @@ -192,33 +187,29 @@ public void constructLong() { assertAll( () -> assertEquals(longValue(1L), tupleValue("{\"longV\":1}").get("longV")), () -> assertEquals(longValue(1L), constructFromObject("longV", 1L)), - () -> assertEquals(longValue(1L), constructFromObject("longV", "1.0")) - ); + () -> assertEquals(longValue(1L), constructFromObject("longV", "1.0"))); } @Test public void constructFloat() { assertAll( () -> assertEquals(floatValue(1f), tupleValue("{\"floatV\":1.0}").get("floatV")), - () -> assertEquals(floatValue(1f), constructFromObject("floatV", 1f)) - ); + () -> assertEquals(floatValue(1f), constructFromObject("floatV", 1f))); } @Test public void constructDouble() { assertAll( () -> assertEquals(doubleValue(1d), tupleValue("{\"doubleV\":1.0}").get("doubleV")), - () -> assertEquals(doubleValue(1d), constructFromObject("doubleV", 1d)) - ); + () -> assertEquals(doubleValue(1d), constructFromObject("doubleV", 1d))); } @Test public void constructString() { assertAll( - () -> assertEquals(stringValue("text"), - tupleValue("{\"stringV\":\"text\"}").get("stringV")), - () -> assertEquals(stringValue("text"), constructFromObject("stringV", "text")) - ); + () -> + assertEquals(stringValue("text"), tupleValue("{\"stringV\":\"text\"}").get("stringV")), + () -> assertEquals(stringValue("text"), constructFromObject("stringV", "text"))); } @Test @@ -228,23 +219,25 @@ public void constructBoolean() { () -> assertEquals(booleanValue(true), constructFromObject("boolV", true)), () -> assertEquals(booleanValue(true), constructFromObject("boolV", "true")), () -> assertEquals(booleanValue(true), constructFromObject("boolV", 1)), - () -> assertEquals(booleanValue(false), constructFromObject("boolV", 0)) - ); + () -> assertEquals(booleanValue(false), constructFromObject("boolV", 0))); } @Test public void constructText() { assertAll( - () -> assertEquals(new OpenSearchExprTextValue("text"), - tupleValue("{\"textV\":\"text\"}").get("textV")), - () -> assertEquals(new OpenSearchExprTextValue("text"), - constructFromObject("textV", "text")), - - () -> assertEquals(new OpenSearchExprTextValue("text"), - tupleValue("{\"textKeywordV\":\"text\"}").get("textKeywordV")), - () -> assertEquals(new OpenSearchExprTextValue("text"), - constructFromObject("textKeywordV", "text")) - ); + () -> + assertEquals( + new OpenSearchExprTextValue("text"), + tupleValue("{\"textV\":\"text\"}").get("textV")), + () -> + assertEquals(new OpenSearchExprTextValue("text"), constructFromObject("textV", "text")), + () -> + assertEquals( + new OpenSearchExprTextValue("text"), + tupleValue("{\"textKeywordV\":\"text\"}").get("textKeywordV")), + () -> + assertEquals( + new OpenSearchExprTextValue("text"), constructFromObject("textKeywordV", "text"))); } @Test @@ -252,95 +245,122 @@ public void constructDates() { ExprValue dateStringV = constructFromObject("dateStringV", "1984-04-12"); assertAll( () -> assertEquals(new ExprDateValue("1984-04-12"), dateStringV), - () -> assertEquals(new ExprDateValue( - LocalDate.ofInstant(Instant.ofEpochMilli(450576000000L), UTC_ZONE_ID)), - constructFromObject("dateV", 450576000000L)), - () -> assertEquals(new ExprDateValue("1984-04-12"), - constructFromObject("dateOrOrdinalDateV", "1984-103")), - () -> assertEquals(new ExprDateValue("2015-01-01"), - tupleValue("{\"dateV\":\"2015-01-01\"}").get("dateV")) - ); + () -> + assertEquals( + new ExprDateValue( + LocalDate.ofInstant(Instant.ofEpochMilli(450576000000L), UTC_ZONE_ID)), + constructFromObject("dateV", 450576000000L)), + () -> + assertEquals( + new ExprDateValue("1984-04-12"), + constructFromObject("dateOrOrdinalDateV", "1984-103")), + () -> + assertEquals( + new ExprDateValue("2015-01-01"), + tupleValue("{\"dateV\":\"2015-01-01\"}").get("dateV"))); } @Test public void constructTimes() { - ExprValue timeStringV = constructFromObject("timeStringV","12:10:30.000Z"); + ExprValue timeStringV = constructFromObject("timeStringV", "12:10:30.000Z"); assertAll( () -> assertTrue(timeStringV.isDateTime()), () -> assertTrue(timeStringV instanceof ExprTimeValue), () -> assertEquals(new ExprTimeValue("12:10:30"), timeStringV), - () -> assertEquals(new ExprTimeValue(LocalTime.from( - Instant.ofEpochMilli(1420070400001L).atZone(UTC_ZONE_ID))), - constructFromObject("timeV", 1420070400001L)), - () -> assertEquals(new ExprTimeValue("09:07:42.000"), - constructFromObject("timeNoMillisOrTimeV", "09:07:42.000Z")), - () -> assertEquals(new ExprTimeValue("09:07:42"), - tupleValue("{\"timeV\":\"09:07:42\"}").get("timeV")) - ); + () -> + assertEquals( + new ExprTimeValue( + LocalTime.from(Instant.ofEpochMilli(1420070400001L).atZone(UTC_ZONE_ID))), + constructFromObject("timeV", 1420070400001L)), + () -> + assertEquals( + new ExprTimeValue("09:07:42.000"), + constructFromObject("timeNoMillisOrTimeV", "09:07:42.000Z")), + () -> + assertEquals( + new ExprTimeValue("09:07:42"), + tupleValue("{\"timeV\":\"09:07:42\"}").get("timeV"))); } @Test public void constructDatetime() { assertAll( - () -> assertEquals( - new ExprTimestampValue("2015-01-01 00:00:00"), - tupleValue("{\"timestampV\":\"2015-01-01\"}").get("timestampV")), - () -> assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - tupleValue("{\"timestampV\":\"2015-01-01T12:10:30Z\"}").get("timestampV")), - () -> assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - tupleValue("{\"timestampV\":\"2015-01-01T12:10:30\"}").get("timestampV")), - () -> assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - tupleValue("{\"timestampV\":\"2015-01-01 12:10:30\"}").get("timestampV")), - () -> assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("timestampV", 1420070400001L)), - () -> assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("timestampV", Instant.ofEpochMilli(1420070400001L))), - () -> assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("epochMillisV", "1420070400001")), - () -> assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("epochMillisV", 1420070400001L)), - () -> assertEquals( - new ExprTimestampValue(Instant.ofEpochSecond(142704001L)), - constructFromObject("epochSecondV", 142704001L)), - () -> assertEquals( - new ExprTimeValue("10:20:30"), - tupleValue("{ \"timeCustomV\" : 102030 }").get("timeCustomV")), - () -> assertEquals( - new ExprDateValue("1961-04-12"), - tupleValue("{ \"dateCustomV\" : 19610412 }").get("dateCustomV")), - () -> assertEquals( - new ExprTimestampValue("1984-05-10 20:30:40"), - tupleValue("{ \"dateTimeCustomV\" : 19840510203040 }").get("dateTimeCustomV")), - () -> assertEquals( - new ExprTimestampValue("2015-01-01 12:10:30"), - constructFromObject("timestampV", "2015-01-01 12:10:30")), - () -> assertEquals( - new ExprDatetimeValue("2015-01-01 12:10:30"), - constructFromObject("datetimeV", "2015-01-01 12:10:30")), - () -> assertEquals( - new ExprDatetimeValue("2015-01-01 12:10:30"), - constructFromObject("datetimeDefaultV", "2015-01-01 12:10:30")), - () -> assertEquals( - new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), - constructFromObject("dateOrEpochMillisV", "1420070400001")), + () -> + assertEquals( + new ExprTimestampValue("2015-01-01 00:00:00"), + tupleValue("{\"timestampV\":\"2015-01-01\"}").get("timestampV")), + () -> + assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + tupleValue("{\"timestampV\":\"2015-01-01T12:10:30Z\"}").get("timestampV")), + () -> + assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + tupleValue("{\"timestampV\":\"2015-01-01T12:10:30\"}").get("timestampV")), + () -> + assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + tupleValue("{\"timestampV\":\"2015-01-01 12:10:30\"}").get("timestampV")), + () -> + assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("timestampV", 1420070400001L)), + () -> + assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("timestampV", Instant.ofEpochMilli(1420070400001L))), + () -> + assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("epochMillisV", "1420070400001")), + () -> + assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("epochMillisV", 1420070400001L)), + () -> + assertEquals( + new ExprTimestampValue(Instant.ofEpochSecond(142704001L)), + constructFromObject("epochSecondV", 142704001L)), + () -> + assertEquals( + new ExprTimeValue("10:20:30"), + tupleValue("{ \"timeCustomV\" : 102030 }").get("timeCustomV")), + () -> + assertEquals( + new ExprDateValue("1961-04-12"), + tupleValue("{ \"dateCustomV\" : 19610412 }").get("dateCustomV")), + () -> + assertEquals( + new ExprTimestampValue("1984-05-10 20:30:40"), + tupleValue("{ \"dateTimeCustomV\" : 19840510203040 }").get("dateTimeCustomV")), + () -> + assertEquals( + new ExprTimestampValue("2015-01-01 12:10:30"), + constructFromObject("timestampV", "2015-01-01 12:10:30")), + () -> + assertEquals( + new ExprDatetimeValue("2015-01-01 12:10:30"), + constructFromObject("datetimeV", "2015-01-01 12:10:30")), + () -> + assertEquals( + new ExprDatetimeValue("2015-01-01 12:10:30"), + constructFromObject("datetimeDefaultV", "2015-01-01 12:10:30")), + () -> + assertEquals( + new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), + constructFromObject("dateOrEpochMillisV", "1420070400001")), // case: timestamp-formatted field, but it only gets a time: should match a time - () -> assertEquals( - new ExprTimeValue("19:36:22"), - tupleValue("{\"timestampV\":\"19:36:22\"}").get("timestampV")), + () -> + assertEquals( + new ExprTimeValue("19:36:22"), + tupleValue("{\"timestampV\":\"19:36:22\"}").get("timestampV")), // case: timestamp-formatted field, but it only gets a date: should match a date - () -> assertEquals( - new ExprDateValue("2011-03-03"), - tupleValue("{\"timestampV\":\"2011-03-03\"}").get("timestampV")) - ); + () -> + assertEquals( + new ExprDateValue("2011-03-03"), + tupleValue("{\"timestampV\":\"2011-03-03\"}").get("timestampV"))); } @Test @@ -350,11 +370,11 @@ public void constructDatetime_fromCustomFormat() { constructFromObject("customFormatV", "2015-01-01-12-10-30")); IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, + assertThrows( + IllegalArgumentException.class, () -> constructFromObject("customFormatV", "2015-01-01 12-10-30")); assertEquals( - "Construct TIMESTAMP from \"2015-01-01 12-10-30\" failed, " - + "unsupported format.", + "Construct TIMESTAMP from \"2015-01-01 12-10-30\" failed, unsupported format.", exception.getMessage()); assertEquals( @@ -369,91 +389,87 @@ public void constructDatetime_fromCustomFormat() { @Test public void constructDatetimeFromUnsupportedFormat_ThrowIllegalArgumentException() { IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, + assertThrows( + IllegalArgumentException.class, () -> constructFromObject("timestampV", "2015-01-01 12:10")); assertEquals( - "Construct TIMESTAMP from \"2015-01-01 12:10\" failed, " - + "unsupported format.", + "Construct TIMESTAMP from \"2015-01-01 12:10\" failed, unsupported format.", exception.getMessage()); // fail with missing seconds exception = - assertThrows(IllegalArgumentException.class, + assertThrows( + IllegalArgumentException.class, () -> constructFromObject("dateOrEpochMillisV", "2015-01-01 12:10")); assertEquals( - "Construct TIMESTAMP from \"2015-01-01 12:10\" failed, " - + "unsupported format.", + "Construct TIMESTAMP from \"2015-01-01 12:10\" failed, unsupported format.", exception.getMessage()); } @Test public void constructTimeFromUnsupportedFormat_ThrowIllegalArgumentException() { - IllegalArgumentException exception = assertThrows( - IllegalArgumentException.class, () -> constructFromObject("timeV", "2015-01-01")); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> constructFromObject("timeV", "2015-01-01")); assertEquals( - "Construct TIME from \"2015-01-01\" failed, " - + "unsupported format.", - exception.getMessage()); + "Construct TIME from \"2015-01-01\" failed, unsupported format.", exception.getMessage()); - exception = assertThrows( - IllegalArgumentException.class, () -> constructFromObject("timeStringV", "10:10")); + exception = + assertThrows( + IllegalArgumentException.class, () -> constructFromObject("timeStringV", "10:10")); assertEquals( - "Construct TIME from \"10:10\" failed, " - + "unsupported format.", - exception.getMessage()); + "Construct TIME from \"10:10\" failed, unsupported format.", exception.getMessage()); } @Test public void constructDateFromUnsupportedFormat_ThrowIllegalArgumentException() { - IllegalArgumentException exception = assertThrows( - IllegalArgumentException.class, () -> constructFromObject("dateV", "12:10:10")); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> constructFromObject("dateV", "12:10:10")); assertEquals( - "Construct DATE from \"12:10:10\" failed, " - + "unsupported format.", - exception.getMessage()); + "Construct DATE from \"12:10:10\" failed, unsupported format.", exception.getMessage()); - exception = assertThrows( - IllegalArgumentException.class, () -> constructFromObject("dateStringV", "abc")); - assertEquals( - "Construct DATE from \"abc\" failed, " - + "unsupported format.", - exception.getMessage()); + exception = + assertThrows( + IllegalArgumentException.class, () -> constructFromObject("dateStringV", "abc")); + assertEquals("Construct DATE from \"abc\" failed, unsupported format.", exception.getMessage()); } @Test public void constructDateFromIncompleteFormat() { - assertEquals( - new ExprDateValue("1984-01-01"), - constructFromObject("incompleteFormatV", "1984")); + assertEquals(new ExprDateValue("1984-01-01"), constructFromObject("incompleteFormatV", "1984")); } @Test public void constructArray() { assertEquals( - new ExprCollectionValue(List.of(new ExprTupleValue( - new LinkedHashMap() { - { - put("info", stringValue("zz")); - put("author", stringValue("au")); - } - }))), + new ExprCollectionValue( + List.of( + new ExprTupleValue( + new LinkedHashMap() { + { + put("info", stringValue("zz")); + put("author", stringValue("au")); + } + }))), tupleValue("{\"arrayV\":[{\"info\":\"zz\",\"author\":\"au\"}]}").get("arrayV")); assertEquals( - new ExprCollectionValue(List.of(new ExprTupleValue( - new LinkedHashMap() { - { - put("info", stringValue("zz")); - put("author", stringValue("au")); - } - }))), - constructFromObject("arrayV", List.of( - ImmutableMap.of("info", "zz", "author", "au")))); + new ExprCollectionValue( + List.of( + new ExprTupleValue( + new LinkedHashMap() { + { + put("info", stringValue("zz")); + put("author", stringValue("au")); + } + }))), + constructFromObject("arrayV", List.of(ImmutableMap.of("info", "zz", "author", "au")))); } @Test public void constructArrayOfStrings() { - assertEquals(new ExprCollectionValue( - List.of(stringValue("zz"), stringValue("au"))), + assertEquals( + new ExprCollectionValue(List.of(stringValue("zz"), stringValue("au"))), constructFromObject("arrayV", List.of("zz", "au"))); } @@ -461,100 +477,71 @@ public void constructArrayOfStrings() { public void constructNestedArraysOfStrings() { assertEquals( new ExprCollectionValue( - List.of( - collectionValue( - List.of("zz", "au") - ), - collectionValue( - List.of("ss") - ) - ) - ), - tupleValueWithArraySupport( - "{\"stringV\":[" - + "[\"zz\", \"au\"]," - + "[\"ss\"]" - + "]}" - ).get("stringV")); + List.of(collectionValue(List.of("zz", "au")), collectionValue(List.of("ss")))), + tupleValueWithArraySupport("{\"stringV\":[ [\"zz\", \"au\"], [\"ss\"] ]}").get("stringV")); } @Test public void constructNestedArraysOfStringsReturnsFirstIndex() { assertEquals( - stringValue("zz"), - tupleValue( - "{\"stringV\":[" - + "[\"zz\", \"au\"]," - + "[\"ss\"]" - + "]}" - ).get("stringV")); + stringValue("zz"), tupleValue("{\"stringV\":[[\"zz\", \"au\"],[\"ss\"]]}").get("stringV")); } @Test public void constructMultiNestedArraysOfStringsReturnsFirstIndex() { assertEquals( stringValue("z"), - tupleValue( - "{\"stringV\":" - + "[\"z\"," - + "[\"s\"]," - + "[\"zz\", \"au\"]" - + "]}" - ).get("stringV")); + tupleValue("{\"stringV\":[\"z\",[\"s\"],[\"zz\", \"au\"]]}").get("stringV")); } @Test public void constructArrayOfInts() { - assertEquals(new ExprCollectionValue( - List.of(integerValue(1), integerValue(2))), + assertEquals( + new ExprCollectionValue(List.of(integerValue(1), integerValue(2))), constructFromObject("arrayV", List.of(1, 2))); } @Test public void constructArrayOfShorts() { // Shorts are treated same as integer - assertEquals(new ExprCollectionValue( - List.of(shortValue((short)3), shortValue((short)4))), + assertEquals( + new ExprCollectionValue(List.of(shortValue((short) 3), shortValue((short) 4))), constructFromObject("arrayV", List.of(3, 4))); } @Test public void constructArrayOfLongs() { - assertEquals(new ExprCollectionValue( - List.of(longValue(123456789L), longValue(987654321L))), + assertEquals( + new ExprCollectionValue(List.of(longValue(123456789L), longValue(987654321L))), constructFromObject("arrayV", List.of(123456789L, 987654321L))); } @Test public void constructArrayOfFloats() { - assertEquals(new ExprCollectionValue( - List.of(floatValue(3.14f), floatValue(4.13f))), + assertEquals( + new ExprCollectionValue(List.of(floatValue(3.14f), floatValue(4.13f))), constructFromObject("arrayV", List.of(3.14f, 4.13f))); } @Test public void constructArrayOfDoubles() { - assertEquals(new ExprCollectionValue( - List.of(doubleValue(9.1928374756D), doubleValue(4.987654321D))), + assertEquals( + new ExprCollectionValue(List.of(doubleValue(9.1928374756D), doubleValue(4.987654321D))), constructFromObject("arrayV", List.of(9.1928374756D, 4.987654321D))); } @Test public void constructArrayOfBooleans() { - assertEquals(new ExprCollectionValue( - List.of(booleanValue(true), booleanValue(false))), + assertEquals( + new ExprCollectionValue(List.of(booleanValue(true), booleanValue(false))), constructFromObject("arrayV", List.of(true, false))); } @Test public void constructNestedObjectArrayNode() { - assertEquals(collectionValue( - List.of( - Map.of("count", 1), - Map.of("count", 2) - )), - tupleValueWithArraySupport("{\"nestedV\":[{\"count\":1},{\"count\":2}]}") - .get("nestedV")); + assertEquals( + collectionValue(List.of(Map.of("count", 1), Map.of("count", 2))), + tupleValueWithArraySupport("{\"nestedV\":[{\"count\":1},{\"count\":2}]}").get("nestedV")); } @Test @@ -562,84 +549,70 @@ public void constructNestedObjectArrayOfObjectArraysNode() { assertEquals( collectionValue( List.of( - Map.of("year", + Map.of( + "year", List.of( Map.of("timeV", new ExprTimeValue("09:07:42")), - Map.of("timeV", new ExprTimeValue("09:07:42")) - ) - ), - Map.of("year", + Map.of("timeV", new ExprTimeValue("09:07:42")))), + Map.of( + "year", List.of( Map.of("timeV", new ExprTimeValue("09:07:42")), - Map.of("timeV", new ExprTimeValue("09:07:42")) - ) - ) - ) - ), + Map.of("timeV", new ExprTimeValue("09:07:42")))))), tupleValueWithArraySupport( - "{\"deepNestedV\":" - + "[" - + "{\"year\":" - + "[" - + "{\"timeV\":\"09:07:42\"}," - + "{\"timeV\":\"09:07:42\"}" - + "]" - + "}," - + "{\"year\":" - + "[" - + "{\"timeV\":\"09:07:42\"}," - + "{\"timeV\":\"09:07:42\"}" - + "]" - + "}" - + "]" - + "}") + "{\"deepNestedV\":" + + " [" + + " {\"year\":" + + " [" + + " {\"timeV\":\"09:07:42\"}," + + " {\"timeV\":\"09:07:42\"}" + + " ]" + + " }," + + " {\"year\":" + + " [" + + " {\"timeV\":\"09:07:42\"}," + + " {\"timeV\":\"09:07:42\"}" + + " ]" + + " }" + + " ]" + + "}") .get("deepNestedV")); } @Test public void constructNestedArrayNode() { - assertEquals(collectionValue( - List.of( - 1969, - 2011 - )), - tupleValueWithArraySupport("{\"nestedV\":[1969,2011]}") - .get("nestedV")); + assertEquals( + collectionValue(List.of(1969, 2011)), + tupleValueWithArraySupport("{\"nestedV\":[1969,2011]}").get("nestedV")); } @Test public void constructNestedObjectNode() { - assertEquals(collectionValue( - List.of( - Map.of("count", 1969) - )), - tupleValue("{\"nestedV\":{\"count\":1969}}") - .get("nestedV")); + assertEquals( + collectionValue(List.of(Map.of("count", 1969))), + tupleValue("{\"nestedV\":{\"count\":1969}}").get("nestedV")); } @Test public void constructArrayOfGeoPoints() { - assertEquals(new ExprCollectionValue( + assertEquals( + new ExprCollectionValue( List.of( new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), - new OpenSearchExprGeoPointValue(-33.6123556, 66.287449)) - ), + new OpenSearchExprGeoPointValue(-33.6123556, 66.287449))), tupleValueWithArraySupport( - "{\"geoV\":[" - + "{\"lat\":42.60355556,\"lon\":-97.25263889}," - + "{\"lat\":-33.6123556,\"lon\":66.287449}" - + "]}" - ).get("geoV") - ); + "{\"geoV\":[" + + "{\"lat\":42.60355556,\"lon\":-97.25263889}," + + "{\"lat\":-33.6123556,\"lon\":66.287449}" + + "]}") + .get("geoV")); } @Test public void constructArrayOfIPsReturnsFirstIndex() { assertEquals( new OpenSearchExprIpValue("192.168.0.1"), - tupleValue("{\"ipV\":[\"192.168.0.1\",\"192.168.0.2\"]}") - .get("ipV") - ); + tupleValue("{\"ipV\":[\"192.168.0.1\",\"192.168.0.2\"]}").get("ipV")); } @Test @@ -647,8 +620,7 @@ public void constructBinaryArrayReturnsFirstIndex() { assertEquals( new OpenSearchExprBinaryValue("U29tZSBiaWsdfsdfgYmxvYg=="), tupleValue("{\"binaryV\":[\"U29tZSBiaWsdfsdfgYmxvYg==\",\"U987yuhjjiy8jhk9vY+98jjdf\"]}") - .get("binaryV") - ); + .get("binaryV")); } @Test @@ -656,26 +628,21 @@ public void constructArrayOfCustomEpochMillisReturnsFirstIndex() { assertEquals( new ExprDatetimeValue("2015-01-01 12:10:30"), tupleValue("{\"customAndEpochMillisV\":[\"2015-01-01 12:10:30\",\"1999-11-09 01:09:44\"]}") - .get("customAndEpochMillisV") - ); + .get("customAndEpochMillisV")); } @Test public void constructArrayOfDateStringsReturnsFirstIndex() { assertEquals( new ExprDateValue("1984-04-12"), - tupleValue("{\"dateStringV\":[\"1984-04-12\",\"2033-05-03\"]}") - .get("dateStringV") - ); + tupleValue("{\"dateStringV\":[\"1984-04-12\",\"2033-05-03\"]}").get("dateStringV")); } @Test public void constructArrayOfTimeStringsReturnsFirstIndex() { assertEquals( new ExprTimeValue("12:10:30"), - tupleValue("{\"timeStringV\":[\"12:10:30.000Z\",\"18:33:55.000Z\"]}") - .get("timeStringV") - ); + tupleValue("{\"timeStringV\":[\"12:10:30.000Z\",\"18:33:55.000Z\"]}").get("timeStringV")); } @Test @@ -683,8 +650,7 @@ public void constructArrayOfEpochMillis() { assertEquals( new ExprTimestampValue(Instant.ofEpochMilli(1420070400001L)), tupleValue("{\"dateOrEpochMillisV\":[\"1420070400001\",\"1454251113333\"]}") - .get("dateOrEpochMillisV") - ); + .get("dateOrEpochMillisV")); } @Test @@ -711,54 +677,64 @@ public void constructStruct() { @Test public void constructIP() { - assertEquals(new OpenSearchExprIpValue("192.168.0.1"), + assertEquals( + new OpenSearchExprIpValue("192.168.0.1"), tupleValue("{\"ipV\":\"192.168.0.1\"}").get("ipV")); } @Test public void constructGeoPoint() { - assertEquals(new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), + assertEquals( + new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), tupleValue("{\"geoV\":{\"lat\":42.60355556,\"lon\":-97.25263889}}").get("geoV")); - assertEquals(new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), + assertEquals( + new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), tupleValue("{\"geoV\":{\"lat\":\"42.60355556\",\"lon\":\"-97.25263889\"}}").get("geoV")); - assertEquals(new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), + assertEquals( + new OpenSearchExprGeoPointValue(42.60355556, -97.25263889), constructFromObject("geoV", "42.60355556,-97.25263889")); } @Test public void constructGeoPointFromUnsupportedFormatShouldThrowException() { IllegalStateException exception = - assertThrows(IllegalStateException.class, + assertThrows( + IllegalStateException.class, () -> tupleValue("{\"geoV\":[42.60355556,-97.25263889]}").get("geoV")); - assertEquals("geo point must in format of {\"lat\": number, \"lon\": number}", - exception.getMessage()); + assertEquals( + "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); exception = - assertThrows(IllegalStateException.class, + assertThrows( + IllegalStateException.class, () -> tupleValue("{\"geoV\":{\"lon\":-97.25263889}}").get("geoV")); - assertEquals("geo point must in format of {\"lat\": number, \"lon\": number}", - exception.getMessage()); + assertEquals( + "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); exception = - assertThrows(IllegalStateException.class, + assertThrows( + IllegalStateException.class, () -> tupleValue("{\"geoV\":{\"lat\":-97.25263889}}").get("geoV")); - assertEquals("geo point must in format of {\"lat\": number, \"lon\": number}", - exception.getMessage()); + assertEquals( + "geo point must in format of {\"lat\": number, \"lon\": number}", exception.getMessage()); exception = - assertThrows(IllegalStateException.class, + assertThrows( + IllegalStateException.class, () -> tupleValue("{\"geoV\":{\"lat\":true,\"lon\":-97.25263889}}").get("geoV")); assertEquals("latitude must be number value, but got value: true", exception.getMessage()); exception = - assertThrows(IllegalStateException.class, + assertThrows( + IllegalStateException.class, () -> tupleValue("{\"geoV\":{\"lat\":42.60355556,\"lon\":false}}").get("geoV")); assertEquals("longitude must be number value, but got value: false", exception.getMessage()); } @Test public void constructBinary() { - assertEquals(new OpenSearchExprBinaryValue("U29tZSBiaW5hcnkgYmxvYg=="), + assertEquals( + new OpenSearchExprBinaryValue("U29tZSBiaW5hcnkgYmxvYg=="), tupleValue("{\"binaryV\":\"U29tZSBiaW5hcnkgYmxvYg==\"}").get("binaryV")); } @@ -769,14 +745,16 @@ public void constructBinary() { @Test public void constructFromOpenSearchArrayReturnFirstElement() { assertEquals(integerValue(1), tupleValue("{\"intV\":[1, 2, 3]}").get("intV")); - assertEquals(new ExprTupleValue( - new LinkedHashMap() { - { - put("id", integerValue(1)); - put("state", stringValue("WA")); - } - }), tupleValue("{\"structV\":[{\"id\":1,\"state\":\"WA\"},{\"id\":2,\"state\":\"CA\"}]}}") - .get("structV")); + assertEquals( + new ExprTupleValue( + new LinkedHashMap() { + { + put("id", integerValue(1)); + put("state", stringValue("WA")); + } + }), + tupleValue("{\"structV\":[{\"id\":1,\"state\":\"WA\"},{\"id\":2,\"state\":\"CA\"}]}}") + .get("structV")); } @Test @@ -799,19 +777,13 @@ public void constructUnsupportedTypeThrowException() { new OpenSearchExprValueFactory(Map.of("type", new TestType())); IllegalStateException exception = assertThrows( - IllegalStateException.class, - () -> exprValueFactory.construct("{\"type\":1}", false) - ); + IllegalStateException.class, () -> exprValueFactory.construct("{\"type\":1}", false)); assertEquals("Unsupported type: TEST_TYPE for value: 1.", exception.getMessage()); exception = assertThrows( - IllegalStateException.class, - () -> exprValueFactory.construct("type", 1, false) - ); - assertEquals( - "Unsupported type: TEST_TYPE for value: 1.", - exception.getMessage()); + IllegalStateException.class, () -> exprValueFactory.construct("type", 1, false)); + assertEquals("Unsupported type: TEST_TYPE for value: 1.", exception.getMessage()); } @Test @@ -820,21 +792,21 @@ public void constructUnsupportedTypeThrowException() { public void factoryMappingsAreExtendableWithoutOverWrite() throws NoSuchFieldException, IllegalAccessException { var factory = new OpenSearchExprValueFactory(Map.of("value", OpenSearchDataType.of(INTEGER))); - factory.extendTypeMapping(Map.of( - "value", OpenSearchDataType.of(DOUBLE), - "agg", OpenSearchDataType.of(DATE))); + factory.extendTypeMapping( + Map.of( + "value", OpenSearchDataType.of(DOUBLE), + "agg", OpenSearchDataType.of(DATE))); // extract private field for testing purposes var field = factory.getClass().getDeclaredField("typeMapping"); field.setAccessible(true); @SuppressWarnings("unchecked") - var mapping = (Map)field.get(factory); + var mapping = (Map) field.get(factory); assertAll( () -> assertEquals(2, mapping.size()), () -> assertTrue(mapping.containsKey("value")), () -> assertTrue(mapping.containsKey("agg")), () -> assertEquals(OpenSearchDataType.of(INTEGER), mapping.get("value")), - () -> assertEquals(OpenSearchDataType.of(DATE), mapping.get("agg")) - ); + () -> assertEquals(OpenSearchDataType.of(DATE), mapping.get("agg"))); } public Map tupleValue(String jsonString) {