
[NSE-223] Add Parquet write support to Arrow data source #324

Merged: 7 commits, May 27, 2021
@@ -15,28 +15,27 @@
* limitations under the License.
*/

package com.intel.oap.execution

import com.intel.oap.vectorized._
package com.intel.oap.sql.execution

import java.util.concurrent.TimeUnit._
import scala.collection.JavaConverters._

import org.apache.spark.{broadcast, TaskContext}
import com.intel.oap.vectorized._

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, SpecializedGetters, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.RowToColumnarExec
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils.UnsafeItr
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.vectorized.ColumnarBatch

class RowToColumnConverter(schema: StructType) extends Serializable {
private val converters = schema.fields.map {
@@ -279,7 +278,7 @@ case class RowToArrowColumnarExec(child: SparkPlan) extends UnaryExecNode {
}

override def next(): ColumnarBatch = {
val vectors: Seq[WritableColumnVector] =
val vectors: Seq[WritableColumnVector] =
ArrowWritableColumnVector.allocateColumns(numRows, schema)
var rowCount = 0
while (rowCount < numRows && rowIterator.hasNext) {
@@ -297,7 +296,7 @@ case class RowToArrowColumnarExec(child: SparkPlan) extends UnaryExecNode {
last_cb
}
}
new CloseableColumnBatchIterator(res)
new UnsafeItr(res)
} else {
Iterator.empty
}
@@ -226,4 +226,41 @@ object SparkMemoryUtils {
val list = new util.ArrayList[NativeMemoryPool](leakedMemoryPools)
list.asScala.toList
}

class UnsafeItr[T <: AutoCloseable](delegate: Iterator[T])
extends Iterator[T] {
val holder = new GenericRetainer[T]()

SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => {
holder.release()
})

override def hasNext: Boolean = {
holder.release()
val hasNext = delegate.hasNext
hasNext
}

override def next(): T = {
val b = delegate.next()
holder.retain(b)
b
}
}

class GenericRetainer[T <: AutoCloseable] {
private var retained: Option[T] = None

def retain(batch: T): Unit = {
if (retained.isDefined) {
throw new IllegalStateException
}
retained = Some(batch)
}

def release(): Unit = {
retained.foreach(b => b.close())
retained = None
}
}
}
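
`UnsafeItr` retains only the most recently returned element and releases it either on the next `hasNext` call or through the leak-safe task-completion listener, so downstream code never has to close batches itself. A minimal usage sketch, assuming it runs inside a Spark task (so a `TaskContext` is available for the listener) and that `ColumnarBatch` satisfies the `AutoCloseable` bound in the targeted Spark version:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils.UnsafeItr
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical helper: wrap a per-partition stream of Arrow-backed batches so
// each batch is closed once the consumer advances, and the last retained batch
// is closed by the task-completion listener registered inside UnsafeItr.
def leakSafe(batches: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
  require(TaskContext.get() != null, "must be called from inside a Spark task")
  new UnsafeItr(batches)
}
```

Note the contract this implies: `hasNext` closes the previously returned batch before probing the delegate, so callers must be finished with the current batch before advancing.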
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.v2.arrow

import scala.collection.JavaConverters._

import com.intel.oap.vectorized.ArrowWritableColumnVector
import org.apache.arrow.memory.ArrowBuf
import org.apache.arrow.vector.FieldVector
import org.apache.arrow.vector.TypeLayout
import org.apache.arrow.vector.ValueVector
import org.apache.arrow.vector.ipc.message.ArrowFieldNode
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch

import org.apache.spark.sql.vectorized.ColumnarBatch

object SparkVectorUtils {

def estimateSize(columnarBatch: ColumnarBatch): Long = {
val cols = (0 until columnarBatch.numCols).toList.map(i =>
columnarBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector())
val nodes = new java.util.ArrayList[ArrowFieldNode]()
val buffers = new java.util.ArrayList[ArrowBuf]()
cols.foreach(vector => {
appendNodes(vector.asInstanceOf[FieldVector], nodes, buffers);
})
buffers.asScala.map(_.getPossibleMemoryConsumed()).sum
}

def toArrowRecordBatch(columnarBatch: ColumnarBatch): ArrowRecordBatch = {
val numRowsInBatch = columnarBatch.numRows()
val cols = (0 until columnarBatch.numCols).toList.map(i =>
columnarBatch.column(i).asInstanceOf[ArrowWritableColumnVector].getValueVector)
toArrowRecordBatch(numRowsInBatch, cols)
}

def toArrowRecordBatch(numRows: Int,
cols: List[ValueVector]): ArrowRecordBatch = {
val nodes = new java.util.ArrayList[ArrowFieldNode]()
val buffers = new java.util.ArrayList[ArrowBuf]()
cols.foreach(vector => {
appendNodes(vector.asInstanceOf[FieldVector], nodes, buffers);
})
new ArrowRecordBatch(numRows, nodes, buffers);
}

private def appendNodes(
vector: FieldVector,
nodes: java.util.List[ArrowFieldNode],
buffers: java.util.List[ArrowBuf]): Unit = {
nodes.add(new ArrowFieldNode(vector.getValueCount, vector.getNullCount))
val fieldBuffers = vector.getFieldBuffers
val expectedBufferCount = TypeLayout.getTypeBufferCount(vector.getField.getType)
if (fieldBuffers.size != expectedBufferCount) {
throw new IllegalArgumentException(
s"Wrong number of buffers for field ${vector.getField} in vector " +
s"${vector.getClass.getSimpleName}. found: ${fieldBuffers}")
}
buffers.addAll(fieldBuffers)
vector.getChildrenFromFields.asScala.foreach(child => appendNodes(child, nodes, buffers))
}
}
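
For reference, a minimal sketch of driving these helpers, assuming every column of the batch is an `ArrowWritableColumnVector` (the only layout the casts above accept); the returned `ArrowRecordBatch` references the same Arrow buffers, so the caller is expected to close it when done:

```scala
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical call site: report the Arrow memory footprint of a batch and
// flatten its (possibly nested) vectors into an Arrow IPC record batch.
def describeAndConvert(batch: ColumnarBatch): (Long, ArrowRecordBatch) = {
  val estimatedBytes = SparkVectorUtils.estimateSize(batch)    // sum of buffer sizes
  val recordBatch = SparkVectorUtils.toArrowRecordBatch(batch) // shares the buffers
  (estimatedBytes, recordBatch)
}
```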
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.oap.spark.sql

import com.intel.oap.spark.sql.ArrowWriteExtension.ArrowWritePostRule
import com.intel.oap.spark.sql.ArrowWriteExtension.DummyRule
import com.intel.oap.spark.sql.ArrowWriteExtension.SimpleColumnarRule
import com.intel.oap.spark.sql.execution.datasources.arrow.ArrowFileFormat
import com.intel.oap.sql.execution.RowToArrowColumnarExec

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.MapData
import org.apache.spark.sql.execution.CodegenSupport
import org.apache.spark.sql.execution.ColumnarRule
import org.apache.spark.sql.execution.ColumnarToRowExec
import org.apache.spark.sql.execution.ColumnarToRowTransition
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.UTF8String

class ArrowWriteExtension extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectColumnar(session => SimpleColumnarRule(DummyRule, ArrowWritePostRule(session)))
}
}

object ArrowWriteExtension {
private object DummyRule extends Rule[SparkPlan] {
def apply(p: SparkPlan): SparkPlan = p
}

private case class SimpleColumnarRule(pre: Rule[SparkPlan], post: Rule[SparkPlan])
extends ColumnarRule {
override def preColumnarTransitions: Rule[SparkPlan] = pre
override def postColumnarTransitions: Rule[SparkPlan] = post
}

case class ArrowWritePostRule(session: SparkSession) extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = plan match {
case rc @ DataWritingCommandExec(cmd, ColumnarToRowExec(child)) =>
cmd match {
case command: InsertIntoHadoopFsRelationCommand =>
if (command.fileFormat
.isInstanceOf[ArrowFileFormat]) {
rc.withNewChildren(Array(ColumnarToFakeRowAdaptor(child)))
} else {
plan.withNewChildren(plan.children.map(apply))
}
case _ => plan.withNewChildren(plan.children.map(apply))
}
case rc @ DataWritingCommandExec(cmd, child) =>
cmd match {
case command: InsertIntoHadoopFsRelationCommand =>
if (command.fileFormat
.isInstanceOf[ArrowFileFormat]) {
rc.withNewChildren(Array(ColumnarToFakeRowAdaptor(RowToArrowColumnarExec(child))))
} else {
plan.withNewChildren(plan.children.map(apply))
}
case _ => plan.withNewChildren(plan.children.map(apply))
}
case plan: SparkPlan => plan.withNewChildren(plan.children.map(apply))
}
}

private case class ColumnarToFakeRowAdaptor(child: SparkPlan) extends ColumnarToRowTransition {
assert(child.supportsColumnar)

override protected def doExecute(): RDD[InternalRow] = {
child.executeColumnar().map { cb =>
new FakeRow(cb)
}
}

override def output: Seq[Attribute] = child.output
}

class FakeRow(val batch: ColumnarBatch) extends InternalRow {
override def numFields: Int = throw new UnsupportedOperationException()
override def setNullAt(i: Int): Unit = throw new UnsupportedOperationException()
override def update(i: Int, value: Any): Unit = throw new UnsupportedOperationException()
override def copy(): InternalRow = throw new UnsupportedOperationException()
override def isNullAt(ordinal: Int): Boolean = throw new UnsupportedOperationException()
override def getBoolean(ordinal: Int): Boolean = throw new UnsupportedOperationException()
override def getByte(ordinal: Int): Byte = throw new UnsupportedOperationException()
override def getShort(ordinal: Int): Short = throw new UnsupportedOperationException()
override def getInt(ordinal: Int): Int = throw new UnsupportedOperationException()
override def getLong(ordinal: Int): Long = throw new UnsupportedOperationException()
override def getFloat(ordinal: Int): Float = throw new UnsupportedOperationException()
override def getDouble(ordinal: Int): Double = throw new UnsupportedOperationException()
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
throw new UnsupportedOperationException()
override def getUTF8String(ordinal: Int): UTF8String = throw new UnsupportedOperationException()
override def getBinary(ordinal: Int): Array[Byte] = throw new UnsupportedOperationException()
override def getInterval(ordinal: Int): CalendarInterval =
throw new UnsupportedOperationException()
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
throw new UnsupportedOperationException()
override def getArray(ordinal: Int): ArrayData = throw new UnsupportedOperationException()
override def getMap(ordinal: Int): MapData = throw new UnsupportedOperationException()
override def get(ordinal: Int, dataType: DataType): AnyRef =
throw new UnsupportedOperationException()
}
}
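
Taken together, `ArrowWritePostRule` rewrites a `DataWritingCommandExec` over `ArrowFileFormat` so batches reach the writer wrapped in `FakeRow` (via `ColumnarToFakeRowAdaptor`, with `RowToArrowColumnarExec` inserted when the child is row-based) instead of being converted row by row; the Arrow writer is then expected to unwrap `FakeRow.batch` on its side. A minimal sketch of wiring the extension in (the extension class name comes from this diff; the `"arrow"` source name and output path are illustrative assumptions):

```scala
import org.apache.spark.sql.SparkSession

// Register the extension so ArrowWritePostRule can rewrite the write path
// for InsertIntoHadoopFsRelationCommand plans that target ArrowFileFormat.
val spark = SparkSession.builder()
  .appName("arrow-write-demo")
  .master("local[*]") // for a local demo
  .config("spark.sql.extensions", "com.intel.oap.spark.sql.ArrowWriteExtension")
  .getOrCreate()

// Illustrative write; the exact source name and options depend on how
// ArrowFileFormat is registered in this project (assumed here to be "arrow").
spark.range(0, 1000)
  .toDF("id")
  .write
  .format("arrow")
  .save("/tmp/arrow-write-demo")
```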