This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Support ListArray in ColumnarShuffle and RowToArrowColumnar operator (#496)

Signed-off-by: Chendi Xue <[email protected]>
xuechendi authored Sep 10, 2021
1 parent 1a0d654 commit 6f3041e
Showing 21 changed files with 949 additions and 271 deletions.
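The heart of this change is representing Spark ArrayType columns with Arrow's variable-size list layout, so that array columns can flow through the columnar shuffle and the row-to-Arrow-columnar conversion. The Scala sketch below is illustrative only and is not part of the commit; it shows the List field shape that the createArrowField helper added to ConverterUtils later in this diff produces. The names allocator, elemField, listField, and vector, and the Int element type, are placeholders chosen for the example.

// Illustrative sketch: a Spark ArrayType(IntegerType) column is described by a
// nullable Arrow List field with one child field for the element type, and is
// backed by a ListVector (validity + offset buffers plus a child data vector).
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.complex.ListVector
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}
import com.google.common.collect.Lists

val allocator = new RootAllocator(Long.MaxValue)
val elemField = Field.nullable("c_0_element", new ArrowType.Int(32, true))
val listField = new Field("c_0", FieldType.nullable(ArrowType.List.INSTANCE),
  Lists.newArrayList(elemField))
val vector = listField.createVector(allocator).asInstanceOf[ListVector]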
@@ -776,7 +776,15 @@ public Decimal getDecimal(int rowId, int precision, int scale) {
public UTF8String getUTF8String(int rowId) {
if (isNullAt(rowId))
return null;
return accessor.getUTF8String(rowId);
if (dataType() instanceof ArrayType) {
UTF8String ret_0 = accessor.getUTF8String(rowId);
for (int i = 0; i < ((ArrayAccessor) accessor).getArrayLength(rowId); i++) {
ret_0 = UTF8String.concat(ret_0, getArray(rowId).getUTF8String(i));
}
return ret_0;
} else {
return accessor.getUTF8String(rowId);
}
}

@Override
@@ -1165,13 +1173,10 @@ final long getLong(int rowId) {

private static class ArrayAccessor extends ArrowVectorAccessor {
private final ListVector accessor;
ArrowWritableColumnVector arrayData;

ArrayAccessor(ListVector vector) {
super(vector);
this.accessor = vector;
arrayData =
new ArrowWritableColumnVector(vector.getDataVector(), 0, vector.size(), false);
}

@Override
@@ -1197,6 +1202,12 @@ public int getArrayOffset(int rowId) {
int index = rowId * ListVector.OFFSET_WIDTH;
return accessor.getOffsetBuffer().getInt(index);
}

@Override
final UTF8String getUTF8String(int rowId) {
return UTF8String.fromString(
"Array[" + getArrayOffset(rowId) + "-" + getArrayLength(rowId) + "]");
}
}

/**
@@ -1849,12 +1860,24 @@ final void setNulls(int rowId, int count) {

private static class ArrayWriter extends ArrowVectorWriter {
private final ListVector writer;
// private final ArrowWritableColumnVector arrayData;

ArrayWriter(ListVector vector, ArrowVectorWriter elementVector) {
super(vector);
this.writer = vector;
}

@Override
void setArray(int rowId, int offset, int length) {
int index = rowId * ListVector.OFFSET_WIDTH;
writer.getOffsetBuffer().setInt(index, offset);
writer.getOffsetBuffer().setInt(index + ListVector.OFFSET_WIDTH, offset + length);
writer.setNotNull(rowId);
}

@Override
final void setNull(int rowId) {
writer.setNull(rowId);
}
}

private static class StructWriter extends ArrowVectorWriter {
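Both the ArrayAccessor offset helpers and the new ArrayWriter.setArray above rely on the ListVector offset-buffer layout: the children of row rowId live in the child data vector between offsets[rowId] and offsets[rowId + 1], each offset being a 4-byte int (ListVector.OFFSET_WIDTH). A minimal sketch of that arithmetic, assuming vec is an already-populated ListVector; the names are placeholders, not part of the commit.

import org.apache.arrow.vector.complex.ListVector

// Reads back what setArray(rowId, offset, length) wrote: the start offset into
// the child data vector and the element count of the list stored at rowId.
def arrayBounds(vec: ListVector, rowId: Int): (Int, Int) = {
  val start = vec.getOffsetBuffer.getInt(rowId * ListVector.OFFSET_WIDTH)
  val end = vec.getOffsetBuffer.getInt((rowId + 1) * ListVector.OFFSET_WIDTH)
  (start, end - start)
}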
@@ -21,11 +21,17 @@ import scala.collection.JavaConverters._

import com.intel.oap.vectorized.ArrowWritableColumnVector
import org.apache.arrow.memory.ArrowBuf
import org.apache.arrow.vector.FieldVector
import org.apache.arrow.vector.TypeLayout
import org.apache.arrow.vector.ValueVector
import org.apache.arrow.vector.ipc.message.ArrowFieldNode
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.arrow.vector.{
BaseFixedWidthVector,
BaseVariableWidthVector,
FieldVector,
TypeLayout,
VectorLoader,
ValueVector,
VectorSchemaRoot
}

import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -49,8 +55,7 @@ object SparkVectorUtils {
toArrowRecordBatch(numRowsInBatch, cols)
}

def toArrowRecordBatch(numRows: Int,
cols: List[ValueVector]): ArrowRecordBatch = {
def toArrowRecordBatch(numRows: Int, cols: List[ValueVector]): ArrowRecordBatch = {
val nodes = new java.util.ArrayList[ArrowFieldNode]()
val buffers = new java.util.ArrayList[ArrowBuf]()
cols.foreach(vector => {
@@ -59,19 +64,48 @@
new ArrowRecordBatch(numRows, nodes, buffers);
}

private def appendNodes(
def getArrowBuffers(vector: FieldVector): Array[ArrowBuf] = {
try {
vector.getFieldBuffers.asScala.toArray
} catch {
case _ : Throwable =>
vector match {
case fixed: BaseFixedWidthVector =>
Array(fixed.getValidityBuffer, fixed.getDataBuffer)
case variable: BaseVariableWidthVector =>
Array(variable.getValidityBuffer, variable.getOffsetBuffer, variable.getDataBuffer)
case _ =>
throw new UnsupportedOperationException(
s"Could not decompress vector of class ${vector.getClass}")
}
}
}

def appendNodes(
vector: FieldVector,
nodes: java.util.List[ArrowFieldNode],
buffers: java.util.List[ArrowBuf]): Unit = {
nodes.add(new ArrowFieldNode(vector.getValueCount, vector.getNullCount))
val fieldBuffers = vector.getFieldBuffers
buffers: java.util.List[ArrowBuf],
bits: java.util.List[Boolean] = null): Unit = {
if (nodes != null) {
nodes.add(new ArrowFieldNode(vector.getValueCount, vector.getNullCount))
}
val fieldBuffers = getArrowBuffers(vector)
val expectedBufferCount = TypeLayout.getTypeBufferCount(vector.getField.getType)
if (fieldBuffers.size != expectedBufferCount) {
throw new IllegalArgumentException(
s"Wrong number of buffers for field ${vector.getField} in vector " +
s"${vector.getClass.getSimpleName}. found: ${fieldBuffers}")
s"${vector.getClass.getSimpleName}. found: ${fieldBuffers}")
}
import collection.JavaConversions._
buffers.addAll(fieldBuffers.toSeq)
if (bits != null) {
val bits_tmp = Array.fill[Boolean](expectedBufferCount)(false)
bits_tmp(0) = true
bits.addAll(bits_tmp.toSeq)
vector.getChildrenFromFields.asScala.foreach(child =>
appendNodes(child, nodes, buffers, bits))
} else {
vector.getChildrenFromFields.asScala.foreach(child => appendNodes(child, nodes, buffers))
}
buffers.addAll(fieldBuffers)
vector.getChildrenFromFields.asScala.foreach(child => appendNodes(child, nodes, buffers))
}
}
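A hedged usage sketch of the appendNodes/getArrowBuffers path above (made public in this change), mirroring toArrowRecordBatch: a populated vector is walked depth-first, its field node and buffers are collected, and the lists are wrapped in an ArrowRecordBatch. Here listVector and rowCount are placeholder names for an already-populated ListVector and its value count, and the visibility of SparkVectorUtils is assumed to allow the call.

// Sketch only; assumes the imports already present in this file.
val nodes = new java.util.ArrayList[ArrowFieldNode]()
val buffers = new java.util.ArrayList[ArrowBuf]()
SparkVectorUtils.appendNodes(listVector, nodes, buffers) // recurses into child vectors
val recordBatch = new ArrowRecordBatch(rowCount, nodes, buffers)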
@@ -20,8 +20,11 @@ package com.intel.oap.execution
import com.intel.oap.expression.ConverterUtils
import com.intel.oap.vectorized.ArrowWritableColumnVector
import com.intel.oap.vectorized.CloseableColumnBatchIterator

import org.apache.arrow.vector.util.VectorBatchAppender
import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.types.pojo.Schema;

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -32,8 +35,10 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.{StructType, StructField}
import org.apache.spark.sql.util.ArrowUtils;

import scala.collection.mutable.ListBuffer
import scala.collection.JavaConverters._

case class CoalesceBatchesExec(child: SparkPlan) extends UnaryExecNode {

@@ -75,9 +80,6 @@ case class CoalesceBatchesExec(child: SparkPlan) extends UnaryExecNode {
new Iterator[ColumnarBatch] {
var numBatchesTotal: Long = _
var numRowsTotal: Long = _
val resultStructType =
StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))

SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit] { _ =>
if (numBatchesTotal > 0) {
avgCoalescedNumRows.set(numRowsTotal.toDouble / numBatchesTotal)
@@ -107,6 +109,16 @@
batchesToAppend += delta
}

// chendi: We need to make sure the target FieldTypes are exactly the same as the source
val expected_output_arrow_fields = if (batchesToAppend.size > 0) {
(0 until batchesToAppend(0).numCols).map(i => {
batchesToAppend(0).column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector.getField
})
} else {
Nil
}

val resultStructType = ArrowUtils.fromArrowSchema(new Schema(expected_output_arrow_fields.asJava))
val beforeConcat = System.nanoTime
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(rowCount, resultStructType).toArray
@@ -33,22 +33,27 @@ import scala.collection.mutable.ListBuffer

class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends BoundReference(ordinal, dataType, nullable)
with ColumnarExpression with Logging {
with ColumnarExpression
with Logging {

buildCheck()

def buildCheck(): Unit = {
try {
ConverterUtils.checkIfTypeSupported(dataType)
dataType match {
case at: ArrayType =>
case _ =>
ConverterUtils.checkIfTypeSupported(dataType)
}
} catch {
case e : UnsupportedOperationException =>
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${dataType} is not supported in ColumnarBoundReference.")
}
}
override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val resultType = CodeGeneration.getResultType(dataType)
val field = Field.nullable(s"c_$ordinal", resultType)
val field = ConverterUtils.createArrowField(s"c_$ordinal", dataType)
val fieldTypes = args.asInstanceOf[java.util.List[Field]]
fieldTypes.add(field)
(TreeBuilder.makeField(field), resultType)
@@ -85,26 +85,31 @@ class ColumnarProjection (
}
}

val (ordinalList, arrowSchema) = if (projPrepareList.size > 0 &&
(s"${projPrepareList.map(_.toProtobuf)}".contains("fnNode") || projPrepareList.size != inputList.size)) {
val inputFieldList = inputList.asScala.toList.distinct
val schema = new Schema(inputFieldList.asJava)
projector = Projector.make(schema, projPrepareList.toList.asJava)
(inputFieldList.map(field => {
field.getName.replace("c_", "").toInt
}),
schema)
} else {
val inputFieldList = inputList.asScala.toList
(inputFieldList.map(field => {
field.getName.replace("c_", "").toInt
}),
new Schema(inputFieldList.asJava))
val (ordinalList, arrowSchema) = {
var protoBufTest: String = null
try {
protoBufTest = s"${projPrepareList.map(_.toProtobuf)}"
} catch {
case _ => protoBufTest = null
}
if (protoBufTest != null && projPrepareList.size > 0 &&
(protoBufTest.contains("fnNode") || projPrepareList.size != inputList.size)) {
val inputFieldList = inputList.asScala.toList.distinct
val schema = new Schema(inputFieldList.asJava)
projector = Projector.make(schema, projPrepareList.toList.asJava)
(inputFieldList.map(field => {
field.getName.replace("c_", "").toInt
}), schema)
} else {
val inputFieldList = inputList.asScala.toList
(inputFieldList.map(field => {
field.getName.replace("c_", "").toInt
}), new Schema(inputFieldList.asJava))
}
}
//System.out.println(s"Project input ordinal is ${ordinalList}, Schema is ${arrowSchema}")
val outputArrowSchema = new Schema(resultAttributes.map(attr => {
Field.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
}).asJava)
val outputArrowSchema = new Schema(
resultAttributes.map(attr => ConverterUtils.createArrowField(attr)).asJava)
val outputSchema = ArrowUtils.fromArrowSchema(outputArrowSchema)

def output(): List[AttributeReference] = {
@@ -34,10 +34,7 @@ import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.{ArrowFieldNode, ArrowRecordBatch, IpcOption, MessageChannelReader, MessageResult, MessageSerializer}
import org.apache.arrow.vector.ipc.message.{ArrowFieldNode, ArrowRecordBatch, IpcOption, MessageChannelReader, MessageResult, MessageSerializer}
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.pojo.Schema
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.arrow.gandiva.expression._
import org.apache.arrow.gandiva.evaluator._

@@ -507,6 +504,23 @@
throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

def createArrowField(name: String, dt: DataType): Field = dt match {
case at: ArrayType =>
new Field(
name,
FieldType.nullable(ArrowType.List.INSTANCE),
Lists.newArrayList(createArrowField(s"${name}_${dt}", at.elementType)))
case mt: MapType =>
throw new UnsupportedOperationException(s"${dt} is not supported yet")
case st: StructType =>
throw new UnsupportedOperationException(s"${dt} is not supported yet")
case _ =>
Field.nullable(name, CodeGeneration.getResultType(dt))
}

def createArrowField(attr: Attribute): Field =
createArrowField(s"${attr.name}#${attr.exprId.id}", attr.dataType)

private def asTimestampType(inType: ArrowType): ArrowType.Timestamp = {
if (inType.getTypeID != ArrowTypeID.Timestamp) {
throw new IllegalArgumentException(s"Value type to convert must be timestamp")
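A hedged example of the new createArrowField helper: for a plain type it falls back to a flat nullable field, while for ArrayType it builds a nullable List field whose single child field describes the element type, which is what ColumnarBoundReference and ColumnarProjection above now consume. The column names and types below are illustrative only.

import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

// A flat nullable field (Utf8 for StringType) named "c_0".
val flat = ConverterUtils.createArrowField("c_0", StringType)

// A nullable List field named "c_1" with one child field describing the
// 32-bit signed integer elements.
val nested = ConverterUtils.createArrowField("c_1", ArrayType(IntegerType))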