From 9cb54101d62c9f09501c147829fb516dcbe0443a Mon Sep 17 00:00:00 2001
From: Hongze Zhang
Date: Wed, 10 Nov 2021 16:35:58 +0800
Subject: [PATCH] Support map and struct types in ArrowWritableColumnVector

---
 .../vectorized/ArrowWritableColumnVector.java | 75 +++++++++++++++----
 .../arrow/ArrowDataSourceTest.scala           | 44 ++++++++---
 2 files changed, 95 insertions(+), 24 deletions(-)

diff --git a/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java b/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java
index 4ec2ec6c2..2936f2f19 100644
--- a/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java
+++ b/arrow-data-source/common/src/main/java/com/intel/oap/vectorized/ArrowWritableColumnVector.java
@@ -58,7 +58,6 @@ public final class ArrowWritableColumnVector extends WritableColumnVector {
 
   private ArrowVectorAccessor accessor;
   private ArrowVectorWriter writer;
-  private int ordinal;
   private ValueVector vector;
   private ValueVector dictionaryVector;
   private static BufferAllocator OffRecordAllocator = SparkMemoryUtils.globalAllocator();
@@ -89,7 +88,7 @@ public static ArrowWritableColumnVector[] allocateColumns(
     ArrowWritableColumnVector[] vectors =
         new ArrowWritableColumnVector[fieldVectors.size()];
     for (int i = 0; i < fieldVectors.size(); i++) {
-      vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), i, capacity, true);
+      vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), capacity, true);
     }
     // LOG.info("allocateColumns allocator is " + allocator);
     return vectors;
@@ -107,7 +106,7 @@ public static ArrowWritableColumnVector[] loadColumns(
         new ArrowWritableColumnVector[fieldVectors.size()];
     for (int i = 0; i < fieldVectors.size(); i++) {
       vectors[i] = new ArrowWritableColumnVector(
-          fieldVectors.get(i), dictionaryVectors.get(i), i, capacity, false);
+          fieldVectors.get(i), dictionaryVectors.get(i), capacity, false);
     }
     return vectors;
   }
@@ -117,7 +116,7 @@ public static ArrowWritableColumnVector[] loadColumns(
     ArrowWritableColumnVector[] vectors =
         new ArrowWritableColumnVector[fieldVectors.size()];
     for (int i = 0; i < fieldVectors.size(); i++) {
-      vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), i, capacity, false);
+      vectors[i] = new ArrowWritableColumnVector(fieldVectors.get(i), capacity, false);
     }
     return vectors;
   }
@@ -140,17 +139,16 @@ public static ArrowWritableColumnVector[] loadColumns(int capacity, Schema arrow
   }
 
   @Deprecated
   public ArrowWritableColumnVector(
-      ValueVector vector, int ordinal, int capacity, boolean init) {
-    this(vector, null, ordinal, capacity, init);
+      ValueVector vector, int capacity, boolean init) {
+    this(vector, null, capacity, init);
   }
 
-  public ArrowWritableColumnVector(ValueVector vector, ValueVector dicionaryVector,
-      int ordinal, int capacity, boolean init) {
+  public ArrowWritableColumnVector(ValueVector vector, ValueVector dicionaryVector, int capacity,
+      boolean init) {
     super(capacity, ArrowUtils.fromArrowField(vector.getField()));
     vectorCount.getAndIncrement();
     refCnt.getAndIncrement();
-    this.ordinal = ordinal;
     this.vector = vector;
     this.dictionaryVector = dicionaryVector;
     if (init) {
@@ -231,21 +229,23 @@ private void createVectorAccessor(ValueVector vector, ValueVector dictionary) {
     } else if (vector instanceof TimeStampMicroVector
         || vector instanceof TimeStampMicroTZVector) {
       accessor = new TimestampMicroAccessor((TimeStampVector) vector);
+    } else if (vector instanceof MapVector) {
+      MapVector mapVector = (MapVector) vector;
+      accessor = new MapAccessor(mapVector);
     } else if (vector instanceof ListVector) {
       ListVector listVector = (ListVector) vector;
       accessor = new ArrayAccessor(listVector);
       childColumns = new ArrowWritableColumnVector[1];
       childColumns[0] = new ArrowWritableColumnVector(
-          listVector.getDataVector(), 0, listVector.size(), false);
+          listVector.getDataVector(), listVector.size(), false);
     } else if (vector instanceof StructVector) {
-      throw new UnsupportedOperationException();
-      /*StructVector structVector = (StructVector) vector;
+      StructVector structVector = (StructVector) vector;
       accessor = new StructAccessor(structVector);
 
       childColumns = new ArrowWritableColumnVector[structVector.size()];
       for (int i = 0; i < childColumns.length; ++i) {
-        childColumns[i] = new ArrowWritableColumnVector(structVector.getVectorById(i));
-      }*/
+        childColumns[i] = new ArrowWritableColumnVector(structVector.getVectorById(i), null, structVector.size(), false);
+      }
     } else {
       throw new UnsupportedOperationException();
     }
@@ -277,6 +277,9 @@ private ArrowVectorWriter createVectorWriter(ValueVector vector) {
     } else if (vector instanceof TimeStampMicroVector
         || vector instanceof TimeStampMicroTZVector) {
       return new TimestampMicroWriter((TimeStampVector) vector);
+    } else if (vector instanceof MapVector) {
+      MapVector mapVector = (MapVector) vector;
+      return new MapWriter(mapVector);
     } else if (vector instanceof ListVector) {
       ListVector listVector = (ListVector) vector;
       ArrowVectorWriter elementVector = createVectorWriter(listVector.getDataVector());
@@ -893,6 +896,10 @@ int getArrayLength(int rowId) {
     int getArrayOffset(int rowId) {
       throw new UnsupportedOperationException();
     }
+
+    ColumnarMap getMap(int rowId) {
+      throw new UnsupportedOperationException();
+    }
   }
 
   private static class BooleanAccessor extends ArrowVectorAccessor {
@@ -1224,6 +1231,40 @@ private static class StructAccessor extends ArrowVectorAccessor {
     }
   }
 
+  private static class MapAccessor extends ArrowVectorAccessor {
+    private final MapVector accessor;
+    private final ArrowColumnVector keys;
+    private final ArrowColumnVector values;
+
+    MapAccessor(MapVector vector) {
+      super(vector);
+      this.accessor = vector;
+      StructVector entries = (StructVector) vector.getDataVector();
+      this.keys = new ArrowColumnVector(entries.getChild(MapVector.KEY_NAME));
+      this.values = new ArrowColumnVector(entries.getChild(MapVector.VALUE_NAME));
+    }
+
+    @Override
+    final ColumnarMap getMap(int rowId) {
+      int index = rowId * MapVector.OFFSET_WIDTH;
+      int offset = accessor.getOffsetBuffer().getInt(index);
+      int length = accessor.getInnerValueCountAt(rowId);
+      return new ColumnarMap(keys, values, offset, length);
+    }
+
+    @Override
+    int getArrayOffset(int rowId) {
+      int index = rowId * MapVector.OFFSET_WIDTH;
+      return accessor.getOffsetBuffer().getInt(index);
+    }
+
+    @Override
+    int getArrayLength(int rowId) {
+      return accessor.getInnerValueCountAt(rowId);
+    }
+  }
+
+
   /* Arrow Vector Writer */
   private abstract static class ArrowVectorWriter {
     private final ValueVector vector;
@@ -1885,4 +1926,10 @@ private static class StructWriter extends ArrowVectorWriter {
       super(vector);
     }
   }
+
+  private static class MapWriter extends ArrowVectorWriter {
+    MapWriter(ValueVector vector) {
+      super(vector);
+    }
+  }
 }
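Note on the read path above: the following is a minimal consumer-side sketch of how a map column produced by this change could be traversed through Spark's vectorized API. It assumes the vector's public getMap(rowId) delegates to the MapAccessor added in this patch; dumpMapColumn, col, and numRows are hypothetical names used only for illustration.

    import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarMap}

    // `col` is assumed to be a map<string, string> column backed by an Arrow MapVector,
    // e.g. one of the vectors returned by ArrowWritableColumnVector.loadColumns(...).
    def dumpMapColumn(col: ColumnVector, numRows: Int): Unit = {
      (0 until numRows).foreach { rowId =>
        if (!col.isNullAt(rowId)) {
          // getMap is served by MapAccessor.getMap: the offset comes from the MapVector's
          // offset buffer, the length from getInnerValueCountAt(rowId).
          val m: ColumnarMap = col.getMap(rowId)
          (0 until m.numElements()).foreach { i =>
            val key = m.keyArray().getUTF8String(i)
            val value = m.valueArray().getUTF8String(i)
            println(s"row=$rowId: $key -> $value")
          }
        }
      }
    }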
diff --git a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
index 536004fe4..ad15b3515 100644
--- a/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
+++ b/arrow-data-source/standard/src/test/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowDataSourceTest.scala
@@ -26,6 +26,7 @@ import com.intel.oap.spark.sql.DataFrameWriterImplicits._
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowOptions
 import com.sun.management.UnixOperatingSystemMXBean
 import org.apache.commons.io.FileUtils
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
@@ -34,14 +35,18 @@ import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
 
 class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
   private val parquetFile1 = "parquet-1.parquet"
   private val parquetFile2 = "parquet-2.parquet"
   private val parquetFile3 = "parquet-3.parquet"
   private val parquetFile4 = "parquet-4.parquet"
   private val parquetFile5 = "parquet-5.parquet"
+  private val parquetFile6 = "parquet-6.parquet"
 
   override protected def sparkConf: SparkConf = {
     val conf = super.sparkConf
@@ -95,6 +100,14 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
       .mode("overwrite")
       .parquet(ArrowDataSourceTest.locateResourcePath(parquetFile5))
 
+    spark.range(100)
+      .map(i => Tuple1((i, Seq(s"val1_$i", s"val2_$i"), Map((s"ka_$i", s"va_$i"),
+        (s"kb_$i", s"vb_$i")))))
+      .write
+      .format("parquet")
+      .mode("overwrite")
+      .parquet(ArrowDataSourceTest.locateResourcePath(parquetFile6))
+
   }
 
   override def afterAll(): Unit = {
@@ -296,22 +309,33 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
     assert(fdGrowth < 100)
   }
 
+  test("parquet reader on data type: struct, array, map") {
+    val path = ArrowDataSourceTest.locateResourcePath(parquetFile6)
+    val frame = spark.read
+      .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "parquet")
+      .arrow(path)
+    frame.createOrReplaceTempView("ptab3")
+    val df = spark.sql("select * from ptab3")
+    df.explain()
+    df.show()
+  }
+
   private val orcFile = "people.orc"
   test("read orc file") {
     val path = ArrowDataSourceTest.locateResourcePath(orcFile)
     verifyFrame(
       spark.read
-          .format("arrow")
-          .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "orc")
-          .load(path), 2, 3)
+        .format("arrow")
+        .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "orc")
+        .load(path), 2, 3)
   }
 
   test("read orc file - programmatic API ") {
     val path = ArrowDataSourceTest.locateResourcePath(orcFile)
     verifyFrame(
       spark.read
-          .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "orc")
-          .arrow(path), 2, 3)
+        .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "orc")
+        .arrow(path), 2, 3)
   }
 
   test("create catalog table for orc") {
@@ -326,14 +350,14 @@ class ArrowDataSourceTest extends QueryTest with SharedSparkSession {
   test("simple SQL query on orc file ") {
     val path = ArrowDataSourceTest.locateResourcePath(orcFile)
     val frame = spark.read
-        .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "orc")
-        .arrow(path)
+      .option(ArrowOptions.KEY_ORIGINAL_FORMAT, "orc")
+      .arrow(path)
     frame.createOrReplaceTempView("people")
     val sqlFrame = spark.sql("select * from people")
     assert(
       sqlFrame.schema ===
-          StructType(Seq(StructField("name", StringType),
-            StructField("age", IntegerType), StructField("job", StringType))))
+        StructType(Seq(StructField("name", StringType),
+          StructField("age", IntegerType), StructField("job", StringType))))
     val rows = sqlFrame.collect()
     assert(rows(0).get(0) == "Jorge")
     assert(rows(0).get(1) == 30)