diff --git a/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowFieldType.java b/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowFieldType.java index 9a0fa3033d4..f36e495a2a5 100644 --- a/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowFieldType.java +++ b/arrow/src/main/java/org/apache/calcite/adapter/arrow/ArrowFieldType.java @@ -25,27 +25,17 @@ import java.math.BigDecimal; import java.sql.Date; -import java.util.List; import static java.util.Objects.requireNonNull; /** * Arrow field type. */ -enum ArrowFieldType { - INT(Primitive.INT), - BOOLEAN(Primitive.BOOLEAN), - STRING(String.class), - FLOAT(Primitive.FLOAT), - DOUBLE(Primitive.DOUBLE), - DATE(Date.class), - LIST(List.class), - DECIMAL(BigDecimal.class), - LONG(Primitive.LONG), - BYTE(Primitive.BYTE), - SHORT(Primitive.SHORT); +class ArrowFieldType { private final Class clazz; + private int precision; + private int scale; ArrowFieldType(Primitive primitive) { this(requireNonNull(primitive.boxClass, "boxClass")); @@ -55,9 +45,20 @@ enum ArrowFieldType { this.clazz = clazz; } + ArrowFieldType(Class clazz, int precision, int scale) { + this.clazz = clazz; + this.precision = precision; + this.scale = scale; + } + public RelDataType toType(JavaTypeFactory typeFactory) { RelDataType javaType = typeFactory.createJavaType(clazz); - RelDataType sqlType = typeFactory.createSqlType(javaType.getSqlTypeName()); + RelDataType sqlType = null; + if (javaType.getSqlTypeName().getName().equals("DECIMAL")) { + sqlType = typeFactory.createSqlType(javaType.getSqlTypeName(), precision, scale); + } else { + sqlType = typeFactory.createSqlType(javaType.getSqlTypeName()); + } return typeFactory.createTypeWithNullability(sqlType, true); } @@ -67,34 +68,36 @@ public static ArrowFieldType of(ArrowType arrowType) { int bitWidth = ((ArrowType.Int) arrowType).getBitWidth(); switch (bitWidth) { case 64: - return LONG; + return new ArrowFieldType(Primitive.LONG); case 32: - return INT; + return new ArrowFieldType(Primitive.INT); case 16: - return SHORT; + return new ArrowFieldType(Primitive.SHORT); case 8: - return BYTE; + return new ArrowFieldType(Primitive.BYTE); default: throw new IllegalArgumentException("Unsupported Int bit width: " + bitWidth); } case Bool: - return BOOLEAN; + return new ArrowFieldType(Primitive.BOOLEAN); case Utf8: - return STRING; + return new ArrowFieldType(String.class); case FloatingPoint: FloatingPointPrecision precision = ((ArrowType.FloatingPoint) arrowType).getPrecision(); switch (precision) { case SINGLE: - return FLOAT; + return new ArrowFieldType(Primitive.FLOAT); case DOUBLE: - return DOUBLE; + return new ArrowFieldType(Primitive.DOUBLE); default: throw new IllegalArgumentException("Unsupported Floating point precision: " + precision); } case Date: - return DATE; + return new ArrowFieldType(Date.class); case Decimal: - return DECIMAL; + return new ArrowFieldType(BigDecimal.class, + ((ArrowType.Decimal) arrowType).getPrecision(), + ((ArrowType.Decimal) arrowType).getScale()); default: throw new IllegalArgumentException("Unsupported type: " + arrowType); } diff --git a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterDataTypesTest.java b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterDataTypesTest.java index 87dff81d4e0..23d4386dff7 100644 --- a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterDataTypesTest.java +++ b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterDataTypesTest.java @@ -68,7 +68,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"tinyIntField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(tinyIntField=[$0])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "tinyIntField=0\ntinyIntField=1\n"; CalciteAssert.that() .with(arrow) @@ -82,7 +82,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"smallIntField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(smallIntField=[$1])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "smallIntField=0\nsmallIntField=1\n"; CalciteAssert.that() .with(arrow) @@ -96,7 +96,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"intField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(intField=[$2])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "intField=0\nintField=1\n"; CalciteAssert.that() .with(arrow) @@ -110,7 +110,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"longField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(longField=[$5])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "longField=0\nlongField=1\n"; CalciteAssert.that() .with(arrow) @@ -124,7 +124,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"floatField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(floatField=[$4])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "floatField=0.0\nfloatField=1.0\n"; CalciteAssert.that() .with(arrow) @@ -138,7 +138,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"doubleField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(doubleField=[$6])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "doubleField=0.0\ndoubleField=1.0\n"; CalciteAssert.that() .with(arrow) @@ -152,7 +152,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"decimalField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(decimalField=[$8])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "decimalField=0.00\ndecimalField=1.00\n"; CalciteAssert.that() .with(arrow) @@ -162,11 +162,26 @@ static void initializeArrowState(@TempDir Path sharedTempDir) .explainContains(plan); } + + @Test void testDecimalProject2() { + String sql = "select \"decimalField2\" from arrowdatatype"; + String plan = "PLAN=ArrowToEnumerableConverter\n" + + " ArrowProject(decimalField2=[$10])\n" + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; + String result = "decimalField2=20.000\ndecimalField2=21.000\n"; + CalciteAssert.that() + .with(arrow) + .query(sql) + .limit(2) + .returns(result) + .explainContains(plan); + } + @Test void testDateProject() { String sql = "select \"dateField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(dateField=[$9])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "dateField=1970-01-01\n" + "dateField=1970-01-02\n"; CalciteAssert.that() @@ -181,7 +196,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select \"booleanField\" from arrowdatatype"; String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(booleanField=[$7])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "booleanField=null\nbooleanField=true\nbooleanField=false\n"; CalciteAssert.that() .with(arrow) @@ -190,5 +205,4 @@ static void initializeArrowState(@TempDir Path sharedTempDir) .returns(result) .explainContains(plan); } - } diff --git a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java index 9e60bca2a4d..f139857c290 100644 --- a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java +++ b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowAdapterTest.java @@ -732,8 +732,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir) @Test void testFilteredAgg() { String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP"; String plan = "PLAN=EnumerableAggregate(group=[{}], SALESSUM=[SUM($0) FILTER $1])\n" - + " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], " - + "expr#10=[IS TRUE($t9)], SAL=[$t5], $f1=[$t10])\n" + + " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(10, 2)], expr#9=[400.00:DECIMAL(10, 2)], " + + "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], SAL=[$t5], $f1=[$t11])\n" + " ArrowToEnumerableConverter\n" + " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n"; String result = "SALESSUM=2500.00\n"; @@ -750,8 +750,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP group by EMPNO"; String plan = "PLAN=EnumerableCalc(expr#0..1=[{inputs}], SALESSUM=[$t1])\n" + " EnumerableAggregate(group=[{0}], SALESSUM=[SUM($1) FILTER $2])\n" - + " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], " - + "expr#10=[IS TRUE($t9)], EMPNO=[$t0], SAL=[$t5], $f2=[$t10])\n" + + " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(10, 2)], expr#9=[400.00:DECIMAL(10, 2)], " + + "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], EMPNO=[$t0], SAL=[$t5], $f2=[$t11])\n" + " ArrowToEnumerableConverter\n" + " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n"; String result = "SALESSUM=1250.00\nSALESSUM=null\n"; @@ -860,7 +860,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(booleanField=[$7])\n" + " ArrowFilter(condition=[$7])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "booleanField=true\nbooleanField=true\n"; CalciteAssert.that() @@ -878,7 +878,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(intField=[$2])\n" + " ArrowFilter(condition=[>($2, 10)])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "intField=11\nintField=12\n"; CalciteAssert.that() @@ -896,7 +896,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(booleanField=[$7])\n" + " ArrowFilter(condition=[NOT($7)])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "booleanField=false\nbooleanField=false\n"; CalciteAssert.that() @@ -915,7 +915,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(booleanField=[$7])\n" + " ArrowFilter(condition=[IS NOT TRUE($7)])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "booleanField=null\nbooleanField=false\n"; CalciteAssert.that() @@ -933,7 +933,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(booleanField=[$7])\n" + " ArrowFilter(condition=[IS NOT FALSE($7)])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "booleanField=null\nbooleanField=true\n"; CalciteAssert.that() @@ -951,7 +951,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir) String plan = "PLAN=ArrowToEnumerableConverter\n" + " ArrowProject(booleanField=[$7])\n" + " ArrowFilter(condition=[IS NULL($7)])\n" - + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n"; + + " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n"; String result = "booleanField=null\n"; CalciteAssert.that() diff --git a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java index 482a15e902f..0d70aa64003 100644 --- a/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java +++ b/arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java @@ -103,7 +103,8 @@ private Schema makeArrowDateTypeSchema() { FieldType doubleType = FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); FieldType booleanType = FieldType.nullable(new ArrowType.Bool()); - FieldType decimalType = FieldType.nullable(new ArrowType.Decimal(10, 2, 128)); + FieldType decimalType = FieldType.nullable(new ArrowType.Decimal(12, 2, 128)); + FieldType decimalType2 = FieldType.nullable(new ArrowType.Decimal(12, 3, 128)); FieldType dateType = FieldType.nullable(new ArrowType.Date(DateUnit.DAY)); childrenBuilder.add(new Field("tinyIntField", tinyIntType, null)); @@ -116,6 +117,7 @@ private Schema makeArrowDateTypeSchema() { childrenBuilder.add(new Field("booleanField", booleanType, null)); childrenBuilder.add(new Field("decimalField", decimalType, null)); childrenBuilder.add(new Field("dateField", dateType, null)); + childrenBuilder.add(new Field("decimalField2", decimalType2, null)); return new Schema(childrenBuilder.build(), null); } @@ -237,35 +239,38 @@ public void writeArrowDataType(File file) throws IOException { vectorSchemaRoot.setRowCount(numRows); for (Field field : vectorSchemaRoot.getSchema().getFields()) { FieldVector vector = vectorSchemaRoot.getVector(field.getName()); - switch (vector.getMinorType()) { - case TINYINT: + switch (field.getName()) { + case "tinyIntField": tinyIntField(vector, numRows); break; - case SMALLINT: + case "smallIntField": smallIntFiled(vector, numRows); break; - case INT: + case "intField": intField(vector, numRows); break; - case FLOAT4: + case "floatField": floatField(vector, numRows); break; - case VARCHAR: + case "stringField": varCharField(vector, numRows); break; - case BIGINT: + case "longField": longField(vector, numRows); break; - case FLOAT8: + case "doubleField": doubleField(vector, numRows); break; - case BIT: + case "booleanField": booleanField(vector, numRows); break; - case DECIMAL: + case "decimalField": decimalField(vector, numRows); break; - case DATEDAY: + case "decimalField2": + decimalField2(vector, numRows); + break; + case "dateField": dateField(vector, numRows); break; default: @@ -386,6 +391,17 @@ private void decimalField(FieldVector fieldVector, int rowCount) { fieldVector.setValueCount(rowCount); } + private void decimalField2(FieldVector fieldVector, int rowCount) { + DecimalVector decimalVector = (DecimalVector) fieldVector; + decimalVector.setInitialCapacity(rowCount); + decimalVector.allocateNew(); + for (int i = 0; i < rowCount; i++) { + decimalVector.set(i, this.decimalValue.setScale(3)); + this.decimalValue = this.decimalValue.add(BigDecimal.ONE); + } + fieldVector.setValueCount(rowCount); + } + private void dateField(FieldVector fieldVector, int rowCount) { DateDayVector dateDayVector = (DateDayVector) fieldVector; dateDayVector.setInitialCapacity(rowCount);