Skip to content

Commit

Permalink
FIX: field function name letter case preserved in select with group by (
Browse files Browse the repository at this point in the history
opendistro-for-elasticsearch#381)

* FIX: Method field name

* REF: parser logic

* RMV: remove unused function

* STY: unused imports

* STY: unused import
  • Loading branch information
chenqi0805 authored Mar 19, 2020
1 parent b056078 commit ad8ad3a
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,18 @@ public class SQLAggregationParser {
public void parse(MySqlSelectQueryBlock queryBlock) {
context = new Context(constructSQLExprAliasMapFromSelect(queryBlock));

//1. extract raw names of selectItems
List<String> selectItemNames = extractSelectItemNames(queryBlock.getSelectList());

//2. rewrite all the function name to lower case.
rewriteFunctionNameToLowerCase(queryBlock);

//2. find all GroupKeyExpr from GroupBy expression.
findAllGroupKeyExprFromGroupByAndSelect(queryBlock);
findAllAggregationExprFromSelect(queryBlock);

//3. parse the select list to expression
parseExprInSelectList(queryBlock, new SQLExprToExpressionConverter(context));
parseExprInSelectList(queryBlock, selectItemNames, new SQLExprToExpressionConverter(context));
}

public List<SQLSelectItem> selectItemList() {
Expand Down Expand Up @@ -149,12 +155,14 @@ public boolean visit(SQLAggregateExpr expr) {
}));
}

private void parseExprInSelectList(MySqlSelectQueryBlock queryBlock, SQLExprToExpressionConverter exprConverter) {
private void parseExprInSelectList(
MySqlSelectQueryBlock queryBlock, List<String> selectItemNames,
SQLExprToExpressionConverter exprConverter) {
List<SQLSelectItem> selectItems = queryBlock.getSelectList();
for (int i = 0; i < selectItems.size(); i++) {
Expression expression = exprConverter.convert(selectItems.get(i).getExpr());
ColumnNode columnNode = ColumnNode.builder()
.name(nameOfSelectItem(selectItems.get(i)))
.name(selectItemNames.get(i))
.alias(selectItems.get(i).getAlias())
.type(columnTypeProvider.get(i))
.expr(expression)
Expand All @@ -163,6 +171,24 @@ private void parseExprInSelectList(MySqlSelectQueryBlock queryBlock, SQLExprToEx
}
}

private List<String> extractSelectItemNames(List<SQLSelectItem> selectItems) {
List<String> selectItemNames = new ArrayList<>();
for (SQLSelectItem selectItem: selectItems){
selectItemNames.add(nameOfSelectItem(selectItem));
}
return selectItemNames;
}

private void rewriteFunctionNameToLowerCase(MySqlSelectQueryBlock query) {
query.accept(new MySqlASTVisitorAdapter() {
@Override
public boolean visit(SQLMethodInvokeExpr x) {
x.setMethodName(x.getMethodName().toLowerCase());
return true;
}
});
}

private String nameOfSelectItem(SQLSelectItem selectItem) {
return Strings.isNullOrEmpty(selectItem.getAlias()) ? Context
.nameOfExpr(selectItem.getExpr()) : selectItem.getAlias();
Expand Down Expand Up @@ -237,7 +263,7 @@ public static String nameOfExpr(SQLExpr expr) {
((SQLAggregateExpr) expr).getArguments().get(0));
} else if (expr instanceof SQLMethodInvokeExpr) {
exprName = String.format("%s(%s)", ((SQLMethodInvokeExpr) expr).getMethodName(),
nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0)));
nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0)));
} else if (expr instanceof SQLIdentifierExpr) {
exprName = ((SQLIdentifierExpr) expr).getName();
} else if (expr instanceof SQLCastExpr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package com.amazon.opendistroforelasticsearch.sql.query.planner.converter;

import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider;
Expand Down Expand Up @@ -54,16 +53,13 @@ public SQLToOperatorConverter(Client client, ColumnTypeProvider columnTypeProvid

@Override
public boolean visit(MySqlSelectQueryBlock query) {
//1. rewrite all the function name to lower case.
rewriteFunctionNameToLowerCase(query);

//2. parse the aggregation
//1. parse the aggregation
aggregationParser.parse(query);


//3. construct the PhysicalOperator
physicalOperator = project(
scroll(query));
//2. construct the PhysicalOperator
physicalOperator = project(scroll(query));
return false;
}

Expand All @@ -76,16 +72,6 @@ public List<ColumnNode> getColumnNodes() {
return aggregationParser.getColumnNodes();
}

private void rewriteFunctionNameToLowerCase(MySqlSelectQueryBlock query) {
query.accept(new MySqlASTVisitorAdapter() {
@Override
public boolean visit(SQLMethodInvokeExpr x) {
x.setMethodName(x.getMethodName().toLowerCase());
return true;
}
});
}

private PhysicalOperator<BindingTuple> project(PhysicalOperator<BindingTuple> input) {
return new PhysicalProject(input, aggregationParser.getColumnNodes());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ public void stringOperatorNameCaseInsensitiveTest() {

@Test
public void dateFunctionNameCaseInsensitiveTest() {
assertEquals(
assertTrue(
executeQuery("SELECT DATE_FORMAT(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " +
"WHERE date_FORMAT(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " +
"GROUP BY DAte_format(insert_time, 'yyyy-MM-dd', 'UTC') " +
"ORDER BY date_forMAT(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc"),
executeQuery("SELECT date_format(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " +
"WHERE date_format(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " +
"GROUP BY date_format(insert_time, 'yyyy-MM-dd', 'UTC') " +
"ORDER BY date_format(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc")
"ORDER BY date_forMAT(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc").equalsIgnoreCase(
executeQuery("SELECT date_format(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " +
"WHERE date_format(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " +
"GROUP BY date_format(insert_time, 'yyyy-MM-dd', 'UTC') " +
"ORDER BY date_format(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc")
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ public void testLengthWithTextFieldReturnsInt() {
verifySchema(response, schema("length(firstname)", null, "integer"));
}

@Test
public void testLengthWithGroupByExpr() {
JSONObject response = executeJdbcRequest("SELECT Length(firstname) FROM " + TestsConstants.TEST_INDEX_ACCOUNT +
" GROUP BY LENGTH(firstname) LIMIT 5");

verifySchema(response, schema("Length(firstname)", null, "integer"));
}

/*
trigFunctions
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package com.amazon.opendistroforelasticsearch.sql.unittest.planner.converter;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.alibaba.druid.util.JdbcConstants;
import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider;
import com.amazon.opendistroforelasticsearch.sql.expression.core.Expression;
Expand All @@ -34,6 +37,7 @@
import org.junit.runner.RunWith;
import org.mockito.runners.MockitoJUnitRunner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

Expand Down Expand Up @@ -108,6 +112,30 @@ public void parseAggWithExpressionShouldPass() {
.ref("MIN_1")))));
}

@Test
public void parseWithRawSelectFuncnameShouldPass() {
String sql = "SELECT LOG(FlightDelayMin) " +
"FROM kibana_sample_data_flights " +
"GROUP BY log(FlightDelayMin)";
SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider());
parser.parse(mYSqlSelectQueryBlock(sql));
List<SQLSelectItem> sqlSelectItems = parser.selectItemList();
List<ColumnNode> columnNodes = parser.getColumnNodes();

assertThat(sqlSelectItems, containsInAnyOrder(group("log(FlightDelayMin)", "log(FlightDelayMin)")));

assertThat(
columnNodes,
containsInAnyOrder(
columnNode(
"LOG(FlightDelayMin)",
null,
ExpressionFactory.ref("log(FlightDelayMin)")
)
)
);
}

@Test
public void functionOverFiledShouldPass() {
String sql = "SELECT dayOfWeek, max(FlightDelayMin) + MIN(FlightDelayMin) as sub " +
Expand Down

0 comments on commit ad8ad3a

Please sign in to comment.