Skip to content

Commit

Permalink
support case statement in Calcite Aggregation Extractor (#2330)
Browse files Browse the repository at this point in the history
Co-authored-by: Chandrasekar Rajasekar <[email protected]>
Co-authored-by: Aaron Klish <[email protected]>
  • Loading branch information
3 people authored Oct 8, 2021
1 parent 64325a0 commit 9ad0d58
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ public List<List<String>> visit(SqlCall call) {
}

for (SqlNode node : call.getOperandList()) {
if (node == null) {
continue;
}
List<List<String>> operandResults = node.accept(this);
if (operandResults != null) {
result.addAll(operandResults);
Expand All @@ -60,6 +63,9 @@ public List<List<String>> visit(SqlCall call) {
public List<List<String>> visit(SqlNodeList nodeList) {
List<List<String>> result = new ArrayList<>();
for (SqlNode node : nodeList) {
if (node == null) {
continue;
}
List<List<String>> inner = node.accept(this);
if (inner != null) {
result.addAll(inner);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.util.SqlBasicVisitor;

import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;
import java.util.stream.Collectors;

/**
* Parses a column expression and rewrites the post aggregation expression AST to reference
Expand Down Expand Up @@ -58,15 +60,19 @@ public SqlNode visit(SqlCall call) {
}
for (int idx = 0; idx < call.getOperandList().size(); idx++) {
SqlNode operand = call.getOperandList().get(idx);
call.setOperand(idx, operand.accept(this));
call.setOperand(idx, operand == null ? null : operand.accept(this));
}

return call;
}

@Override
public SqlNode visit(SqlNodeList nodeList) {
return nodeList;

return SqlNodeList.of(SqlParserPos.ZERO,
nodeList.getList().stream()
.map(sqlNode -> sqlNode == null ? null : sqlNode.accept(this))
.collect(Collectors.toList()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,26 @@ public void testExpressionParsing() throws Exception {
assertEquals("SUM(`blah`)", aggregations.get(1).get(0));
}

@Test
public void testCaseStmtParsing() throws Exception {
String sql = " CASE\n"
+ " WHEN SUM(blah) = 0 THEN 1\n"
+ " ELSE SUM('number_of_lectures') / SUM(blah)\n"
+ " END";
SqlParser sqlParser = SqlParser.create(sql, CalciteUtils.constructParserConfig(dialect));
SqlNode node = sqlParser.parseExpression();
CalciteInnerAggregationExtractor extractor = new CalciteInnerAggregationExtractor(dialect);
List<List<String>> aggregations = node.accept(extractor);

assertEquals(3, aggregations.size());
assertEquals(1, aggregations.get(0).size());
assertEquals(1, aggregations.get(1).size());
assertEquals(1, aggregations.get(2).size());
assertEquals("SUM(`blah`)", aggregations.get(0).get(0));
assertEquals("SUM('number_of_lectures')", aggregations.get(1).get(0));
assertEquals("SUM(`blah`)", aggregations.get(2).get(0));
}

@Test
public void testInvalidAggregationFunction() throws Exception {
String sql = "CUSTOM_SUM(blah)";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,24 @@ public void testExpressionParsing() throws Exception {
assertEquals(expected, actual);
}

@Test
public void testCaseStmtParsing() throws Exception {
String sql = " CASE\n"
+ " WHEN SUM(blah) = 0 THEN 1\n"
+ " ELSE SUM('number_of_lectures') / SUM(blah)\n"
+ " END";

SqlParser sqlParser = SqlParser.create(sql, CalciteUtils.constructParserConfig(dialect));
SqlNode node = sqlParser.parseExpression();

List<List<String>> substitutions = Arrays.asList(Arrays.asList("SUB1"), Arrays.asList("SUB2"), Arrays.asList("SUB1"));
CalciteOuterAggregationExtractor extractor = new CalciteOuterAggregationExtractor(dialect, substitutions);
String actual = node.accept(extractor).toSqlString(dialect.getCalciteDialect()).getSql();
String expected = "CASE WHEN SUM(`SUB1`) = 0 THEN 1 ELSE SUM(`SUB2`) / SUM(`SUB1`) END";

assertEquals(expected, actual);
}

@Test
public void testCustomAggregationFunction() throws Exception {
String sql = "CUSTOM_SUM(blah)";
Expand Down

0 comments on commit 9ad0d58

Please sign in to comment.