Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6690] Refactor the Arrow adapter type system #4052

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,17 @@

import java.math.BigDecimal;
import java.sql.Date;
import java.util.List;

import static java.util.Objects.requireNonNull;

/**
* Arrow field type.
*/
enum ArrowFieldType {
INT(Primitive.INT),
BOOLEAN(Primitive.BOOLEAN),
STRING(String.class),
FLOAT(Primitive.FLOAT),
DOUBLE(Primitive.DOUBLE),
DATE(Date.class),
LIST(List.class),
DECIMAL(BigDecimal.class),
LONG(Primitive.LONG),
BYTE(Primitive.BYTE),
SHORT(Primitive.SHORT);
class ArrowFieldType {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think enums are preferable, since they are typechecked by the compiler.
A Class can be anything.

Copy link
Member Author

@caicancai caicancai Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later, I may dynamically pass some parameters according to the data type of arrow, which is difficult to achieve through enumeration.
e.g.

new ArrowFieldType(BigDecimal.class,
          ((ArrowType.Decimal) arrowType).getPrecision(),
          ((ArrowType.Decimal) arrowType).getScale())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why later if you are refactoring the code now?
Do it correctly the first time.
I am suggesting to have 3 arguments for the enum constructor: Class, precision, scale.

Copy link
Member Author

@caicancai caicancai Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mihaibudiu Sorry for the late reply, you are right. I looked at the arrow types and it seems that Class, precision, and scale are sufficient.

I will finish this work on the weekend

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mihaibudiu Hi. I have a question, how to dynamically pass parameters in enumeration class? Precision and scale are not constants for decimal, timestamp and other types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand the question.
There are lots of enum classes with parameters in the codebase.
See https://github.com/apache/calcite/blob/main/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java for example


private final Class<?> clazz;
private int precision;
private int scale;

ArrowFieldType(Primitive primitive) {
this(requireNonNull(primitive.boxClass, "boxClass"));
Expand All @@ -55,9 +45,20 @@ enum ArrowFieldType {
this.clazz = clazz;
}

ArrowFieldType(Class<?> clazz, int precision, int scale) {
this.clazz = clazz;
this.precision = precision;
this.scale = scale;
}

public RelDataType toType(JavaTypeFactory typeFactory) {
RelDataType javaType = typeFactory.createJavaType(clazz);
RelDataType sqlType = typeFactory.createSqlType(javaType.getSqlTypeName());
RelDataType sqlType = null;
if (javaType.getSqlTypeName().getName().equals("DECIMAL")) {
sqlType = typeFactory.createSqlType(javaType.getSqlTypeName(), precision, scale);
} else {
sqlType = typeFactory.createSqlType(javaType.getSqlTypeName());
}
return typeFactory.createTypeWithNullability(sqlType, true);
}

Expand All @@ -67,34 +68,36 @@ public static ArrowFieldType of(ArrowType arrowType) {
int bitWidth = ((ArrowType.Int) arrowType).getBitWidth();
switch (bitWidth) {
case 64:
return LONG;
return new ArrowFieldType(Primitive.LONG);
case 32:
return INT;
return new ArrowFieldType(Primitive.INT);
case 16:
return SHORT;
return new ArrowFieldType(Primitive.SHORT);
case 8:
return BYTE;
return new ArrowFieldType(Primitive.BYTE);
default:
throw new IllegalArgumentException("Unsupported Int bit width: " + bitWidth);
}
case Bool:
return BOOLEAN;
return new ArrowFieldType(Primitive.BOOLEAN);
case Utf8:
return STRING;
return new ArrowFieldType(String.class);
case FloatingPoint:
FloatingPointPrecision precision = ((ArrowType.FloatingPoint) arrowType).getPrecision();
switch (precision) {
case SINGLE:
return FLOAT;
return new ArrowFieldType(Primitive.FLOAT);
case DOUBLE:
return DOUBLE;
return new ArrowFieldType(Primitive.DOUBLE);
default:
throw new IllegalArgumentException("Unsupported Floating point precision: " + precision);
}
case Date:
return DATE;
return new ArrowFieldType(Date.class);
case Decimal:
return DECIMAL;
return new ArrowFieldType(BigDecimal.class,
((ArrowType.Decimal) arrowType).getPrecision(),
((ArrowType.Decimal) arrowType).getScale());
default:
throw new IllegalArgumentException("Unsupported type: " + arrowType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"tinyIntField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(tinyIntField=[$0])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "tinyIntField=0\ntinyIntField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -82,7 +82,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"smallIntField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(smallIntField=[$1])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "smallIntField=0\nsmallIntField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -96,7 +96,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"intField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(intField=[$2])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "intField=0\nintField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -110,7 +110,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"longField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(longField=[$5])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "longField=0\nlongField=1\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -124,7 +124,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"floatField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(floatField=[$4])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "floatField=0.0\nfloatField=1.0\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -138,7 +138,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"doubleField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(doubleField=[$6])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "doubleField=0.0\ndoubleField=1.0\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -152,7 +152,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"decimalField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(decimalField=[$8])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "decimalField=0.00\ndecimalField=1.00\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -162,11 +162,26 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
.explainContains(plan);
}


@Test void testDecimalProject2() {
String sql = "select \"decimalField2\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(decimalField2=[$10])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "decimalField2=20.000\ndecimalField2=21.000\n";
CalciteAssert.that()
.with(arrow)
.query(sql)
.limit(2)
.returns(result)
.explainContains(plan);
}

@Test void testDateProject() {
String sql = "select \"dateField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(dateField=[$9])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "dateField=1970-01-01\n"
+ "dateField=1970-01-02\n";
CalciteAssert.that()
Expand All @@ -181,7 +196,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select \"booleanField\" from arrowdatatype";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=null\nbooleanField=true\nbooleanField=false\n";
CalciteAssert.that()
.with(arrow)
Expand All @@ -190,5 +205,4 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
.returns(result)
.explainContains(plan);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -732,8 +732,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
@Test void testFilteredAgg() {
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP";
String plan = "PLAN=EnumerableAggregate(group=[{}], SALESSUM=[SUM($0) FILTER $1])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], SAL=[$t5], $f1=[$t10])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(10, 2)], expr#9=[400.00:DECIMAL(10, 2)], "
+ "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], SAL=[$t5], $f1=[$t11])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
String result = "SALESSUM=2500.00\n";
Expand All @@ -750,8 +750,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP group by EMPNO";
String plan = "PLAN=EnumerableCalc(expr#0..1=[{inputs}], SALESSUM=[$t1])\n"
+ " EnumerableAggregate(group=[{0}], SALESSUM=[SUM($1) FILTER $2])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], EMPNO=[$t0], SAL=[$t5], $f2=[$t10])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(10, 2)], expr#9=[400.00:DECIMAL(10, 2)], "
+ "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], EMPNO=[$t0], SAL=[$t5], $f2=[$t11])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
String result = "SALESSUM=1250.00\nSALESSUM=null\n";
Expand Down Expand Up @@ -860,7 +860,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[$7])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=true\nbooleanField=true\n";

CalciteAssert.that()
Expand All @@ -878,7 +878,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(intField=[$2])\n"
+ " ArrowFilter(condition=[>($2, 10)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "intField=11\nintField=12\n";

CalciteAssert.that()
Expand All @@ -896,7 +896,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[NOT($7)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=false\nbooleanField=false\n";

CalciteAssert.that()
Expand All @@ -915,7 +915,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[IS NOT TRUE($7)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=null\nbooleanField=false\n";

CalciteAssert.that()
Expand All @@ -933,7 +933,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(booleanField=[$7])\n"
+ " ArrowFilter(condition=[IS NOT FALSE($7)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n\n";
String result = "booleanField=null\nbooleanField=true\n";

CalciteAssert.that()
Expand Down
40 changes: 28 additions & 12 deletions arrow/src/test/java/org/apache/calcite/adapter/arrow/ArrowData.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ private Schema makeArrowDateTypeSchema() {
FieldType doubleType =
FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE));
FieldType booleanType = FieldType.nullable(new ArrowType.Bool());
FieldType decimalType = FieldType.nullable(new ArrowType.Decimal(10, 2, 128));
FieldType decimalType = FieldType.nullable(new ArrowType.Decimal(12, 2, 128));
FieldType decimalType2 = FieldType.nullable(new ArrowType.Decimal(12, 3, 128));
FieldType dateType = FieldType.nullable(new ArrowType.Date(DateUnit.DAY));

childrenBuilder.add(new Field("tinyIntField", tinyIntType, null));
Expand All @@ -116,6 +117,7 @@ private Schema makeArrowDateTypeSchema() {
childrenBuilder.add(new Field("booleanField", booleanType, null));
childrenBuilder.add(new Field("decimalField", decimalType, null));
childrenBuilder.add(new Field("dateField", dateType, null));
childrenBuilder.add(new Field("decimalField2", decimalType2, null));

return new Schema(childrenBuilder.build(), null);
}
Expand Down Expand Up @@ -237,35 +239,38 @@ public void writeArrowDataType(File file) throws IOException {
vectorSchemaRoot.setRowCount(numRows);
for (Field field : vectorSchemaRoot.getSchema().getFields()) {
FieldVector vector = vectorSchemaRoot.getVector(field.getName());
switch (vector.getMinorType()) {
case TINYINT:
switch (field.getName()) {
case "tinyIntField":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a necessary change? What is the difference between decimalField2 and decimalField?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Required, the scale of decimalfield and decimalfiled2 are different, one is 2, the other is 3, the main purpose is to capture the accuracy of whether it is effective

tinyIntField(vector, numRows);
break;
case SMALLINT:
case "smallIntField":
smallIntFiled(vector, numRows);
break;
case INT:
case "intField":
intField(vector, numRows);
break;
case FLOAT4:
case "floatField":
floatField(vector, numRows);
break;
case VARCHAR:
case "stringField":
varCharField(vector, numRows);
break;
case BIGINT:
case "longField":
longField(vector, numRows);
break;
case FLOAT8:
case "doubleField":
doubleField(vector, numRows);
break;
case BIT:
case "booleanField":
booleanField(vector, numRows);
break;
case DECIMAL:
case "decimalField":
decimalField(vector, numRows);
break;
case DATEDAY:
case "decimalField2":
decimalField2(vector, numRows);
break;
case "dateField":
dateField(vector, numRows);
break;
default:
Expand Down Expand Up @@ -386,6 +391,17 @@ private void decimalField(FieldVector fieldVector, int rowCount) {
fieldVector.setValueCount(rowCount);
}

private void decimalField2(FieldVector fieldVector, int rowCount) {
DecimalVector decimalVector = (DecimalVector) fieldVector;
decimalVector.setInitialCapacity(rowCount);
decimalVector.allocateNew();
for (int i = 0; i < rowCount; i++) {
decimalVector.set(i, this.decimalValue.setScale(3));
this.decimalValue = this.decimalValue.add(BigDecimal.ONE);
}
fieldVector.setValueCount(rowCount);
}

private void dateField(FieldVector fieldVector, int rowCount) {
DateDayVector dateDayVector = (DateDayVector) fieldVector;
dateDayVector.setInitialCapacity(rowCount);
Expand Down
Loading