From 29c57004dec233f577a8af32f70333aea35ecc33 Mon Sep 17 00:00:00 2001 From: Chen Dai <46505291+dai-chen@users.noreply.github.com> Date: Wed, 16 Dec 2020 09:45:15 -0800 Subject: [PATCH] Use name and alias in JDBC format (#932) * Rename getName to getNameOrAlias * Use original name as name in JDBC format * Support alias in CLI * Use local CLI for doctest * Add UT * Fix IT * Fix IT * Fix UT * Update javadoc --- .../sql/analysis/Analyzer.java | 6 ++--- .../ExpressionReferenceOptimizer.java | 3 ++- .../sql/expression/NamedExpression.java | 8 +++---- .../planner/physical/AggregationOperator.java | 2 +- .../sql/planner/physical/ProjectOperator.java | 2 +- .../analysis/NamedExpressionAnalyzerTest.java | 2 +- .../sql/expression/NamedExpressionTest.java | 6 ++--- .../planner/physical/ProjectOperatorTest.java | 4 ++-- doctest/bootstrap.sh | 3 +-- .../aggregation/AggregationQueryBuilder.java | 2 +- .../dsl/BucketAggregationBuilder.java | 2 +- .../sql/legacy/AggregationExpressionIT.java | 11 +++++---- .../sql/legacy/AggregationIT.java | 19 ++++++++++----- .../sql/legacy/PrettyFormatResponseIT.java | 2 ++ .../sql/legacy/SQLFunctionsIT.java | 14 +++++------ .../sql/sql/JdbcFormatIT.java | 2 +- .../sql/protocol/response/QueryResult.java | 11 +++++++-- .../protocol/response/QueryResultTest.java | 18 +++++++++++++-- .../format/CsvResponseFormatterTest.java | 8 +++---- .../SimpleJsonResponseFormatterTest.java | 19 +++++++++++++-- sql-cli/src/odfe_sql_cli/formatter.py | 2 +- sql-cli/tests/test_formatter.py | 23 +++++++++++++++++++ 22 files changed, 119 insertions(+), 50 deletions(-) diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java index 961b6bdd16..b8e418a43a 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/Analyzer.java @@ -194,7 +194,7 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) { for (UnresolvedExpression expr : node.getAggExprList()) { NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context); aggregatorBuilder - .add(new NamedAggregator(aggExpr.getName(), (Aggregator) aggExpr.getDelegated())); + .add(new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated())); } ImmutableList aggregators = aggregatorBuilder.build(); @@ -210,7 +210,7 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) { aggregators.forEach(aggregator -> newEnv.define(new Symbol(Namespace.FIELD_NAME, aggregator.getName()), aggregator.type())); groupBys.forEach(group -> newEnv.define(new Symbol(Namespace.FIELD_NAME, - group.getName()), group.type())); + group.getNameOrAlias()), group.type())); return new LogicalAggregation(child, aggregators, groupBys); } @@ -291,7 +291,7 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) { context.push(); TypeEnvironment newEnv = context.peek(); namedExpressions.forEach(expr -> newEnv.define(new Symbol(Namespace.FIELD_NAME, - expr.getName()), expr.type())); + expr.getNameOrAlias()), expr.type())); return new LogicalProject(child, namedExpressions); } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java index b999f50f15..b98c7be53e 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -136,7 +136,8 @@ public Void visitAggregation(LogicalAggregation plan, Void context) { new ReferenceExpression(namedAggregator.getName(), namedAggregator.type()))); // Create the mapping for all the group by. plan.getGroupByList().forEach(groupBy -> expressionMap - .put(groupBy.getDelegated(), new ReferenceExpression(groupBy.getName(), groupBy.type()))); + .put(groupBy.getDelegated(), + new ReferenceExpression(groupBy.getNameOrAlias(), groupBy.type()))); return null; } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpression.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpression.java index 17fd1225eb..8153239ebd 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpression.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpression.java @@ -24,7 +24,6 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; -import lombok.ToString; /** * Named expression that represents expression with name. @@ -33,6 +32,7 @@ */ @AllArgsConstructor @EqualsAndHashCode +@Getter @RequiredArgsConstructor public class NamedExpression implements Expression { @@ -44,13 +44,11 @@ public class NamedExpression implements Expression { /** * Expression that being named. */ - @Getter private final Expression delegated; /** * Optional alias. */ - @Getter private String alias; @Override @@ -67,7 +65,7 @@ public ExprType type() { * Get expression name using name or its alias (if it's present). * @return expression name */ - public String getName() { + public String getNameOrAlias() { return Strings.isNullOrEmpty(alias) ? name : alias; } @@ -78,7 +76,7 @@ public T accept(ExpressionNodeVisitor visitor, C context) { @Override public String toString() { - return getName(); + return getNameOrAlias(); } } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/AggregationOperator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/AggregationOperator.java index 3bcf8301f5..ec38e3911f 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/AggregationOperator.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/AggregationOperator.java @@ -172,7 +172,7 @@ public GroupKey(ExprValue value) { public LinkedHashMap groupKeyMap() { LinkedHashMap map = new LinkedHashMap<>(); for (int i = 0; i < groupByExprList.size(); i++) { - map.put(groupByExprList.get(i).getName(), groupByValueList.get(i)); + map.put(groupByExprList.get(i).getNameOrAlias(), groupByValueList.get(i)); } return map; } diff --git a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperator.java b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperator.java index 32a6906298..c0fde16367 100644 --- a/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperator.java +++ b/core/src/main/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperator.java @@ -62,7 +62,7 @@ public ExprValue next() { ImmutableMap.Builder mapBuilder = new Builder<>(); for (NamedExpression expr : projectList) { ExprValue exprValue = expr.valueOf(inputValue.bindingTuples()); - mapBuilder.put(expr.getName(), exprValue); + mapBuilder.put(expr.getNameOrAlias(), exprValue); } return ExprTupleValue.fromExprValueMap(mapBuilder.build()); } diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/NamedExpressionAnalyzerTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/NamedExpressionAnalyzerTest.java index 4386ec7501..1ef80e8248 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/NamedExpressionAnalyzerTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/analysis/NamedExpressionAnalyzerTest.java @@ -41,6 +41,6 @@ void visit_named_seleteitem() { new NamedExpressionAnalyzer(expressionAnalyzer); NamedExpression analyze = analyzer.analyze(alias, analysisContext); - assertEquals("integer_value", analyze.getName()); + assertEquals("integer_value", analyze.getNameOrAlias()); } } \ No newline at end of file diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpressionTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpressionTest.java index c8330d2fba..dbfca07b76 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpressionTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/expression/NamedExpressionTest.java @@ -30,7 +30,7 @@ void name_an_expression() { LiteralExpression delegated = DSL.literal(10); NamedExpression namedExpression = DSL.named("10", delegated); - assertEquals("10", namedExpression.getName()); + assertEquals("10", namedExpression.getNameOrAlias()); assertEquals(delegated.type(), namedExpression.type()); assertEquals(delegated.valueOf(valueEnv()), namedExpression.valueOf(valueEnv())); } @@ -39,7 +39,7 @@ void name_an_expression() { void name_an_expression_with_alias() { LiteralExpression delegated = DSL.literal(10); NamedExpression namedExpression = DSL.named("10", delegated, "ten"); - assertEquals("ten", namedExpression.getName()); + assertEquals("ten", namedExpression.getNameOrAlias()); } @Test @@ -48,7 +48,7 @@ void name_an_named_expression() { Expression expression = DSL.named("10", delegated, "ten"); NamedExpression namedExpression = DSL.named(expression); - assertEquals("ten", namedExpression.getName()); + assertEquals("ten", namedExpression.getNameOrAlias()); } } \ No newline at end of file diff --git a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperatorTest.java b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperatorTest.java index 5baf541f62..873c0f2734 100644 --- a/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperatorTest.java +++ b/core/src/test/java/com/amazon/opendistroforelasticsearch/sql/planner/physical/ProjectOperatorTest.java @@ -104,11 +104,11 @@ public void project_keep_missing_value() { public void project_schema() { PhysicalPlan project = project(inputPlan, DSL.named("response", DSL.ref("response", INTEGER)), - DSL.named("action", DSL.ref("action", STRING))); + DSL.named("action", DSL.ref("action", STRING), "act")); assertThat(project.schema().getColumns(), contains( new ExecutionEngine.Schema.Column("response", null, INTEGER), - new ExecutionEngine.Schema.Column("action", null, STRING) + new ExecutionEngine.Schema.Column("action", "act", STRING) )); } } diff --git a/doctest/bootstrap.sh b/doctest/bootstrap.sh index 55d841c957..29f6105386 100755 --- a/doctest/bootstrap.sh +++ b/doctest/bootstrap.sh @@ -21,5 +21,4 @@ fi $DIR/.venv/bin/pip install -U pip setuptools wheel $DIR/.venv/bin/pip install -r $DIR/requirements.txt -# Temporary fix, add odfe-sql-cli dependency into requirements.txt once we have released cli to PyPI -$DIR/.venv/bin/pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple odfe-sql-cli==0.0.2 +$DIR/.venv/bin/pip install -e ../sql-cli diff --git a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/AggregationQueryBuilder.java b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/AggregationQueryBuilder.java index b2d16fdf4f..d0081ae27d 100644 --- a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/AggregationQueryBuilder.java +++ b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/AggregationQueryBuilder.java @@ -101,7 +101,7 @@ public Map buildTypeMapping( List groupByList) { ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); namedAggregatorList.forEach(agg -> builder.put(agg.getName(), agg.type())); - groupByList.forEach(group -> builder.put(group.getName(), group.type())); + groupByList.forEach(group -> builder.put(group.getNameOrAlias(), group.type())); return builder.build(); } diff --git a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java index 3449f1bc9f..158a685a31 100644 --- a/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java +++ b/elasticsearch/src/main/java/com/amazon/opendistroforelasticsearch/sql/elasticsearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java @@ -47,7 +47,7 @@ public List> build( new ImmutableList.Builder<>(); for (Pair groupPair : groupList) { TermsValuesSourceBuilder valuesSourceBuilder = - new TermsValuesSourceBuilder(groupPair.getLeft().getName()) + new TermsValuesSourceBuilder(groupPair.getLeft().getNameOrAlias()) .missingBucket(true) .order(groupPair.getRight()); resultBuilder diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java index 028b51b8c7..5d37b0ba0b 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationExpressionIT.java @@ -83,6 +83,7 @@ public void noGroupKeyAvgOnIntegerShouldPass() { @Test public void hasGroupKeyAvgOnIntegerShouldPass() { + Assume.assumeTrue(isNewQueryEngineEabled()); JSONObject response = executeJdbcRequest(String.format( "SELECT gender, AVG(age) as avg " + "FROM %s " + @@ -91,7 +92,7 @@ public void hasGroupKeyAvgOnIntegerShouldPass() { verifySchema(response, schema("gender", null, "text"), - schema("avg", "avg", "double")); + schema("AVG(age)", "avg", "double")); verifyDataRows(response, rows("m", 34.25), rows("f", 33.666666666666664d)); @@ -181,6 +182,8 @@ public void AddLiteralOnGroupKeyShouldPass() { @Test public void logWithAddLiteralOnGroupKeyShouldPass() { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest(String.format( "SELECT gender, Log(age+10) as logAge, max(balance) as max " + "FROM %s " + @@ -191,8 +194,8 @@ public void logWithAddLiteralOnGroupKeyShouldPass() { verifySchema(response, schema("gender", null, "text"), - schema("logAge", "logAge", "double"), - schema("max", "max", "long")); + schema("Log(age+10)", "logAge", "double"), + schema("max(balance)", "max", "long")); verifyDataRows(response, rows("m", 3.4011973816621555d, 49568), rows("m", 3.4339872044851463d, 49433)); @@ -264,7 +267,7 @@ public void aggregateCastStatementShouldNotReturnZero() { "SELECT SUM(CAST(male AS INT)) AS male_sum FROM %s", Index.BANK.getName())); - verifySchema(response, schema("male_sum", "male_sum", "integer")); + verifySchema(response, schema("SUM(CAST(male AS INT))", "male_sum", "integer")); verifyDataRows(response, rows(4)); } diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationIT.java index 921409f544..caae2f957d 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/AggregationIT.java @@ -45,6 +45,7 @@ import org.json.JSONArray; import org.json.JSONObject; import org.junit.Assert; +import org.junit.Assume; import org.junit.Ignore; import org.junit.Test; @@ -470,10 +471,12 @@ public void orderByAscTest() { @Test public void orderByAliasAscTest() { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest(String.format("SELECT COUNT(*) as count FROM %s " + "GROUP BY gender ORDER BY count", TEST_INDEX_ACCOUNT)); - verifySchema(response, schema("count", "count", "integer")); + verifySchema(response, schema("COUNT(*)", "count", "integer")); verifyDataRowsInOrder(response, rows(493), rows(507)); @@ -492,10 +495,12 @@ public void orderByDescTest() throws IOException { @Test public void orderByAliasDescTest() throws IOException { + Assume.assumeTrue(isNewQueryEngineEabled()); + JSONObject response = executeJdbcRequest(String.format("SELECT COUNT(*) as count FROM %s " + "GROUP BY gender ORDER BY count DESC", TEST_INDEX_ACCOUNT)); - verifySchema(response, schema("count", "count", "integer")); + verifySchema(response, schema("COUNT(*)", "count", "integer")); verifyDataRowsInOrder(response, rows(507), rows(493)); @@ -503,13 +508,15 @@ public void orderByAliasDescTest() throws IOException { @Test public void orderByGroupFieldWithAlias() throws IOException { + Assume.assumeTrue(isNewQueryEngineEabled()); + // ORDER BY field name JSONObject response = executeJdbcRequest(String.format("SELECT gender as g, COUNT(*) as count " + "FROM %s GROUP BY gender ORDER BY gender", TEST_INDEX_ACCOUNT)); verifySchema(response, - schema("g", "g", "text"), - schema("count", "count", "integer")); + schema("gender", "g", "text"), + schema("COUNT(*)", "count", "integer")); verifyDataRowsInOrder(response, rows("f", 493), rows("m", 507)); @@ -519,8 +526,8 @@ public void orderByGroupFieldWithAlias() throws IOException { + "FROM %s GROUP BY gender ORDER BY g", TEST_INDEX_ACCOUNT)); verifySchema(response, - schema("g", "g", "text"), - schema("count", "count", "integer")); + schema("gender", "g", "text"), + schema("COUNT(*)", "count", "integer")); verifyDataRowsInOrder(response, rows("f", 493), rows("m", 507)); diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/PrettyFormatResponseIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/PrettyFormatResponseIT.java index 2fbf70dfeb..7ade383945 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/PrettyFormatResponseIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/PrettyFormatResponseIT.java @@ -357,6 +357,8 @@ public void aggregationFunctionInSelectCaseCheck() throws IOException { @Test public void aggregationFunctionInSelectWithAlias() throws IOException { + Assume.assumeFalse(isNewQueryEngineEabled()); + JSONObject response = executeQuery( String.format(Locale.ROOT, "SELECT COUNT(*) AS total FROM %s GROUP BY age", TestsConstants.TEST_INDEX_ACCOUNT)); diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java index b95d07d11a..85c785a5f3 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/legacy/SQLFunctionsIT.java @@ -227,7 +227,7 @@ public void castIntFieldToFloatWithoutAliasJdbcFormatTest() { " ORDER BY balance DESC LIMIT 1"); verifySchema(response, - schema("cast_balance", null, "float")); + schema("CAST(balance AS FLOAT)", "cast_balance", "float")); verifyDataRows(response, rows(49989.0)); @@ -242,7 +242,7 @@ public void castIntFieldToFloatWithAliasJdbcFormatTest() { "FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " ORDER BY jdbc_float_alias LIMIT 1"); verifySchema(response, - schema("jdbc_float_alias", null, "float")); + schema("CAST(balance AS FLOAT)", "jdbc_float_alias", "float")); verifyDataRows(response, rows(1011.0)); @@ -394,10 +394,10 @@ public void castBoolFieldToNumericValueInSelectClause() { verifySchema(response, schema("male", "boolean"), - schema("cast_int", "integer"), - schema("cast_long", "long"), - schema("cast_float", "float"), - schema("cast_double", "double") + schema("CAST(male AS INT)", "cast_int", "integer"), + schema("CAST(male AS LONG)", "cast_long", "long"), + schema("CAST(male AS FLOAT)", "cast_float", "float"), + schema("CAST(male AS DOUBLE)", "cast_double", "double") ); verifyDataRows(response, rows(true, 1, 1, 1.0, 1.0), @@ -419,7 +419,7 @@ public void castBoolFieldToNumericValueWithGroupByAlias() { ); verifySchema(response, - schema("cast_int", "cast_int", "integer"), + schema("CAST(male AS INT)", "cast_int", "integer"), schema("COUNT(*)", "integer") ); verifyDataRows(response, diff --git a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/JdbcFormatIT.java b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/JdbcFormatIT.java index 51cf961ca5..836cb5636c 100644 --- a/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/JdbcFormatIT.java +++ b/integ-test/src/test/java/com/amazon/opendistroforelasticsearch/sql/sql/JdbcFormatIT.java @@ -52,7 +52,7 @@ public void testAliasInSchema() { JSONObject response = new JSONObject(executeQuery( "SELECT account_number AS acc FROM " + TEST_INDEX_BANK, "jdbc")); - verifySchema(response, schema("acc", "acc", "long")); + verifySchema(response, schema("account_number", "acc", "long")); } } diff --git a/protocol/src/main/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResult.java b/protocol/src/main/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResult.java index 83a09366b9..5deb7e0f56 100644 --- a/protocol/src/main/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResult.java +++ b/protocol/src/main/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResult.java @@ -19,6 +19,7 @@ import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils; import com.amazon.opendistroforelasticsearch.sql.executor.ExecutionEngine; +import com.amazon.opendistroforelasticsearch.sql.executor.ExecutionEngine.Schema.Column; import java.util.Collection; import java.util.Iterator; import java.util.LinkedHashMap; @@ -53,11 +54,13 @@ public int size() { /** * Parse column name from results. * - * @return mapping from column names to its expression type + * @return mapping from column names to its expression type. + * note that column name could be original name or its alias if any. */ public Map columnNameTypes() { Map colNameTypes = new LinkedHashMap<>(); - schema.getColumns().forEach(column -> colNameTypes.put(column.getName(), + schema.getColumns().forEach(column -> colNameTypes.put( + getColumnName(column), column.getExprType().typeName().toLowerCase())); return colNameTypes; } @@ -72,6 +75,10 @@ public Iterator iterator() { .iterator(); } + private String getColumnName(Column column) { + return (column.getAlias() != null) ? column.getAlias() : column.getName(); + } + private Object[] convertExprValuesToValues(Collection exprValues) { return exprValues .stream() diff --git a/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResultTest.java b/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResultTest.java index 785b8949af..5373650b25 100644 --- a/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResultTest.java +++ b/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/QueryResultTest.java @@ -33,8 +33,8 @@ class QueryResultTest { private ExecutionEngine.Schema schema = new ExecutionEngine.Schema(ImmutableList.of( - new ExecutionEngine.Schema.Column("name", "name", STRING), - new ExecutionEngine.Schema.Column("age", "age", INTEGER))); + new ExecutionEngine.Schema.Column("name", null, STRING), + new ExecutionEngine.Schema.Column("age", null, INTEGER))); @Test @@ -63,6 +63,20 @@ void columnNameTypes() { ); } + @Test + void columnNameTypesWithAlias() { + ExecutionEngine.Schema schema = new ExecutionEngine.Schema(ImmutableList.of( + new ExecutionEngine.Schema.Column("name", "n", STRING))); + QueryResult response = new QueryResult( + schema, + Collections.singletonList(tupleValue(ImmutableMap.of("n", "John")))); + + assertEquals( + ImmutableMap.of("n", "string"), + response.columnNameTypes() + ); + } + @Test void columnNameTypesFromEmptyExprValues() { QueryResult response = new QueryResult( diff --git a/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/CsvResponseFormatterTest.java b/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/CsvResponseFormatterTest.java index fdae598121..f720a3e775 100644 --- a/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/CsvResponseFormatterTest.java +++ b/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/CsvResponseFormatterTest.java @@ -53,10 +53,10 @@ void formatResponse() { @Test void sanitizeHeaders() { ExecutionEngine.Schema schema = new ExecutionEngine.Schema(ImmutableList.of( - new ExecutionEngine.Schema.Column("=firstname", "firstname", STRING), - new ExecutionEngine.Schema.Column("+lastname", "lastname", STRING), - new ExecutionEngine.Schema.Column("-city", "city", STRING), - new ExecutionEngine.Schema.Column("@age", "age", INTEGER))); + new ExecutionEngine.Schema.Column("=firstname", null, STRING), + new ExecutionEngine.Schema.Column("+lastname", null, STRING), + new ExecutionEngine.Schema.Column("-city", null, STRING), + new ExecutionEngine.Schema.Column("@age", null, INTEGER))); QueryResult response = new QueryResult(schema, Arrays.asList( tupleValue(ImmutableMap.of( "=firstname", "John", "+lastname", "Smith", "-city", "Seattle", "@age", 20)))); diff --git a/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/SimpleJsonResponseFormatterTest.java b/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/SimpleJsonResponseFormatterTest.java index f7d8d6e710..6c72ed8d38 100644 --- a/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/SimpleJsonResponseFormatterTest.java +++ b/protocol/src/test/java/com/amazon/opendistroforelasticsearch/sql/protocol/response/format/SimpleJsonResponseFormatterTest.java @@ -36,8 +36,8 @@ class SimpleJsonResponseFormatterTest { private final ExecutionEngine.Schema schema = new ExecutionEngine.Schema(ImmutableList.of( - new ExecutionEngine.Schema.Column("firstname", "name", STRING), - new ExecutionEngine.Schema.Column("age", "age", INTEGER))); + new ExecutionEngine.Schema.Column("firstname", null, STRING), + new ExecutionEngine.Schema.Column("age", null, INTEGER))); @Test void formatResponse() { @@ -92,6 +92,21 @@ void formatResponsePretty() { formatter.format(response)); } + @Test + void formatResponseSchemaWithAlias() { + ExecutionEngine.Schema schema = new ExecutionEngine.Schema(ImmutableList.of( + new ExecutionEngine.Schema.Column("firstname", "name", STRING))); + QueryResult response = + new QueryResult( + schema, + ImmutableList.of(tupleValue(ImmutableMap.of("name", "John", "age", 20)))); + SimpleJsonResponseFormatter formatter = new SimpleJsonResponseFormatter(COMPACT); + assertEquals( + "{\"schema\":[{\"name\":\"name\",\"type\":\"string\"}]," + + "\"datarows\":[[\"John\",20]],\"total\":1,\"size\":1}", + formatter.format(response)); + } + @Test void formatResponseWithMissingValue() { QueryResult response = diff --git a/sql-cli/src/odfe_sql_cli/formatter.py b/sql-cli/src/odfe_sql_cli/formatter.py index acfb401fe2..de8248ea13 100644 --- a/sql-cli/src/odfe_sql_cli/formatter.py +++ b/sql-cli/src/odfe_sql_cli/formatter.py @@ -74,7 +74,7 @@ def format_output(self, data): # get header and type as lists, for future usage for i in schema: - fields.append(i["name"]) + fields.append(i.get("alias", i["name"])) types.append(i["type"]) output = formatter.format_output(datarows, fields, **self.output_kwargs) diff --git a/sql-cli/tests/test_formatter.py b/sql-cli/tests/test_formatter.py index b0f85c34a5..3131eac91b 100644 --- a/sql-cli/tests/test_formatter.py +++ b/sql-cli/tests/test_formatter.py @@ -79,6 +79,29 @@ def test_format_output(self): ] assert list(results) == expected + def test_format_alias_output(self): + settings = OutputSettings(table_format="psql") + formatter = Formatter(settings) + data = { + "schema": [{"name": "name", "alias": "n", "type": "text"}], + "total": 1, + "datarows": [["Tim"]], + "size": 1, + "status": 200, + } + + results = formatter.format_output(data) + + expected = [ + "fetched rows / total rows = 1/1", + "+-----+", + "| n |", + "|-----|", + "| Tim |", + "+-----+", + ] + assert list(results) == expected + def test_format_array_output(self): settings = OutputSettings(table_format="psql") formatter = Formatter(settings)