Skip to content

Commit

Permalink
Use name and alias in JDBC format (opendistro-for-elasticsearch#932)
Browse files Browse the repository at this point in the history
* Rename getName to getNameOrAlias

* Use original name as name in JDBC format

* Support alias in CLI

* Use local CLI for doctest

* Add UT

* Fix IT

* Fix IT

* Fix UT

* Update javadoc
  • Loading branch information
dai-chen authored and penghuo committed Dec 16, 2020
1 parent dc53a74 commit 29c5700
Show file tree
Hide file tree
Showing 22 changed files with 119 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
for (UnresolvedExpression expr : node.getAggExprList()) {
NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context);
aggregatorBuilder
.add(new NamedAggregator(aggExpr.getName(), (Aggregator) aggExpr.getDelegated()));
.add(new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated()));
}
ImmutableList<NamedAggregator> aggregators = aggregatorBuilder.build();

Expand All @@ -210,7 +210,7 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
aggregators.forEach(aggregator -> newEnv.define(new Symbol(Namespace.FIELD_NAME,
aggregator.getName()), aggregator.type()));
groupBys.forEach(group -> newEnv.define(new Symbol(Namespace.FIELD_NAME,
group.getName()), group.type()));
group.getNameOrAlias()), group.type()));
return new LogicalAggregation(child, aggregators, groupBys);
}

Expand Down Expand Up @@ -291,7 +291,7 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) {
context.push();
TypeEnvironment newEnv = context.peek();
namedExpressions.forEach(expr -> newEnv.define(new Symbol(Namespace.FIELD_NAME,
expr.getName()), expr.type()));
expr.getNameOrAlias()), expr.type()));
return new LogicalProject(child, namedExpressions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ public Void visitAggregation(LogicalAggregation plan, Void context) {
new ReferenceExpression(namedAggregator.getName(), namedAggregator.type())));
// Create the mapping for all the group by.
plan.getGroupByList().forEach(groupBy -> expressionMap
.put(groupBy.getDelegated(), new ReferenceExpression(groupBy.getName(), groupBy.type())));
.put(groupBy.getDelegated(),
new ReferenceExpression(groupBy.getNameOrAlias(), groupBy.type())));
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;

/**
* Named expression that represents expression with name.
Expand All @@ -33,6 +32,7 @@
*/
@AllArgsConstructor
@EqualsAndHashCode
@Getter
@RequiredArgsConstructor
public class NamedExpression implements Expression {

Expand All @@ -44,13 +44,11 @@ public class NamedExpression implements Expression {
/**
* Expression that being named.
*/
@Getter
private final Expression delegated;

/**
* Optional alias.
*/
@Getter
private String alias;

@Override
Expand All @@ -67,7 +65,7 @@ public ExprType type() {
* Get expression name using name or its alias (if it's present).
* @return expression name
*/
public String getName() {
public String getNameOrAlias() {
return Strings.isNullOrEmpty(alias) ? name : alias;
}

Expand All @@ -78,7 +76,7 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {

@Override
public String toString() {
return getName();
return getNameOrAlias();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public GroupKey(ExprValue value) {
public LinkedHashMap<String, ExprValue> groupKeyMap() {
LinkedHashMap<String, ExprValue> map = new LinkedHashMap<>();
for (int i = 0; i < groupByExprList.size(); i++) {
map.put(groupByExprList.get(i).getName(), groupByValueList.get(i));
map.put(groupByExprList.get(i).getNameOrAlias(), groupByValueList.get(i));
}
return map;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public ExprValue next() {
ImmutableMap.Builder<String, ExprValue> mapBuilder = new Builder<>();
for (NamedExpression expr : projectList) {
ExprValue exprValue = expr.valueOf(inputValue.bindingTuples());
mapBuilder.put(expr.getName(), exprValue);
mapBuilder.put(expr.getNameOrAlias(), exprValue);
}
return ExprTupleValue.fromExprValueMap(mapBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ void visit_named_seleteitem() {
new NamedExpressionAnalyzer(expressionAnalyzer);

NamedExpression analyze = analyzer.analyze(alias, analysisContext);
assertEquals("integer_value", analyze.getName());
assertEquals("integer_value", analyze.getNameOrAlias());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void name_an_expression() {
LiteralExpression delegated = DSL.literal(10);
NamedExpression namedExpression = DSL.named("10", delegated);

assertEquals("10", namedExpression.getName());
assertEquals("10", namedExpression.getNameOrAlias());
assertEquals(delegated.type(), namedExpression.type());
assertEquals(delegated.valueOf(valueEnv()), namedExpression.valueOf(valueEnv()));
}
Expand All @@ -39,7 +39,7 @@ void name_an_expression() {
void name_an_expression_with_alias() {
LiteralExpression delegated = DSL.literal(10);
NamedExpression namedExpression = DSL.named("10", delegated, "ten");
assertEquals("ten", namedExpression.getName());
assertEquals("ten", namedExpression.getNameOrAlias());
}

@Test
Expand All @@ -48,7 +48,7 @@ void name_an_named_expression() {
Expression expression = DSL.named("10", delegated, "ten");

NamedExpression namedExpression = DSL.named(expression);
assertEquals("ten", namedExpression.getName());
assertEquals("ten", namedExpression.getNameOrAlias());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ public void project_keep_missing_value() {
public void project_schema() {
PhysicalPlan project = project(inputPlan,
DSL.named("response", DSL.ref("response", INTEGER)),
DSL.named("action", DSL.ref("action", STRING)));
DSL.named("action", DSL.ref("action", STRING), "act"));

assertThat(project.schema().getColumns(), contains(
new ExecutionEngine.Schema.Column("response", null, INTEGER),
new ExecutionEngine.Schema.Column("action", null, STRING)
new ExecutionEngine.Schema.Column("action", "act", STRING)
));
}
}
3 changes: 1 addition & 2 deletions doctest/bootstrap.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,4 @@ fi

$DIR/.venv/bin/pip install -U pip setuptools wheel
$DIR/.venv/bin/pip install -r $DIR/requirements.txt
# Temporary fix, add odfe-sql-cli dependency into requirements.txt once we have released cli to PyPI
$DIR/.venv/bin/pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple odfe-sql-cli==0.0.2
$DIR/.venv/bin/pip install -e ../sql-cli
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public Map<String, ExprType> buildTypeMapping(
List<NamedExpression> groupByList) {
ImmutableMap.Builder<String, ExprType> builder = new ImmutableMap.Builder<>();
namedAggregatorList.forEach(agg -> builder.put(agg.getName(), agg.type()));
groupByList.forEach(group -> builder.put(group.getName(), group.type()));
groupByList.forEach(group -> builder.put(group.getNameOrAlias(), group.type()));
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public List<CompositeValuesSourceBuilder<?>> build(
new ImmutableList.Builder<>();
for (Pair<NamedExpression, SortOrder> groupPair : groupList) {
TermsValuesSourceBuilder valuesSourceBuilder =
new TermsValuesSourceBuilder(groupPair.getLeft().getName())
new TermsValuesSourceBuilder(groupPair.getLeft().getNameOrAlias())
.missingBucket(true)
.order(groupPair.getRight());
resultBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public void noGroupKeyAvgOnIntegerShouldPass() {

@Test
public void hasGroupKeyAvgOnIntegerShouldPass() {
Assume.assumeTrue(isNewQueryEngineEabled());
JSONObject response = executeJdbcRequest(String.format(
"SELECT gender, AVG(age) as avg " +
"FROM %s " +
Expand All @@ -91,7 +92,7 @@ public void hasGroupKeyAvgOnIntegerShouldPass() {

verifySchema(response,
schema("gender", null, "text"),
schema("avg", "avg", "double"));
schema("AVG(age)", "avg", "double"));
verifyDataRows(response,
rows("m", 34.25),
rows("f", 33.666666666666664d));
Expand Down Expand Up @@ -181,6 +182,8 @@ public void AddLiteralOnGroupKeyShouldPass() {

@Test
public void logWithAddLiteralOnGroupKeyShouldPass() {
Assume.assumeTrue(isNewQueryEngineEabled());

JSONObject response = executeJdbcRequest(String.format(
"SELECT gender, Log(age+10) as logAge, max(balance) as max " +
"FROM %s " +
Expand All @@ -191,8 +194,8 @@ public void logWithAddLiteralOnGroupKeyShouldPass() {

verifySchema(response,
schema("gender", null, "text"),
schema("logAge", "logAge", "double"),
schema("max", "max", "long"));
schema("Log(age+10)", "logAge", "double"),
schema("max(balance)", "max", "long"));
verifyDataRows(response,
rows("m", 3.4011973816621555d, 49568),
rows("m", 3.4339872044851463d, 49433));
Expand Down Expand Up @@ -264,7 +267,7 @@ public void aggregateCastStatementShouldNotReturnZero() {
"SELECT SUM(CAST(male AS INT)) AS male_sum FROM %s",
Index.BANK.getName()));

verifySchema(response, schema("male_sum", "male_sum", "integer"));
verifySchema(response, schema("SUM(CAST(male AS INT))", "male_sum", "integer"));
verifyDataRows(response, rows(4));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.json.JSONArray;
import org.json.JSONObject;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Ignore;
import org.junit.Test;

Expand Down Expand Up @@ -470,10 +471,12 @@ public void orderByAscTest() {

@Test
public void orderByAliasAscTest() {
Assume.assumeTrue(isNewQueryEngineEabled());

JSONObject response = executeJdbcRequest(String.format("SELECT COUNT(*) as count FROM %s " +
"GROUP BY gender ORDER BY count", TEST_INDEX_ACCOUNT));

verifySchema(response, schema("count", "count", "integer"));
verifySchema(response, schema("COUNT(*)", "count", "integer"));
verifyDataRowsInOrder(response,
rows(493),
rows(507));
Expand All @@ -492,24 +495,28 @@ public void orderByDescTest() throws IOException {

@Test
public void orderByAliasDescTest() throws IOException {
Assume.assumeTrue(isNewQueryEngineEabled());

JSONObject response = executeJdbcRequest(String.format("SELECT COUNT(*) as count FROM %s " +
"GROUP BY gender ORDER BY count DESC", TEST_INDEX_ACCOUNT));

verifySchema(response, schema("count", "count", "integer"));
verifySchema(response, schema("COUNT(*)", "count", "integer"));
verifyDataRowsInOrder(response,
rows(507),
rows(493));
}

@Test
public void orderByGroupFieldWithAlias() throws IOException {
Assume.assumeTrue(isNewQueryEngineEabled());

// ORDER BY field name
JSONObject response = executeJdbcRequest(String.format("SELECT gender as g, COUNT(*) as count "
+ "FROM %s GROUP BY gender ORDER BY gender", TEST_INDEX_ACCOUNT));

verifySchema(response,
schema("g", "g", "text"),
schema("count", "count", "integer"));
schema("gender", "g", "text"),
schema("COUNT(*)", "count", "integer"));
verifyDataRowsInOrder(response,
rows("f", 493),
rows("m", 507));
Expand All @@ -519,8 +526,8 @@ public void orderByGroupFieldWithAlias() throws IOException {
+ "FROM %s GROUP BY gender ORDER BY g", TEST_INDEX_ACCOUNT));

verifySchema(response,
schema("g", "g", "text"),
schema("count", "count", "integer"));
schema("gender", "g", "text"),
schema("COUNT(*)", "count", "integer"));
verifyDataRowsInOrder(response,
rows("f", 493),
rows("m", 507));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ public void aggregationFunctionInSelectCaseCheck() throws IOException {

@Test
public void aggregationFunctionInSelectWithAlias() throws IOException {
Assume.assumeFalse(isNewQueryEngineEabled());

JSONObject response = executeQuery(
String.format(Locale.ROOT, "SELECT COUNT(*) AS total FROM %s GROUP BY age",
TestsConstants.TEST_INDEX_ACCOUNT));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ public void castIntFieldToFloatWithoutAliasJdbcFormatTest() {
" ORDER BY balance DESC LIMIT 1");

verifySchema(response,
schema("cast_balance", null, "float"));
schema("CAST(balance AS FLOAT)", "cast_balance", "float"));

verifyDataRows(response,
rows(49989.0));
Expand All @@ -242,7 +242,7 @@ public void castIntFieldToFloatWithAliasJdbcFormatTest() {
"FROM " + TestsConstants.TEST_INDEX_ACCOUNT + " ORDER BY jdbc_float_alias LIMIT 1");

verifySchema(response,
schema("jdbc_float_alias", null, "float"));
schema("CAST(balance AS FLOAT)", "jdbc_float_alias", "float"));

verifyDataRows(response,
rows(1011.0));
Expand Down Expand Up @@ -394,10 +394,10 @@ public void castBoolFieldToNumericValueInSelectClause() {

verifySchema(response,
schema("male", "boolean"),
schema("cast_int", "integer"),
schema("cast_long", "long"),
schema("cast_float", "float"),
schema("cast_double", "double")
schema("CAST(male AS INT)", "cast_int", "integer"),
schema("CAST(male AS LONG)", "cast_long", "long"),
schema("CAST(male AS FLOAT)", "cast_float", "float"),
schema("CAST(male AS DOUBLE)", "cast_double", "double")
);
verifyDataRows(response,
rows(true, 1, 1, 1.0, 1.0),
Expand All @@ -419,7 +419,7 @@ public void castBoolFieldToNumericValueWithGroupByAlias() {
);

verifySchema(response,
schema("cast_int", "cast_int", "integer"),
schema("CAST(male AS INT)", "cast_int", "integer"),
schema("COUNT(*)", "integer")
);
verifyDataRows(response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void testAliasInSchema() {
JSONObject response = new JSONObject(executeQuery(
"SELECT account_number AS acc FROM " + TEST_INDEX_BANK, "jdbc"));

verifySchema(response, schema("acc", "acc", "long"));
verifySchema(response, schema("account_number", "acc", "long"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.executor.ExecutionEngine;
import com.amazon.opendistroforelasticsearch.sql.executor.ExecutionEngine.Schema.Column;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -53,11 +54,13 @@ public int size() {
/**
* Parse column name from results.
*
* @return mapping from column names to its expression type
* @return mapping from column names to its expression type.
* note that column name could be original name or its alias if any.
*/
public Map<String, String> columnNameTypes() {
Map<String, String> colNameTypes = new LinkedHashMap<>();
schema.getColumns().forEach(column -> colNameTypes.put(column.getName(),
schema.getColumns().forEach(column -> colNameTypes.put(
getColumnName(column),
column.getExprType().typeName().toLowerCase()));
return colNameTypes;
}
Expand All @@ -72,6 +75,10 @@ public Iterator<Object[]> iterator() {
.iterator();
}

private String getColumnName(Column column) {
return (column.getAlias() != null) ? column.getAlias() : column.getName();
}

private Object[] convertExprValuesToValues(Collection<ExprValue> exprValues) {
return exprValues
.stream()
Expand Down
Loading

0 comments on commit 29c5700

Please sign in to comment.