Support Spark ColumnarBatch to ArrowArray
boneanxs committed Feb 29, 2024
1 parent 69ffb4c commit fcfd758
Showing 4 changed files with 359 additions and 0 deletions.
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.execution

import io.glutenproject.backendsapi.velox.ValidatorApiImpl
import io.glutenproject.columnarbatch.ColumnarBatches
import io.glutenproject.exec.Runtimes
import io.glutenproject.extension.GlutenPlan
import io.glutenproject.memory.arrowalloc.ArrowBufferAllocators
import io.glutenproject.memory.nmm.NativeMemoryManagers
import io.glutenproject.utils.{ArrowAbiUtil, Iterators}
import io.glutenproject.vectorized.VanillaColumnarToNativeColumnarJniWrapper
import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data}
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.arrow.ArrowColumnarBatchConverter
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.TaskResources

case class VanillaColumnarToVeloxColumnarExec(child: SparkPlan) extends GlutenPlan with UnaryExecNode {

override def supportsColumnar: Boolean = true

override protected def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
}

override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
new ValidatorApiImpl().doSchemaValidate(schema).foreach {
reason =>
throw new UnsupportedOperationException(
s"Input schema contains unsupported type when convert columnar to columnar for $schema " +
s"due to $reason")
}

val numInputBatches = longMetric("numInputBatches")
val numOutputBatches = longMetric("numOutputBatches")
val convertTime = longMetric("convertTime")
// Instead of creating a new config we reuse columnBatchSize. If this is later combined with
// the Arrow conversion tools, the related configs will need to be unified.
val numRows = conf.columnBatchSize
// This avoids calling `schema` in the RDD closure, so that we don't need to include the entire
// plan (this) in the closure.
val localSchema = schema
child.execute().mapPartitions {
rowIterator =>
VanillaColumnarToVeloxColumnarExec.toColumnarBatchIterator(
rowIterator.asInstanceOf[Iterator[ColumnarBatch]],
localSchema,
numInputBatches,
numOutputBatches,
convertTime,
TaskContext.get())
}
}

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
copy(child = newChild)
}

override def output: Seq[Attribute] = child.output
}

object VanillaColumnarToVeloxColumnarExec {

def toColumnarBatchIterator(it: Iterator[ColumnarBatch],
schema: StructType,
numInputBatches: SQLMetric,
numOutputBatches: SQLMetric,
convertTime: SQLMetric,
taskContext: TaskContext): Iterator[ColumnarBatch] = {
if (it.isEmpty) {
return Iterator.empty
}

val arrowSchema =
SparkArrowUtil.toArrowSchema(schema, SQLConf.get.sessionLocalTimeZone)
val jniWrapper = VanillaColumnarToNativeColumnarJniWrapper.create()
val allocator = ArrowBufferAllocators.contextInstance()
val cSchema = ArrowSchema.allocateNew(allocator)
val c2cHandle =
try {
ArrowAbiUtil.exportSchema(allocator, arrowSchema, cSchema)
jniWrapper.init(
cSchema.memoryAddress(),
NativeMemoryManagers
.contextInstance("ColumnarToColumnar")
.getNativeInstanceHandle)
} finally {
cSchema.close()
}

val converter = ArrowColumnarBatchConverter.create(arrowSchema, allocator)

val res: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {

override def hasNext: Boolean = {
it.hasNext
}

def nativeConvert(cb: ColumnarBatch): ColumnarBatch = {
var arrowArray: ArrowArray = null
TaskResources.addRecycler("ColumnarToColumnar_arrowArray", 100) {
// Safety net: release the ArrowArray if the task ends before it is closed below.
if (arrowArray != null) {
arrowArray.close()
}
}

numInputBatches += 1
try {
arrowArray = ArrowArray.allocateNew(allocator)
converter.write(cb)
converter.finish()
Data.exportVectorSchemaRoot(allocator, converter.root, null, arrowArray)
val handle = jniWrapper
.nativeConvertVanillaColumnarToColumnar(c2cHandle, arrowArray.memoryAddress())
ColumnarBatches.create(Runtimes.contextInstance(), handle)
} finally {
converter.reset()
arrowArray.close()
arrowArray = null
}
}

override def next(): ColumnarBatch = {
val currentBatch = it.next()
val start = System.currentTimeMillis()
val cb = nativeConvert(currentBatch)
numOutputBatches += 1
convertTime += System.currentTimeMillis() - start
cb
}
}

if (taskContext != null) {
taskContext.addTaskCompletionListener[Unit] { _ =>
jniWrapper.close(c2cHandle)
allocator.close()
converter.close()
}
}

Iterators
.wrap(res)
.recycleIterator {
jniWrapper.close(c2cHandle)
allocator.close()
converter.close()
}
.recyclePayload(_.close())
.create()
}
}
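At the plan level, the new node is intended to sit between a vanilla Spark columnar child and Velox-backed operators, converting each vanilla ColumnarBatch into a native one. A minimal sketch of how it could be driven, assuming some vanilla columnar child plan is available (vanillaChild below is a placeholder, not part of this commit):

// Hypothetical wiring; vanillaChild stands in for any vanilla Spark columnar plan,
// e.g. a Parquet scan with spark.gluten.sql.columnar.batchscan=false and
// spark.sql.columnVector.offheap.enabled=true (see the test config change below).
val vanillaChild: SparkPlan = ???
val toVelox = VanillaColumnarToVeloxColumnarExec(vanillaChild)
assert(toVelox.supportsColumnar)          // the row-based doExecute() path throws
val veloxBatches: RDD[ColumnarBatch] = toVelox.executeColumnar()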
@@ -52,6 +52,8 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.sql.sources.useV1SourceList", "avro")
.set("spark.gluten.sql.columnar.batchscan", "false")
.set("spark.sql.columnVector.offheap.enabled", "true")
}

test("simple_select") {
@@ -0,0 +1,29 @@
package io.glutenproject.vectorized;

import io.glutenproject.exec.Runtime;
import io.glutenproject.exec.RuntimeAware;
import io.glutenproject.exec.Runtimes;

public class VanillaColumnarToNativeColumnarJniWrapper implements RuntimeAware {
private final Runtime runtime;

private VanillaColumnarToNativeColumnarJniWrapper(Runtime runtime) {
this.runtime = runtime;
}

public static VanillaColumnarToNativeColumnarJniWrapper create() {
return new VanillaColumnarToNativeColumnarJniWrapper(Runtimes.contextInstance());
}

@Override
public long handle() {
return runtime.getHandle();
}

public native long init(long cSchema, long memoryManagerHandle);

public native long nativeConvertVanillaColumnarToColumnar(
long c2cHandle, long bufferAddress);

public native void close(long c2cHandle);
}
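The wrapper's lifecycle mirrors the Scala operator in this commit: init is called once per task with the memory address of an ArrowSchema exported through the Arrow C data interface plus a native memory manager handle, each batch is converted by passing the address of an exported ArrowArray, and close releases the native converter at task end. A condensed sketch in Scala, simplified from the operator above (nativeMemoryManagerHandle, cSchema and arrowArray are assumed to be set up as shown there):

val jniWrapper = VanillaColumnarToNativeColumnarJniWrapper.create()
// once per task: hand over the exported schema and the memory manager handle
val c2cHandle = jniWrapper.init(cSchema.memoryAddress(), nativeMemoryManagerHandle)
// per batch: pass the exported ArrowArray's address, get back a native batch handle
val batchHandle =
  jniWrapper.nativeConvertVanillaColumnarToColumnar(c2cHandle, arrowArray.memoryAddress())
// on task completion: release the native converter
jniWrapper.close(c2cHandle)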
@@ -0,0 +1,153 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.arrow

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.complex._
import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

import scala.collection.JavaConverters._

object ArrowColumnarBatchConverter {

def create(arrowSchema: Schema, bufferAllocator: BufferAllocator): ArrowColumnarBatchConverter = {
val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator)
create(root)
}

def create(root: VectorSchemaRoot): ArrowColumnarBatchConverter = {
val children = root.getFieldVectors.asScala.map { vector =>
vector.allocateNew()
createFieldWriter(vector)
}
new ArrowColumnarBatchConverter(root, children.toArray)
}

private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
(ArrowUtils.fromArrowField(vector.getField), vector) match {
case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
case (LongType, vector: BigIntVector) => new LongWriter(vector)
case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
new DecimalWriter(vector, precision, scale)
case (StringType, vector: VarCharVector) => new StringWriter(vector)
case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
case (DateType, vector: DateDayVector) => new DateWriter(vector)
case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector)
case (ArrayType(_, _), vector: ListVector) =>
val elementVector = createFieldWriter(vector.getDataVector)
new ArrayWriter(vector, elementVector)
case (MapType(_, _, _), vector: MapVector) =>
val structVector = vector.getDataVector.asInstanceOf[StructVector]
val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
new MapWriter(vector, structVector, keyWriter, valueWriter)
case (StructType(_), vector: StructVector) =>
val children = (0 until vector.size()).map { ordinal =>
createFieldWriter(vector.getChildByOrdinal(ordinal))
}
new StructWriter(vector, children.toArray)
case (NullType, vector: NullVector) => new NullWriter(vector)
case (_: YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
case (_: DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector)
case (dt, _) =>
throw QueryExecutionErrors.unsupportedDataTypeError(dt.catalogString)
}
}
}

case class ColumnarSpecializedGetters(columnVector: ColumnVector) extends SpecializedGetters {

override def isNullAt(rowId: Int): Boolean = columnVector.isNullAt(rowId)

override def getBoolean(rowId: Int): Boolean = columnVector.getBoolean(rowId)

override def getByte(rowId: Int): Byte = columnVector.getByte(rowId)

override def getShort(rowId: Int): Short = columnVector.getShort(rowId)

override def getInt(rowId: Int): Int = columnVector.getInt(rowId)

override def getLong(rowId: Int): Long = columnVector.getLong(rowId)

override def getFloat(rowId: Int): Float = columnVector.getFloat(rowId)

override def getDouble(rowId: Int): Double = columnVector.getDouble(rowId)

override def getDecimal(rowId: Int, precision: Int, scale: Int): Decimal = columnVector.getDecimal(rowId, precision, scale)

override def getUTF8String(rowId: Int): UTF8String = columnVector.getUTF8String(rowId)

override def getBinary(rowId: Int): Array[Byte] = columnVector.getBinary(rowId)

override def getInterval(rowId: Int): CalendarInterval = columnVector.getInterval(rowId)

override def getStruct(rowId: Int, numFields: Int): InternalRow = columnVector.getStruct(rowId)

override def getArray(rowId: Int): ArrayData = columnVector.getArray(rowId)

override def getMap(rowId: Int): MapData = columnVector.getMap(rowId)

override def get(rowId: Int, dataType: DataType): AnyRef = {
throw new UnsupportedOperationException("Not implemented yet")
}
}

class ArrowColumnarBatchConverter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {

private var count: Int = 0

def write(columnarBatch: ColumnarBatch): Unit = {
fields.zipWithIndex.foreach { case (field, ordinal) =>
val columnVector = ColumnarSpecializedGetters(columnarBatch.column(ordinal))
for (rowId <- 0 until columnarBatch.numRows()) {
field.write(columnVector, rowId)
}
}
count += columnarBatch.numRows()
}

def finish(): Unit = {
root.setRowCount(count)
fields.foreach(_.finish())
}

def reset(): Unit = {
root.setRowCount(0)
count = 0
fields.foreach(_.reset())
}

def close(): Unit = {
root.close()
}
}
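As used by the operator in this commit, the converter's lifecycle is: create it from an Arrow schema, write one or more ColumnarBatches into its VectorSchemaRoot, call finish to set the row and value counts, export the root over the Arrow C data interface, then reset before the next batch and close when the task ends. A condensed sketch, simplified from the operator above (allocator, arrowSchema and batch are assumed to be in scope):

val converter = ArrowColumnarBatchConverter.create(arrowSchema, allocator)
converter.write(batch)   // copy each column of the ColumnarBatch into the Arrow vectors
converter.finish()       // set the root's row count and each vector's value count
val arrowArray = ArrowArray.allocateNew(allocator)
Data.exportVectorSchemaRoot(allocator, converter.root, null, arrowArray)
// ... the native side imports the batch from arrowArray.memoryAddress() ...
converter.reset()        // clear the vectors for the next batch
arrowArray.close()
converter.close()        // releases the VectorSchemaRoot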
