From 24dd82e81ac310308323a26d58b5a4bdab9ed43b Mon Sep 17 00:00:00 2001 From: George Chen Date: Fri, 13 Mar 2020 11:45:58 -0500 Subject: [PATCH 1/5] FIX: Method field name --- .../converter/SQLAggregationParser.java | 24 ++-- .../converter/SQLToOperatorConverter.java | 21 +++- .../sql/esintgtest/JdbcTestIT.java | 13 ++- .../sql/esintgtest/TypeInformationIT.java | 8 ++ .../converter/SQLAggregationParserTest.java | 104 ++++++++++++++++-- 5 files changed, 141 insertions(+), 29 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java index 42628fc2ac..ce8e6c0933 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java @@ -55,7 +55,7 @@ public class SQLAggregationParser { @Getter private List columnNodes = new ArrayList<>(); - public void parse(MySqlSelectQueryBlock queryBlock) { + public void parse(MySqlSelectQueryBlock queryBlock, List selectMethodNames) { context = new Context(constructSQLExprAliasMapFromSelect(queryBlock)); //2. find all GroupKeyExpr from GroupBy expression. @@ -63,7 +63,7 @@ public void parse(MySqlSelectQueryBlock queryBlock) { findAllAggregationExprFromSelect(queryBlock); //3. parse the select list to expression - parseExprInSelectList(queryBlock, new SQLExprToExpressionConverter(context)); + parseExprInSelectList(queryBlock, selectMethodNames, new SQLExprToExpressionConverter(context)); } public List selectItemList() { @@ -149,12 +149,14 @@ public boolean visit(SQLAggregateExpr expr) { })); } - private void parseExprInSelectList(MySqlSelectQueryBlock queryBlock, SQLExprToExpressionConverter exprConverter) { + private void parseExprInSelectList( + MySqlSelectQueryBlock queryBlock, List selectMethodNames, + SQLExprToExpressionConverter exprConverter) { List selectItems = queryBlock.getSelectList(); for (int i = 0; i < selectItems.size(); i++) { Expression expression = exprConverter.convert(selectItems.get(i).getExpr()); ColumnNode columnNode = ColumnNode.builder() - .name(nameOfSelectItem(selectItems.get(i))) + .name(nameOfSelectItem(selectItems.get(i), selectMethodNames.get(i))) .alias(selectItems.get(i).getAlias()) .type(columnTypeProvider.get(i)) .expr(expression) @@ -163,9 +165,9 @@ private void parseExprInSelectList(MySqlSelectQueryBlock queryBlock, SQLExprToEx } } - private String nameOfSelectItem(SQLSelectItem selectItem) { + private String nameOfSelectItem(SQLSelectItem selectItem, String selectMethodName) { return Strings.isNullOrEmpty(selectItem.getAlias()) ? Context - .nameOfExpr(selectItem.getExpr()) : selectItem.getAlias(); + .nameOfExpr(selectItem.getExpr(), selectMethodName) : selectItem.getAlias(); } @RequiredArgsConstructor @@ -208,7 +210,7 @@ public class GroupKeyExpr { public GroupKeyExpr(SQLExpr expr) { this.expr = expr; - String exprName = nameOfExpr(expr).replace(".", "#"); + String exprName = nameOfExpr(expr, null).replace(".", "#"); if (expr instanceof SQLIdentifierExpr && selectSQLExprAliasMap.values().contains(((SQLIdentifierExpr) expr).getName())) { exprName = ((SQLIdentifierExpr) expr).getName(); @@ -230,14 +232,16 @@ public AggregationExpr(SQLAggregateExpr expr) { } } - public static String nameOfExpr(SQLExpr expr) { + public static String nameOfExpr(SQLExpr expr, String selectMethodName) { String exprName = expr.toString().toLowerCase(); if (expr instanceof SQLAggregateExpr) { exprName = String.format("%s(%s)", ((SQLAggregateExpr) expr).getMethodName(), ((SQLAggregateExpr) expr).getArguments().get(0)); } else if (expr instanceof SQLMethodInvokeExpr) { - exprName = String.format("%s(%s)", ((SQLMethodInvokeExpr) expr).getMethodName(), - nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0))); + String funcName = ( + selectMethodName == null ? ((SQLMethodInvokeExpr) expr).getMethodName() : selectMethodName); + exprName = String.format("%s(%s)", funcName, + nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0), null)); } else if (expr instanceof SQLIdentifierExpr) { exprName = ((SQLIdentifierExpr) expr).getName(); } else if (expr instanceof SQLCastExpr) { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java index e492d76171..6b0b402e54 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java @@ -15,7 +15,9 @@ package com.amazon.opendistroforelasticsearch.sql.query.planner.converter; +import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; +import com.alibaba.druid.sql.ast.statement.SQLSelectItem; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter; import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider; @@ -33,6 +35,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.client.Client; +import java.util.ArrayList; import java.util.List; /** @@ -54,11 +57,14 @@ public SQLToOperatorConverter(Client client, ColumnTypeProvider columnTypeProvid @Override public boolean visit(MySqlSelectQueryBlock query) { + //1. extract function names in select + List selectMethodNames = extractSelectFunctionNames(query.getSelectList()); + //1. rewrite all the function name to lower case. rewriteFunctionNameToLowerCase(query); //2. parse the aggregation - aggregationParser.parse(query); + aggregationParser.parse(query, selectMethodNames); //3. construct the PhysicalOperator @@ -76,6 +82,19 @@ public List getColumnNodes() { return aggregationParser.getColumnNodes(); } + public List extractSelectFunctionNames(List selectItems) { + List methodNames = new ArrayList<>(); + for (SQLSelectItem selectItem: selectItems){ + SQLExpr selectItemExpr = selectItem.getExpr(); + if (selectItemExpr instanceof SQLMethodInvokeExpr) { + methodNames.add(((SQLMethodInvokeExpr) selectItemExpr).getMethodName()); + } else { + methodNames.add(null); + } + } + return methodNames; + } + private void rewriteFunctionNameToLowerCase(MySqlSelectQueryBlock query) { query.accept(new MySqlASTVisitorAdapter() { @Override diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/JdbcTestIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/JdbcTestIT.java index 8ad7ea7fa0..3bdcf384fe 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/JdbcTestIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/JdbcTestIT.java @@ -118,15 +118,16 @@ public void stringOperatorNameCaseInsensitiveTest() { @Test public void dateFunctionNameCaseInsensitiveTest() { - assertEquals( + assertTrue( executeQuery("SELECT DATE_FORMAT(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " + "WHERE date_FORMAT(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " + "GROUP BY DAte_format(insert_time, 'yyyy-MM-dd', 'UTC') " + - "ORDER BY date_forMAT(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc"), - executeQuery("SELECT date_format(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " + - "WHERE date_format(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " + - "GROUP BY date_format(insert_time, 'yyyy-MM-dd', 'UTC') " + - "ORDER BY date_format(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc") + "ORDER BY date_forMAT(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc").equalsIgnoreCase( + executeQuery("SELECT date_format(insert_time, 'yyyy-MM-dd', 'UTC') FROM elasticsearch-sql_test_index_online " + + "WHERE date_format(insert_time, 'yyyy-MM-dd', 'UTC') > '2014-01-01' " + + "GROUP BY date_format(insert_time, 'yyyy-MM-dd', 'UTC') " + + "ORDER BY date_format(insert_time, 'yyyy-MM-dd', 'UTC')", "jdbc") + ) ); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TypeInformationIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TypeInformationIT.java index 81d6198d98..80adad1ea5 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TypeInformationIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TypeInformationIT.java @@ -89,6 +89,14 @@ public void testLengthWithTextFieldReturnsInt() { verifySchema(response, schema("length(firstname)", null, "integer")); } + @Test + public void testLengthWithGroupByExpr() { + JSONObject response = executeJdbcRequest("SELECT Length(firstname) FROM " + TestsConstants.TEST_INDEX_ACCOUNT + + " GROUP BY LENGTH(firstname) LIMIT 5"); + + verifySchema(response, schema("Length(firstname)", null, "integer")); + } + /* trigFunctions */ diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java index 06259dabc6..67fb1bfd01 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java @@ -16,10 +16,13 @@ package com.amazon.opendistroforelasticsearch.sql.unittest.planner.converter; import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.expr.SQLAggregateExpr; +import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; import com.alibaba.druid.sql.ast.statement.SQLSelectItem; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; +import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter; import com.alibaba.druid.util.JdbcConstants; import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider; import com.amazon.opendistroforelasticsearch.sql.expression.core.Expression; @@ -34,6 +37,7 @@ import org.junit.runner.RunWith; import org.mockito.runners.MockitoJUnitRunner; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -55,7 +59,9 @@ public void parseAggWithoutExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -75,7 +81,9 @@ public void parseAggWithFunctioniWithoutExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -95,7 +103,9 @@ public void parseAggWithExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -108,13 +118,48 @@ public void parseAggWithExpressionShouldPass() { .ref("MIN_1"))))); } + @Test + public void parseWithRawSelectFuncnameShouldPass() { + String sql = "SELECT LOG(FlightDelayMin) " + + "FROM kibana_sample_data_flights " + + "GROUP BY log(FlightDelayMin)"; + SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, false); + query.accept(new MySqlASTVisitorAdapter() { + @Override + public boolean visit(SQLMethodInvokeExpr x) { + x.setMethodName(x.getMethodName().toLowerCase()); + return true; + } + }); + parser.parse(query, selectMethodNames); + List sqlSelectItems = parser.selectItemList(); + List columnNodes = parser.getColumnNodes(); + + assertThat(sqlSelectItems, containsInAnyOrder(group("log(FlightDelayMin)", "log(FlightDelayMin)"))); + + assertThat( + columnNodes, + containsInAnyOrder( + columnNode( + "LOG(FlightDelayMin)", + null, + ExpressionFactory.ref("log(FlightDelayMin)") + ) + ) + ); + } + @Test public void functionOverFiledShouldPass() { String sql = "SELECT dayOfWeek, max(FlightDelayMin) + MIN(FlightDelayMin) as sub " + "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -133,7 +178,9 @@ public void parseCompoundAggWithExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY ASCII(dayOfWeek)"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -151,7 +198,9 @@ public void parseCompoundAggWithExpressionShouldPass() { public void parseSingleFunctionOverAggShouldPass() { String sql = "SELECT log(max(age)) FROM accounts"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -164,7 +213,9 @@ public void parseSingleFunctionOverAggShouldPass() { public void parseFunctionGroupColumnOverShouldPass() { String sql = "SELECT CAST(balance AS FLOAT) FROM accounts GROUP BY balance"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -177,7 +228,9 @@ public void parseFunctionGroupColumnOverShouldPass() { public void withoutAggregationShouldPass() { String sql = "SELECT age, gender FROM accounts GROUP BY age, gender"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -193,7 +246,9 @@ public void withoutAggregationShouldPass() { public void groupKeyInSelectWithFunctionShouldPass() { String sql = "SELECT log(age), max(balance) FROM accounts GROUP BY age"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -209,7 +264,9 @@ public void groupKeyInSelectWithFunctionShouldPass() { public void theDotInFieldNameShouldBeReplaceWithSharp() { String sql = "SELECT name.lastname, max(balance) FROM accounts GROUP BY name.lastname"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -225,7 +282,9 @@ public void theDotInFieldNameShouldBeReplaceWithSharp() { public void noGroupKeyInSelectShouldPass() { String sql = "SELECT AVG(age) FROM t GROUP BY age"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -247,7 +306,28 @@ public void aggregationWithNestedShouldThrowException() { + "FROM t " + "GROUP BY nested(projects.name.keyword, 'projects')"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - parser.parse(mYSqlSelectQueryBlock(sql)); + MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); + List selectMethodNames = prepareRawMethodNames(query, true); + parser.parse(query, selectMethodNames); + } + + private List prepareRawMethodNames(MySqlSelectQueryBlock query, Boolean fillWithNull) { + List selectItems = query.getSelectList(); + if (fillWithNull) { + return Arrays.asList(new String[selectItems.size()]); + } + List selectMethodNames = new ArrayList<>(); + for (SQLSelectItem selectItem: selectItems){ + SQLExpr selectItemExpr = selectItem.getExpr(); + if (selectItemExpr instanceof SQLMethodInvokeExpr) { + selectMethodNames.add(((SQLMethodInvokeExpr) selectItemExpr).getMethodName()); + } else if (selectItemExpr instanceof SQLAggregateExpr) { + selectMethodNames.add(((SQLAggregateExpr) selectItemExpr).getMethodName()); + } else { + selectMethodNames.add(null); + } + } + return selectMethodNames; } private MySqlSelectQueryBlock mYSqlSelectQueryBlock(String sql) { From e59e2eb4a736ee682e1bad056144e0a17d3002f2 Mon Sep 17 00:00:00 2001 From: George Chen Date: Wed, 18 Mar 2020 10:42:48 -0500 Subject: [PATCH 2/5] REF: parser logic --- .../converter/SQLAggregationParser.java | 46 ++++++++--- .../converter/SQLToOperatorConverter.java | 24 +----- .../converter/SQLAggregationParserTest.java | 78 ++++--------------- 3 files changed, 51 insertions(+), 97 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java index ce8e6c0933..e847fca83a 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLAggregationParser.java @@ -55,15 +55,21 @@ public class SQLAggregationParser { @Getter private List columnNodes = new ArrayList<>(); - public void parse(MySqlSelectQueryBlock queryBlock, List selectMethodNames) { + public void parse(MySqlSelectQueryBlock queryBlock) { context = new Context(constructSQLExprAliasMapFromSelect(queryBlock)); + //1. extract raw names of selectItems + List selectItemNames = extractSelectItemNames(queryBlock.getSelectList()); + + //2. rewrite all the function name to lower case. + rewriteFunctionNameToLowerCase(queryBlock); + //2. find all GroupKeyExpr from GroupBy expression. findAllGroupKeyExprFromGroupByAndSelect(queryBlock); findAllAggregationExprFromSelect(queryBlock); //3. parse the select list to expression - parseExprInSelectList(queryBlock, selectMethodNames, new SQLExprToExpressionConverter(context)); + parseExprInSelectList(queryBlock, selectItemNames, new SQLExprToExpressionConverter(context)); } public List selectItemList() { @@ -150,13 +156,13 @@ public boolean visit(SQLAggregateExpr expr) { } private void parseExprInSelectList( - MySqlSelectQueryBlock queryBlock, List selectMethodNames, + MySqlSelectQueryBlock queryBlock, List selectItemNames, SQLExprToExpressionConverter exprConverter) { List selectItems = queryBlock.getSelectList(); for (int i = 0; i < selectItems.size(); i++) { Expression expression = exprConverter.convert(selectItems.get(i).getExpr()); ColumnNode columnNode = ColumnNode.builder() - .name(nameOfSelectItem(selectItems.get(i), selectMethodNames.get(i))) + .name(selectItemNames.get(i)) .alias(selectItems.get(i).getAlias()) .type(columnTypeProvider.get(i)) .expr(expression) @@ -165,9 +171,27 @@ private void parseExprInSelectList( } } - private String nameOfSelectItem(SQLSelectItem selectItem, String selectMethodName) { + private List extractSelectItemNames(List selectItems) { + List selectItemNames = new ArrayList<>(); + for (SQLSelectItem selectItem: selectItems){ + selectItemNames.add(nameOfSelectItem(selectItem)); + } + return selectItemNames; + } + + private void rewriteFunctionNameToLowerCase(MySqlSelectQueryBlock query) { + query.accept(new MySqlASTVisitorAdapter() { + @Override + public boolean visit(SQLMethodInvokeExpr x) { + x.setMethodName(x.getMethodName().toLowerCase()); + return true; + } + }); + } + + private String nameOfSelectItem(SQLSelectItem selectItem) { return Strings.isNullOrEmpty(selectItem.getAlias()) ? Context - .nameOfExpr(selectItem.getExpr(), selectMethodName) : selectItem.getAlias(); + .nameOfExpr(selectItem.getExpr()) : selectItem.getAlias(); } @RequiredArgsConstructor @@ -210,7 +234,7 @@ public class GroupKeyExpr { public GroupKeyExpr(SQLExpr expr) { this.expr = expr; - String exprName = nameOfExpr(expr, null).replace(".", "#"); + String exprName = nameOfExpr(expr).replace(".", "#"); if (expr instanceof SQLIdentifierExpr && selectSQLExprAliasMap.values().contains(((SQLIdentifierExpr) expr).getName())) { exprName = ((SQLIdentifierExpr) expr).getName(); @@ -232,16 +256,14 @@ public AggregationExpr(SQLAggregateExpr expr) { } } - public static String nameOfExpr(SQLExpr expr, String selectMethodName) { + public static String nameOfExpr(SQLExpr expr) { String exprName = expr.toString().toLowerCase(); if (expr instanceof SQLAggregateExpr) { exprName = String.format("%s(%s)", ((SQLAggregateExpr) expr).getMethodName(), ((SQLAggregateExpr) expr).getArguments().get(0)); } else if (expr instanceof SQLMethodInvokeExpr) { - String funcName = ( - selectMethodName == null ? ((SQLMethodInvokeExpr) expr).getMethodName() : selectMethodName); - exprName = String.format("%s(%s)", funcName, - nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0), null)); + exprName = String.format("%s(%s)", ((SQLMethodInvokeExpr) expr).getMethodName(), + nameOfExpr(((SQLMethodInvokeExpr) expr).getParameters().get(0))); } else if (expr instanceof SQLIdentifierExpr) { exprName = ((SQLIdentifierExpr) expr).getName(); } else if (expr instanceof SQLCastExpr) { diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java index 6b0b402e54..58ac4cd5de 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java @@ -57,19 +57,13 @@ public SQLToOperatorConverter(Client client, ColumnTypeProvider columnTypeProvid @Override public boolean visit(MySqlSelectQueryBlock query) { - //1. extract function names in select - List selectMethodNames = extractSelectFunctionNames(query.getSelectList()); - //1. rewrite all the function name to lower case. - rewriteFunctionNameToLowerCase(query); + //1. parse the aggregation + aggregationParser.parse(query); - //2. parse the aggregation - aggregationParser.parse(query, selectMethodNames); - - //3. construct the PhysicalOperator - physicalOperator = project( - scroll(query)); + //2. construct the PhysicalOperator + physicalOperator = project(scroll(query)); return false; } @@ -95,16 +89,6 @@ public List extractSelectFunctionNames(List selectItems) return methodNames; } - private void rewriteFunctionNameToLowerCase(MySqlSelectQueryBlock query) { - query.accept(new MySqlASTVisitorAdapter() { - @Override - public boolean visit(SQLMethodInvokeExpr x) { - x.setMethodName(x.getMethodName().toLowerCase()); - return true; - } - }); - } - private PhysicalOperator project(PhysicalOperator input) { return new PhysicalProject(input, aggregationParser.getColumnNodes()); } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java index 67fb1bfd01..0c825b5476 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/unittest/planner/converter/SQLAggregationParserTest.java @@ -59,9 +59,7 @@ public void parseAggWithoutExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -81,9 +79,7 @@ public void parseAggWithFunctioniWithoutExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -103,9 +99,7 @@ public void parseAggWithExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -124,16 +118,7 @@ public void parseWithRawSelectFuncnameShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY log(FlightDelayMin)"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, false); - query.accept(new MySqlASTVisitorAdapter() { - @Override - public boolean visit(SQLMethodInvokeExpr x) { - x.setMethodName(x.getMethodName().toLowerCase()); - return true; - } - }); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -157,9 +142,7 @@ public void functionOverFiledShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY dayOfWeek"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -178,9 +161,7 @@ public void parseCompoundAggWithExpressionShouldPass() { "FROM kibana_sample_data_flights " + "GROUP BY ASCII(dayOfWeek)"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -198,9 +179,7 @@ public void parseCompoundAggWithExpressionShouldPass() { public void parseSingleFunctionOverAggShouldPass() { String sql = "SELECT log(max(age)) FROM accounts"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -213,9 +192,7 @@ public void parseSingleFunctionOverAggShouldPass() { public void parseFunctionGroupColumnOverShouldPass() { String sql = "SELECT CAST(balance AS FLOAT) FROM accounts GROUP BY balance"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -228,9 +205,7 @@ public void parseFunctionGroupColumnOverShouldPass() { public void withoutAggregationShouldPass() { String sql = "SELECT age, gender FROM accounts GROUP BY age, gender"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -246,9 +221,7 @@ public void withoutAggregationShouldPass() { public void groupKeyInSelectWithFunctionShouldPass() { String sql = "SELECT log(age), max(balance) FROM accounts GROUP BY age"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -264,9 +237,7 @@ public void groupKeyInSelectWithFunctionShouldPass() { public void theDotInFieldNameShouldBeReplaceWithSharp() { String sql = "SELECT name.lastname, max(balance) FROM accounts GROUP BY name.lastname"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -282,9 +253,7 @@ public void theDotInFieldNameShouldBeReplaceWithSharp() { public void noGroupKeyInSelectShouldPass() { String sql = "SELECT AVG(age) FROM t GROUP BY age"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); + parser.parse(mYSqlSelectQueryBlock(sql)); List sqlSelectItems = parser.selectItemList(); List columnNodes = parser.getColumnNodes(); @@ -306,28 +275,7 @@ public void aggregationWithNestedShouldThrowException() { + "FROM t " + "GROUP BY nested(projects.name.keyword, 'projects')"; SQLAggregationParser parser = new SQLAggregationParser(new ColumnTypeProvider()); - MySqlSelectQueryBlock query = mYSqlSelectQueryBlock(sql); - List selectMethodNames = prepareRawMethodNames(query, true); - parser.parse(query, selectMethodNames); - } - - private List prepareRawMethodNames(MySqlSelectQueryBlock query, Boolean fillWithNull) { - List selectItems = query.getSelectList(); - if (fillWithNull) { - return Arrays.asList(new String[selectItems.size()]); - } - List selectMethodNames = new ArrayList<>(); - for (SQLSelectItem selectItem: selectItems){ - SQLExpr selectItemExpr = selectItem.getExpr(); - if (selectItemExpr instanceof SQLMethodInvokeExpr) { - selectMethodNames.add(((SQLMethodInvokeExpr) selectItemExpr).getMethodName()); - } else if (selectItemExpr instanceof SQLAggregateExpr) { - selectMethodNames.add(((SQLAggregateExpr) selectItemExpr).getMethodName()); - } else { - selectMethodNames.add(null); - } - } - return selectMethodNames; + parser.parse(mYSqlSelectQueryBlock(sql)); } private MySqlSelectQueryBlock mYSqlSelectQueryBlock(String sql) { From 5b138beb811c02599cb88e4fd72b4e90b07bfa4c Mon Sep 17 00:00:00 2001 From: George Chen Date: Wed, 18 Mar 2020 10:47:43 -0500 Subject: [PATCH 3/5] RMV: remove unused function --- .../planner/converter/SQLToOperatorConverter.java | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java index 58ac4cd5de..fc665157be 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java @@ -76,19 +76,6 @@ public List getColumnNodes() { return aggregationParser.getColumnNodes(); } - public List extractSelectFunctionNames(List selectItems) { - List methodNames = new ArrayList<>(); - for (SQLSelectItem selectItem: selectItems){ - SQLExpr selectItemExpr = selectItem.getExpr(); - if (selectItemExpr instanceof SQLMethodInvokeExpr) { - methodNames.add(((SQLMethodInvokeExpr) selectItemExpr).getMethodName()); - } else { - methodNames.add(null); - } - } - return methodNames; - } - private PhysicalOperator project(PhysicalOperator input) { return new PhysicalProject(input, aggregationParser.getColumnNodes()); } From ae275fd14a56487f6296c3868e5773c126ff0344 Mon Sep 17 00:00:00 2001 From: George Chen Date: Wed, 18 Mar 2020 10:49:57 -0500 Subject: [PATCH 4/5] STY: unused imports --- .../sql/query/planner/converter/SQLToOperatorConverter.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java index fc665157be..a780b9f24c 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java @@ -15,9 +15,6 @@ package com.amazon.opendistroforelasticsearch.sql.query.planner.converter; -import com.alibaba.druid.sql.ast.SQLExpr; -import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; -import com.alibaba.druid.sql.ast.statement.SQLSelectItem; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter; import com.amazon.opendistroforelasticsearch.sql.domain.ColumnTypeProvider; From c7aaf016d7cc28a84dc41d088de9bf3300b3ed7a Mon Sep 17 00:00:00 2001 From: George Chen Date: Wed, 18 Mar 2020 10:51:51 -0500 Subject: [PATCH 5/5] STY: unused import --- .../sql/query/planner/converter/SQLToOperatorConverter.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java index a780b9f24c..4312e2bc09 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/query/planner/converter/SQLToOperatorConverter.java @@ -32,7 +32,6 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.client.Client; -import java.util.ArrayList; import java.util.List; /**