Support Spark ColumnarBatch to ArrowArray
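This commit adds a transition operator that converts vanilla Spark ColumnarBatches into Velox-native batches by exporting them as ArrowArray structs through the Arrow C Data Interface. For orientation, a minimal standalone sketch of that export step (not code from this commit; it assumes only the arrow-vector, arrow-memory and arrow-c-data artifacts on the classpath):

import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}

object ArrowExportSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator()
    val vector = new IntVector("id", allocator)
    vector.allocateNew(3)
    (0 until 3).foreach(i => vector.setSafe(i, i))
    vector.setValueCount(3)
    val root = VectorSchemaRoot.of(vector)

    // Allocate the C Data Interface structs and move the Arrow buffers into them.
    val cArray = ArrowArray.allocateNew(allocator)
    val cSchema = ArrowSchema.allocateNew(allocator)
    Data.exportVectorSchemaRoot(allocator, root, null, cArray, cSchema)

    // cArray.memoryAddress() / cSchema.memoryAddress() are what a JNI layer consumes.
    // Closing the structs releases the exported buffers (no native consumer here).
    cArray.close()
    cSchema.close()
    root.close()
    allocator.close()
  }
}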
Showing 4 changed files with 359 additions and 0 deletions.
175 changes: 175 additions & 0 deletions
...-velox/src/main/scala/io/glutenproject/execution/VanillaColumnarToVeloxColumnarExec.scala
@@ -0,0 +1,175 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.glutenproject.execution

import io.glutenproject.backendsapi.velox.ValidatorApiImpl
import io.glutenproject.columnarbatch.ColumnarBatches
import io.glutenproject.exec.Runtimes
import io.glutenproject.extension.GlutenPlan
import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
import io.glutenproject.memory.nmm.NativeMemoryManagers
import io.glutenproject.utils.{ArrowAbiUtil, Iterators}
import io.glutenproject.vectorized.VanillaColumnarToNativeColumnarJniWrapper
import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data}
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.arrow.ArrowColumnarBatchConverter
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.TaskResources

case class VanillaColumnarToVeloxColumnarExec(child: SparkPlan)
  extends GlutenPlan
  with UnaryExecNode {

  override def supportsColumnar: Boolean = true

  // Assumed addition: these metrics are looked up via longMetric() below, so the map must
  // exist for the lookups to succeed.
  override lazy val metrics: Map[String, SQLMetric] = Map(
    "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
    "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"),
    "convertTime" -> SQLMetrics.createTimingMetric(sparkContext, "total time to convert")
  )

  override protected def doExecute(): RDD[InternalRow] = {
    throw new UnsupportedOperationException("This operator doesn't support doExecute().")
  }

  override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
    new ValidatorApiImpl().doSchemaValidate(schema).foreach {
      reason =>
        throw new UnsupportedOperationException(
          s"Input schema contains unsupported type when converting columnar to columnar " +
            s"for $schema due to $reason")
    }

    val numInputBatches = longMetric("numInputBatches")
    val numOutputBatches = longMetric("numOutputBatches")
    val convertTime = longMetric("convertTime")
    // Instead of creating a new config we are reusing columnBatchSize. In the future, if we
    // combine this with some of the Arrow conversion tools, we will need to unify the configs.
    // Note: not yet consumed by the conversion below.
    val numRows = conf.columnBatchSize
    // Materialize the schema locally so the RDD closure doesn't capture the entire plan (this).
    val localSchema = schema
    // The child is a vanilla columnar plan, so pull batches from executeColumnar() directly
    // rather than casting a row iterator from execute().
    child.executeColumnar().mapPartitions {
      batchIterator =>
        VanillaColumnarToVeloxColumnarExec.toColumnarBatchIterator(
          batchIterator,
          localSchema,
          numInputBatches,
          numOutputBatches,
          convertTime,
          TaskContext.get())
    }
  }

  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
    copy(child = newChild)
  }

  override def output: Seq[Attribute] = child.output
}

object VanillaColumnarToVeloxColumnarExec {

  def toColumnarBatchIterator(
      it: Iterator[ColumnarBatch],
      schema: StructType,
      numInputBatches: SQLMetric,
      numOutputBatches: SQLMetric,
      convertTime: SQLMetric,
      taskContext: TaskContext): Iterator[ColumnarBatch] = {
    if (it.isEmpty) {
      return Iterator.empty
    }

    val arrowSchema =
      SparkArrowUtil.toArrowSchema(schema, SQLConf.get.sessionLocalTimeZone)
    val jniWrapper = VanillaColumnarToNativeColumnarJniWrapper.create()
    val allocator = ArrowBufferAllocators.contextInstance()
    val cSchema = ArrowSchema.allocateNew(allocator)
    val c2cHandle =
      try {
        ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
        jniWrapper.init(
          cSchema.memoryAddress(),
          NativeMemoryManagers
            .contextInstance("ColumnarToColumnar")
            .getNativeInstanceHandle)
      } finally {
        cSchema.close()
      }

    val converter = ArrowColumnarBatchConverter.create(arrowSchema, allocator)

    val res: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {

      override def hasNext: Boolean = {
        it.hasNext
      }

      def nativeConvert(cb: ColumnarBatch): ColumnarBatch = {
        var arrowArray: ArrowArray = null
        TaskResources.addRecycler("ColumnarToColumnar_arrowArray", 100) {
          // If the task ends before the finally block below releases the array, close it here.
          if (arrowArray != null) {
            arrowArray.close()
          }
        }

        numInputBatches += 1
        try {
          arrowArray = ArrowArray.allocateNew(allocator)
          converter.write(cb)
          converter.finish()
          Data.exportVectorSchemaRoot(allocator, converter.root, null, arrowArray)
          val handle = jniWrapper
            .nativeConvertVanillaColumnarToColumnar(c2cHandle, arrowArray.memoryAddress())
          ColumnarBatches.create(Runtimes.contextInstance(), handle)
        } finally {
          converter.reset()
          // Guard against allocateNew failing, which would leave arrowArray null here.
          if (arrowArray != null) {
            arrowArray.close()
            arrowArray = null
          }
        }
      }

      override def next(): ColumnarBatch = {
        val currentBatch = it.next()
        val start = System.currentTimeMillis()
        val cb = nativeConvert(currentBatch)
        numOutputBatches += 1
        convertTime += System.currentTimeMillis() - start
        cb
      }
    }

    if (taskContext != null) {
      taskContext.addTaskCompletionListener[Unit] {
        _ =>
          jniWrapper.close(c2cHandle)
          allocator.close()
          converter.close()
      }
    }

    Iterators
      .wrap(res)
      .recycleIterator {
        jniWrapper.close(c2cHandle)
        allocator.close()
        converter.close()
      }
      .recyclePayload(_.close())
      .create()
  }
}
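For context, a sketch of how such a transition could be wired into query planning. The rule name and the placement predicate are hypothetical illustrations, not part of this commit:

import io.glutenproject.execution.VanillaColumnarToVeloxColumnarExec
import io.glutenproject.extension.GlutenPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical rule: wrap any vanilla-columnar child of a Gluten node in the transition.
case class InsertVanillaToVeloxTransition() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case p: GlutenPlan =>
      p.withNewChildren(p.children.map {
        case c if c.supportsColumnar && !c.isInstanceOf[GlutenPlan] =>
          VanillaColumnarToVeloxColumnarExec(c)
        case c => c
      })
  }
}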
29 changes: 29 additions & 0 deletions
.../src/main/java/io/glutenproject/vectorized/VanillaColumnarToNativeColumnarJniWrapper.java
@@ -0,0 +1,29 @@
package io.glutenproject.vectorized;

import io.glutenproject.exec.Runtime;
import io.glutenproject.exec.RuntimeAware;
import io.glutenproject.exec.Runtimes;

public class VanillaColumnarToNativeColumnarJniWrapper implements RuntimeAware {
  private final Runtime runtime;

  private VanillaColumnarToNativeColumnarJniWrapper(Runtime runtime) {
    this.runtime = runtime;
  }

  public static VanillaColumnarToNativeColumnarJniWrapper create() {
    return new VanillaColumnarToNativeColumnarJniWrapper(Runtimes.contextInstance());
  }

  @Override
  public long handle() {
    return runtime.getHandle();
  }

  public native long init(long cSchema, long memoryManagerHandle);

  public native long nativeConvertVanillaColumnarToColumnar(
      long c2cHandle, long bufferAddress);

  public native void close(long c2cHandle);
}
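To make the JNI contract concrete, a minimal Scala usage sketch. The exported ArrowSchema/ArrowArray structs and the memory manager handle are produced by the caller exactly as in VanillaColumnarToVeloxColumnarExec above; note the operator keeps one c2cHandle per partition and reuses it across batches, while this sketch converts a single array:

import io.glutenproject.vectorized.VanillaColumnarToNativeColumnarJniWrapper
import org.apache.arrow.c.{ArrowArray, ArrowSchema}

// Sketch only: assumes an active Gluten Runtime on this thread, since the wrapper
// binds to Runtimes.contextInstance().
def convertOnce(cSchema: ArrowSchema, arrowArray: ArrowArray, memoryManagerHandle: Long): Long = {
  val wrapper = VanillaColumnarToNativeColumnarJniWrapper.create()
  val c2cHandle = wrapper.init(cSchema.memoryAddress(), memoryManagerHandle)
  try {
    // Returns a handle to the native batch; wrap it via ColumnarBatches.create(...).
    wrapper.nativeConvertVanillaColumnarToColumnar(c2cHandle, arrowArray.memoryAddress())
  } finally {
    wrapper.close(c2cHandle) // the native converter must be released exactly once
  }
}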
153 changes: 153 additions & 0 deletions
...ata/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowColumnarBatchConverter.scala
@@ -0,0 +1,153 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.arrow

import scala.collection.JavaConverters._

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.complex._
import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

object ArrowColumnarBatchConverter {

  def create(
      arrowSchema: Schema,
      bufferAllocator: BufferAllocator): ArrowColumnarBatchConverter = {
    val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator)
    create(root)
  }

  def create(root: VectorSchemaRoot): ArrowColumnarBatchConverter = {
    val children = root.getFieldVectors.asScala.map {
      vector =>
        vector.allocateNew()
        createFieldWriter(vector)
    }
    new ArrowColumnarBatchConverter(root, children.toArray)
  }

  private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
    (ArrowUtils.fromArrowField(vector.getField), vector) match {
      case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
      case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
      case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
      case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
      case (LongType, vector: BigIntVector) => new LongWriter(vector)
      case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
      case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
        new DecimalWriter(vector, precision, scale)
      case (StringType, vector: VarCharVector) => new StringWriter(vector)
      case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
      case (DateType, vector: DateDayVector) => new DateWriter(vector)
      case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector)
      case (ArrayType(_, _), vector: ListVector) =>
        val elementVector = createFieldWriter(vector.getDataVector)
        new ArrayWriter(vector, elementVector)
      case (MapType(_, _, _), vector: MapVector) =>
        val structVector = vector.getDataVector.asInstanceOf[StructVector]
        val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
        val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
        new MapWriter(vector, structVector, keyWriter, valueWriter)
      case (StructType(_), vector: StructVector) =>
        val children = (0 until vector.size()).map {
          ordinal => createFieldWriter(vector.getChildByOrdinal(ordinal))
        }
        new StructWriter(vector, children.toArray)
      case (NullType, vector: NullVector) => new NullWriter(vector)
      case (_: YearMonthIntervalType, vector: IntervalYearVector) =>
        new IntervalYearWriter(vector)
      case (_: DayTimeIntervalType, vector: IntervalDayVector) =>
        new IntervalDayWriter(vector)
      case (dt, _) =>
        throw QueryExecutionErrors.unsupportedDataTypeError(dt.catalogString)
    }
  }
}

case class ColumnarSpecializedGetters(columnVector: ColumnVector) extends SpecializedGetters {

  override def isNullAt(rowId: Int): Boolean = columnVector.isNullAt(rowId)

  override def getBoolean(rowId: Int): Boolean = columnVector.getBoolean(rowId)

  override def getByte(rowId: Int): Byte = columnVector.getByte(rowId)

  override def getShort(rowId: Int): Short = columnVector.getShort(rowId)

  override def getInt(rowId: Int): Int = columnVector.getInt(rowId)

  override def getLong(rowId: Int): Long = columnVector.getLong(rowId)

  override def getFloat(rowId: Int): Float = columnVector.getFloat(rowId)

  override def getDouble(rowId: Int): Double = columnVector.getDouble(rowId)

  override def getDecimal(rowId: Int, precision: Int, scale: Int): Decimal =
    columnVector.getDecimal(rowId, precision, scale)

  override def getUTF8String(rowId: Int): UTF8String = columnVector.getUTF8String(rowId)

  override def getBinary(rowId: Int): Array[Byte] = columnVector.getBinary(rowId)

  override def getInterval(rowId: Int): CalendarInterval = columnVector.getInterval(rowId)

  override def getStruct(rowId: Int, numFields: Int): InternalRow = columnVector.getStruct(rowId)

  override def getArray(rowId: Int): ArrayData = columnVector.getArray(rowId)

  override def getMap(rowId: Int): MapData = columnVector.getMap(rowId)

  override def get(rowId: Int, dataType: DataType): AnyRef = {
    throw new UnsupportedOperationException("Not implemented yet")
  }
}

class ArrowColumnarBatchConverter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {

  private var count: Int = 0

  def write(columnarBatch: ColumnarBatch): Unit = {
    fields.zipWithIndex.foreach {
      case (field, ordinal) =>
        // Each Arrow field writer reads one source column through the SpecializedGetters view,
        // appending rows at its internal offset.
        val columnVector = ColumnarSpecializedGetters(columnarBatch.column(ordinal))
        for (rowId <- 0 until columnarBatch.numRows()) {
          field.write(columnVector, rowId)
        }
    }
    count += columnarBatch.numRows()
  }

  def finish(): Unit = {
    root.setRowCount(count)
    fields.foreach(_.finish())
  }

  def reset(): Unit = {
    root.setRowCount(0)
    count = 0
    fields.foreach(_.reset())
  }

  def close(): Unit = {
    root.close()
  }
}
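The converter's per-batch protocol, mirroring nativeConvert in VanillaColumnarToVeloxColumnarExec above; a usage sketch under the assumption that the caller allocates the ArrowArray and hands it to a native consumer before closing it:

import org.apache.arrow.c.{ArrowArray, Data}
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.spark.sql.execution.arrow.ArrowColumnarBatchConverter
import org.apache.spark.sql.vectorized.ColumnarBatch

def exportBatch(
    arrowSchema: Schema,
    allocator: BufferAllocator,
    batch: ColumnarBatch,
    arrowArray: ArrowArray): Unit = {
  val converter = ArrowColumnarBatchConverter.create(arrowSchema, allocator)
  try {
    converter.write(batch) // append every row of the Spark batch into the Arrow vectors
    converter.finish() // set the row count so converter.root becomes readable
    Data.exportVectorSchemaRoot(allocator, converter.root, null, arrowArray)
  } finally {
    converter.reset() // clear counts so a reused converter starts from row zero
    converter.close() // release the VectorSchemaRoot (the operator defers this to stream end)
  }
}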