Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

FIX: field function name letter case preserved in select with group by #381

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,18 @@ public class SQLAggregationParser {
public void parse(MySqlSelectQueryBlock queryBlock) {
context = new Context(constructSQLExprAliasMapFromSelect(queryBlock));

//1. extract raw names of selectItems
List<String> selectItemNames = extractSelectItemNames(queryBlock.getSelectList());

//2. rewrite all the function name to lower case.
rewriteFunctionNameToLowerCase(queryBlock);

//2. find all GroupKeyExpr from GroupBy expression.
findAllGroupKeyExprFromGroupByAndSelect(queryBlock);
findAllAggregationExprFromSelect(queryBlock);

//3. parse the select list to expression
parseExprInSelectList(queryBlock, new SQLExprToExpressionConverter(context));
parseExprInSelectList(queryBlock, selectItemNames, new SQLExprToExpressionConverter(context));
}

public List<SQLSelectItem> selectItemList() {
Expand Down Expand Up @@ -149,12 +155,14 @@ public boolean visit(SQLAggregateExpr expr) {
}));
}

private void parseExprInSelectList(MySqlSelectQueryBlock queryBlock, SQLExprToExpressionConverter exprConverter) {
private void parseExprInSelectList(
MySqlSelectQueryBlock queryBlock, List<String> selectItemNames,
SQLExprToExpressionConverter exprConverter) {
List<SQLSelectItem> selectItems = queryBlock.getSelectList();
for (int i = 0; i < selectItems.size(); i++) {
Expression expression = exprConverter.convert(selectItems.get(i).getExpr());
ColumnNode columnNode = ColumnNode.builder()
.name(nameOfSelectItem(selectItems.get(i)))
.name(selectItemNames.get(i))
.alias(selectItems.get(i).getAlias())
.type(columnTypeProvider.get(i))
.expr(expression)
Expand All @@ -163,6 +171,24 @@ private void parseExprInSelectList(MySqlSelectQueryBlock queryBlock, SQLExprToEx
}
}

private List<String> extractSelectItemNames(List<SQLSelectItem> selectItems) {
List<String> selectItemNames = new ArrayList<>();
for (SQLSelectItem selectItem: selectItems){
selectItemNames.add(nameOfSelectItem(selectItem));
}
return selectItemNames;
}

private void rewriteFunctionNameToLowerCase(MySqlSelectQueryBlock query) {
query.accept(new MySqlASTVisitorAdapter() {
@Override
public boolean visit(SQLMethodInvokeExpr x) {
x.setMethodName(x.getMethodName().toLowerCase());
return true;
}
});
}

private String nameOfSelectItem(SQLSelectItem selectItem) {
return Strings.isNullOrEmpty(selectItem.getAlias()) ? Context
.nameOfExpr(selectItem.getExpr()) : selectItem.getAlias();
Expand Down Expand Up @@ -237,7 +263,7 @@ public static String nameOfExpr(SQLExpr expr) {
((SQLAggregateExpr) expr).getArguments().get(0));
} else if (expr instanceof SQLMethodInvokeExpr) {
exprName = String.format("%s(%s)", ((SQLMethodInvokeExpr) expr).getMethodName(),
nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0)));
nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0)));
} else if (expr instanceof SQLIdentifierExpr) {
exprName = ((SQLIdentifierExpr) expr).getName();
} else if (expr instanceof SQLCastExpr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package com.amazon.opendistroforelasticsearch.sql.query.planner.converter;

import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider;
Expand Down Expand Up @@ -54,16 +53,13 @@ public SQLToOperatorConverter(Client client, ColumnTypeProvider columnTypeProvid

@Override
public boolean visit(MySqlSelectQueryBlock query) {
//1. rewrite all the function name to lower case.
rewriteFunctionNameToLowerCase(query);

//2. parse the aggregation
//1. parse the aggregation
aggregationParser.parse(query);


//3. construct the PhysicalOperator
physicalOperator = project(
scroll(query));
//2. construct the PhysicalOperator
physicalOperator = project(scroll(query));
return false;
}

Expand All @@ -76,16 +72,6 @@ public List<ColumnNode> getColumnNodes() {
return aggregationParser.getColumnNodes();
}

private void rewriteFunctionNameToLowerCase(MySqlSelectQueryBlock query) {
query.accept(new MySqlASTVisitorAdapter() {
@Override
public boolean visit(SQLMethodInvokeExpr x) {
x.setMethodName(x.getMethodName().toLowerCase());
return true;
}
});
}

private PhysicalOperator<BindingTuple> project(PhysicalOperator<BindingTuple> input) {
return new PhysicalProject(input, aggregationParser.getColumnNodes());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ public void stringOperatorNameCaseInsensitiveTest() {

@Test
public void dateFunctionNameCaseInsensitiveTest() {
assertEquals(
assertTrue(
penghuo marked this conversation as resolved.
Show resolved Hide resolved
executeQuery("SELECT DATE_FORMAT(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " +
"WHERE date_FORMAT(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " +
"GROUP BY DAte_format(insert_time, 'yyyy-MM-dd', 'UTC') " +
"ORDER BY date_forMAT(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc"),
executeQuery("SELECT date_format(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " +
"WHERE date_format(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " +
"GROUP BY date_format(insert_time, 'yyyy-MM-dd', 'UTC') " +
"ORDER BY date_format(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc")
"ORDER BY date_forMAT(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc").equalsIgnoreCase(
executeQuery("SELECT date_format(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " +
"WHERE date_format(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " +
"GROUP BY date_format(insert_time, 'yyyy-MM-dd', 'UTC') " +
"ORDER BY date_format(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc")
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ public void testLengthWithTextFieldReturnsInt() {
verifySchema(response, schema("length(firstname)", null, "integer"));
}

@Test
public void testLengthWithGroupByExpr() {
JSONObject response = executeJdbcRequest("SELECT Length(firstname) FROM " + TestsConstants.TEST_INDEX_ACCOUNT +
" GROUP BY LENGTH(firstname) LIMIT 5");

verifySchema(response, schema("Length(firstname)", null, "integer"));
}

/*
trigFunctions
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package com.amazon.opendistroforelasticsearch.sql.unittest.planner.converter;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.alibaba.druid.util.JdbcConstants;
import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider;
import com.amazon.opendistroforelasticsearch.sql.expression.core.Expression;
Expand All @@ -34,6 +37,7 @@
import org.junit.runner.RunWith;
import org.mockito.runners.MockitoJUnitRunner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

Expand Down Expand Up @@ -108,6 +112,30 @@ public void parseAggWithExpressionShouldPass() {
.ref("MIN_1")))));
}

@Test
public void parseWithRawSelectFuncnameShouldPass() {
String sql = "SELECT LOG(FlightDelayMin) " +
"FROM kibana_sample_data_flights " +
"GROUP BY log(FlightDelayMin)";
SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider());
parser.parse(mYSqlSelectQueryBlock(sql));
List<SQLSelectItem> sqlSelectItems = parser.selectItemList();
List<ColumnNode> columnNodes = parser.getColumnNodes();

assertThat(sqlSelectItems, containsInAnyOrder(group("log(FlightDelayMin)", "log(FlightDelayMin)")));

assertThat(
columnNodes,
containsInAnyOrder(
columnNode(
"LOG(FlightDelayMin)",
null,
ExpressionFactory.ref("log(FlightDelayMin)")
)
)
);
}

@Test
public void functionOverFiledShouldPass() {
String sql = "SELECT dayOfWeek, max(FlightDelayMin) + MIN(FlightDelayMin) as sub " +
Expand Down