From 6437e9005024c1f8cf419073922e9d64bfb196c7 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Tue, 12 Sep 2023 16:05:45 -0700 Subject: [PATCH] feat(calcite): support reading in list and map literals (#177) * feat: inject ExpressionRexConverter into SubstraitRelNodeConverter * feat: allow subclasses of ExpressionRexConverter to re-use fields --- .../isthmus/SubstraitRelNodeConverter.java | 23 +++++++++++++--- .../expression/ExpressionRexConverter.java | 26 +++++++++++++++---- .../isthmus/ExpressionConvertabilityTest.java | 11 ++++++++ 3 files changed, 52 insertions(+), 8 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index ba1dc3f9e..036e30d84 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -84,15 +84,32 @@ public SubstraitRelNodeConverter( AggregateFunctionConverter aggregateFunctionConverter, WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter) { + this( + typeFactory, + relBuilder, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter, + new ExpressionRexConverter( + typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter)); + } + + public SubstraitRelNodeConverter( + RelDataTypeFactory typeFactory, + RelBuilder relBuilder, + ScalarFunctionConverter scalarFunctionConverter, + AggregateFunctionConverter aggregateFunctionConverter, + WindowFunctionConverter windowFunctionConverter, + TypeConverter typeConverter, + ExpressionRexConverter expressionRexConverter) { this.typeFactory = typeFactory; this.typeConverter = typeConverter; this.relBuilder = relBuilder; this.rexBuilder = new RexBuilder(typeFactory); this.scalarFunctionConverter = scalarFunctionConverter; this.aggregateFunctionConverter = aggregateFunctionConverter; - this.expressionRexConverter = - new ExpressionRexConverter( - typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter); + this.expressionRexConverter = expressionRexConverter; } public static RelNode convert( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 84f879ee0..7e2180991 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -44,11 +44,11 @@ */ public class ExpressionRexConverter extends AbstractExpressionVisitor implements FunctionArg.FuncArgVisitor { - private final RelDataTypeFactory typeFactory; - private final TypeConverter typeConverter; - private final RexBuilder rexBuilder; - private final ScalarFunctionConverter scalarFunctionConverter; - private final WindowFunctionConverter windowFunctionConverter; + protected final RelDataTypeFactory typeFactory; + protected final TypeConverter typeConverter; + protected final RexBuilder rexBuilder; + protected final ScalarFunctionConverter scalarFunctionConverter; + protected final WindowFunctionConverter windowFunctionConverter; private static final SqlIntervalQualifier YEAR_MONTH_INTERVAL = new SqlIntervalQualifier( @@ -218,6 +218,22 @@ public RexNode visit(Expression.DecimalLiteral expr) throws RuntimeException { return rexBuilder.makeLiteral(decimal, typeConverter.toCalcite(typeFactory, expr.getType())); } + @Override + public RexNode visit(Expression.ListLiteral expr) throws RuntimeException { + List args = + expr.values().stream().map(l -> l.accept(this)).collect(Collectors.toList()); + return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args); + } + + @Override + public RexNode visit(Expression.MapLiteral expr) throws RuntimeException { + var args = + expr.values().entrySet().stream() + .flatMap(entry -> Stream.of(entry.getKey().accept(this), entry.getValue().accept(this))) + .collect(Collectors.toList()); + return rexBuilder.makeCall(SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR, args); + } + @Override public RexNode visit(Expression.IfThen expr) throws RuntimeException { // In Calcite, the arguments to the CASE operator are given as: diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java index 55328ede1..455569a6b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.util.List; import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.sql.parser.SqlParseException; import org.junit.jupiter.api.Test; /** Tests which test that an expression can be converted to and from Calcite expressions. */ @@ -32,6 +33,16 @@ public class ExpressionConvertabilityTest extends PlanTestBase { final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + @Test + public void listLiteral() throws IOException, SqlParseException { + assertFullRoundTrip("select ARRAY[1,2,3] from ORDERS"); + } + + @Test + public void mapLiteral() throws IOException, SqlParseException { + assertFullRoundTrip("select MAP[1, 'hello'] from ORDERS"); + } + @Test public void singleOrList() throws IOException { Plan.Root root =